/******************************************************************************/
/*                                                                            */
/*      ctk_HMM_builder.cpp                                                           */
/*                                                                            */
/*      Support for HTK HMM File reading                                              */
/*                                                                            */
/*      Author: Jon Barker, Sheffield University                              */
/*                                                                            */
/*      CTK VERSION 1.3.5  Apr 22, 2007                              */
/*  */
/******************************************************************************/

#include "ctk-config.h"

#include <cmath>
#include <algorithm>
#include <numeric>

#include "ctk_local.hh"
#include "ctk_error.hh"

#include "ctk_ro_file.hh"

#include "ctk_HMM.hh"
#include "ctk_HMM_builder.hh"


/******************************************************************************/

// scanf-style conversion used in this file for every floating-point value read
// from an HMM file (weights, means, variances, transition probabilities, ...).
// NOTE(review): "%g" makes fscanf store through a float*, so this presumably
// relies on HMMFloat being float -- confirm against the HMMFloat typedef.
static const char *HMM_FLOAT_FORMAT = "%g";

/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMObjectBuilder  	                 	       	      */
/*									      */
/******************************************************************************/

// Feature vector size shared by all builders (static member).
// 0 means "not yet known": check_size_consistency() adopts the first size it sees.
int HMMObjectBuilder::vec_size = 0;

// Remembers the HMM set this builder works for and writes the bounded word
// format "%<MAX_STRING_SIZE>s" into read_word_format_string_.
// The second (unnamed) argument is not used.
HMMObjectBuilder::HMMObjectBuilder(SetOfHMMs *an_hmmset, int /*=0*/)
  : hmmset(an_hmmset)
{
  sprintf(read_word_format_string_, "%%%ds", MAX_STRING_SIZE);
}

// Reads the next HTK token from fp.
//
// Everything up to the first '~' or '<' is skipped.  A '~' token consists of
// the tilde plus the single character that follows it (e.g. "~s"); a '<' token
// runs up to and including the closing '>'.  Any "_X" qualifiers are stripped
// from the token (e.g. <MFCC_D_A> => <MFCC>) and the rewrite is echoed to cerr.
//
// Returns an empty string if the input ends before a token starts, or the
// partial token collected so far if the input ends in the middle of one.
string HMMObjectBuilder::get_token(FILE *fp) {
  string token;
  char c = '\0';

  // Skip ahead to the start of the next token.  Every read is checked so that
  // an empty or exhausted stream can never leave 'c' uninitialised or stale.
  do {
    if (fscanf(fp,"%c",&c)!=1) return token;
  } while (c!='~' && c!='<');

  token+=c;
  if (c=='~') {
    // Macro token: '~' plus exactly one type character
    if (fscanf(fp,"%c",&c)!=1) return token;
    token+=c;
  } else {
    // Keyword token: collect up to and including the closing '>'
    while (c!='>') {
      if (fscanf(fp,"%c",&c)!=1) return token;
      token+=c;
    }
  }

  // strip off any _X  qualifiers   e.g.  <MFCC_D_A> => <MFCC>
  string::iterator pos=find(token.begin(), token.end(), '_');
  if (pos!=token.end()) {
    cerr << "Token: " << token;
    token.erase(pos, token.end());
    token +='>';
    cerr << " => " << token << "\n";
  }
  return token;
}


// Raises the builder's 'finished' flag.
void HMMObjectBuilder::set_finished()
{
  finished = true;
}

// Reports a parse failure and aborts the read.
//
// The next word in fp is printed as the offending keyword, then a CTKError is
// thrown.  The buffer starts out empty so that a failed read (for example
// when the error occurs at end of file) prints an empty keyword instead of
// uninitialised memory.
void HMMObjectBuilder::error_in_HMM_file(FILE *fp) {
  char buffer[MAX_STRING_SIZE+1] = "";
  if (fscanf(fp, read_word_format_string(), buffer)!=1)
    buffer[0]='\0';
  cerr << "HMM::read_file unknown keyword " << buffer << endl;  
  throw(CTKError(__FILE__, __LINE__));
}


// Reads the static vector size from fp - all vectors read afterwards must
// match this size.  Returns true only if an integer could be parsed.
bool HMMObjectBuilder::read_vec_size(FILE *fp)
{
  return fscanf(fp, "%d", &vec_size) == 1;
}

