/******************************************************************************/
/*									      */
/*	ctk_HMM_MS_decoder.cpp	 	       			              */
/*									      */
/*	Block for MS HMM decoding  (see also - ctk_HMM_decoder.cpp)           */
/*									      */
/*	Author: Jon Barker, Sheffield University			      */
/*									      */
/*      CTK VERSION 1.3.5  Apr 22, 2007		         	      */
/*									      */
/******************************************************************************/
 
#include "ctk-config.h"

#include <cmath>
#include <vector>
#include <algorithm>
#include <numeric>

#include "ctk_local.hh"
#include "ctk_error.hh"

#include "ctk_function_classes.hh"
#include "ctk_param.hh"
#include "ctk_socket.hh"
#include "ctk_data_descriptor.hh"

#include "ctk_reco.hh"
#include "ctk_decoder.hh"
#include "ctk_HMM_decoder.hh"
#include "ctk_HMM_MS_decoder.hh"



/******************************************************************************/
/*                                                                            */
/*       CLASS NAME: BaseMS_HMMDecoderBlock                                   */
/*                                                                            */
/******************************************************************************/

// Tables of the values accepted by the NORMALISE_MODE and PROB_CALC enumerated parameters.
// NOTE(review): the trailing "\0" entry presumably acts as the end-of-list marker expected
// by ParamEnumerated - confirm against ctk_param.hh.
static const char *VALID_NORMALISE_MODES[] = {"MODE_1", "MODE_2", "\0"};
static const char *VALID_PROB_CALCS[] = {"MULTISOURCE", "MISSING_DATA", "\0"};
// Default selections for the two enumerated parameters above
static const char *PARAM_DEFAULT_NORMALISE_MODE       = "MODE_1";
static const char *PARAM_DEFAULT_PROB_CALC       = "MULTISOURCE";
// NOTE(review): PARAM_DEFAULT_ONE_ZERO_ROUNDING is not referenced in the part of the file visible here
const Float   PARAM_DEFAULT_ONE_ZERO_ROUNDING        = 0.0;      // round mask to 1 or 0 if within this tolerance
const Float   PARAM_DEFAULT_MD_WEIGHT  		    = 1.0;      // Multiplier applied to the MD bounds component

