#include <vector>
#include <cppunit/TestCase.h>
#include <cppunit/TestCaller.h>
#include <cppunit/TestResult.h>
#include <cppunit/TestSuite.h>
#include <cppunit/extensions/HelperMacros.h>
#include <portablethreads/thread.h>
#include <portablethreads/time.h>
#include <portablethreads/condition.h>



using namespace std;
using namespace PortableThreads;
using namespace PortableThreads::LockFree;

namespace
{
	// Thread base whose unexpectedException() hook does nothing.
	// NOTE(review): presumably PThread calls this hook when threadMain() lets an
	// exception escape, so such exceptions are silently dropped — confirm
	// against PThread.
	class ThreadSilentBase : public PThread
	{
	private:
		void unexpectedException() throw()
		{
			// intentionally empty
		}
	};
	// Thread base for workers that are not supposed to throw: the
	// unexpectedException() hook (presumably invoked by PThread when
	// threadMain() throws — confirm) turns the situation into a failed assertion.
	class ThreadBase : public ThreadSilentBase
	{
		// NOTE(review): CPPUNIT_ASSERT reports a failure by throwing. Throwing out
		// of a throw() function ends in std::terminate, so a failure here aborts
		// the whole run rather than being reported against a single test.
		void unexpectedException() throw()
		{
			CPPUNIT_ASSERT(false && "no exception expected");
		}
	};
	// Worker for the auto-reset broadcast test: spins (yielding with give())
	// until the shared flag turns true, then blocks on the condition and
	// records that it got past it.
	class Thread0 : public ThreadBase
	{
	public:
		Thread0(volatile bool& flag, PTCondition& cond)
			: condition_(&cond), flag_(&flag), passed_(false)
		{
		}

		// True once threadMain() has returned from condition_->wait().
		bool passed() const
		{
			return passed_;
		}

	private:
		void threadMain()
		{
			// Busy-wait for the go signal from the test.
			for(;;)
			{
				if(*flag_)
					break;
				give();
			}
			condition_->wait();
			passed_ = true;
		}

		PTCondition* condition_;
		volatile bool* flag_;
		volatile bool passed_;
	};

	// Helper thread for the timed-wait test: sleeps for 300 (milliseconds, per
	// pt_milli_sleep) and then signals the condition once.
	class Thread1 : public ThreadBase
	{
	public:
		// explicit: a PTCondition must not silently convert into a thread object.
		explicit Thread1(PTCondition& cond)
			:	condition_(&cond)
		{}

	private:
		void threadMain()
		{
			pt_milli_sleep(300);
			condition_->signal();
		}
	private:
		PTCondition* condition_;
	};

	// Worker for the manual-reset broadcast test. It goes through the condition
	// twice:
	//  - first after the shared flag turns true;
	//  - again after the flag turns back to false.
	// It records each step in passed1_ / passed2_.
	class Thread2 : public ThreadBase
	{
	public:
		Thread2(volatile bool& flag, PTCondition& cond)
			: condition_(&cond), flag_(&flag), passed1_(false), passed2_(false)
		{
		}

		// True once the first condition_->wait() has returned.
		bool passed1() const
		{
			return passed1_;
		}

		// True once the second condition_->wait() has returned.
		bool passed2() const
		{
			return passed2_;
		}

	private:
		void threadMain()
		{
			// Phase 1: wait for the flag to be raised, then for the condition.
			for(;;)
			{
				if(*flag_)
					break;
				give();
			}
			condition_->wait();
			passed1_ = true;

			// Phase 2: wait for the flag to be lowered, then for the condition again.
			for(;;)
			{
				if(!*flag_)
					break;
				give();
			}
			condition_->wait();
			passed2_ = true;
		}

		PTCondition* condition_;
		volatile bool* flag_;
		volatile bool passed1_;
		volatile bool passed2_;
	};

}

