#include "common.hpp"
#include "Utils/UnwrappedSequenceNumber.hpp"
#include <catch2/catch_test_macros.hpp>

SCENARIO("SCTP UnwrappedSequenceNumber", "[sctp]")
{
	using TestSequence = Utils::UnwrappedSequenceNumber<uint16_t>;

	SECTION("simple unwrapping")
	{
		TestSequence::Unwrapper unwrapper;

		TestSequence s0 = unwrapper.Unwrap(0);
		TestSequence s1 = unwrapper.Unwrap(1);
		TestSequence s2 = unwrapper.Unwrap(2);
		TestSequence s3 = unwrapper.Unwrap(3);

		REQUIRE(s0 < s1);
		REQUIRE(s0 < s2);
		REQUIRE(s0 < s3);
		REQUIRE(s1 < s2);
		REQUIRE(s1 < s3);
		REQUIRE(s2 < s3);

		REQUIRE(TestSequence::Difference(s1, s0) == 1);
		REQUIRE(TestSequence::Difference(s2, s0) == 2);
		REQUIRE(TestSequence::Difference(s3, s0) == 3);

		REQUIRE(s1 > s0);
		REQUIRE(s2 > s0);
		REQUIRE(s3 > s0);
		REQUIRE(s2 > s1);
		REQUIRE(s3 > s1);
		REQUIRE(s3 > s2);

		s0.Increment();
		REQUIRE(s0 == s1);

		s1.Increment();
		REQUIRE(s1 == s2);

		s2.Increment();
		REQUIRE(s2 == s3);

		REQUIRE(TestSequence::AddTo(s0, 2) == s3);
	}

	SECTION("mid value unwrapping")
	{
		TestSequence::Unwrapper unwrapper;

		TestSequence s0 = unwrapper.Unwrap(0x7FFE);
		TestSequence s1 = unwrapper.Unwrap(0x7FFF);
		TestSequence s2 = unwrapper.Unwrap(0x8000);
		TestSequence s3 = unwrapper.Unwrap(0x8001);

		REQUIRE(s0 < s1);
		REQUIRE(s0 < s2);
		REQUIRE(s0 < s3);
		REQUIRE(s1 < s2);
		REQUIRE(s1 < s3);
		REQUIRE(s2 < s3);

		REQUIRE(TestSequence::Difference(s1, s0) == 1);
		REQUIRE(TestSequence::Difference(s2, s0) == 2);
		REQUIRE(TestSequence::Difference(s3, s0) == 3);

		REQUIRE(s1 > s0);
		REQUIRE(s2 > s0);
		REQUIRE(s3 > s0);
		REQUIRE(s2 > s1);
		REQUIRE(s3 > s1);
		REQUIRE(s3 > s2);

		s0.Increment();
		REQUIRE(s0 == s1);

		s1.Increment();
		REQUIRE(s1 == s2);

		s2.Increment();
		REQUIRE(s2 == s3);

		REQUIRE(TestSequence::AddTo(s0, 2) == s3);
	}

	SECTION("wrapped unwrapping")
	{
		TestSequence::Unwrapper unwrapper;

		TestSequence s0 = unwrapper.Unwrap(0xFFFE);
		TestSequence s1 = unwrapper.Unwrap(0xFFFF);
		TestSequence s2 = unwrapper.Unwrap(0x0000);
		TestSequence s3 = unwrapper.Unwrap(0x0001);

		REQUIRE(s0 < s1);
		REQUIRE(s0 < s2);
		REQUIRE(s0 < s3);
		REQUIRE(s1 < s2);
		REQUIRE(s1 < s3);
		REQUIRE(s2 < s3);

		REQUIRE(TestSequence::Difference(s1, s0) == 1);
		REQUIRE(TestSequence::Difference(s2, s0) == 2);
		REQUIRE(TestSequence::Difference(s3, s0) == 3);

		REQUIRE(s1 > s0);
		REQUIRE(s2 > s0);
		REQUIRE(s3 > s0);
		REQUIRE(s2 > s1);
		REQUIRE(s3 > s1);
		REQUIRE(s3 > s2);

		s0.Increment();
		REQUIRE(s0 == s1);

		s1.Increment();
		REQUIRE(s1 == s2);

		s2.Increment();
		REQUIRE(s2 == s3);

		REQUIRE(TestSequence::AddTo(s0, 2) == s3);
	}

	SECTION("wrap around a few times")
	{
		TestSequence::Unwrapper unwrapper;

		const TestSequence s0 = unwrapper.Unwrap(0);
		TestSequence prev     = s0;

		for (uint32_t i{ 1 }; i < 65536 * 3; ++i)
		{
			const auto wrapped    = static_cast<uint16_t>(i);
			const TestSequence si = unwrapper.Unwrap(wrapped);

			REQUIRE(s0 < si);
			REQUIRE(prev < si);

			prev = si;
		}
	}

	SECTION("increment is same as wrapped")
	{
		TestSequence::Unwrapper unwrapper;

		TestSequence s0   = unwrapper.Unwrap(0);
		TestSequence prev = s0;

		for (uint32_t i{ 1 }; i < 65536 * 2; ++i)
		{
			const auto wrapped    = static_cast<uint16_t>(i);
			const TestSequence si = unwrapper.Unwrap(wrapped);

			s0.Increment();
			REQUIRE(s0 == si);

			prev = si;
		}
	}

	SECTION("unwrapping larger number is always larger")
	{
		TestSequence::Unwrapper unwrapper;

		for (uint32_t i{ 1 }; i < 65536 * 2; ++i)
		{
			const auto wrapped    = static_cast<uint16_t>(i);
			const TestSequence si = unwrapper.Unwrap(wrapped);

			REQUIRE(unwrapper.Unwrap(wrapped + 1) > si);
			REQUIRE(unwrapper.Unwrap(wrapped + 5) > si);
			REQUIRE(unwrapper.Unwrap(wrapped + 10) > si);
			REQUIRE(unwrapper.Unwrap(wrapped + 100) > si);
		}
	}

	SECTION("unwrapping smaller number is always smaller")
	{
		TestSequence::Unwrapper unwrapper;

		for (uint32_t i{ 1 }; i < 65536 * 2; ++i)
		{
			const auto wrapped    = static_cast<uint16_t>(i);
			const TestSequence si = unwrapper.Unwrap(wrapped);

			REQUIRE(unwrapper.Unwrap(wrapped - 1) < si);
			REQUIRE(unwrapper.Unwrap(wrapped - 5) < si);
			REQUIRE(unwrapper.Unwrap(wrapped - 10) < si);
			REQUIRE(unwrapper.Unwrap(wrapped - 100) < si);
		}
	}

	SECTION("difference is absolute")
	{
		TestSequence::Unwrapper unwrapper;
		const TestSequence thisValue  = unwrapper.Unwrap(10);
		const TestSequence otherValue = TestSequence::AddTo(thisValue, 100);

		REQUIRE(TestSequence::Difference(thisValue, otherValue) == 100);
		REQUIRE(TestSequence::Difference(otherValue, thisValue) == 100);

		const TestSequence minusValue = TestSequence::AddTo(thisValue, -100);

		REQUIRE(TestSequence::Difference(thisValue, minusValue) == 100);
		REQUIRE(TestSequence::Difference(minusValue, thisValue) == 100);
	}
}