// Compares this_size with the shared vector size.  If no size has been
// recorded yet (vec_size is 0), this_size is adopted as the reference.
bool HMMObjectBuilder::check_size_consistency(int this_size) const
{
  if (vec_size == 0) {
    // First vector seen: its size becomes the reference
    vec_size = this_size;
    return true;
  }
  return this_size == vec_size;
}

// Handler for keywords that are accepted but carry no information for us:
// consumes nothing and always succeeds.
bool HMMObjectBuilder::interpret_IGNORE(FILE *)
{
  return true;
}

// Handler for <STREAMINFO>: skips up to two integers without storing them.
// Always reports success.
bool HMMObjectBuilder::interpret_StreamInfo(FILE *fp)
{
  fscanf(fp, "%*d%*d");   // values are discarded
  return true;
}

// Handler for <VECSIZE>: delegates to read_vec_size() and passes on its result.
bool HMMObjectBuilder::interpret_VecSize(FILE *fp)
{
  const bool size_was_read = read_vec_size(fp);
  return size_was_read;
}

/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMBuilder  	                 	       	      */
/*									      */
/******************************************************************************/

// Creates a builder for whole HMMs and registers its keyword handlers.
HMMBuilder::HMMBuilder(SetOfHMMs *an_hmmset)
  : HMMObjectBuilder(an_hmmset)
{
  build_op_map();
}

// Builds an HMM from the definition stored in the named file.
//
//   filename - path of the HMM definition file
//   name     - name given to the new HMM
//
// Returns a newly allocated HMM (allocated with new).
// Throws CTKError if the file cannot be opened; exceptions raised while
// parsing are passed on to the caller after the file has been closed, so the
// file handle is never leaked.
HMM *HMMBuilder::getHMMObject(const string &filename,  const string &name) {
  FILE *fp=fopen(filename.c_str(),"r");
  if (fp==0) {
    cerr << "HMM FILE ERROR: Cannot open HMM file: " << filename << endl;
    throw(CTKError(__FILE__, __LINE__));
  }

  HMM *hmmp=0;
  try {
    do_read_file(fp);
    hmmp=new HMM(states, pTrans, get_vec_size(), name, filename);
  } catch (...) {
    // Release the file before propagating the parse/allocation failure
    fclose(fp);
    throw;
  }

  fclose(fp);
  
  return hmmp;
}


// Builds an HMM from an already opened stream.  The stream is left open;
// the returned HMM is newly allocated and carries the given name.
HMM *HMMBuilder::getHMMObject(FILE *fp,  const string &name)
{
  do_read_file(fp);
  return new HMM(states, pTrans, get_vec_size(), name);
}

void HMMBuilder::do_read_file(FILE *fp) {
  states.resize(0);
  num_states=0;
  current_state=0;
  pTrans.clear();
  
  if (read_file(this, fp, op_translator)==false)
    error_in_HMM_file(fp);

  // If the transition matrix was defined via a macro name ...
  string this_macro_name = get_macro_name();
  if (!this_macro_name.empty()) {
    pTrans=get_hmmset()->lookup_transition_macro(this_macro_name);
    if (pTrans.size()==0) {
      cerr << "HMM FILE ERROR: Definition for transition macro \"" << this_macro_name << "\" not found." << endl;
    }
  }

}

// Fills op_translator with the handler for every keyword that may appear at
// the top level of an HMM definition.
void HMMBuilder::build_op_map()
{
  // Keywords that are recognised but need no processing: structural markers
  // and the feature vector data types.
  static const char *const ignored_tags[] = {
    "<BEGINHMM>", "<NULLD>", "<DIAGC>",
    "<DISCRETE>", "<LPC>", "<LPCEPSTRA>", "<MFCC>", "<MELSPEC>",
    "<LPREFC>", "<LPDELCEP>", "<FBANK>", "<USER>"
  };
  const size_t n_ignored = sizeof(ignored_tags)/sizeof(ignored_tags[0]);
  for (size_t k=0; k<n_ignored; ++k)
    op_translator[ignored_tags[k]] = &HMMBuilder::interpret_IGNORE;

  // Keywords with dedicated handlers
  op_translator["<NUMSTATES>"]  = &HMMBuilder::interpret_NumStates;
  op_translator["<STREAMINFO>"] = &HMMBuilder::interpret_StreamInfo;
  op_translator["<VECSIZE>"]    = &HMMBuilder::interpret_VecSize;
  op_translator["<STATE>"]      = &HMMBuilder::interpret_State;
  op_translator["<TRANSP>"]     = &HMMBuilder::interpret_TransP;
  op_translator["<ENDHMM>"]     = &HMMBuilder::interpret_EndHMM;
  op_translator["<USE>"]        = &HMMBuilder::interpret_Use;

  // Reference to a transition matrix macro
  op_translator["~T"]           = &HMMBuilder::interpret_macro_t;
}


