#include <iostream>
#include <cppunit/TestCase.h>
#include <cppunit/TestCaller.h>
#include <cppunit/TestResult.h>
#include <cppunit/TestSuite.h>
#include <cppunit/extensions/HelperMacros.h>
#include <portablethreads/config.h>

#define PT_ARCH_COMMON_INCLUDED

#include <portablethreads/arch/free-high-bits-muxer.h>
#include <portablethreads/arch/x-byte-pointer-to-int-compression.h>


using namespace std;
using namespace PortableThreads;
using namespace PortableThreads::LockFree;

#ifdef _MSC_VER
#	pragma warning(disable:4311) // pointer to int
#	pragma warning(disable:4312) // int to pointer
#endif

/**
 * Debug helper: wraps an integral value so that streaming it prints the
 * value's bit pattern, e.g. `std::cout << bits<int>(5)`.
 */
template<typename T>
struct bits
{
	bits(T i)
		:	int_(i)
	{}
	T int_; // the wrapped value
};

/**
 * Streams the wrapped value as sizeof(T)*8 binary digits, most significant
 * bit first, followed by " = " and the value formatted as a number.
 *
 * Each digit is obtained by shifting the *value* right and masking with 1.
 * Shifting a single-bit mask right instead does not work for signed T: the
 * mask starts out with only the sign bit set, and an arithmetic right shift
 * then replicates that bit, so the "mask" ends up covering several bits and
 * digits below the first set bit are reported as '1'.
 *
 * @param os    stream to write to
 * @param value wrapper holding the integer to print
 * @return os, to allow chaining
 */
template<typename T>
std::ostream& operator<<(std::ostream& os, bits<T> value)
{
	// Assumes 8-bit bytes, like the rest of this file.
	const unsigned bitCount = static_cast<unsigned>(sizeof(T) * 8);
	for(unsigned i = bitCount; i-- > 0; )
	{
		// Whether the shift is arithmetic or logical does not matter here:
		// for i < bitCount the lowest bit of the result is bit i of the value.
		os << (((value.int_ >> i) & static_cast<T>(1)) == 0 ? '0' : '1');
	}
	// Unary plus promotes character-sized T to int so that it is printed as
	// a number rather than as a raw character.
	os << " = " << +value.int_;
	return os;
}

/**
 * CppUnit fixture for the FreeHighBits pointer/counter multiplexer, the
 * pointer <-> integer inflate/deflate helpers and PTPointerCAS.
 *
 * All four template parameters are forwarded unchanged to
 * PortableThreads::LockFree::Private::FreeHighBits; the fixture itself only
 * uses them through the resulting Muxer typedef.
 */
template<typename T, unsigned POINTER_BITS, unsigned HW_POINTER_BITS, unsigned ALIGNMENT_BITS>
class BitTwiddlingTest : public CppUnit::TestFixture
{
public:
	typedef PortableThreads::LockFree::Private::PTPointerCAS PTPointerCAS;
	typedef PortableThreads::LockFree::Private::token_t token_t;
	typedef PortableThreads::LockFree::Private::FreeHighBits<T, POINTER_BITS, HW_POINTER_BITS, ALIGNMENT_BITS> Muxer;
	typedef typename Muxer::int_t int_t;

	BitTwiddlingTest()
	{
		// The tests below build values with bits above bit 31, so int_t must
		// be at least 8 bytes wide.
		// NOTE(review): this check runs when the fixture is constructed, not
		// inside a test method -- confirm that is the intended place.
		CPPUNIT_ASSERT(sizeof(int_t) >= 8);
	}

	// Converts a pointer to its integer form via pt_inflate_pointer<U>.
	template<typename U>
	int_t inflate(U* p)
	{
		return PortableThreads::LockFree::Private::pt_inflate_pointer<U>(p);
	}

	// Converts an integer back to a pointer via pt_deflate_pointer<U>.
	template<typename U>
	U* deflate(int_t p)
	{
		return PortableThreads::LockFree::Private::pt_deflate_pointer<U>(p);
	}

private:
	// Builds a mask of Muxer::freeBits() consecutive one-bits located
	// directly above the low Muxer::hardwarePointerBits() bits.
	static int_t highBitsMask()
	{
		const int_t freeBitCount = Muxer::freeBits();
		int_t mask = (static_cast<int_t>(1) << freeBitCount) - 1;
		mask <<= Muxer::hardwarePointerBits();
		return mask;
	}

	// Multiplexes p with three different counts and checks what comes back
	// out of the multiplexed word.
	// Debug aid: `cout << bits<int_t>(mux) << endl;` prints the bit pattern.
	void checkCountRoundTrips(int_t p)
	{
		// easy count
		int_t count = 0;
		int_t mux = Muxer::multiplex(p, count);
		CPPUNIT_ASSERT_EQUAL(p, Muxer::value(mux));
		CPPUNIT_ASSERT_EQUAL(count, Muxer::count(mux));

		// reasonable count
		count = 43884;
		mux = Muxer::multiplex(p, count);
		CPPUNIT_ASSERT_EQUAL(p, Muxer::value(mux));
		CPPUNIT_ASSERT_EQUAL(count, Muxer::count(mux));

		// too big a count, must be adapted
		count = static_cast<int_t>(1) << Muxer::freeBits();
		mux = Muxer::multiplex(p, count);
		CPPUNIT_ASSERT_EQUAL(p, Muxer::value(mux));

		// cannot assert that count is still ok, because it will get changed to
		// fit into bits
	}

