#include <iostream>
#include <cstdio>
#include <distribution/mpi/mpi_control.h>
#include <distribution/mpi/mpi_platypus_types.h>

namespace Platypus
{

  // Releases everything this controller owns: any DelegatableChoice objects
  // still waiting in choices_, and the helper objects allocated in the
  // constructor.
  MPIControl::~MPIControl()
  {
    // Choices are heap-allocated when they are queued (setup() and
    // handleDelegatableChoice()) and are only freed when they are handed to
    // a worker in handleDelegatableChoiceRequest(). Anything still queued
    // at destruction time would otherwise leak.
    while(choices_.size())
      {
        delete choices_.front();
        choices_.pop();
      }

    delete pa_;
    delete modelHandler_;
    delete exitHandler_;
    delete stateHandler_;
    delete serializer_;
  }
  
  // Builds a controller for the given program.
  //
  // program              - stored by address; also used to construct the
  //                        serializer and the scratch PartialAssignment.
  // statistics           - stored by address; updated by the message handlers.
  // printer              - stored by address; used by print().
  // requested_answersets - stored in requestedAnswerSets_ and later passed
  //                        to the model handler in setup().
  // suppressed           - when true, received answer sets are neither
  //                        deserialized nor printed.
  //
  // The serializer, the three handlers and the PartialAssignment are
  // allocated here and deleted in the destructor. All counters start at zero.
  MPIControl::MPIControl(const ProgramInterface& program,
                         MPIStatistics& statistics,
                         AnswerSetPrinterBase& printer,
                         size_t requested_answersets,
                         bool suppressed)
    : program_(&program),
      serializer_(new MPISerializer(program)),
      statistics_(&statistics),
      printer_(&printer),
      stateHandler_(new MPIControlStateHandler()),
      exitHandler_(new MPIControlExitHandler()),
      modelHandler_(new MPIControlModelHandler()),
      pa_(new PartialAssignment(program)),
      requestedAnswerSets_(requested_answersets),
      receivedAnswerSets_(0),
      globalAnswersFound_(0),
      suppressed_(suppressed),
      requests_(0),
      shutdown_(false),
      workers_(0),
      answersFoundByWorker_(0),
      start_time_(0)
  {
  }
  
  void MPIControl::setup()
  {
    start_time_ = MPI::Wtime();
    workers_ = MPI::COMM_WORLD.Get_size() - 1;
    stateHandler_->setup(workers_);
    exitHandler_->setup(workers_);
    modelHandler_->setup(workers_, requestedAnswerSets_);

    //set the buffer size for the two message buffers
    size_t bufferSize = serializer_->bytesRequired();
    paBuffer_.resize(bufferSize);
    dcBuffer_.resize(bufferSize);
    
    //bootstrap by putting a choice into the queue
    DelegatableChoice* dc = new DelegatableChoice();
    choices_.push(dc);
    
  }

  // Sends a one-byte TERMINATE_TAG message to every worker (destinations
  // 1..workers_) and counts each send in the statistics.
  void MPIControl::initiateTermination(){

    // The payload carries no information; only the tag matters.
    char payload = 0;

    unsigned notified = 0;
    while(notified < workers_)
      {
        // The destination rank is one more than the zero-based worker index.
        const int rank = notified + 1;
        MPI::COMM_WORLD.Send(&payload, 1, MPI::CHAR, rank, TERMINATE_TAG);
        statistics_->incMessagesSent(MASTER);
        ++notified;
      }
  }
  
