/******************************************************************************/
/*                                                                            */
/*      ctk_HMM_edit.cpp                                           */
/*                                                                            */
/*      Support for HMM editing                                               */
/*                                                                            */
/*      Author: Jon Barker, Sheffield University                              */
/*                                                                            */
/*      CTK VERSION 1.3.5  Apr 22, 2007                              */
/*  */
/******************************************************************************/

#include "ctk-config.h"

#include <cmath>

#include <string>
#include <numeric>
 
#include "ctk_local.hh"
#include "ctk_error.hh"

#include "ctk_HMM.hh"
#include "ctk_function_classes.hh"

#include "ctk_HMM_edit.hh"


// HMM Edit 
// Construct new HMMs by transforming existing HMMs

// Matrix type: a vector of vectors of HMMFloat values.
// NOTE(review): not referenced in the visible part of this file.
typedef vector<vector<HMMFloat> > TransMatrix;


/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMEditNullEdit	                    	       	      */
/*									      */
/******************************************************************************/

// Construct the identity editor (takes no parameters).
HMMEditNullEdit::HMMEditNullEdit() : HMMEdit() {}

// Return a new HMM built from hmm's own states, transition matrix, vector size
// and name, i.e. no transformation is applied. The HMMState pointers are passed
// to the HMM constructor unchanged. The result is allocated with new.
HMM *HMMEditNullEdit::edit(const HMM &hmm) const {
  return new HMM(hmm.get_states(), hmm.getTrans(), hmm.get_vec_size(), hmm.getName());
}

/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMEditPruneTransitions	               	       	      */
/*									      */
/******************************************************************************/

// Construct a pruning editor: edit() removes every transition whose
// probability is below `threshold`.
HMMEditPruneTransitions::HMMEditPruneTransitions(float threshold)
  : HMMEdit(), threshold_(threshold) {}

// Return a new HMM whose transition matrix is hmm's with every entry below
// threshold_ (self transitions included) set to 0.0.
//
// The pruned matrix is then passed through Transitions::rationalise(), which
// deletes states that no longer lead to the exit state and reports their
// indices. Those states are dropped from the state list. The surviving HMMState
// pointers are reused from `hmm`, in their original order. The result is
// allocated with new.
//
// NOTE(review): elsewhere in this file matrix index k corresponds to states[k-1]
// (index 0 is the entry state); confirm rationalise() reports indices in the
// same numbering as get_states(), which is what the comparison below assumes.
// NOTE(review): pruned rows are not renormalised here unless rationalise() does so.
HMM *HMMEditPruneTransitions::edit(const HMM &hmm) const {

  Transitions pruned = hmm.getTrans();

  // Zero every entry of the matrix that falls below the threshold.
  for (int row = 0; row < pruned.size(); ++row) {
    for (int col = 0; col < pruned.size(); ++col) {
      const float p = pruned.get_transition_prob(row, col);
      if (p < threshold_)
        pruned.add_transition(row, col, 0.0);
    }
  }

  // Indices of the states rationalise() deleted.
  const vector<unsigned int> removed = pruned.rationalise();

  // Keep every original state that was not deleted.
  const vector<HMMState*> &old_states = hmm.get_states();
  vector<HMMState*> kept;
  for (unsigned int k = 0; k < old_states.size(); ++k) {
    const bool was_removed = find(removed.begin(), removed.end(), k) != removed.end();
    if (!was_removed)
      kept.push_back(old_states[k]);
  }

  return new HMM(kept, pruned, hmm.get_vec_size(), hmm.getName());
}

/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMEditMaxDuration	                    	       	      */
/*									      */
/******************************************************************************/

// Construct the max-duration editor (takes no parameters).
HMMEditMaxDuration::HMMEditMaxDuration():HMMEdit(){}