bool HMMBuilder::interpret_macro_t(FILE *fp) {
  char buffer[MAX_STRING_SIZE+1];
  fscanf(fp, read_word_format_string(), buffer);
  set_macro_name(buffer);

  return true;
}

bool HMMBuilder::interpret_Use(FILE *fp){
 
  char buffer[MAX_STRING_SIZE+1];
  fscanf(fp, read_word_format_string(), buffer);

  string macro_filename(buffer);
  
  // Remove quotes and spaces
  macro_filename.erase(remove(macro_filename.begin(), macro_filename.end(), '\"'), macro_filename.end());
  macro_filename.erase(remove(macro_filename.begin(), macro_filename.end(), ' '), macro_filename.end());
  
 
  if (get_hmmset()->load_macro_file(macro_filename)==false) {
    cerr << "HMM FILE ERROR: Problem reading macro file: " << macro_filename << endl;
    throw(CTKError(__FILE__, __LINE__));
  }
  
  return true;
}

bool HMMBuilder::interpret_NumStates(FILE *fp){
  fscanf(fp, "%d", &num_states);

  if (num_states<3) {
    cerr << "HMM FILE ERROR: All HMMs must have at least 3 states" << endl;
    return false;
  }
  
  // Don't need to worry about the start and end non-emitting states
  num_states-=2;
  states.resize(num_states);

  return true;
};


bool HMMBuilder::interpret_State(FILE *fp){

  if (num_states==0) {
    cerr << "HMM FILE ERROR: HMM needs a <NumStates> before the first <State> definition." << endl;
    return false;
  }
  
  fscanf(fp, "%d", &current_state);

  current_state-=2;

  if (current_state<0 || current_state>=num_states) {
    cerr << "HMM FILE ERROR: State number out of range." << endl;
    return false;
  }

  HMMStateBuilder state_builder(get_hmmset());
  
  HMMState *sp = state_builder.getHMMStateObject(fp);
   
  // If the mixture was defined via a macro name delete it and find macro in table
  string this_macro_name = state_builder.get_macro_name();
  if (!this_macro_name.empty()) {
    delete sp;
    sp=get_hmmset()->lookup_state_macro(this_macro_name);
    if (sp==NULL) {
      cerr << "HMM FILE ERROR: Definition for state macro \"" << this_macro_name << "\" not found." << endl;
      return false;
    }
  }

  if (states[current_state]!=NULL) {
    cerr << "HMM FILE ERROR: State has multiple definitions." << endl;
    return false;
  }
  
  states[current_state]=sp;
  
  return true;
};


bool HMMBuilder::interpret_TransP(FILE *fp){

  HMMTransitionBuilder transition_builder(get_hmmset());

  pTrans = transition_builder.getHMMTransitionObject(fp);

  return true;
}

// Handler for <ENDHMM>: marks the builder as finished and verifies that every
// emitting state received a definition.
// Throws CTKError if a state slot is still empty.
bool HMMBuilder::interpret_EndHMM(FILE *)
{
  set_finished();

  int slot=0;
  while (slot<num_states) {
    if (states[slot]==NULL) {
      cerr << "HMM FILE ERROR: Missing state definition." << endl;
      throw(CTKError(__FILE__, __LINE__));
    }
    ++slot;
  }

  return true;
}



/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMStateBuilder   	                 	       	      */
/*									      */
/******************************************************************************/