// Constructs the common base of the multisource decoder blocks.
//
//   a_name - instance name, passed on to CTKObject and BaseMD_HMMDecoderBlock
//   a_type - block type name, passed on to BaseMD_HMMDecoderBlock
//
// Adjusts the visibility of parameters inherited from the base block, registers the
// multisource specific parameters (MD_WEIGHT, NORMALISE_MODE, PROB_CALC, GROUP_FILE_NAME,
// TEMPORAL_WINDOW_SIZE, SPEECH_GROUP_BIAS) and creates seven input sockets and one output socket.
BaseMS_HMMDecoderBlock::BaseMS_HMMDecoderBlock(const string &a_name, const string &a_type):CTKObject(a_name),BaseMD_HMMDecoderBlock(a_name, a_type) {

  // Hard multisource decoding unless a derived class requests the soft variant
  use_soft_multisource=false;

  // Expose the parameters that are relevant to multisource decoding
  unset_parameter_hidden("NOISE_GRAMMAR_FILE");
  unset_parameter_hidden("NOISE_HMM_FILE");

  unset_parameter_hidden("DISPLAY_GROUPS");
  unset_parameter_hidden("MASK_OUTPUT_FILENAME");
  unset_parameter_hidden("VERBOSITY_LEVEL");

  // The USE_BOUNDS switch currently has no effect for multisource, so it is hidden
  set_parameter_hidden("USE_BOUNDS");

  // Set up MD_WEIGHT parameter (validated against a lower limit of float epsilon, i.e. it must be positive)
  MD_WEIGHT_param = new ParamFloat("MD_WEIGHT", PARAM_DEFAULT_MD_WEIGHT);
  MD_WEIGHT_param->set_helptext("This is a tuning parameter that sets the balance between groups being accepted/rejected from the mask (must be greater than 0.0).");
  MD_WEIGHT_param->install_validator(new Validator(Validator::VLOWER, numeric_limits<float>::epsilon()));
  parameters->register_parameter(MD_WEIGHT_param);

  // Set up normalisation mode parameter
  NORMALISE_MODE_param= new ParamEnumerated("NORMALISE_MODE", VALID_NORMALISE_MODES, PARAM_DEFAULT_NORMALISE_MODE);
  NORMALISE_MODE_param->set_helptext("The normalisation mode - see documentation for details.");
  parameters->register_parameter(NORMALISE_MODE_param);

  // Set up the probability calculation parameter
  PROB_CALC_param= new ParamEnumerated("PROB_CALC", VALID_PROB_CALCS, PARAM_DEFAULT_PROB_CALC);
  PROB_CALC_param->set_helptext("Probability computation to employ - see documentation for details.");
  parameters->register_parameter(PROB_CALC_param);

  // Set up the group file name parameter (no default value)
  GROUP_FILE_NAME_param= new ParamString("GROUP_FILE_NAME");
  GROUP_FILE_NAME_param->set_helptext("Name of the file containing the group summary data.");
  parameters->register_parameter(GROUP_FILE_NAME_param);

  // Set up temporal window size parameter
  TEMPORAL_WINDOW_SIZE_param= new ParamInt("TEMPORAL_WINDOW_SIZE", 0);
  TEMPORAL_WINDOW_SIZE_param->set_helptext("Length in frames of the temporal window used when applying sequential grouping constraints");
  TEMPORAL_WINDOW_SIZE_param->set_unit("frames");
  TEMPORAL_WINDOW_SIZE_param->install_validator(new Validator(Validator::VLOWER, 0.0));
  parameters->register_parameter(TEMPORAL_WINDOW_SIZE_param);

  // Set up the speech group bias parameter
  SPEECH_GROUP_BIAS_param= new ParamFloat("SPEECH_GROUP_BIAS", 0.0);
  SPEECH_GROUP_BIAS_param->set_helptext("The bias that is given to include groups that are a priori labelled as speech (i.e. cheating)");
  parameters->register_parameter(SPEECH_GROUP_BIAS_param);

  // Inputs: data, mask, group labels, lower/upper bounds and the group start/end markers.
  // The lower bounds and the two marker inputs are optional.
  make_input_sockets(7);
  input_sockets->set_description("in1", "data");
  input_sockets->set_description("in2", "mask");
  input_sockets->set_description("in3", "group");
  input_sockets->set_description("in4", "lower bounds");
  input_sockets->set_optional("in4");
  input_sockets->set_description("in5", "upper bounds");
  input_sockets->set_description("in6", "groups starting");
  input_sockets->set_optional("in6");
  input_sockets->set_description("in7", "groups ending");
  input_sockets->set_optional("in7");

  // Output socket reporting the number of active segregation hypotheses
  make_output_sockets(1);
  output_sockets->set_description("out1", "num active segregation hypotheses");
}


void BaseMS_HMMDecoderBlock::build_output_data_descriptors() {
  
  // Make a default 'sample data' descriptor for the 'num active segregation hypotheses' output
  (*output_sockets)[0]->set_data_descriptor(new DataDescriptor());
}

void BaseMS_HMMDecoderBlock::reset() {

  BaseMD_HMMDecoderBlock::reset();

  groups.resize(0);
  group_label_to_number_map.clear();
  group_number_to_label_map.clear();
  groups_active.clear();

  group_summary_map.clear();
  
  if (GROUP_FILE_NAME_param->get_set_flag() ) 
    read_group_summary_file(GROUP_FILE_NAME_param->get_value(), group_summary_map);

  
  groups_active.push_back(false);  // Put a false in groups_active[0], as group numbers start at 1
  
  // The special 0 and -1 group labels retain their value
  group_label_to_number_map[0]=0;
  group_label_to_number_map[-1]=-1;
  group_number_to_label_map[0]=0;
  group_number_to_label_map[-1]=-1;

  // The "-1" group is the special "always include" group, must pass it onto the decoder without changing the label
  
  next_group_number = 1;

  // Check the inputs all have the same shape
  if (input_shape_check()==false) {
    Integer x1=(*input_sockets)[0]->get_data_descriptor()->get_storage();
    Integer x2=(*input_sockets)[1]->get_data_descriptor()->get_storage();
    Integer x3=(*input_sockets)[2]->get_data_descriptor()->get_storage();
    cerr << "BaseMD_HMMDecoderBlock:: Inputs have unequal widths (data=" << x1 << ", mask=" << x2 << ", groups=" << x3;
    if ((*input_sockets)[3]->connected())
      cerr << ", Lower Bounds=" << (*input_sockets)[3]->get_data_descriptor()->get_storage();
    cerr << ", Upper Bounds=" << (*input_sockets)[4]->get_data_descriptor()->get_storage();
    if ((*input_sockets)[5]->connected())
      cerr << ", Group starts=" << (*input_sockets)[5]->get_data_descriptor()->get_storage();
    if ((*input_sockets)[6]->connected())
      cerr << ", Group ends=" << (*input_sockets)[6]->get_data_descriptor()->get_storage();
    cerr << ")" << endl;
    throw(CTKError(__FILE__, __LINE__));
  }

  frame_no=0;

  mean_likelihood_scale_factor=1.0;
  group_hypothesis_merging_parameter=1.0;

  // Interpretation of MD_WEIGHT_param depends on the NORMALISE_MODE
  if (NORMALISE_MODE_param->get_value()==string("MODE_1"))
    mean_likelihood_scale_factor=MD_WEIGHT_param->get_value();
  else
    group_hypothesis_merging_parameter=MD_WEIGHT_param->get_value();

  speech_group_bias=1.0;
  if (SPEECH_GROUP_BIAS_param->get_set_flag())
    speech_group_bias=SPEECH_GROUP_BIAS_param->get_value();
}
 