// Impose Max Duration through HMM topology.
//
// Each state s of `hmm` is repeated max(1, d) times in the new state list,
// where d = s->get_max_duration(). The repeats are the same HMMState pointer.
// construct_max_duration_trans() then rebuilds the transition matrix so that
// each run of repeats forms a chain.
//
// Side effect: once the durations have been recorded, set_max_duration(0) is
// called on every state of `hmm`. This modifies the HMMState objects that
// `hmm` points to, even though `hmm` is passed by const reference. The new HMM
// is constructed afterwards, so its states also carry max duration 0.
// The result is allocated with new.
HMM *HMMEditMaxDuration::edit(const HMM &hmm) const {

  const vector<HMMState*> &states = hmm.get_states();

  vector<HMMState*> new_states;
  vector<int> durations; // Successive state durations, one per original state

  for (vector<HMMState*>::const_iterator sp=states.begin(); sp!=states.end(); ++sp) {
    const int duration = (*sp)->get_max_duration();
    const int copies = max(1, duration);
    for (int i=0; i<copies; ++i)
      new_states.push_back(*sp);
    durations.push_back(duration);
  }

  // The duration limits are now encoded by the chain topology, so clear them.
  // A plain loop replaces for_each(bind2nd(mem_fun(...), 0)): those adaptors
  // were deprecated in C++11 and removed in C++17.
  for (vector<HMMState*>::const_iterator sp=states.begin(); sp!=states.end(); ++sp)
    (*sp)->set_max_duration(0);

  Transitions new_trans=construct_max_duration_trans(hmm.getTrans(), durations);

  return new HMM(new_states, new_trans, hmm.get_vec_size(), hmm.getName());
}

/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMEditWarpTransitions	               	       	      */
/*									      */
/******************************************************************************/

// Construct a warping editor parameterised by `threshold`
// (see edit() for how it is applied).
HMMEditWarpTransitions::HMMEditWarpTransitions(float threshold)
  : HMMEdit(), threshold_(threshold) {}

// Return a new HMM that reuses hmm's states, with its transition matrix
// rewritten as follows for every off-diagonal entry (i, j):
//   a_ij <  threshold_  ->  a_ij = 0
//   a_ij >= threshold_  ->  a_ij = 1 - a_ii
// Self transitions a_ii are never changed. The result is allocated with new.
//
// NOTE(review): if more than one off-diagonal entry in a row survives the
// threshold, each is set to 1 - a_ii and the row sums to more than one. No
// renormalisation is done here; confirm this is intended.
HMM *HMMEditWarpTransitions::edit(const HMM &hmm) const {

  Transitions warped = hmm.getTrans();

  for (int row = 0; row < warped.size(); ++row) {
    for (int col = 0; col < warped.size(); ++col) {
      if (col == row)
        continue;  // the diagonal is left untouched

      const float p = warped.get_transition_prob(row, col);
      const float replacement = (p < threshold_)
        ? 0.0f
        : static_cast<float>(1.0 - warped.get_transition_prob(row, row));
      warped.add_transition(row, col, replacement);
    }
  }

  return new HMM(hmm.get_states(), warped, hmm.get_vec_size(), hmm.getName());
}

/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMEditMixtureSeparation	               	       	      */
/*									      */
/******************************************************************************/

// Construct an editor that splits each state into per-mixture states, each
// built from `nmixes` nearest-neighbour mixtures (see get_nearest_neighbours()).
HMMEditMixtureSeparation::HMMEditMixtureSeparation(int nmixes)
  : HMMEdit(), nmixes_(nmixes) {}

// Return a new HMM in which each state of `hmm` is replaced by a set of
// parallel states.
//
// For every mixture m of an old state, a candidate HMMState is built from:
//   - the mixtures returned by get_nearest_neighbours(state, m, nmixes_);
//   - their weights, normalised by log_normalise();
//   - the old state's voicing and max duration.
// If find_pointer() matches the candidate against the states already kept for
// the same old state, the candidate is deleted. Presumably find_pointer
// compares the pointed-to states; confirm.
//
// Each kept state receives the value log_normalise() returned for its group of
// weights. These values are then log-normalised across the states derived from
// the same old state. The concatenated values and the per-old-state counts are
// passed to construct_temporal_continuity_trans() to build the matrix.
// The result is allocated with new.
//
// NOTE(review): log_normalise() is applied first to get_mixture_weight()
// values and then to the values it returned; confirm it expects the same
// representation (linear vs log) in both uses.
HMM *HMMEditMixtureSeparation::edit(const HMM &hmm) const {

  const vector<HMMState*> &old_states = hmm.get_states();

  vector<HMMState*> split_states;      // all new states, in order
  vector<HMMFloat> split_log_weights;  // one weight per new state
  vector<int> split_counts;            // number of new states per old state

  for (unsigned int s = 0; s < old_states.size(); ++s) {
    HMMState *old_state = old_states[s];

    const int mix_count = old_state->get_num_mixes();
    vector<HMMState*> unique_states;     // distinct new states derived from old_state
    vector<HMMFloat> unique_log_masses;  // log_normalise() result for each of them
    const float voicing = old_state->get_voicing();
    const int max_duration = old_state->get_max_duration();

    for (int m = 0; m < mix_count; ++m) {
      vector<HMMMixture*> group_mixes;
      vector<HMMFloat> group_weights;
      vector<int> neighbours = get_nearest_neighbours(old_state, m, nmixes_);

      for (unsigned int k = 0; k < neighbours.size(); ++k) {
        group_mixes.push_back(old_state->get_mixture(neighbours[k]));
        group_weights.push_back(old_state->get_mixture_weight(neighbours[k]));
      }

      // Normalise the group's weights to sum to 1.0, stored as log probabilities.
      HMMFloat group_log_mass = log_normalise(group_weights);

      HMMState *candidate = new HMMState(group_mixes, group_weights, voicing, max_duration);

      const bool seen = find_pointer(unique_states.begin(), unique_states.end(), candidate)
                        != unique_states.end();
      if (seen) {
        delete candidate;
        continue;
      }

      split_states.push_back(candidate);
      unique_states.push_back(candidate);
      unique_log_masses.push_back(group_log_mass);
    }

    // Normalise across the new states that replace this old state.
    log_normalise(unique_log_masses);
    split_log_weights.insert(split_log_weights.end(), unique_log_masses.begin(), unique_log_masses.end());
    split_counts.push_back(unique_states.size());
  }

  Transitions trans = construct_temporal_continuity_trans(hmm.getTrans(), split_counts, split_log_weights);

  return new HMM(split_states, trans, hmm.get_vec_size(), hmm.getName());
}