// Creates a state builder with max_duration 0 and voicing 0.0 and registers
// its keyword handlers.  in_macro_def is passed on to build_op_map().
HMMStateBuilder::HMMStateBuilder(SetOfHMMs *an_hmmset, bool in_macro_def/*=false*/)
  : HMMObjectBuilder(an_hmmset),
    max_duration(0),
    voicing(0.0)
{
  build_op_map(in_macro_def);
}


// Reads one state definition from fp and returns it as a newly allocated
// HMMState built from the mixtures, weights, voicing and max duration seen.
HMMState *HMMStateBuilder::getHMMStateObject(FILE *fp)
{
  current_mixture=-1;   // -1 == no mixture seen yet
  num_mixes=0;

  // The read may have failed because the NumMixture and Mixture tags are
  // missing - these can legally be omitted for models with only one mixture.
  // In that case (and only then) try reading a single mixture instead.
  const bool parsed = read_file(this, fp, op_translator);
  if (!parsed && current_mixture==-1)
    get_single_mixture(fp);

  check_mixtures();

  return new HMMState(mixture, mix_weights, voicing, max_duration);
}



// Registers the keyword handlers for a state definition.  The "~S" macro
// reference is only accepted when we are not inside a macro definition.
void HMMStateBuilder::build_op_map(bool macro_def)
{
  if (!macro_def)
    op_translator["~S"]          = &HMMStateBuilder::interpret_macro_s;

  op_translator["<NUMMIXES>"]    = &HMMStateBuilder::interpret_NumMixes;
  op_translator["<MIXTURE>"]     = &HMMStateBuilder::interpret_Mixture;
  op_translator["<MAXDURATION>"] = &HMMStateBuilder::interpret_MaxDuration;
  op_translator["<VOICING>"]     = &HMMStateBuilder::interpret_Voicing;
}

bool HMMStateBuilder::interpret_macro_s(FILE *fp) {
  char buffer[MAX_STRING_SIZE+1];
  fscanf(fp, read_word_format_string(), buffer);
  set_macro_name(buffer);

  current_mixture=1; // Otherwise state will think no mixtures have been read
  return true;
}

// Handler for <VOICING>: reads the voicing value of the state.
// Returns false (after reporting the problem) if no value follows the keyword.
bool HMMStateBuilder::interpret_Voicing(FILE *fp)
{
  const bool got_value = (fscanf(fp, HMM_FLOAT_FORMAT, &voicing)==1);
  if (!got_value)
    cerr << "HMM FILE ERROR: missing voicing value." << endl;
  return got_value;
}

// Handler for <MAXDURATION>: reads the maximum duration of the state.
// Returns false (after reporting the problem) if no value follows the keyword.
bool HMMStateBuilder::interpret_MaxDuration(FILE *fp)
{
  const bool got_value = (fscanf(fp, "%d", &max_duration)==1);
  if (!got_value)
    cerr << "HMM FILE ERROR: missing max duration value." << endl;
  return got_value;
}

bool HMMStateBuilder::interpret_NumMixes(FILE *fp){

  if (num_mixes!=0) {
    cerr << "HMM FILE ERROR: State with more than one <NumMixes> tag" << endl;
    return false;
  }
  
  fscanf(fp, "%d", &num_mixes);

  mixture.clear();
  mixture.resize(num_mixes, 0);
  mix_weights.resize(num_mixes);

  return true;
}

bool HMMStateBuilder::interpret_Mixture(FILE *fp){
  HMMFloat weight;
  
  if (num_mixes==0) {
    cerr << "HMM FILE ERROR: HMM needs a <NumMixes> before the first <Mixture> definition." << endl;
    return false;
  }

  fscanf(fp, "%d", &current_mixture);
  fscanf(fp, HMM_FLOAT_FORMAT, &weight);  

  if (current_mixture<1 || current_mixture>num_mixes) {
    cerr << "HMM FILE ERROR: Mixture number out of range." << endl;
    return false;
  }

  current_mixture-=1;

  mix_weights[current_mixture]=log(weight);

  return get_mixture(fp);
  
}
 