// Processes one frame.
//
// Reads one vector from every connected input, works out which groups start and end in
// this frame (from the marker inputs when both are connected, otherwise by inspecting the
// group labels), starts a merge countdown for every ending group, updates the group
// statistics, merges the groups whose countdown has expired, spawns hypotheses for the
// starting groups and finally passes the frame to process_frame_multisource().
void BaseMS_HMMDecoderBlock::compute() {

  // Frame vectors obtained from the sockets; all of them are deleted at the end of this method
  CTKVector *data=NULL;
  CTKVector *mask=NULL;
  CTKVector *group_labels=NULL;
  CTKVector *lower_bounds=NULL;
  CTKVector *upper_bounds=NULL;
  CTKVector *group_starting_labels=NULL;
  CTKVector *group_ending_labels=NULL;

  (*input_sockets)[0]->get_vector(data);
  (*input_sockets)[1]->get_vector(mask);
  (*input_sockets)[2]->get_vector(group_labels);
  if ((*input_sockets)[3]->connected())
    (*input_sockets)[3]->get_vector(lower_bounds);
  (*input_sockets)[4]->get_vector(upper_bounds);
  if ((*input_sockets)[5]->connected())
    (*input_sockets)[5]->get_vector(group_starting_labels);
  if ((*input_sockets)[6]->connected())
    (*input_sockets)[6]->get_vector(group_ending_labels);

  Boolean verbose=(get_verbosity_level()>0);

  // NOTE(review): within this method frame_no only advances when verbose output is
  // enabled - confirm that nothing else depends on it counting every frame.
  if (verbose) cerr << frame_no++ << ": ";

  // Find groups that are starting, ending or restarting and update the group map
  vector<Integer> group_starting_numbers, group_ending_numbers;

  if ((*input_sockets)[5]->connected() && (*input_sockets)[6]->connected()) {
    // Start and end markers are supplied - but they need remapping before they can be used

    // First strip out any '0.0's  and remove any duplicates
    tidy_label_list(group_starting_labels);
    tidy_label_list(group_ending_labels);

    // Update group label map
    update_group_map(*group_starting_labels);

    // Remap start and end group labels
    map_group_labels(*group_starting_labels, group_starting_numbers);
    map_group_labels(*group_ending_labels, group_ending_numbers);

  } else {
    // Start and end markers are not supplied - find starts and ends labels
    start_end_detection(*group_labels, group_starting_numbers, group_ending_numbers, group_label_to_number_map, group_number_to_label_map);
  }

  // Start counters for ending groups
  for (vector<Integer>::iterator genp=group_ending_numbers.begin(), genp_end=group_ending_numbers.end(); genp!=genp_end; ++genp) {
    group_merge_counters.push_back( pair<int,int>(*genp, TEMPORAL_WINDOW_SIZE_param->get_value()));
  }

  // Map frame of group labels into frame of group numbers
  // The original group labels can be arbitrary integers, the group numbers are consecutive integers
  // starting at 1
  //
  // The frame of group numbers lives in automatic storage so that it is released even when
  // one of the calls below throws; the pointer only exists because
  // process_frame_multisource() expects one.
  vector<Integer> group_numbers_storage;
  vector<Integer> *group_numbers = &group_numbers_storage;

  map_group_labels(*group_labels, *group_numbers);

  // Update group statistics according to the incoming mask and group data.
  update_group_stats(*group_numbers, *mask, groups);

  // For each group counter, merge group if counter reaches 0, else decrement counter
  list<pair<int, int> >::iterator gmcp=group_merge_counters.begin();
  list<pair<int, int> >::iterator gmcp_end=group_merge_counters.end();
  while (gmcp!=gmcp_end) {
    if (gmcp->second==0) {
      int group_number=gmcp->first;
      if (GROUP_FILE_NAME_param->get_set_flag() ) {
	// Look up stored group details if group detail file has been loaded.
	// Groups whose stored source is 1 are merged with the speech bias, all others with its reciprocal.
	int group_label=group_number_to_label_map[group_number];
	if (group_summary_map[group_label].source()==1)
	  group_hypothesis_merging_parameter=speech_group_bias;
	else
	  group_hypothesis_merging_parameter=1.0/speech_group_bias;
      }

      merge_group(group_number, group_hypothesis_merging_parameter);

      // erase() returns the iterator following the removed counter
      gmcp=group_merge_counters.erase(gmcp);
    } else {
      --(gmcp->second);
      ++gmcp;
    }
  }

  // Spawn hypotheses for every group that starts (or restarts) in this frame
  for (unsigned int i=0; i<group_starting_numbers.size(); ++i) {
    decoder->spawn_group_hypotheses(groups[group_starting_numbers[i]]);
  }

  process_frame_multisource(*data, *mask, lower_bounds, upper_bounds, group_numbers, mean_likelihood_scale_factor, use_soft_multisource);

  // Release the frame vectors (the optional ones are NULL when their socket is not connected)
  delete data;
  delete mask;
  delete group_labels;
  delete lower_bounds;
  delete upper_bounds;
  delete group_starting_labels;
  delete group_ending_labels;
}


