/******************************************************************************/
/*                                                                            */
/*      ctk_HMM.hh                                                            */
/*                                                                            */
/*      Class declarations for ctk_HMM.cpp                                    */
/*                                                                            */
/*      Author: Jon Barker, Sheffield University                              */
/*                                                                            */
/*      CTK VERSION 1.3.5  Apr 22, 2007                              */
/*                                                                            */
/******************************************************************************/

#ifndef CTK_HMM_HH
#define CTK_HMM_HH



#include <cmath>

#include "ctk_local.hh"

#include <string>
#include <vector>
#include <list>
#include <map>

#include "boost/smart_ptr.hpp"

#include "ctk_ro_file.hh"
#include "ctk_HMM_types.hh"
#include "ctk_feature_vector.hh"

// ---- Numeric constants shared by the HMM code ----

// Generic small positive number. Written as 1e-9, which is exactly the value
// of the literal 10e-10 used historically (10 x 10^-10), not 1e-10.
const HMMFloat CTK_SMALL_NUMBER_HMMFLOAT = 1e-9;

// log(0.5); added to the erf-difference term in the marginal calculations.
const HMMFloat LOG_HALF = log(0.5);

// Maximum word length used to build the scanf-style "%<N>s" format string
// (see SetOfHMMs::read_word_format_string).
const int MAX_STRING_SIZE = 255;

// ---- Forward declarations ----
class SetOfHMMs;
class HMM;
class HMMState;
class HMMMixture;

using boost::shared_ptr;

// Fill name_label_map from a_file_name (presumably HMM name -> output label;
// see SetOfHMMs::lookup_label).
void read_hmm_label_file(const string &a_file_name, map<string, string> &name_label_map);

// Fill dictionary from a_file_name (presumably name -> pronunciation, i.e. a
// list of HMM names; see SetOfHMMs::lookup_pronunciation).
void read_dictionary(const string &a_file_name, map<string, list<string> > &dictionary);


/////////////////////////////////////////////////////////////////////////

// Maths stuff - including stuff for log and exp approximations

// Static union used for direct manipulation of 4-byte floating point types
// (employed in table-based log approximations, logT8 and logT10, and in the
// EXP_FLOAT macro).
//
// The same 32 bits viewed four ways:
//   d - as a float
//   n - as four bytes; c1 is always the most significant byte
//   s - as two 16-bit halves; s1 is always the most significant half
//   i - as one unsigned int
// The struct member order is reversed between big- and little-endian builds,
// so c1/s1 name the most significant byte/half on both.
// If neither CTK_BIG_ENDIAN nor CTK_LITTLE_ENDIAN is defined, the #error in the
// `s` struct stops compilation.
// Assumes float and unsigned int are 4 bytes and unsigned short is 2 bytes.
// TODO: confirm this on each target platform.
//
// NOTE(review): this is a namespace-scope `static` object in a header, so every
// translation unit that includes it gets its own copy. It is used as shared
// scratch space (logT8/logT10 and EXP_FLOAT write to it) with no
// synchronisation, so it is not safe for concurrent use. Reading a union member
// other than the one last written is also undefined behaviour in ISO C++,
// although compilers such as GCC document it as supported.
static union {
  float d;
  struct {
#if defined CTK_BIG_ENDIAN
    unsigned char c1;
    unsigned char c2;
    unsigned char c3;
    unsigned char c4;
#elif defined CTK_LITTLE_ENDIAN
    unsigned char c4;
    unsigned char c3;
    unsigned char c2;
    unsigned char c1;
#endif
  } n;
  struct {
#if defined CTK_BIG_ENDIAN
    unsigned short s1;
    unsigned short s2;
#elif defined CTK_LITTLE_ENDIAN
    unsigned short s2;
    unsigned short s1;    
#else
#error You must define CTK_BIG_ENDIAN or CTK_LITTLE_ENDIAN, there is no default. (Intel/VAX is LE, Sun/IBM/HP are BE)
#endif
  } s;
  unsigned int i;
} ieee_flt;

// Static union used for direct manipulation of 8-byte double floating point types
// (employed in exp approximation)
//
// The same 64 bits viewed two ways:
//   d - as a double
//   s - as two ints; s1 is always the most significant 32 bits (sign, exponent
//       and top of the mantissa), s2 the least significant, on either endianness
// Assumes int is 4 bytes and double is 8 bytes. TODO: confirm this.
// Within this header only s.s1 is written (by EXP_DOUBLE). Because the object
// has static storage, s.s2 starts at zero.
//
// NOTE(review): as with ieee_flt, each translation unit gets its own copy of
// this header `static`, and it is unsynchronised shared scratch space.
static union {
  double d;
  struct {
#if defined CTK_BIG_ENDIAN
    int s1;
    int s2;
#elif defined CTK_LITTLE_ENDIAN
    int s2;
    int s1;    
#else
#error You must define CTK_BIG_ENDIAN or CTK_LITTLE_ENDIAN, there is no default. (Intel/VAX is LE, Sun/IBM/HP are BE)
#endif
  } s;
} ieee_dbl;