// Return the indices of the mixtures of `state` that are nearest to mixture
// `mix_index`, using distance_MAH() as the distance.
//
// Every mixture whose distance is <= the K-th smallest distance is returned, in
// ascending index order. Ties at the cut-off can therefore yield more than K
// indices. Mixture `mix_index` itself is a candidate like any other.
//
// K is clamped to [1, number of mixtures], so the K-th smallest distance always
// exists. An out-of-range K would otherwise index outside the sorted distance
// vector, which is undefined behaviour. A state with no mixtures yields an
// empty result. mix_index is assumed to be a valid mixture index; it is not
// checked here.
vector<int> get_nearest_neighbours(HMMState *state, int mix_index, int K) {
  vector<int> nearest_neighbours;

  const int num_mixes = state->get_num_mixes();
  if (num_mixes <= 0)
    return nearest_neighbours;

  if (K > num_mixes) K = num_mixes;
  if (K < 1) K = 1;

  // Distance from the reference mixture to every mixture of the state.
  HMMMixture *mix1 = state->get_mixture(mix_index);
  vector<float> distances;
  distances.reserve(num_mixes);
  for (int j = 0; j < num_mixes; ++j) {
    float dist_MAH = distance_MAH(*mix1, *state->get_mixture(j));
    distances.push_back(dist_MAH);
  }

  // The K-th smallest distance is the inclusive cut-off.
  vector<float> sorted_distances(distances);
  sort(sorted_distances.begin(), sorted_distances.end());
  const float threshold = sorted_distances[K-1];

  for (int j = 0; j < num_mixes; ++j) {
    if (distances[j] <= threshold)
      nearest_neighbours.push_back(j);
  }

  return nearest_neighbours;
}

///////////////////////////////////////////////


Transitions construct_temporal_continuity_trans(const Transitions &old_trans, const vector<int> &new_states_per_old_state, const vector<HMMFloat> &weights) {
// Construct a transition matrix for a HMM that has been transformed by expanding
// an old HMM such that old emitting state i (matrix index i) is replaced by
// new_states_per_old_state[i-1] parallel states.
//
// In both matrices index 0 is the entry state and the last index is the exit
// state; each of these maps to a single new state. For every non-zero a_ij of
// old_trans:
//  - i == j: each parallel state replacing i gets a self transition a_ii
//    (no transitions are added between parallel states of the same old state);
//  - i != j: every state replacing i is connected to every state replacing j
//    with probability a_ij * exp(weights[k-1]), where k is the destination's
//    new index; transitions into the exit state use a_ij unscaled.
//
// weights must hold one entry per new emitting state, in new-state order
// (entry k-1 for new state k). If each old state's group of weights is
// log-normalised, as HMMEditMixtureSeparation::edit does via log_normalise(),
// the expansion presumably preserves each row's total; confirm.
// Any add_transition() failure is treated as an internal error and the
// process exits.

  // first_states[i] is the new index of the first state replacing old state i;
  // first_states[i+1] is one past its last state.
  vector<int> first_states;
  int new_state_index=1;
  first_states.push_back(0);
  first_states.push_back(1);
  for (unsigned int i=0; i<new_states_per_old_state.size(); ++i) {
    new_state_index+=new_states_per_old_state[i];
    first_states.push_back(new_state_index);
  }
  first_states.push_back(new_state_index+1);

  // new_state_index is now the index of the new exit state.
  Transitions new_trans(new_state_index+1);

  const int old_max_state=new_states_per_old_state.size();
  const int old_exit=old_max_state+1;
  for (int from=0; from<=old_max_state; ++from) {
    for (int to=1; to<=old_exit; ++to) {
      HMMFloat prob=old_trans.get_transition_prob(from,to);
      if (prob==0.0) continue; // ignore 0 prob transitions

      if (from==to) {
        // Expand self transition onto each of the parallel states.
        for (int state=first_states[from]; state<first_states[from+1]; ++state) {
          if (new_trans.add_transition(state, state, prob)==CTK_FAILURE) {
            cerr << "internal error\n"; exit(-1);
          }
        }
      } else {
        // Expand non-self transition: fully connect the states replacing `from`
        // to those replacing `to`, scaled by the destination state's weight
        // (the exit state carries no weight).
        for (int from_state=first_states[from]; from_state<first_states[from+1]; ++from_state) {
          for (int to_state=first_states[to]; to_state<first_states[to+1]; ++to_state) {
            if (new_trans.add_transition(from_state, to_state,
                                         to==old_exit ? prob : prob*exp(weights[to_state-1]))==CTK_FAILURE) {
              cerr << "internal error\n"; exit(-1);
            }
          }
        }
      }
    }
  }

  return new_trans;
}