// End-of-run hook: merges every group that is still flagged active, then closes the base block.
// When a group summary file has been supplied the merging parameter is re-derived per group
// from the stored source of that group before the merge.
void BaseMS_HMMDecoderBlock::close() {

  for (unsigned int slot=0; slot<groups_active.size(); ++slot) {
    if (!groups_active[slot])
      continue;   // nothing to merge for inactive groups

    if (GROUP_FILE_NAME_param->get_set_flag() ) {
      const int label=group_number_to_label_map[slot];
      const bool from_source_one=(group_summary_map[label].source()==1);
      // Source 1 groups use the bias itself, all other groups its reciprocal
      group_hypothesis_merging_parameter = from_source_one ? speech_group_bias : 1.0/speech_group_bias;
    }

    merge_group(slot, group_hypothesis_merging_parameter);
  }

  BaseMD_HMMDecoderBlock::close();
}

// Writes 'hyp' to 'file' through RecoHypothesis::write_all_multisource(), supplying the
// group number -> original group label map of this block.
void BaseMS_HMMDecoderBlock::write_hypothesis(RecoHypothesis *hyp, FILE *file) {
  hyp->write_all_multisource(file, group_number_to_label_map);
}

// Asks the decoder to merge the hypotheses of group 'group_number' and marks the group inactive.
// NOTE: the float argument is unnamed and ignored; the member
// group_hypothesis_merging_parameter is what is passed to the decoder.
void BaseMS_HMMDecoderBlock::merge_group(int group_number, float) {
  decoder->merge_group_hypotheses(groups[group_number], group_hypothesis_merging_parameter);
  groups_active[group_number]=false;
}

// Make lists of groups that are starting, groups that ended in the last frame
// (This method also updates a group map that is used to map arbitrary integer group labels onto sequential
// integer group numbers)