	// Inflates p to an integer, deflates it again and expects the same pointer.
	void checkPointerRoundTrip(void* p)
	{
		int64 compressed = inflate(p);
		CPPUNIT_ASSERT_EQUAL(p, deflate<void>(compressed));
	}

public:
	// Value 0: none of the high bits are set, and it must survive multiplexing.
	void testMuxerWithHighZerosPointerIsZero()
	{
		int_t p = 0;
		CPPUNIT_ASSERT_EQUAL(static_cast<int_t>(0), p & highBitsMask());

		checkCountRoundTrips(p);
	}

	// Non-zero value whose high bits (those covered by highBitsMask()) are
	// all zero.
	void testMuxerWithHighZerosPointer()
	{
		int_t p = static_cast<int_t>(1) << 30;
		p ^= static_cast<int_t>(23939) << 3;
		CPPUNIT_ASSERT(p < (static_cast<int_t>(1) << Muxer::hardwarePointerBits()));
		CPPUNIT_ASSERT_EQUAL(static_cast<int_t>(0), p & highBitsMask());

		checkCountRoundTrips(p);
	}

	// Same low bits as above, but with every bit of highBitsMask() set.
	void testMuxerWithHighOnesPointer()
	{
		int_t p = static_cast<int_t>(1) << 30;
		p ^= static_cast<int_t>(23939) << 3;
		CPPUNIT_ASSERT(p < (static_cast<int_t>(1) << Muxer::hardwarePointerBits()));

		int_t mask = highBitsMask();
		// p has no high bits set at this point, so XOR switches all of them on.
		p ^= mask;
		CPPUNIT_ASSERT_EQUAL(mask, p & mask);

		checkCountRoundTrips(p);
	}

	// inflate()/deflate() must round-trip a set of representative pointers.
	void testPointerCompression()
	{
		// NOTE: This method will execute differently depending 
		// on the size of a POINTER of the target platform!

		void* p = 0;
		checkPointerRoundTrip(p);

		p = reinterpret_cast<void*>(0xfffffff0);
		checkPointerRoundTrip(p);

		int_t mask = highBitsMask();

		// only the high bits set
		p = reinterpret_cast<void*>(mask);
		checkPointerRoundTrip(p);

		// high bits set plus a low 32-bit pattern
		p = reinterpret_cast<void*>(mask ^ static_cast<int_t>(0xfffffff0));
		checkPointerRoundTrip(p);
	}

	// Stores pointer values in a PTPointerCAS via cas() and reads them back.
	void testNativePointerCAS()
	{
		PTPointerCAS x(0);
		void* p = 0;
		
		token_t t;
		
		// NOTE(review): the result of cas() is ignored in both steps below;
		// success is only observed indirectly through the following get().
		t = x.get();
		x.cas(reinterpret_cast<PTPointerCAS::int_t>(p), t);
		CPPUNIT_ASSERT_EQUAL(p, reinterpret_cast<void*>(x.get().value()));

		p = reinterpret_cast<void*>(0xffff3430);
		t = x.get();
		x.cas(reinterpret_cast<PTPointerCAS::int_t>(p), t);
		CPPUNIT_ASSERT_EQUAL(p, reinterpret_cast<void*>(x.get().value()));
	}

	CPPUNIT_TEST_SUITE( BitTwiddlingTest );
		CPPUNIT_TEST( testMuxerWithHighZerosPointerIsZero );
		CPPUNIT_TEST( testMuxerWithHighZerosPointer );
		CPPUNIT_TEST( testMuxerWithHighOnesPointer );
		CPPUNIT_TEST( testPointerCompression );
		CPPUNIT_TEST( testNativePointerCAS );
	CPPUNIT_TEST_SUITE_END();
};

// Concrete fixture instantiations. Both use int64 with POINTER_BITS = 64 and
// ALIGNMENT_BITS = 3; they differ only in HW_POINTER_BITS (48 vs. 44).
// NOTE(review): the names suggest these model AMD and Sun pointer layouts --
// confirm against the platform configuration headers.
typedef BitTwiddlingTest< ::PortableThreads::int64, 64, 48, 3> AMDTest;
typedef BitTwiddlingTest< ::PortableThreads::int64, 64, 44, 3> SunTest;

// Register both suites with the CppUnit test factory registry.
CPPUNIT_TEST_SUITE_REGISTRATION( AMDTest );
CPPUNIT_TEST_SUITE_REGISTRATION( SunTest ); 