// Reads the one and only mixture of a single mixture state, for which the
// <NumMixtures> and <Mixture> keywords can be omitted.  The mixture gets
// weight 1 (stored as its log).  Returns the result of get_mixture().
bool HMMStateBuilder::get_single_mixture(FILE *fp)
{
  // Behave as if "<NumMixes> 1 <Mixture> 1 1.0" had been read
  num_mixes=1;
  mixture.resize(1);
  mix_weights.resize(1);

  current_mixture=0;
  mix_weights[0]=log(1.0);

  return get_mixture(fp);
}


// Reads one mixture definition from fp and stores it in the slot selected by
// current_mixture.
//
// If the mixture was given as a macro reference, the mixture registered under
// that name in the HMM set is stored instead of the one read here.
//
// Returns false (after reporting the problem) if a referenced macro is
// unknown, the definition is incomplete, or the slot is already occupied.
bool HMMStateBuilder::get_mixture(FILE *fp){ 

  HMMMixtureBuilder mixture_builder(get_hmmset());
  HMMMixture *mp = mixture_builder.getHMMMixtureObject(fp);
  bool owns_mixture = true;   // true while mp is the object returned just above
 
  // If the mixture was defined via a macro name delete it and find macro in table ...
  string this_macro_name = mixture_builder.get_macro_name();
  if (!this_macro_name.empty()) {
    delete mp;
    owns_mixture=false;       // from here on mp belongs to the macro table
    mp=get_hmmset()->lookup_mixture_macro(this_macro_name);
    if (mp==NULL) {
      cerr << "HMM FILE ERROR: Definition for mixture macro \"" << this_macro_name << "\"not found" << endl;
      return false;
    }
  } else {
    // ... else check the mixture definition is complete
    if (mp==NULL) {
      cerr << "Incomplete mixture definition\n";
      return false;
    }

  }

  if (mixture[current_mixture]!=NULL) {
    // Do not leak the mixture that was built for the rejected second definition
    if (owns_mixture) delete mp;
    cerr << "HMM FILE ERROR: Mixture  has multiple definitions." << endl;
    return false;
  }
  
  mixture[current_mixture]=mp;

  return true;
}


// Check all mixtures are present.
//
// NOTE: the check is currently disabled - the body below is commented out, so
// this function does nothing and a state with an empty mixture slot is not
// detected here.  When enabled, the code would throw CTKError on the first
// empty slot.
void HMMStateBuilder::check_mixtures() const {
  
  /*  for (vector<HMMMixture*>::const_iterator mp=mixture.begin(), mp_end=mixture.end(); mp!=mp_end; ++mp) {
    if ((*mp)==NULL) {
      cerr << "HMM FILE ERROR: Missing mixture definition." << endl;
      throw(CTKError(__FILE__, __LINE__));
    }
  }
  */
}


/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMMixtureBuilder   	                 	       	      */
/*									      */
/******************************************************************************/

// Creates a mixture builder and registers its keyword handlers.
// in_macro_def is passed on to build_op_map().
HMMMixtureBuilder::HMMMixtureBuilder(SetOfHMMs *an_hmmset, bool in_macro_def/*=false*/)
  : HMMObjectBuilder(an_hmmset)
{
  build_op_map(in_macro_def);
}


// Reads one mixture definition from fp.
//
// Returns a clone of the HMM set's mixture prototype built from the mean and
// inverse variance vectors that were read.  Returns NULL (after reporting
// what is missing) if the mixture was not given as a macro reference and its
// mean or variance is absent.
HMMMixture *HMMMixtureBuilder::getHMMMixtureObject(FILE *fp)
{
  read_mean=false;
  read_variance=false;
  read_file(this, fp, op_translator);

  const bool complete = read_mean && read_variance;
  if (!complete) {
    // An incomplete definition is only acceptable for a macro reference
    if (get_macro_name().size()==0) {
      if (!read_mean)
        cerr << "HMM FILE ERROR: HMM mixture with missing Mean values." << endl;
      if (!read_variance)
        cerr << "HMM FILE ERROR: HMM mixture with missing Variance values." << endl;
      return NULL;
    }
  }

  return get_hmmset()->get_mixture_prototype()->clone(mu, ivar);
}