///  exp approximations 
//
// Fast exp(y) built by constructing the IEEE bit pattern directly, in the
// style of Schraudolph (1999). y is scaled by 2^k/ln(2) and offset by the
// exponent bias. The result is written as an integer into the high-order bits
// of a float/double, so the integer part lands in the exponent field and the
// fractional part in the top mantissa bits.
//   EXP_A_* : scale factor 2^k/ln(2), where k is the number of mantissa bits in
//             the written field (7 for the top 16 bits of a float, 20 for the
//             top 32 bits of a double)
//   EXP_C_* : correction subtracted from the biased offset (presumably tuned to
//             reduce the approximation error)
//   16256      = 127  * 2^7   (float exponent bias, positioned in the top 16 bits)
//   1072693248 = 1023 * 2^20  (double exponent bias, positioned in the top 32 bits)
// M_LN2 is a common <cmath> extension rather than ISO C++ (MSVC, for example,
// needs _USE_MATH_DEFINES).
#define EXP_A_FLOAT (128/M_LN2) 
#define EXP_A_DOUBLE (1048576/M_LN2) 

#define EXP_C_FLOAT 7
#define EXP_C_DOUBLE 45799

// Both macros are comma expressions: they write the shared ieee_flt/ieee_dbl
// scratch union, then read the value back. Only the high-order half is written;
// the low-order half keeps whatever it last held.
// No range checking is done. y must be small enough in magnitude for the scaled
// value to fit the integer field. (logT8_add/logT10_add only pass values in
// [-20, 0].)
#define EXP_FLOAT(y) (ieee_flt.s.s1=(int)(EXP_A_FLOAT*(y) + (16256-EXP_C_FLOAT)), ieee_flt.d)
#define EXP_DOUBLE(y) (ieee_dbl.s.s1=(int)(EXP_A_DOUBLE*(y) + (1072693248-EXP_C_DOUBLE)), ieee_dbl.d)


// log add function based on log1p library function
//
// Returns log(exp(x1) + exp(x2)) without evaluating either exponential
// directly, since those can overflow. The result is computed as
// larger + log1p(exp(-|x2 - x1|)). When the operands differ by more than 80,
// the smaller term is negligible and the larger operand is returned unchanged.
//
// Equal infinite operands are handled explicitly. In that case x2 - x1 is NaN,
// which would otherwise propagate into the result. The correct answer is the
// operand itself: log(0 + 0) = -inf and log(inf + inf) = inf.
// A NaN operand still yields NaN.
template <class T> 
T log_add(T x1, T x2) {
  T diff = x2-x1;

  // A NaN difference between equal operands means both are the same infinity.
  if (diff != diff && x1 == x2)
    return x1;

  if (diff>0) {
    if (diff>80)
      return x2;
    else
      return x2+log1p(exp(-diff));
  } else {
    if (diff < -80)
      return x1;
    else
      return x1+log1p(exp(diff));
  }
}


// Handy template methods

template <class T>
void remove_duplicates(vector<T> &array) {
  // Leave exactly one copy of each distinct value, in ascending order.
  // Sorting makes equal values adjacent, so unique() can collapse them;
  // erase() then drops the leftover tail.
  sort(array.begin(), array.end());
  array.erase(unique(array.begin(), array.end()), array.end());
}

template <class T>
void remove_duplicates(list<T> &array) {
  // Remove duplicates from a list, leaving one copy of each distinct value in
  // ascending order (the same result as the vector overload).
  // std::sort needs random-access iterators, which std::list does not provide,
  // and a list has no vector iterators or iterator subtraction. So the list's
  // own member algorithms are used instead.
  array.sort();
  array.unique();
}

template <class T>
void multi_erase(vector<T> &array, vector<unsigned int> positions) {
  // Erase the elements of `array` at the given indices.
  //
  // `positions` is taken by value, so sorting it here does not affect the
  // caller, and the caller may pass indices in any order. Erasing from the back
  // to the front keeps the still-pending indices valid, because each erase only
  // shifts elements after the erased one.
  sort(positions.begin(), positions.end(), greater<unsigned int>());
  // A repeated index would otherwise erase an unrelated neighbouring element
  // the second time, so repeats are dropped (they are adjacent after sorting).
  positions.erase(unique(positions.begin(), positions.end()), positions.end());
  for (unsigned int i=0; i<positions.size(); ++i) {
    array.erase(array.begin()+positions[i]);
  }
}

template <class T>
void replace_list(vector<T*> &old_list, vector<T*> &new_list) {
  // Replaces an old list of pointers to things with a new list, and deletes any
  // things in the old list that are no longer used in the new list.
  //
  // Both lists are sorted here, so callers need not pre-sort them. Note that
  // this reorders the caller's new_list as a side effect, and old_list ends up
  // as a copy of the sorted new_list.
  vector<T*> dead_things;
  sort(old_list.begin(), old_list.end());
  // std::set_difference has multiset semantics. Suppose a pointer appears twice
  // in old_list. If new_list still holds it once, it would be reported dead
  // anyway; if new_list lacks it, it would be reported (and deleted) twice.
  // Collapsing repeats first makes each pointer count once. old_list is
  // overwritten below, so modifying it here is harmless.
  old_list.erase(unique(old_list.begin(), old_list.end()), old_list.end());
  sort(new_list.begin(), new_list.end());
  set_difference(old_list.begin(), old_list.end(), new_list.begin(), new_list.end(), back_inserter(dead_things));
  //  cerr << "Deleting " << dead_things.size() << " things\n";
  sequence_delete(dead_things.begin(), dead_things.end());
  old_list=new_list;
}

/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMMixture		                 	       	      */
/*									      */
/******************************************************************************/
  