  // Waits until every worker has sent its TERMINATE_CONFIRMATION_TAG
  // message and folds the per-worker counters carried by each message
  // (NUM_STATS unsigned longs) into the statistics object.
  //
  // Uses a blocking Probe, as start() does, so that the controller sleeps
  // inside MPI while waiting instead of spinning on Iprobe.
  void MPIControl::completeTermination(){

    unsigned terminationsReceived = 0;
    std::vector<unsigned long> buffer(NUM_STATS,0);

    while(terminationsReceived < workers_)
      {
        // Block until some worker's confirmation is available, then
        // receive from exactly that worker.
        MPI::COMM_WORLD.Probe(MPI::ANY_SOURCE, TERMINATE_CONFIRMATION_TAG, localControllerStatus_);
        int source = localControllerStatus_.Get_source();

        MPI::COMM_WORLD.Recv(&buffer[0],
                             NUM_STATS,
                             MPI::UNSIGNED_LONG,
                             source,
                             TERMINATE_CONFIRMATION_TAG);
        statistics_->incMessagesReceived(MASTER);

        // Gather the stats reported by this worker.
        statistics_->incExpanderInits(source, buffer[EXPANDER_INITS]);
        statistics_->incConflicts(source, buffer[CONFLICTS]);
        statistics_->incModels(source, buffer[MODELS]);
        statistics_->incBacktracks(source, buffer[BACKTRACKS]);
        statistics_->incThreadDelegations(source, buffer[DELEGATIONS]);
        statistics_->incMessagesSent(source, buffer[MESSAGES_SENT]);
        statistics_->incMessagesReceived(source, buffer[MESSAGES_RECEIVED]);
        statistics_->incDroppedRequests(source, buffer[DROPPED_REQUESTS]);

        terminationsReceived++;
      }
  }

  void MPIControl::requestDCFromDelegatedWorkers(unsigned source)
  {

    //more debug timing crap
    //double startRequest = 0;
    //double finishRequest = 0;

    //startRequest = MPI::Wtime();
    //std::cout << "requestDCFromDelegatedWorkers() started " << startRequest - start_time_ << " after run began." << std::endl;

    vector<unsigned>* states = stateHandler_->getStates();
    for(unsigned i = 0; i < (*states).size(); i++)
      {
	if((*states)[i] == DELEGATED)
	  {
	    //the mpi id is one more than the index into the vector
	    unsigned worker = i+1;
	    char dummy = 0;
	    localControllerRequest_ = MPI::COMM_WORLD.Isend(&dummy,
							    1,
							    MPI::CHAR,
							    worker,
							    DC_NEEDED_TAG);

	    
	    while(!localControllerRequest_.Test())
	      {
		if(MPI::COMM_WORLD.Iprobe(worker, DC_REQUEST_TAG))
		  {
		    //index into data structure is one less than mpi id
		    stateHandler_->setState(worker-1, FILED);
		    localControllerRequest_.Cancel();
		  }
		if(!suppressed_)
		  {
		    MPI::Status localStatus;
		    if(MPI::COMM_WORLD.Iprobe(worker, ANSWER_MESSAGE_TAG, localStatus))
		      {
			unsigned count = localStatus.Get_count(MPI::UNSIGNED_LONG);
			handleAnswerSet(worker, count);
		      }
		  }
	      }
	    statistics_->incWorkRequestsFromMaster(worker);
	    statistics_->incMessagesSent(MASTER);
	    
	  }
      }
    //finishRequest = MPI::Wtime();
    //std::cout << "Request for DC from delegated workers ended " << finishRequest - start_time_ << " after run began." << std::endl;
    //std::cout << "Total time for requestDCFromDelegatedWorkers(): " << finishRequest - startRequest << std::endl;

  }

  // Receives one ANSWER_MESSAGE_TAG message of count unsigned longs from
  // source, updates the statistics and the local model counter, and (unless
  // output is suppressed) deserializes and prints the answer set.
  void MPIControl::handleAnswerSet(unsigned source, unsigned count)
  {
    // Make the buffer exactly as long as the incoming message.
    if(paBuffer_.size() != count)
      {
        paBuffer_.resize(count);
      }

    MPI::COMM_WORLD.Recv(&paBuffer_[0], paBuffer_.size(), MPI::UNSIGNED_LONG, source, ANSWER_MESSAGE_TAG);

    statistics_->incMessagesReceived(MASTER);
    statistics_->incAnswers();
    modelHandler_->incLocalModels();

    if(suppressed_)
      return;

    serializer_->deserialize(*pa_, paBuffer_);
    print(*pa_);
  }