//
//
//


// Build the transition matrix for an HMM expanded by HMMEditMaxDuration.
//
// Old emitting state i (matrix index i, 1..durations.size()) is replaced by a
// chain of max(1, durations[i-1]) consecutive new states. The entry state (0)
// and the exit state each map to a single new state. For every non-zero a_ij
// of old_trans:
//  - i == j and durations[i-1] == 0: kept as a self transition on the single
//    new state;
//  - i == j otherwise: replaced by links k -> k+1 along the chain, each with
//    probability a_ii. The chain's last state has no self transition, so the
//    chain can be occupied for at most max(1, durations[i-1]) steps;
//  - i != j: added to the first new state of j, from the single new state when
//    i == 0 or durations[i-1] == 0, otherwise from every state of i's chain.
// Finally new_trans.normalise() is called, because rows that lost their self
// transition no longer sum to one.
//
// NOTE(review): a negative duration is not treated like 0. It yields a single
// state with no self transition (the same as duration 1); confirm this is
// intended.
Transitions construct_max_duration_trans(const Transitions &old_trans, const vector<int> &durations) {

  // first_states[i] is the new index of the first state of the chain replacing
  // old state i; first_states[i+1] is one past its last state.
  vector<int> first_states;
  int new_state_index=1;
  first_states.push_back(0);
  first_states.push_back(1);
  for (unsigned int i=0; i<durations.size(); ++i) {
    new_state_index+=max(1, durations[i]);
    first_states.push_back(new_state_index);
  }
  first_states.push_back(new_state_index+1);

  // new_state_index is now the index of the new exit state.
  Transitions new_trans(new_state_index+1);

  const int old_max_state=durations.size();
  for (int from=0; from<=old_max_state; ++from) {
    // New states [first_state, last_state] replace old state `from`.
    const int first_state=first_states[from];
    const int last_state=first_states[from+1]-1;
    // Entry state, or a state with no duration limit: a single unchained state.
    const bool unchained=(from==0 || durations[from-1]==0);

    for (int to=1; to<=old_max_state+1; ++to) {
      HMMFloat prob=old_trans.get_transition_prob(from,to);
      if (prob==0.0) continue; // ignore 0 prob transitions

      if (from==to) {
        if (unchained) {
          // Duration=0 => use normal self transition
          new_trans.add_transition(first_state, first_state, prob);
        } else {
          // ... else expand self transition into a simple chain
          for (int state=first_state; state<last_state; ++state)
            new_trans.add_transition(state, state+1, prob);
        }
      } else {
        // Non-self transitions point to the first state of the destination
        // chain, from every state of the source chain.
        const int to_state=first_states[to];
        if (unchained) {
          new_trans.add_transition(first_state, to_state, prob);
        } else {
          for (int state=first_state; state<=last_state; ++state)
            new_trans.add_transition(state, to_state, prob);
        }
      }
    }
  }

  // Normalise to take care of transitions from final state of chain that no longer
  // sum to one now that there are no self transitions.
  new_trans.normalise();

  return new_trans;
}

//
// 
//
//

/* End of ctk_HMM_edit.cpp */