// Abstract base class for one Gaussian mixture component with a diagonal
// covariance, stored as means (mu) and reciprocal variances (ivar).
//
// Subclasses implement clone() and get_prob(). This base class holds:
//   - terms precomputed from mu/ivar;
//   - the currently attached observation and missing-data mask;
//   - per-feature likelihood/marginal caches with their "changed"/"valid" flags;
//   - table-based erf/log approximations used by the inlined per-feature
//     calculations.
class HMMMixture  {

  // Range and sampling resolution of the erf lookup table used by erft().
  static const Float ERF_TABLE_MAX;
  static const UInteger16 ERF_TABLE_RESOLUTION;

private: 

  // Builders for the static lookup tables (defined out of line).
  static HMMFloat *make_erf_table();
  static float *make_log_table_exponent();
  static float *make_log_table_mantissa_8bit();
  static float *make_log_table_mantissa_10bit();
  
  // Lookup tables shared by all mixtures.
  // NOTE(review): the make_log_table_* builders return float*, but these
  // pointers are HMMFloat*. Presumably HMMFloat is float; confirm.
  static HMMFloat *erf_table;
  static HMMFloat *log_table_exponent;
  static HMMFloat *log_table_mantissa_8bit;
  static HMMFloat *log_table_mantissa_10bit;
 
  int mix_ID;   // Unique mixture ID number

  vector<HMMFloat> mu;          // Means
  vector<HMMFloat> ivar;         // Reciprocal Variances;

  // The following expressions of mu and ivar appear as terms in probability calculations
  // and can be precomputed at the expense of some storage
  
  vector<HMMFloat> ivarsm05;    //  -0.5 * ivar
  vector<HMMFloat> sqrtivar; 	// sqrt(ivar)
  vector<HMMFloat> sqrtivar05; 	// sqrt(0.5*ivar)
  vector<HMMFloat> likcon; 	// -0.5 log(2 PI var)  - constant occurring in likelihood computation
  vector<HMMFloat> lik0; 	// likelihood of x=0
  vector<HMMFloat> erfmu; 	// erf(-mu * sqrtivar05): the erf term for a lower bound of 0

  // Currently attached observation and missing-data mask.
  shared_ptr<FeatureVector> feature_vector;
  shared_ptr<MaskVector> mask_vector;

  shared_ptr<HMMMixture> parallel_mixture;  // Mixture can be attached to a parallel mixture for PMC
  
  // Feature-selection switches (set via the set_use_* methods).
  bool use_marginals;
  bool use_delta_marginals;
  bool use_deltas;
  
  // These intermediate variables depend on the data but not on the mask
  mutable vector<HMMFloat> likelihoods; // Present data components
  mutable vector<HMMFloat> marginals;   // Bounded marginals - i.e. missing data components
  
  // Change-tracking flags; cleared via the clear_*_has_changed() methods.
  mutable bool data_has_changed; 
  mutable bool mask_has_changed;
  mutable bool parallel_mixture_has_changed;

  // Presumably record whether the likelihoods/marginals caches are up to date.
  mutable bool likelihoods_valid;
  mutable bool marginals_valid;
  
  //
  //
  //
  
public:
  // Construct from means and reciprocal variances.
  HMMMixture(const vector<HMMFloat> &amu, const vector<HMMFloat> &aivar);
  
  virtual ~HMMMixture();

  // Virtual constructor: a new mixture of the same concrete type with the given
  // means and reciprocal variances.
  virtual HMMMixture *clone(const vector<HMMFloat> amu, const vector<HMMFloat> &ivar) const = 0;
  // Probability of the attached observation under this mixture (subclass-specific).
  virtual HMMFloat get_prob() const=0;

  void set_missing_data_mask(shared_ptr<MaskVector> mask_vector);
  void set_observed_data(shared_ptr<FeatureVector> feature_vector);
  void set_parallel_mixture(shared_ptr<HMMMixture> parallel_mixture);
  
  void set_use_deltas(bool x) {use_deltas=x;}
  void set_use_marginals(bool x) {use_marginals=x;}
  void set_use_delta_marginals(bool x) {use_delta_marginals=x;}
  
  const vector<HMMFloat> &get_mu() const {return mu;}
  const vector<HMMFloat> &get_ivar() const {return ivar;}

  int get_ID() const {return mix_ID;}
  
  void calculate_full_likelihood() const;
  void calculate_full_marginal() const;

  // Number of features in use: the full observation size, halved when the
  // observation has deltas but deltas are not being used.
  // Requires an observation to have been attached (feature_vector is dereferenced).
  int get_size() const {
    int n = feature_vector->size();
    if (feature_vector->has_deltas() && !use_deltas) n=n/2;
    return n;
  }
  
  vector<HMMFloat> get_expected_values(const vector<bool> &mask) const;
 

  // Read-only access to the cached per-feature likelihoods and marginals.
  vector<HMMFloat>::const_iterator get_likelihoods_begin() const {return likelihoods.begin();}
  vector<HMMFloat>::const_iterator get_likelihoods_end() const {return likelihoods.end();}
  vector<HMMFloat>::const_iterator get_marginals_begin() const {return marginals.begin();}
  vector<HMMFloat>::const_iterator get_marginals_end() const {return marginals.end();}

  friend ostream& operator<< (ostream& out, const HMMMixture& x);

protected:
  
  const shared_ptr<MaskVector> get_mask() const  {return mask_vector;}
  