  void MPIControl::handleDelegatableChoice(unsigned source, unsigned count)
  {
    
    //unsigned index = source - 1;
    //stateHandler_->setState(index, DELEGATED);          
    
    dcBuffer_.resize(count);
    //std::cout << "count in handleDelegatableChoice(unsigned source, unsigned count) from worker " << source << ": " << count << std::endl;

    dcBuffer_.resize(count);
    MPI::COMM_WORLD.Recv(&dcBuffer_[0],                                                                                   
			 //dcBuffer_.size(),
			 count,
			 MPI::UNSIGNED_LONG,                                                                              
			 source,                                                                                          
			 DC_TAG_FROM_SLAVE);                                                                              
    statistics_->incMessagesReceived(MASTER);                                                                             
    statistics_->incWorkDelegationsToMaster(source);                                                                      
        
    if(filedWorkers_.size())                                                                                              
      {                                                                                                                   
	unsigned workerCandidate = filedWorkers_.front();                                                                 
	MPI::COMM_WORLD.Send(&dcBuffer_[0],                                                                               
			     dcBuffer_.size(),                                                                            
			     MPI::UNSIGNED_LONG,                                                                          
			     workerCandidate,                                                                             
			     DC_TAG_FROM_CONTROLLER);                                                                     
	statistics_->incMessagesSent(MASTER);                                                                             
	filedWorkers_.pop();                                                                                              
      
	//index into data structure is one less than mpi id                                                               
	stateHandler_->setState(workerCandidate-1, DELEGATED);                                                            
	exitHandler_->decRequests();                                                                                      
      }                                                                                                                   
    else                                                                                                                  
      {                                                                                                                   
	DelegatableChoice* newDc = new DelegatableChoice();                                                               
	serializer_->deserialize(*newDc, dcBuffer_);                                                                      
	choices_.push(newDc);                                                                                             
	statistics_->checkMaxQueueSize(choices_.size());                                                                  
      }                                              
                                                                     
  }
  
  void MPIControl::handleDelegatableChoiceRequest(unsigned source)
  {

    unsigned index = source - 1;//index into state manager is one less than mpi id
    answersFoundByWorker_ = 0;                                                                                                           

    MPI::COMM_WORLD.Recv(&answersFoundByWorker_,                                                                                         
			 1 ,                                                                                                             
			 MPI::UNSIGNED_LONG,                                                                                             
			 source,                                                                                                         
			 DC_REQUEST_TAG);                                                                                                

    statistics_->incMessagesReceived(MASTER);                                                                                            
    statistics_->incWorkRequestsToMaster(source);                                                                                        

    //update the global number of answer sets found                                                                                      
    modelHandler_->setGlobalModels(index, answersFoundByWorker_);                                                                        

    if(choices_.size())                                                                                                                  
      {                                                                                                                                  
	//handle the bootstrap case of an empty
	stateHandler_->setState(index, DELEGATED);                                                                                       

	DelegatableChoice * sendDc = choices_.front();                                                                                   
	serializer_->serialize(dcBuffer_, *sendDc);                                                                                      
	MPI::COMM_WORLD.Send(&dcBuffer_[0],                                                                                              
			     dcBuffer_.size(),                                                                                           
			     MPI::UNSIGNED_LONG,                                                                                         
			     source,                                                                                                     
			     DC_TAG_FROM_CONTROLLER);                                                                                    

	statistics_->incMessagesSent(MASTER);                                                                                            
	statistics_->incWorkDelegationsFromMaster(source);                                                                               
	delete sendDc;                                                                                                                   
	choices_.pop();                                                                                                                  
      }                                                                                                                                  
    else                                                                                                                                 
      {                                                                                                                                  
	stateHandler_->setState(index, FILED);                                                                                       
	exitHandler_->incRequests();      
	statistics_->incWorkDenials(source);                                                                                             
	filedWorkers_.push(source);                                                                                                      
	statistics_->checkMaxFiledWorkers(filedWorkers_.size());                                                                         
       
	requestDCFromDelegatedWorkers(source);                                                                                           
      }
  }

