/******************************************************************************/
/*									      */
/*	ctk_tracker.cpp	                  	      	      	              */
/*									      */
/*	   tracks frame colourings using one or more parameters               */
/*									      */
/*	Author: Martin Cooke, Sheffield University	                      */
/*									      */
/*     	CTK VERSION 1.3.5  Apr 22, 2007		      	      */
/*						 			      */
/******************************************************************************/
 
#include "ctk-config.h"

#include <cmath>
#include <iomanip>
#include <vector>
#include <map>
#include <algorithm>
 
#include "ctk_local.hh"
#include "ctk_error.hh"

#include "ctk_dsp.hh"
#include "ctk_auditory.hh"
#include "ctk_socket.hh"
#include "ctk_param.hh"
#include "ctk_data_descriptor.hh"

#include "ctk_tracker.hh"

/******************************************************************************/
/*                                                                            */
/*       CLASS NAME: TrackerBlock                                             */
/*                                                                            */
/******************************************************************************/

// Help text and registered type name for the Tracker block.
const string TrackerBlock::help_text("Tracks colourings");
const string TrackerBlock::type_name("Tracker");

/*
 * Construct a Tracker block named a_name.
 *
 * Creates two inputs ("in1" = colouring, "in2" = data) and one output
 * ("out1" = tracked colouring), and registers the float parameter
 * "Continuity_threshold" (default PARAM_DEFAULT_CONT) guarded by a
 * VLOWER validator with bound 0.0.
 */
TrackerBlock::TrackerBlock(const string &a_name)
  : CTKObject(a_name), Block(a_name, type_name) {

  // Socket layout: colouring + data in, relabelled colouring out.
  make_input_sockets(2);
  make_output_sockets(1);
  input_sockets->set_description("in1", "colouring");
  input_sockets->set_description("in2", "data");
  output_sockets->set_description("out1", "tracked colouring");

  // Continuity threshold: compute() only links two tracks whose mean data
  // values differ by less than this amount.
  CONT_param = new ParamFloat("Continuity_threshold", PARAM_DEFAULT_CONT);
  CONT_param->set_helptext("Ends track if no continuation within this amount");
  Validator *lower_bound = new Validator(Validator::VLOWER, 0.0);
  CONT_param->install_validator(lower_bound);
  parameters->register_parameter(CONT_param);
}
 
/*
 * Create a copy of this block via copy_this_block_to().  The copy is
 * named n, or takes this block's own name when n is empty.
 */
Block* TrackerBlock::clone(const string &n) const{
  if (n.empty())
    return copy_this_block_to(new TrackerBlock(getname()));
  return copy_this_block_to(new TrackerBlock(n));
}

void TrackerBlock::reset(){
  Block::reset();

  // Check validity of input data
  if (inputs_are_all_sample_data) {
    cerr << "Error in network at block: " << getname() << endl;
    cerr << "Attempting to apply Tracker to sample data. " << endl;
    cerr << "Tracker can only be applied to frame data." << endl;
    throw(CTKError(__FILE__, __LINE__));
  }

  cont = CONT_param->get_value();

  vector_size = (*input_sockets)[0]->get_data_descriptor()->get_storage();
  // later, check for size consistency etc

  first_frame = true;
  nextcolour=0;

}


/*
 * Debugging aid: print the number of entries in a colour map, then each key
 * (colour) on its own line, to cerr.
 *
 * The map is taken by value (a copy), which lets a plain Titerator walk it.
 */
void dumpmap(Tmap m) {
  cerr << "Map has " << m.size() << " elements" << endl;
  for (Titerator i=m.begin(); i != m.end(); ++i)
    cerr << (*i).first << " " << endl;
}
  

void TrackerBlock::compute() {
  
  CTKVector *colouring;
  (*input_sockets)[0]->get_vector(colouring);
  CTKVector *newcolouring = new CTKVector(*colouring);

  CTKVector *data;
  (*input_sockets)[1]->get_vector(data);

  Tmap m;
  vector<int> i2c;
  int ind=0;
  for (unsigned int i=0; i < colouring->size(); i++) {
    int colour = (int)(*colouring)[i];
    if (colour > 0) { // zero encoded background which is not tracked
      if (m.count(colour) == 0) { // first time for this colour
	m[colour].sum =(*data)[i];
	m[colour].index=ind++;
	m[colour].count=1;
	i2c.push_back(colour);
      } else {      
	m[colour].sum += (*data)[i];
	m[colour].count++;
      }
    }
  }

  CTKVector mean(i2c.size());
  CTKVector weight_this(i2c.size());
  vector<bool> contin(i2c.size());
  for (unsigned int i=0; i < i2c.size(); i++) {
    int colour = i2c[i];
    mean[i] = m[colour].sum/m[colour].count;
    weight_this[i] = m[colour].count;
    contin[i]=false;
  }
  
  int numtracks_this = mean.size();  

   
  if (!first_frame) {

    vector< vector<float> > dist;
    dist.resize(numtracks_this);
    for (int i=0; i < numtracks_this; i++) {
      dist[i].resize(numtracks_last);
      for (int j=0; j < numtracks_last; j++) {
	float d = fabs(mean[i]-mean_last[j]);
	if (d < cont) {
	  float wr = max(weight_this[i]/weight_last[j],weight_last[j]/weight_this[i]);
	  dist[i][j] = (weight_this[i]+weight_last[j])/(wr*(d+1.0));
	} else
	  dist[i][j] = 0.0;
      }
    }

//    for (int i=0; i < numtracks_this; i++) {
//      for (int j=0; j < numtracks_last; j++) {
//	cerr << (int) dist[i][j] << " ";
//      }
//      cerr << endl;
//    }
//    cerr << endl << endl ;;


    // find assignments, best first
    float best;
    int besti, bestj;
    do {
      best = -1.0;
      besti = 0; bestj = 0;
      for (int i=0; i < numtracks_this; i++) {
	for (int j=0; j < numtracks_last; j++) {
	  if (dist[i][j] > best) {
	    best = dist[i][j];
	    besti = i;
	    bestj = j;
	  }
	}
      }

      if (best > 0.000001) {
	// assign besti as extension of bestj
	// cerr << "extension of col " << i2c_last[bestj] << " ";
	replace(newcolouring->begin(),newcolouring->end(),i2c[besti],i2c_last[bestj]);
	i2c[besti]=i2c_last[bestj];
	// zero row besti and column bestj
	for (int j=0; j < numtracks_last; j++) dist[besti][j] = 0.0;
	for (int i=0; i < numtracks_this; i++) dist[i][bestj] = 0.0;
	// record what has happened to track
	contin[besti] = true;
      }
    } while (best > 0.000001);
	  
    // now start new tracks (new colours) for all i for which contin[i] is false
    for (int i=0; i < numtracks_this; i++) {
      if (!contin[i]) {
	replace(newcolouring->begin(),newcolouring->end(),i2c[i],nextcolour);
	//cerr << "new colour for track" << i2c[i-1] << " is " << nextcolour << endl;
	i2c[i]=nextcolour++;
      }
    }
    



  } else {
    first_frame = false;
    // determine when to start next colour
    // ensure that nextcolour for a group increases
    nextcolour = (int)(*max_element(i2c.begin(),i2c.end()))+1;

  }

  numtracks_last = numtracks_this;
  weight_last = weight_this;
  i2c_last = i2c;
  mean_last = mean;

  (*output_sockets)[0]->put_vector(newcolouring);
  delete data;
  delete colouring;
}


/* End of ctk_tracker.cpp */
 