  const shared_ptr<HMMMixture> get_parallel_mixture() const {return parallel_mixture;}
  const shared_ptr<FeatureVector> get_feature_vector() const {return feature_vector;}
  
  bool get_use_deltas() const {return use_deltas;}

  bool get_use_marginals() const {return use_marginals;}
  bool get_use_delta_marginals() const {return use_delta_marginals;}
  
  // Default constructor, used by the subclasses' default constructors
  // (presumably for prototype objects that are only used to clone() real
  // mixtures; see SetOfHMMs::mixture_prototype).
  // No parameters are loaded, but every scalar member is given a defined value
  // so that none is read uninitialised. The flags describe a "nothing cached
  // yet" state, and mix_ID is -1 because no ID has been assigned.
  HMMMixture()
    : mix_ID(-1),
      use_marginals(false), use_delta_marginals(false), use_deltas(false),
      data_has_changed(true), mask_has_changed(true), parallel_mixture_has_changed(true),
      likelihoods_valid(false), marginals_valid(false),
      tmp_hmmfloat(0) {}
  
  void calculate_masked_likelihood() const;
  HMMFloat accumulate_masked_likelihood() const;
    
  void calculate_masked_marginal() const;
  HMMFloat accumulate_masked_marginal() const;
  
  bool get_data_has_changed() const {return data_has_changed;}
  bool get_mask_has_changed() const {return mask_has_changed;}
  bool get_parallel_mixture_has_changed() const {return parallel_mixture_has_changed;}
  void clear_data_has_changed() const {data_has_changed=false;}
  void clear_mask_has_changed() const {mask_has_changed=false;}
  void clear_parallel_mixture_has_changed() const {parallel_mixture_has_changed=false;}

private:
  
  void do_all_precomputation();
    
  // The likelihood and marginal calculations are inlined. The HMMFloat 'tmp_hmmfloat'
  // is declared as a private class member, rather than a local method variable, to
  // make sure there is no unnecessary construction and destruction.

  // Log Gaussian density of one feature:
  //   ivarsm05*(x-mu)^2 + likcon = -0.5*ivar*(x-mu)^2 - 0.5*log(2*PI*var)
  inline HMMFloat calc_feature_likelihood(HMMFloat data_, HMMFloat mu_, HMMFloat ivarsm05_, HMMFloat likcon_) const {
    tmp_hmmfloat=data_-mu_; return (ivarsm05_*tmp_hmmfloat*tmp_hmmfloat)+likcon_;
  };

  //  Note: BUG_1
  //
  // The method calc_feature_marginal can be compiled with a deliberate bug to simulate
  // the behaviour of CTKv1.1.0. This is useful if running systems that have been tuned 
  // under CTKv1.1.0 and hence were put out of tune when the bug was corrected
  //
  // Log of the Gaussian mass between lower_bound_ and upper_bound_, i.e.
  //   log(0.5*(erf((u-mu)*sqrtivar05) - erf((l-mu)*sqrtivar05))) + marg_norm_
  // (marg_norm_ is presumably a normalising term supplied by the caller).
  // An empty or inverted interval returns lik_lower_bound_. A lower bound of
  // exactly 0 uses the precomputed erfmu_ instead of a second erf lookup.

  inline HMMFloat calc_feature_marginal(HMMFloat lower_bound_, HMMFloat upper_bound_, HMMFloat mu_, HMMFloat sqrtivar05_, HMMFloat erfmu_, HMMFloat marg_norm_, HMMFloat lik_lower_bound_) const{
    if (upper_bound_<=lower_bound_)
      return lik_lower_bound_;
    else if (lower_bound_==0.0) {
#ifdef WITH_V1_1_0_BUGS
      return logT8(erft((upper_bound_ - mu_)* sqrtivar05_)  - erfmu_)+marg_norm_;  // BUGGY VERSION!
#else
      return logT8(erft((upper_bound_ - mu_)* sqrtivar05_)  - erfmu_)+marg_norm_ + LOG_HALF; // CORRECT VERSION
#endif
    } else {
#ifdef WITH_V1_1_0_BUGS
      return logT8(erft((upper_bound_ - mu_)* sqrtivar05_)  - erft((lower_bound_ - mu_)* sqrtivar05_))+marg_norm_;  // BUGGY VERSION!
#else
      return logT8(erft((upper_bound_ - mu_)* sqrtivar05_)  - erft((lower_bound_ - mu_)* sqrtivar05_))+marg_norm_+ LOG_HALF; // CORRECT VERSION
#endif
    }
  }


    /*
  // Debugging variant of calc_feature_marginal that traces its inputs and result.
  inline HMMFloat calc_feature_marginal(HMMFloat lower_bound_, HMMFloat upper_bound_, HMMFloat mu_, HMMFloat sqrtivar05_, HMMFloat erfmu_, HMMFloat marg_norm_, HMMFloat lik_lower_bound_) const{
    cerr << lower_bound_ << " " << upper_bound_ << " " << mu_ << " " << sqrtivar05_ << " " << erfmu_ << " " << marg_norm_ << " " << lik_lower_bound_ <<  " = ";
    if (upper_bound_<=lower_bound_) {
      cerr << "A " << lik_lower_bound_ << "\n";
      return lik_lower_bound_;
    } else if (lower_bound_==0.0) {
      HMMFloat x= logT8(erft((upper_bound_ - mu_)* sqrtivar05_)  - erfmu_)+marg_norm_ + LOG_HALF;
      cerr << "B " << x << "\n";
      return x;
    } else {

      HMMFloat x= logT8(erft((upper_bound_ - mu_)* sqrtivar05_)  - erft((lower_bound_ - mu_)* sqrtivar05_))+marg_norm_+ LOG_HALF;
      cerr << "C " << x << "\n";
      return x;
    }
  }
  */
  