  /*
   * Algorithm:
   *
   * WHILE requests != number of workers AND requested number of answers not reached
   *       IF some message from worker to master exists
   *          IF message contains choice from worker
   *             receive choice message
   *             queue message locally
   *             set flag for worker who sent message to delegated
   *          ELSE IF message contains work request from worker
   *             receive choice request
   *             IF local queue is not empty
   *                decrement requests counter
   *                set worker flag to delegated
   *                remove choice from front of queue and send to worker
   *
   *             ELSE
   *                increment requests counter
   *                set worker flag to filed
   *                send worker message that request for choice was denied
   *             END IF
   *          ELSE IF message from worker contains an answer
   *             receive answer message
   *             increment answer counter
   *             IF answer printing not suppressed
   *                print answer
   *             END IF
   *          END IF
   *       END IF
   *
   * END WHILE
   */
  // Main loop of the controller: until stop() reports true, block for the
  // next incoming message from any worker and dispatch it by tag. When the
  // loop ends, finish() runs the termination sequence.
  void MPIControl::start()
  {
    for(;;)
      {
        if(stop())
          break;

        assert(requests_ >= 0);

        // Blocking probe: wait for a message of any tag from any source.
        MPI::COMM_WORLD.Probe(MPI::ANY_SOURCE, MPI::ANY_TAG, localControllerStatus_);

        size_t sender = localControllerStatus_.Get_source();
        size_t messageTag = localControllerStatus_.Get_tag();
        size_t length = localControllerStatus_.Get_count(MPI::UNSIGNED_LONG);

        if(messageTag == DC_TAG_FROM_SLAVE)
          {
            // A worker is handing a choice to the controller.
            handleDelegatableChoice(sender, length);
          }
        else if(messageTag == DC_REQUEST_TAG)
          {
            // A worker is asking for work.
            handleDelegatableChoiceRequest(sender);
          }
        else if(messageTag == ANSWER_MESSAGE_TAG)
          {
            // A worker is reporting an answer set.
            handleAnswerSet(sender, length);
          }
      }

    // Terminate the workers, collect their confirmations and clean up.
    finish();
  }

  // Runs the shutdown sequence in order: tell the workers to terminate,
  // wait for their confirmations, then drain leftover messages.
  void MPIControl::finish()
  {
    initiateTermination();   // send TERMINATE_TAG to every worker
    completeTermination();   // collect the workers' confirmations
    cleanup();               // clean up stray messages
  }

  // Returns true when the main loop should end: either enough answer sets
  // have been received, or the exit handler reports that it is time to exit.
  // The exit handler is only consulted when the first check fails.
  bool MPIControl::stop()
  {
    if(enoughReceived())
      return true;

    return exitHandler_->exit();
  }

  // Returns the model handler's verdict on whether enough answer sets have
  // been found.
  bool MPIControl::enoughReceived() const
  {
    const bool satisfied = modelHandler_->enough();
    return satisfied;
  }

  // Passes pa to the printer together with the model handler's current
  // local model count.
  void MPIControl::print(PartialAssignment& pa)
  {
    printer_->print(modelHandler_->localModels(),
                    pa);
  }

  // Drains answer-set messages that are still in flight after termination.
  //
  // Nothing is done if enough answer sets were already received. Otherwise
  // ANSWER_MESSAGE_TAG messages are received (and, unless output is
  // suppressed, printed) until receivedAnswerSets_ has caught up with
  // globalAnswersFound_.
  //
  // Changes relative to the previous version:
  //  - each drained answer now advances receivedAnswerSets_; before, neither
  //    counter changed inside the loop, so once entered it could never end;
  //  - the loop runs while received < found, so it also ends if the
  //    received counter is already ahead;
  //  - the receive buffer is sized to the probed message length, as in
  //    handleAnswerSet(), instead of reusing whatever size it last had;
  //  - the branches for DC_TAG_FROM_SLAVE and DC_REQUEST_TAG were removed:
  //    the probe only matches ANSWER_MESSAGE_TAG, so they were unreachable.
  void MPIControl::cleanup() 
  {
    if(enoughReceived())
      return;

    while(receivedAnswerSets_ < globalAnswersFound_)
      {
        if(!MPI::COMM_WORLD.Iprobe(MPI::ANY_SOURCE, ANSWER_MESSAGE_TAG, localControllerStatus_))
          continue;

        unsigned source = localControllerStatus_.Get_source();
        unsigned count = localControllerStatus_.Get_count(MPI::UNSIGNED_LONG);

        // Match the buffer to the message so that no stale trailing data
        // reaches the deserializer and a longer message is not truncated.
        paBuffer_.resize(count);
        MPI::COMM_WORLD.Recv(&paBuffer_[0],
                             paBuffer_.size(),
                             MPI::UNSIGNED_LONG,
                             source,
                             ANSWER_MESSAGE_TAG);
        statistics_->incMessagesReceived(MASTER);

        // One more outstanding answer accounted for; this is what lets
        // the loop terminate.
        ++receivedAnswerSets_;

        if(!suppressed_)
          {
            serializer_->deserialize(*pa_, paBuffer_);
            print(*pa_);
          }
      }
  }

}