void BaseMS_HMMDecoderBlock::start_end_detection(const vector<Float> &group, vector<Integer> &group_starting_numbers, vector<Integer> &group_ending_numbers, map<Integer, Integer> &a_group_label_to_number_map, map<Integer, Integer> &) {

  vector<Float> sorted = group;

  // All previously active groups may potentially have ended this frame
  vector<Boolean> group_has_ended=groups_active;

  Integer group_number;
  
  // Construct the list of group numbers appearing in this frame
  sort(sorted.begin(), sorted.end());
  vector<Float>::iterator new_end = unique(sorted.begin(), sorted.end());
  for (vector<Float>::iterator gp=sorted.begin(); gp!=new_end; ++gp) {
    int group_label = (int)*gp;
    if (group_label > 0) { // gp==0 => background - i.e. never part of the mask, gp==-1 => foreground i.e. always part of mask
      if ((group_number=a_group_label_to_number_map[group_label])==0) {
	// This is a new group ... assign the next group number and add it to the group starting list
	group_has_ended.push_back(true);
	group_starting_numbers.push_back(next_group_number);
	update_group_map(vector<Float>(1, group_label));
      }
      if (groups_active[group_number]==false) {
	// groups that are restarting
	groups_active[group_number]=true;
	group_starting_numbers.push_back(group_number);
      }
      // Group has continued to this frame so mark 'group_has_ended' as false
      group_has_ended[group_number]=false;
    }
  }
  
  Integer n=0;
  for (vector<int>::iterator gp=group_has_ended.begin(), gp_end=group_has_ended.end(); gp!=gp_end; ++gp, ++n) {
    // Groups ending are those for which 'group_has_ended' has remained 'true' 
    if (*gp) {
      group_ending_numbers.push_back(n);
      groups_active[n]=false;
    }
  }
  
}


// Loads group summary records from 'filename' into 'group_summary_map', keyed by the
// value each record returns from number(). Records are read with Group::input() until it
// stops returning CTK_SUCCESS; an existing entry with the same key is overwritten.
//
// Throws CTKError when the file cannot be opened. The file is closed again if reading a
// record or inserting it into the map throws.
void BaseMS_HMMDecoderBlock::read_group_summary_file(const string &filename, map<int, Group> &group_summary_map) {

  FILE *a_file=fopen(filename.c_str(), "r");
  if (a_file==NULL) {
    cerr << "Cannot open group summary file: " << filename << endl;
    throw(CTKError(__FILE__, __LINE__));
  }

  try {
    Group group;
    while (group.input(a_file)==CTK_SUCCESS) {
      group_summary_map[group.number()]=group;
    }
  } catch (...) {
    // Do not leak the file handle when parsing or insertion fails
    fclose(a_file);
    throw;
  }

  fclose(a_file);

}

// Registers every label in 'group_labels' that has no group number yet.
//
// Each new label receives the next consecutive group number, is entered in both the
// label -> number and the number -> label map, and gets an 'active' entry appended to
// groups_active. Labels whose integer value is not positive are ignored.
void BaseMS_HMMDecoderBlock::update_group_map(const vector<Float> &group_labels) {

  for (vector<Float>::const_iterator glp =group_labels.begin(), glp_end=group_labels.end(); glp!=glp_end; ++glp) {
    // Test the truncated label, i.e. the value that is actually used as the map key.
    // Testing the float instead would let a label in (0,1) truncate to 0 and overwrite
    // the reserved mapping of the special background label 0.
    const int label=static_cast<int>(*glp);
    if (label<=0)
      continue;

    if (group_label_to_number_map[label]==0) {
      group_label_to_number_map[label]=next_group_number;
      group_number_to_label_map[next_group_number]=label;
      ++next_group_number;
      groups_active.push_back(true);
    }
  }
}

void BaseMS_HMMDecoderBlock::map_group_labels(const vector<Float> &group_labels, vector<int> &group_numbers) {
  // Map from groups labels to group numbers
  for (vector<Float>::const_iterator glp =group_labels.begin(), glp_end=group_labels.end(); glp!=glp_end; ++glp) {
    group_numbers.push_back(group_label_to_number_map[(int)*glp]);
  } 
}


// Update group statistics according to the incoming mask and group data. 
void BaseMS_HMMDecoderBlock::update_group_stats(const vector<Integer> &mapped_group, const CTKVector &mask, vector<Group> &groups) {
  CTKVector::const_iterator maskp = mask.begin();
  for ( vector<Integer>::const_iterator mgp=mapped_group.begin(), mgp_end=mapped_group.end(); mgp!=mgp_end; ++mgp, ++maskp) {

    if (*mgp>0) {  // Ignore the special 0 and -1 forced background and foreground groups
      // construct new groups if necessary
      for (int i=groups.size(); i<=*mgp; ++i) {
	groups.push_back(Group(i, group_number_to_label_map[i]));
	// cerr << "Constructing group: " << i << "; " << groups.size() << " groups now exist\n";
      }
      
      groups[*mgp].add_point(0, 0, *maskp);
    }
  }

  
}