  mutable HMMFloat tmp_hmmfloat;   // Scratch value for calc_feature_likelihood

  // Sigmoid-based erf approximation: 2/(1+exp(-2.3236x)) - 1.
  inline HMMFloat erfs(HMMFloat x) const {
    return 2.0/(1.0+exp(-2.3236*x))-1.0;
  }
  
  
  /***** The table-based erf function approximation*****************************/
  // Uses odd symmetry for x<0, and saturates to +/-1 beyond +/-ERF_TABLE_MAX.
  // The table index is x*ERF_TABLE_RESOLUTION truncated to UInteger16.
  // NOTE(review): at |x| == ERF_TABLE_MAX exactly, the index is
  // ERF_TABLE_MAX*ERF_TABLE_RESOLUTION. The table therefore needs one entry more
  // than that; confirm in make_erf_table().
  
  inline HMMFloat erft(HMMFloat x) const {
    if (x<0.0F) {
      if (x<-ERF_TABLE_MAX) return -1.0F;
      return -erf_table[UInteger16(-x*ERF_TABLE_RESOLUTION)];
    } else {
      if (x>ERF_TABLE_MAX) return 1.0F;
      return erf_table[UInteger16(x*ERF_TABLE_RESOLUTION)];
    }
  }
  
  /***** The table-based log function approximation  *******************************/
  // log(f) is computed as exponent-table[exponent byte] plus a mantissa-table
  // entry. The initial <<1 discards the sign bit (so the result depends only
  // on |f|) and leaves the 8 exponent bits in c1.
  // Both functions write the shared ieee_flt scratch union.

  // 10-bit variant: clear the exponent byte, shift again, and index the
  // mantissa table with the top 10 mantissa bits (now the low bits of s1).
  inline float logT10(float f) const {
    ieee_flt.d = f;
    ieee_flt.i<<=1;
    float x=log_table_exponent[ieee_flt.n.c1];
    ieee_flt.n.c1=0;
    ieee_flt.i<<=2;  
    return (x+log_table_mantissa_10bit[ieee_flt.s.s1]); 
  }

  // 8-bit variant: after the shift, c2 holds the top 8 mantissa bits.
  inline float logT8(float f) const {
    ieee_flt.d = f;
    ieee_flt.i<<=1;
    return (log_table_exponent[ieee_flt.n.c1]+log_table_mantissa_8bit[ieee_flt.n.c2]); 
  }

  
public:
  // log add based on log and exp approximations:
  //   log(exp(x1)+exp(x2)) ~= larger + logT8(1 + EXP_DOUBLE(-|x2-x1|)).
  // If the operands differ by more than MAX_LOG_ADD_DIFF, the larger is returned.
  inline float logT8_add(float x1, float x2) const {

    static const int MAX_LOG_ADD_DIFF = 20;
    
    float diff=x2-x1;
    if (diff>0.0) {
      if (diff > MAX_LOG_ADD_DIFF)
	return x2;
      else
	return x2+logT8(1.0F+EXP_DOUBLE(-diff));
    } else {
      if (diff < -MAX_LOG_ADD_DIFF)
	return x1;
      else
	return x1+logT8(1.0F+EXP_DOUBLE(diff));
    }
  }

  // As logT8_add, but using the 10-bit mantissa log table.
  inline float logT10_add(float x1, float x2) const {

    static const int MAX_LOG_ADD_DIFF = 20;
    
    float diff=x2-x1;
    if (diff>0.0) {
      if (diff > MAX_LOG_ADD_DIFF)
	return x2;
      else
	return x2+logT10(1.0F+EXP_DOUBLE(-diff));
    } else {
      if (diff < -MAX_LOG_ADD_DIFF)
	return x1;
      else
	return x1+logT10(1.0F+EXP_DOUBLE(diff));
    }
  }

};

/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMMixtureStandard		              	       	      */
/*									      */
/******************************************************************************/

// Concrete mixture type; get_prob() is implemented out of line.
class HMMMixtureStandard: public HMMMixture {

public:

  HMMMixtureStandard() {}

  HMMMixtureStandard(const vector<HMMFloat> &amu, const vector<HMMFloat> &aivar)
    : HMMMixture(amu, aivar) {}

  virtual ~HMMMixtureStandard() {}

  // Allocate a fresh HMMMixtureStandard with the given means and reciprocal
  // variances. The returned object is newly allocated and not retained here.
  HMMMixture *clone(const vector<HMMFloat> amu, const vector<HMMFloat> &aivar) const {
    HMMMixture *fresh = new HMMMixtureStandard(amu, aivar);
    return fresh;
  }

  HMMFloat get_prob() const;

private:

  // Result cache: the likelihood given mean, var, mask and data.
  mutable HMMFloat prob;

};

/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMMixtureMD		              	       	      */
/*									      */
/******************************************************************************/

// Concrete mixture type (presumably the missing-data variant); get_prob() is
// implemented out of line.
class HMMMixtureMD: public HMMMixture {

public:

  HMMMixtureMD() {}

  HMMMixtureMD(const vector<HMMFloat> &amu, const vector<HMMFloat> &aivar)
    : HMMMixture(amu, aivar) {}

  virtual ~HMMMixtureMD() {}

  // Allocate a fresh HMMMixtureMD with the given means and reciprocal
  // variances. The returned object is newly allocated and not retained here.
  HMMMixture *clone(const vector<HMMFloat> amu, const vector<HMMFloat> &aivar) const {
    HMMMixture *fresh = new HMMMixtureMD(amu, aivar);
    return fresh;
  }

  HMMFloat get_prob() const;

private:

  // Result cache: the likelihood given mean, var, mask and data.
  mutable HMMFloat prob;

};



/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMMixtureMDSoft		              	       	      */
/*									      */
/******************************************************************************/

// Concrete mixture type (presumably missing data with soft masks); get_prob()
// is implemented out of line.
class HMMMixtureMDSoft: public HMMMixture {

public:

  HMMMixtureMDSoft() {}

  HMMMixtureMDSoft(const vector<HMMFloat> &amu, const vector<HMMFloat> &aivar)
    : HMMMixture(amu, aivar) {}

  virtual ~HMMMixtureMDSoft() {}

  // Allocate a fresh HMMMixtureMDSoft with the given means and reciprocal
  // variances. The returned object is newly allocated and not retained here.
  HMMMixture *clone(const vector<HMMFloat> amu, const vector<HMMFloat> &aivar) const {
    HMMMixture *fresh = new HMMMixtureMDSoft(amu, aivar);
    return fresh;
  }

  HMMFloat get_prob() const;

private:

  // Result cache: the likelihood given mean, var, mask and data.
  mutable HMMFloat prob;

};


/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMMixturePMC		              	       	      */
/*									      */
/******************************************************************************/

// Concrete mixture type for PMC, i.e. used with an attached parallel mixture;
// get_prob() is implemented out of line.
class HMMMixturePMC: public HMMMixture {

public:

  HMMMixturePMC() {}

  HMMMixturePMC(const vector<HMMFloat> &amu, const vector<HMMFloat> &aivar)
    : HMMMixture(amu, aivar) {}

  virtual ~HMMMixturePMC() {}

  // Allocate a fresh HMMMixturePMC with the given means and reciprocal
  // variances. The returned object is newly allocated and not retained here.
  HMMMixture *clone(const vector<HMMFloat> amu, const vector<HMMFloat> &aivar) const {
    HMMMixture *fresh = new HMMMixturePMC(amu, aivar);
    return fresh;
  }

  HMMFloat get_prob() const;

private:

  // Result cache: the likelihood given mean, var, mask and data.
  mutable HMMFloat prob;

};



/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMMixtureMultisource	              	       	      */
/*									      */
/******************************************************************************/

// Concrete mixture type (multisource variant); get_prob() is implemented out
// of line.
class HMMMixtureMultisource: public HMMMixture {

public:

  HMMMixtureMultisource() {}

  HMMMixtureMultisource(const vector<HMMFloat> &amu, const vector<HMMFloat> &aivar)
    : HMMMixture(amu, aivar) {}

  virtual ~HMMMixtureMultisource() {}

  // Allocate a fresh HMMMixtureMultisource with the given means and reciprocal
  // variances. The returned object is newly allocated and not retained here.
  HMMMixture *clone(const vector<HMMFloat> amu, const vector<HMMFloat> &aivar) const {
    HMMMixture *fresh = new HMMMixtureMultisource(amu, aivar);
    return fresh;
  }

  HMMFloat get_prob() const;

private:

  // Result cache: the likelihood given mean, var, mask and data.
  mutable HMMFloat prob;

};


/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMMixtureMultisourceSoft	              	       	      */
/*									      */
/******************************************************************************/

// Concrete mixture type (multisource, presumably with soft masks); get_prob()
// is implemented out of line.
class HMMMixtureMultisourceSoft: public HMMMixture {

public:

  HMMMixtureMultisourceSoft() {}

  HMMMixtureMultisourceSoft(const vector<HMMFloat> &amu, const vector<HMMFloat> &aivar)
    : HMMMixture(amu, aivar) {}

  virtual ~HMMMixtureMultisourceSoft() {}

  // Allocate a fresh HMMMixtureMultisourceSoft with the given means and
  // reciprocal variances. The returned object is newly allocated and not
  // retained here.
  HMMMixture *clone(const vector<HMMFloat> amu, const vector<HMMFloat> &aivar) const {
    HMMMixture *fresh = new HMMMixtureMultisourceSoft(amu, aivar);
    return fresh;
  }

  HMMFloat get_prob() const;

private:

  // Result cache: the likelihood given mean, var, mask and data.
  mutable HMMFloat prob;

};




/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMState		                 	       	      */
/*									      */
/******************************************************************************/

// n.b. States with 0 mixtures are non-emitting states

// One HMM state: a weighted set of mixture components, optional voicing and
// maximum-duration parameters, and cached probability results.
class HMMState {

private:

  int state_ID;                      // Unique state ID number
  int num_mixes;

  vector<HMMMixture*> mixture;       // Mixture components
  vector<HMMFloat> mix_weights;      // Weight of each mixture component

  float voicing;                     // Optional voicing parameter - not currently used

  // Optional limit on how long a token may stay in this state.
  // Setting max_duration to 0 turns the limit off.
  int max_duration;