class ConditionTest : public CppUnit::TestFixture
{
public:
	void testImpossible()
	{
		try
		{
			PTCondition c(true, true);
			CPPUNIT_ASSERT(false && "Assumption: Autoresetting and broadcasted do not work together!");
		}
		catch(PTParameterError&)
		{
			return;
		}
		CPPUNIT_ASSERT(false && "Wrong exception thrown!");
	}
	void testDryTimedWait()
	{
		PTCondition c1, c2(false, true);
		CPPUNIT_ASSERT_EQUAL(false, c1.wait(0));
		CPPUNIT_ASSERT_EQUAL(false, c1.wait(0, 10));

		CPPUNIT_ASSERT_EQUAL(true, c2.wait(0));
		CPPUNIT_ASSERT_EQUAL(true, c2.wait(0, 10));

		c2.reset();
		CPPUNIT_ASSERT_EQUAL(false, c2.wait(0));
		CPPUNIT_ASSERT_EQUAL(false, c2.wait(0, 10));

	}
	void testTimedWait()
	{
		PTCondition c;
		Thread1 t(c);

		t.run();

		CPPUNIT_ASSERT_EQUAL(false, c.wait(0, 50));
		CPPUNIT_ASSERT_EQUAL(true, c.wait(1));

		t.join();

		CPPUNIT_ASSERT_EQUAL(false, c.wait(0, 50));
	}
	void testBroadcastAutoreset()
	{
		volatile bool flag = false;
		const unsigned t = 8;
		PTCondition c;

		vector<Thread0*> threads(t);
		for(unsigned i = 0; i < t; ++i)
		{
			threads[i] = new Thread0(flag, c);
			threads[i]->run();
		}

		// no one started yet
		for(unsigned i = 0; i < t; ++i)
		{
			CPPUNIT_ASSERT_EQUAL(false, threads[i]->passed());
		}

		// let them go..
		flag = true;

		pt_milli_sleep(100);

		// no one should be past the condition
		for(unsigned i = 0; i < t; ++i)
		{
			CPPUNIT_ASSERT_EQUAL(false, threads[i]->passed());
		}

		c.signal();
		
		for(unsigned i = 0; i < t; ++i)
		{
			threads[i]->join();
			CPPUNIT_ASSERT_EQUAL(true, threads[i]->passed());
			delete threads[i];
		}

	}
	void testBroadcastManualreset()
	{
		volatile bool flag = false;
		const unsigned t = 8;
		PTCondition c(false, true);

		vector<Thread2*> threads(t);
		for(unsigned i = 0; i < t; ++i)
		{
			threads[i] = new Thread2(flag, c);
			threads[i]->run();
		}

		// no one started yet
		for(unsigned i = 0; i < t; ++i)
		{
			CPPUNIT_ASSERT_EQUAL(false, threads[i]->passed1());
		}

		// let them go..
		flag = true;

		
		for(unsigned passed = 0; passed != t;)
		{
			passed = 0;
			for(unsigned i = 0; i < t; ++i)
			{
				passed += threads[i]->passed1();
			}
		}
		

		c.reset();

		// let them go..
		flag = false;

		for(unsigned i = 0; i < t; ++i)
		{
			CPPUNIT_ASSERT_EQUAL(false, threads[i]->passed2());
		}
		
		c.signal();

		for(unsigned i = 0; i < t; ++i)
		{
			threads[i]->join();
			CPPUNIT_ASSERT_EQUAL(true, threads[i]->passed2());
			delete threads[i];
		}

	}
	
	CPPUNIT_TEST_SUITE( ConditionTest );
		CPPUNIT_TEST( testImpossible );
		CPPUNIT_TEST( testDryTimedWait );
		CPPUNIT_TEST( testTimedWait );
		CPPUNIT_TEST( testBroadcastAutoreset );
		CPPUNIT_TEST( testBroadcastManualreset );
	CPPUNIT_TEST_SUITE_END();
};


// Register ConditionTest with CppUnit's default test factory registry so a
// registry-driven test runner picks it up without further wiring.
CPPUNIT_TEST_SUITE_REGISTRATION( ConditionTest );