void BaseMS_HMMDecoderBlock::tidy_label_list(vector<Float> *labels) {
  // removes labels <= 0  and duplicate labels
  vector<Float>::iterator new_end;
  
  new_end=remove_if(labels->begin(), labels->end(), bind2nd(less_equal<Float>(), 0.5));
  labels->erase(new_end, labels->end());
  sort(labels->begin(), labels->end());
  new_end=unique(labels->begin(), labels->end());
  labels->erase(new_end, labels->end());
}



/******************************************************************************/
/*                                                                            */
/*       CLASS NAME: MS_HMMDecoderBlock                                       */
/*                                                                            */
/******************************************************************************/

// Block type name (passed to the base class by the constructor) and help text of the multisource decoder block
const string MS_HMMDecoderBlock::type_name = "HMMDecoderMultisource";
const string MS_HMMDecoderBlock::help_text = HMM_DECODER_MULTISOURCE_BLOCK_HELP_TEXT;

// Constructs a multisource decoder block called 'a_name'. All set up is done by
// BaseMS_HMMDecoderBlock, which receives type_name as the block type.
MS_HMMDecoderBlock::MS_HMMDecoderBlock(const string &a_name):CTKObject(a_name),BaseMS_HMMDecoderBlock(a_name, type_name) {

}


// Creates a new MS_HMMDecoderBlock and returns the result of copy_this_block_to() applied
// to it. The new block is called 'n', or carries the name of this block when 'n' is empty.
Block* MS_HMMDecoderBlock::clone(const string &n) const{
  string clone_name(n);
  if (clone_name.empty())
    clone_name=getname();

  Block *fresh_block = new MS_HMMDecoderBlock(clone_name);
  return copy_this_block_to(fresh_block);
}

// Returns a newly allocated mixture prototype selected by the PROB_CALC parameter:
// HMMMixtureMultisource for "MULTISOURCE", HMMMixtureMD (the standard missing data
// computation) for any other value.
HMMMixture * MS_HMMDecoderBlock::get_HMM_mixture_prototype() const {
  const bool wants_multisource = (get_PROB_CALC_param()->get_value()==string("MULTISOURCE"));

  if (!wants_multisource)
    return new HMMMixtureMD();

  return new HMMMixtureMultisource();
}


/******************************************************************************/
/*                                                                            */
/*       CLASS NAME: MSSoft_HMMDecoderBlock                                   */
/*                                                                            */
/******************************************************************************/

// Block type name (passed to the base class by the constructor) and help text of the soft multisource decoder block
const string MSSoft_HMMDecoderBlock::type_name = "HMMDecoderMultisourceSoft";
const string MSSoft_HMMDecoderBlock::help_text = HMM_DECODER_MULTISOURCE_SOFT_BLOCK_HELP_TEXT;

// Constructs a soft multisource decoder block called 'a_name': the common set up is done by
// BaseMS_HMMDecoderBlock, after which soft multisource decoding is requested and the
// PROB_CALC parameter is un-hidden.
// NOTE(review): PROB_CALC is un-hidden here, yet get_HMM_mixture_prototype() of this class
// does not read it, whereas MS_HMMDecoderBlock does read it - confirm this is intended.
MSSoft_HMMDecoderBlock::MSSoft_HMMDecoderBlock(const string &a_name):CTKObject(a_name),BaseMS_HMMDecoderBlock(a_name, type_name) {
  set_use_soft_multisource();
  unset_parameter_hidden("PROB_CALC"); 
}


// Creates a new MSSoft_HMMDecoderBlock and returns the result of copy_this_block_to() applied
// to it. The new block is called 'n', or carries the name of this block when 'n' is empty.
Block* MSSoft_HMMDecoderBlock::clone(const string &n) const{
  string clone_name(n);
  if (clone_name.empty())
    clone_name=getname();

  Block *fresh_block = new MSSoft_HMMDecoderBlock(clone_name);
  return copy_this_block_to(fresh_block);
}

// Returns a newly allocated HMMMixtureMultisourceSoft; no parameter is consulted.
HMMMixture* MSSoft_HMMDecoderBlock::get_HMM_mixture_prototype() const {
  return new HMMMixtureMultisourceSoft();
}





/* End of ctk_HMM_MS_decoder.cpp */