  mutable HMMFloat prob;             // Computed state probability
  mutable int max_mixture;           // Index of the max likelihood mixture

public:

  // ----- Construction -----
  HMMState();   // Constructs a non-emitting state (no mixtures)
  HMMState(const vector<HMMMixture*> &amixture, const vector<HMMFloat> &amix_weights, float avoicing, int amax_duration);
  ~HMMState();

  // ----- Identity and per-state parameters -----
  int get_ID() const {
    return state_ID;
  }
  float get_voicing() const {
    return voicing;
  }
  int get_max_duration() const {
    return max_duration;
  }
  void set_max_duration(int duration) {
    max_duration = duration;
  }

  // ----- Mixture components (numbered 0 to num_mixtures-1) -----
  Integer get_num_mixes() const;
  bool emits() const;
  HMMMixture* get_mixture(Integer index) const {
    return mixture[index];
  }
  HMMFloat get_mixture_weight(Integer index) const {
    return mix_weights[index];
  }

  // ----- Probability computation and caching -----
  void calc_prob(bool max_approx_param) const;
  HMMFloat get_prob() const;
  void set_prob(HMMFloat value) const {
    prob = value;
  }
  // Find the max likelihood mixture and store its index (read via get_max_mixture).
  void store_max_mixture() const;
  int get_max_mixture() const {
    return max_mixture;
  }

  bool operator== (HMMState x);

  friend ostream& operator<< (ostream& out, const HMMState& x);

};


/******************************************************************************/
/*									      */
/*	CLASS NAME: Transitions		                      	       	      */
/*									      */
/******************************************************************************/

// Square matrix of state-to-state transition probabilities.
class Transitions {

  vector<vector<HMMFloat> > probabilities;

public:

  // Construct an empty n-by-n transition matrix.
  Transitions(int n=0);

  // ----- Dimensions -----
  int size() const;
  void resize(int n);
  void clear();

  // ----- Element access -----
  HMMFloat get_transition_prob(unsigned int from, unsigned int to) const;
  vector<HMMFloat> get_row(unsigned int row) const;
  CTKStatus add_transition(unsigned int from, unsigned int to, HMMFloat prob);

  // ----- Whole-matrix operations -----

  // Rescale so that the exit probabilities from each state sum to 1.0.
  void normalise();

  // Shrink the matrix by removing states that do not lead to the exit state.
  // Returns the indices of the removed states.
  vector<unsigned int> rationalise();

  CTKStatus validate() const;

  friend ostream& operator<< (ostream& out, const Transitions& x);

};

/******************************************************************************/
/*									      */
/*	CLASS NAME: HMM			                 	       	      */
/*									      */
/******************************************************************************/

// A single hidden Markov model: its states, the transition matrix between
// them, a name, and the feature vector size.
class HMM: public ReadOnceFile {

private:

  vector<HMMState*> states;   // States, numbered 0 to num_states-1
  Transitions pTrans;         // Transition matrix
  string name;                // Model name
  int num_states;
  int vec_size;               // Feature vector size

public:

  // ----- Construction -----
  HMM(const vector<HMMState*> &astates, const Transitions &atrans, int avec_size, const string &aname);
  HMM(const vector<HMMState*> &astates, const Transitions &atrans, int avec_size, const string &aname, const string &filename);
  virtual ~HMM();

  // ----- Model properties -----
  const string &getName() const {
    return name;
  }
  Integer get_vec_size() const {
    return vec_size;
  }

  // ----- Topology -----
  Integer get_num_states() const;
  const vector<HMMState*> &get_states() const {
    return states;
  }
  HMMState* get_state(Integer num) const;   // num in 0 .. num_states-1
  const Transitions &getTrans() const {
    return pTrans;
  }

  friend ostream& operator<< (ostream& out, const HMM& x);
};


/******************************************************************************/
/*									      */
/*	CLASS NAME: SetOfHMMs		                 	       	      */
/*									      */
/******************************************************************************/
// A collection of HMMs together with the macro tables (states, mixtures,
// variances, transitions), the name->label map and the pronunciation
// dictionary used to build them. It also holds the pools of states and
// mixtures that per-frame data/mask updates are applied to.
class SetOfHMMs: public ReadOnceFile {

  typedef vector<HMM*>::iterator HMMIt;
  typedef vector<HMM*>::const_iterator HMMConstIt;
  typedef vector<HMMMixture*>::iterator HMMMixtureIt;
  typedef vector<HMMState*>::iterator HMMStateIt;

  friend class HMMDecoderBlock;
  friend class DecoderInfo;

private:

  char read_word_format_string[MAX_STRING_SIZE+1];  // Format for "%(MAX_STRING_SIZE)s"
  
  map<string, HMMState *> HMM_state_macros;        
  map<string, HMMMixture *> HMM_mixture_macros;
  map<string, vector<HMMFloat> *> HMM_variance_macros;
  map<string, Transitions> HMM_transition_macros;

  map<string, string> name_label_map; 
  map<string, list<string> > dictionary; 

  vector<string> labels;
  vector<string> names;
  
  vector<string> macro_list;     // list of macro files read    
  
  vector<HMM *> HMM_list;
  vector<HMMState *> state_list;
  vector<HMMMixture *> mixture_list;      // The complete set of mixtures  e.g. 960x1

  HMMMixture *mixture_prototype;
  