// Registers the keyword handlers for a mixture definition.  The "~M" macro
// reference is only accepted when we are not inside a macro definition.
void HMMMixtureBuilder::build_op_map(bool macro_def)
{
  if (!macro_def)
    op_translator["~M"]       = &HMMMixtureBuilder::interpret_macro_m;

  op_translator["<GCONST>"]   = &HMMMixtureBuilder::interpret_GConst;
  op_translator["<VARIANCE>"] = &HMMMixtureBuilder::interpret_Variance;
  op_translator["<MEAN>"]     = &HMMMixtureBuilder::interpret_Mean;
}

bool HMMMixtureBuilder::interpret_macro_m(FILE *fp) {
  char buffer[MAX_STRING_SIZE+1];
  fscanf(fp, read_word_format_string(), buffer);
  set_macro_name(buffer);
  return true;
}

bool HMMMixtureBuilder::interpret_Mean(FILE *fp){
  int this_size;
  
  if (read_mean) {
    cerr << "HMM FILE ERROR: MixtureBuilder with multiple mean definitions." << endl;
    return false;
  }

  if (fscanf(fp, "%d", &this_size)!=1) {
    cerr << "HMM FILE ERROR: error in mean definition." << endl;
    return false;
  }

  if (check_size_consistency(this_size)==false) {
    cerr << "HMM FILE ERROR: mean vector size inconsistency." << endl;
    return false;
  }

  mu.resize(get_vec_size());
  vector<HMMFloat>::iterator mup=mu.begin();
  for (int i=0; i<get_vec_size(); ++i)
    if (fscanf(fp, HMM_FLOAT_FORMAT, &(*mup++))!=1)   {
      cerr << "HMM FILE ERROR: missing mean parameters." << endl;
      return false;
    }

  read_mean=true;
  return true;
}

bool HMMMixtureBuilder::interpret_Variance(FILE *fp){
  int this_size;

  if (read_variance) {
    cerr << "HMM FILE ERROR: MixtureBuilder with multiple variance definitions." << endl;
    return false;
  }
  
  if (fscanf(fp, "%d", &this_size)!=1) {
    cerr << "HMM FILE ERROR: error in variance definition." << endl;
    return false;
  }
  
  if (check_size_consistency(this_size)==false) {
    cerr << "HMM FILE ERROR: variance vector size inconsistency." << endl;
    return false;
  }

  ivar.resize(get_vec_size());
  vector<HMMFloat>::iterator ivarp=ivar.begin();

  for (int i=0; i<get_vec_size(); ++i) {
    if (fscanf(fp, HMM_FLOAT_FORMAT, &(*ivarp))!=1) {
      cerr << "HMM FILE ERROR: missing variance parameters." << endl;
      return false;
    }
    *ivarp=1.0/ *ivarp;    // variances are stored as inverse variances (1/variance)   
    ++ivarp;
  }
  
  read_variance=true;
  return true;

};


bool HMMMixtureBuilder::interpret_GConst(FILE *fp){
  HMMFloat dummy;
  if (fscanf(fp, HMM_FLOAT_FORMAT, &dummy)==1) return true;
  else {
    cerr << "HMM FILE ERROR: missing gConst value." << endl;
    return false;
  }
};

/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMVarianceBuilder   	              	       	      */
/*									      */
/******************************************************************************/

// Creates a variance builder and registers its keyword handler.
// in_macro_def is passed on to build_op_map().
HMMVarianceBuilder::HMMVarianceBuilder(SetOfHMMs *an_hmmset, bool in_macro_def/*=false*/)
  : HMMObjectBuilder(an_hmmset)
{
  build_op_map(in_macro_def);
}


// Reads a variance vector from fp and returns a newly allocated copy of the
// inverse variances.  Throws CTKError if no variance was read.
vector<HMMFloat> *HMMVarianceBuilder::getHMMVarianceObject(FILE *fp)
{
  read_variance=false;
  read_file(this, fp, op_translator);

  if (read_variance) {
    vector<HMMFloat> *inverse_variances = new vector<HMMFloat>(ivar);
    return inverse_variances;
  }

  cerr << "HMM FILE ERROR: HMM mixture with missing Variance values." << endl;
  throw(CTKError(__FILE__, __LINE__));
}

// Registers the only keyword a variance definition understands.
// The (unnamed) flag is not used.
void HMMVarianceBuilder::build_op_map(bool)
{
  op_translator["<VARIANCE>"] = &HMMVarianceBuilder::interpret_Variance;
}