  int total_num_states;  // Total number of states - NOTE: This is not necessarily the same as state_list.size() because tied states only have one entry in state_list but are counted separately in total_num_states
  
  string hmmpath;    // The path of the HMM file currently being read.

  CTKStatus error_status;   // Set to CTK_FAILURE if a problem is encountered during construction
  //
  //
  //
  
public:
  
  // Default constructor: an empty set with no mixture prototype. The format
  // buffer, prototype pointer and state count are given defined values so that
  // none of them is read uninitialised (e.g. by the destructor).
  SetOfHMMs(): read_word_format_string(), mixture_prototype(0), total_num_states(0) {}
  SetOfHMMs(string filename, const string &owner_blockname, map<string, string> &a_name_label_map, map<string, list<string> > &dictionary, HMMMixture *amixture_prototype);

  friend ostream& operator<< (ostream& out, const SetOfHMMs& x);
  
  CTKStatus get_error_status() const {return error_status;}
  
  const HMMMixture *get_mixture_prototype() const {return mixture_prototype;}
  // Return the n'th HMM; n must be a valid index into the HMM list (not checked).
  const HMM* get_HMM(int n) const {return HMM_list[n];}
  const HMM* get_HMM_by_name(const string &name) const;
  HMM* get_HMM_by_name(const string &name);

  virtual ~SetOfHMMs();
  
  void display(ostream &outfile);

  Integer get_num_HMMs() const;

  Integer get_total_num_states() const;

  Integer get_vec_size() const;
  
  // Set missing data mask for all mixtures in the pool
  void set_missing_data_mask(shared_ptr<MaskVector> mask_vector);

  // Set observed data for all mixtures in the pool
  void set_observed_data(shared_ptr<FeatureVector> feature_vector);

  // Set parallel mixture for all mixtures in the pool
  void set_parallel_mixture(shared_ptr<HMMMixture> parallel_mixture);

  // Set use deltas for all mixtures
  void set_use_deltas(bool use_deltas);
  
  // Set use marginals (on static features) for all mixtures
  void set_use_marginals(bool use_marginals);

  // Set use marginals (on delta features) for all mixtures
  void set_use_delta_marginals(bool use_delta_marginals);

  // Insert the likelihoods for each HMM state into the vector supplied
  void construct_likelihood_vector(vector<Float> &likelihoods);

  // Insert the likelihoods supplied into the HMM states ready for decoding
  void set_likelihoods(const vector<Float> &likelihoods);

  void construct_winning_mixture_vector(vector<Float> &winning_mixture);

  void calc_prob(bool max_mixtures_param);

  bool load_macro_file(string macro_filename);
  
  // Macro lookups. These return the same values as map::operator[] would
  // (a null pointer, or an empty Transitions, for an unknown name), but use
  // find() so that a failed lookup does not insert an empty entry into the
  // macro table.

  HMMState *lookup_state_macro(const string &macro_name) const {
    map<string, HMMState *>::const_iterator it = HMM_state_macros.find(macro_name);
    return (it == HMM_state_macros.end()) ? 0 : it->second;
  }
  
  HMMMixture *lookup_mixture_macro(const string &macro_name) const {
    map<string, HMMMixture *>::const_iterator it = HMM_mixture_macros.find(macro_name);
    return (it == HMM_mixture_macros.end()) ? 0 : it->second;
  }

  Transitions lookup_transition_macro(const string &macro_name) const {
    map<string, Transitions>::const_iterator it = HMM_transition_macros.find(macro_name);
    return (it == HMM_transition_macros.end()) ? Transitions() : it->second;
  }

  // Return the HMM label given the HMM name
  string lookup_label(const string &name) const;

  // Return the sequence of physical HMM names defining the pronunciation of the grammar unit name
  list<string> lookup_pronunciation(const string &name) const;
  
  // Return the list of output labels
  const vector<string> &get_label_list() const;
  // Return the list of grammar leaf names 
  const vector<string> &get_name_list() const;
  
  // Apply a transformation to a set of HMMs
  void edit_HMMs(const class HMMEdit *edit);
  
private:
  
  // Construct a sorted list of unique output labels and grammar unit names
  void make_name_and_label_lists();

  void make_default_name_label_map_from_dictionary();
  void make_default_name_label_map_from_HMMs();
  void make_default_dictionary();

  // Make a list of all HMM used in the dictionary
  vector<HMM *> make_HMM_list_from_dictionary();
  
  void store_HMM(HMM *hmmp);

  bool read_macro(FILE *fp);
  
  void error_in_HMM_file(FILE *fp) const;
  
  int detect_HMM_file_type(FILE *ro_fp);

  // Replaces the HMM list - taking care of underlying state and mixture lists
  void replace_HMM_list(vector<HMM *> &new_HMM_list);
};


/******************************************************************************/
// Non class methods

// ----- Distance metrics between two mixture components -----
// Each returns a float measure of dissimilarity between mix1 and mix2
// (implemented out of line).

float distance_KL(const HMMMixture &mix1, const HMMMixture &mix2);    // Kullback-Leibler
float distance_BHA(const HMMMixture &mix1, const HMMMixture &mix2);   // Bhattacharyya
float distance_MAH(const HMMMixture &mix1, const HMMMixture &mix2);   // Mahalanobis
float distance_EUC(const HMMMixture &mix1, const HMMMixture &mix2);   // Euclidean






#endif


/* End of ctk_HMM.hh */