bool HMMVarianceBuilder::interpret_Variance(FILE *fp){
  int this_size;

  if (read_variance) {
    cerr << "HMM FILE ERROR: VarianceBuilder with multiple variance definitions." << endl;
    return false;
  }
  
  if (fscanf(fp, "%d", &this_size)!=1) {
    cerr << "HMM FILE ERROR: error in variance definition." << endl;
    return false;
  }
  
  ivar.resize(get_vec_size());
  vector<HMMFloat>::iterator ivarp=ivar.begin();

  for (int i=0; i<get_vec_size(); ++i) {
    if (fscanf(fp, HMM_FLOAT_FORMAT, &(*ivarp))!=1) {
      cerr << "HMM FILE ERROR: missing variance parameters." << endl;
      return false;
    }
    *ivarp=1.0/ *ivarp;    // variances are stored as inverse variances (1/variance)   
    ++ivarp;
  }
  
  read_variance=true;
  return true;

};

/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMTransitionBuilder   	              	       	      */
/*									      */
/******************************************************************************/

// Creates a transition builder.  The (unnamed) flag is not used and no
// keyword handlers are registered.
HMMTransitionBuilder::HMMTransitionBuilder(SetOfHMMs *an_hmmset, bool )
  : HMMObjectBuilder(an_hmmset)
{
}


// Reads a transition matrix from fp and returns a copy of it.
// The status returned by interpret_TransP() is not inspected.
Transitions HMMTransitionBuilder::getHMMTransitionObject(FILE *fp)
{
  interpret_TransP(fp);
  return pTrans;
}

bool HMMTransitionBuilder::interpret_TransP(FILE *fp){
  int tp_size;
  fscanf(fp,"%d", &tp_size);

  pTrans.resize(tp_size);

  HMMFloat prob;
  for (int i=0; i<tp_size; ++i) {
    for (int j=0; j<tp_size; ++j) {
      fscanf(fp, HMM_FLOAT_FORMAT, &prob);
      pTrans.add_transition(i,j,prob);
    }
  }

  //  // Validate transition probabilities
  //  if (pTrans.validate()==CTK_FAILURE) {
  //    cerr << "HMM FILE ERROR: Inconsistent transition matrix." << endl;
  //   throw(CTKError(__FILE__, __LINE__));
  //  }

  return true;
};


/******************************************************************************/
/*									      */
/*	CLASS NAME: HMMGlobalBuilder   	              	       	      */
/*									      */
/******************************************************************************/

// Creates a builder for the global options section and registers its keyword
// handlers.  in_macro_def is passed on to build_op_map().
HMMGlobalBuilder::HMMGlobalBuilder(SetOfHMMs *an_hmmset, bool in_macro_def/*=false*/)
  : HMMObjectBuilder(an_hmmset)
{
  build_op_map(in_macro_def);
}


// Reads the global options from fp.  Nothing is returned and the status of
// the read is not inspected; the handlers store what they read.
void HMMGlobalBuilder::getHMMGlobalObject(FILE *fp)
{
  read_file(this, fp, op_translator);

  //  return new vector<HMMFloat>(ivar);
}

// Registers the keyword handlers for the global options section.
// The (unnamed) flag is not used.
void HMMGlobalBuilder::build_op_map(bool)
{
  op_translator["<VECSIZE>"]    = &HMMGlobalBuilder::interpret_VecSize;
  op_translator["<STREAMINFO>"] = &HMMGlobalBuilder::interpret_StreamInfo;

  // Keywords that are recognised but need no processing: covariance/duration
  // kind markers and the feature vector data types.
  static const char *const ignored_tags[] = {
    "<NULLD>", "<DIAGC>",
    "<DISCRETE>", "<LPC>", "<LPCEPSTRA>", "<MFCC>", "<MELSPEC>",
    "<LPREFC>", "<LPDELCEP>", "<FBANK>", "<USER>"
  };
  const size_t n_ignored = sizeof(ignored_tags)/sizeof(ignored_tags[0]);
  for (size_t k=0; k<n_ignored; ++k)
    op_translator[ignored_tags[k]] = &HMMGlobalBuilder::interpret_IGNORE;
}


