/******************************************************************************/
/*									      */
/*	ctk_reco.cpp	 	       			              */
/*									      */
/*	Block for HMM decoding                                                */
/*									      */
/*	Author: Jon Barker, Sheffield University			      */
/*									      */
/*      CTK VERSION 1.3.5  Apr 22, 2007		         	      */
/*									      */
/******************************************************************************/

#include "ctk-config.h"

#include "ctk_reco.hh"

#include <vector>
#include <sstream>

#include "ctk_local.hh"


/******************************************************************************/
/*									      */
/*	CLASS NAME: RecoHypothesis		              	       	      */
/*									      */
/******************************************************************************/

// Builds a hypothesis from its token sequence, its score, the list of
// boundary values, the per-model state sequences and the group-record
// string.  Each argument is copied into the member of the matching role.
RecoHypothesis::RecoHypothesis(const list<string> &a_solution,
                               HMMFloat a_score,
                               list<int> the_bounds,
                               list<vector<int> > the_states,
                               string groups)
  : solution(a_solution),
    score(a_score),
    bounds(the_bounds),
    states(the_states),
    group_record(groups)
{
}

// Read-only access to the stored token sequence, with no processing applied.
const list<string> &RecoHypothesis::get_solution_raw() const
{
  return solution;
}

// Returns the token sequence reduced to source one.
//
// Every token is offered to parse_model_name():
//   - if it is recognised as a source-tagged label, only its source-one
//     name is kept, and nothing is emitted when that name is empty;
//   - otherwise the token is copied through unchanged, exactly as
//     get_solution_raw() would expose it.
const list<string> RecoHypothesis::get_solution_source() const {
  list<string> source_tokens;

  list<string>::const_iterator token=solution.begin();
  const list<string>::const_iterator last=solution.end();
  while (token!=last) {
    string first_name, second_name;
    const bool tagged=parse_model_name(*token, first_name, second_name);
    if (!tagged)
      source_tokens.push_back(*token);
    else if (!first_name.empty())
      source_tokens.push_back(first_name);
    ++token;
  }

  return source_tokens;
}

// Tests the hypothesis against a compiled regular expression.
// The solution tokens are joined with single spaces and handed to
// regexec(); the result is true exactly when regexec() returns 0.
// No sub-match information is requested (nmatch is 0, pmatch is NULL).
Boolean RecoHypothesis::matches_filter(regex_t *filter) const {
  const string joined_tokens = list_string_to_string(solution);
  const int status = regexec(filter, joined_tokens.c_str(), 0, NULL, 0);
  return status==0;
}

// Returns a copy of the group-record string stored at construction.
string RecoHypothesis::get_group_record() const
{
  return group_record;
}

// Writes the full description of the hypothesis to `file`:
//   1. the space-joined token sequence inside <hypothesis> tags,
//   2. the score inside <score> tags,
//   3. the per-model records produced by write_model_sequence().
void RecoHypothesis::write_all(FILE *file) {
  const string tokens=list_string_to_string(solution);

  fprintf(file, "<hypothesis> %s </hypothesis>\n", tokens.c_str());
  fprintf(file, "<score> %8.5e </score>\n", score);

  write_model_sequence(file);
}

// Multi-source variant of write_all(): writes the hypothesis and score,
// then a <groups> line, then the model sequence.
//
// The <groups> line lists number_to_label_map[i] for every position i of
// group_record that holds the character '1'.  Position 0 is never
// examined, since the scan starts at index 1.
// NOTE(review): presumably index 0 is reserved; confirm against the code
// that builds group_record.
// The map is indexed with operator[], which is why it is taken by
// non-const reference.
void RecoHypothesis::write_all_multisource(FILE *file, map<Integer, Integer> &number_to_label_map ) {
  const string tokens=list_string_to_string(solution);
  fprintf(file,"<hypothesis> %s </hypothesis>\n",tokens.c_str());
  fprintf(file,"<score> %8.5e </score>\n", score);

  fprintf(file,"<groups> ");
  const string::size_type record_length=group_record.size();
  unsigned int group=1;
  while (group<record_length) {
    if (group_record[group]=='1')
      fprintf(file,"%d ",number_to_label_map[group]);
    ++group;
  }
  fprintf(file,"</groups>\n");

  write_model_sequence(file);
}


void RecoHypothesis::write_boundaries(FILE *file) const {
  fprintf(file,"<boundaries>");
  for (list<int>::const_iterator bp=bounds.begin(); bp!=bounds.end(); ++bp)
    fprintf(file,"%d ",*bp);
  fprintf(file,"</boundaries>\n");
}

// Writes one "start end label" line per token of the solution.
//
//   file              - destination stream.
//   samples_per_frame - factor applied to each stored boundary value to
//                       obtain the end sample of the segment.
//
// The first segment starts at sample 0; each subsequent segment starts
// where the previous one ended.  Tokens and boundaries are walked in
// lock-step, and output stops as soon as either list is exhausted, so a
// hypothesis with fewer boundaries than tokens no longer reads past the
// end of `bounds`.
void RecoHypothesis::write_TIMIT_format(FILE *file, int samples_per_frame) const {

  int start_sample=0;

  list<int>::const_iterator bp=bounds.begin();
  const list<int>::const_iterator bp_end=bounds.end();
  list<string>::const_iterator sp=solution.begin();
  const list<string>::const_iterator sp_end=solution.end();

  // Iterator comparison also avoids calling list::size() on every pass.
  for (; sp!=sp_end && bp!=bp_end; ++sp, ++bp) {
    const int end_sample=(*bp)*samples_per_frame;
    fprintf(file,"%d %d %s\n", start_sample, end_sample, sp->c_str());
    start_sample=end_sample;
  }

}

// Writes one <model ...> ... </model> record per token of the solution.
//
// Each token consumes one entry of `bounds`, used as the record's end
// value.  Start values begin at 1 and afterwards are the previous end + 1.
//
//   - Tokens recognised by parse_model_name() produce a "1:name" record
//     and/or a "2:name" record.  Both use the same end value, and each
//     source keeps its own running start (start1 / start2).  No state
//     list is written for these tokens and `states` is not advanced.
//   - Any other token produces a single record that uses start1.  While
//     entries remain in `states`, the next entry is written as a
//     space-separated line inside the record and then consumed.
//
// The walk stops when either the tokens or the boundaries run out, and a
// state list is only read while one remains, so a hypothesis whose lists
// have inconsistent lengths cannot cause a read past the end of a list.
void RecoHypothesis::write_model_sequence(FILE *file) const {

  int start1=1;
  int start2=1;
  list<vector<int> >::const_iterator sp=states.begin();
  const list<vector<int> >::const_iterator sp_end=states.end();
  list<int>::const_iterator bp=bounds.begin();
  const list<int>::const_iterator bp_end=bounds.end();
  list<string>::const_iterator solp=solution.begin();
  const list<string>::const_iterator solp_end=solution.end();

  for (; solp!=solp_end && bp!=bp_end; ++solp, ++bp) {
    const int end_value=*bp;
    string name1, name2;

    if (parse_model_name(*solp, name1, name2)) {
      if (!name1.empty()) {
        fprintf(file,"<model name=\"1:%s\" start=\"%d\" end=\"%d\">\n", name1.c_str(), start1, end_value);
        fprintf(file,"</model>\n");
        start1=end_value+1;
      }
      if (!name2.empty()) {
        fprintf(file,"<model name=\"2:%s\" start=\"%d\" end=\"%d\">\n", name2.c_str(), start2, end_value);
        fprintf(file,"</model>\n");
        start2=end_value+1;
      }
    } else {
      fprintf(file,"<model name=\"%s\" start=\"%d\" end=\"%d\">\n", solp->c_str(), start1, end_value);
      if (sp!=sp_end) {
        for (vector<int>::const_iterator xp=sp->begin(); xp!=sp->end(); ++xp)
          fprintf(file,"%d ",*xp);
        fprintf(file,"\n");
        ++sp;
      }
      fprintf(file,"</model>\n");
      start1=end_value+1;
    }
  }

}


// Extracts labelA and labelB from string labels of the forms
//   1:labelA
//   2:labelB
//   +1:labelA+2:labelB
//   +2:labelB+1:labelA
//
//   label - the token to examine.
//   name1 - receives the source-one name when one is present.
//   name2 - receives the source-two name when one is present.
//
// Returns true when the label starts with one of the prefixes "1:",
// "2:", "+1:" or "+2:", and false otherwise.  name1 and name2 are only
// ever assigned, never cleared, so callers should pass empty strings.
// A label that starts with "+1:" (or "+2:") but has no matching "+2:"
// (or "+1:") part still returns true and leaves both names untouched.
//
// The search position is held in std::string::size_type.  Storing it in
// an unsigned int truncates npos where size_type is wider than unsigned
// int, which makes the "not found" test always fail.
bool  RecoHypothesis::parse_model_name(const string &label, string &name1, string &name2) const {

  if (label.compare(0, 2, "1:")==0) {
    name1=label.substr(2);
  } else if (label.compare(0, 2, "2:")==0) {
    name2=label.substr(2);
  } else if (label.compare(0, 3, "+1:")==0) {
    const std::string::size_type pos2=label.find("+2:");
    if (pos2!=std::string::npos) {
      name1=label.substr(3,pos2-3);
      name2=label.substr(pos2+3);
    }
  } else if (label.compare(0, 3, "+2:")==0) {
    const std::string::size_type pos1=label.find("+1:");
    if (pos1!=std::string::npos) {
      name2=label.substr(3,pos1-3);
      name1=label.substr(pos1+3);
    }
  } else {
    return false;
  }

  return true;

}

// Streams the hypothesis as its tokens joined by single spaces.
// Nothing else (score, boundaries, states) is written.
ostream& operator<< (ostream& out, const RecoHypothesis &hyp) {
  const string text=list_string_to_string(hyp.solution);
  return out << text;
}

// Joins the strings of `vstring` into one string, putting exactly one
// space between consecutive elements.  There is no leading or trailing
// separator, and an empty list gives an empty string.
string list_string_to_string(const list<string> &vstring) {
  string joined;

  const list<string>::const_iterator first=vstring.begin();
  const list<string>::const_iterator last=vstring.end();
  for (list<string>::const_iterator item=first; item!=last; ++item) {
    if (item!=first)
      joined+=" ";
    joined+=*item;
  }

  return joined;
}

// Splits `str` into whitespace-delimited words using stream extraction
// and returns them in order.  Runs of whitespace give no empty words, so
// an empty or all-blank input gives an empty list.
list<string> string_to_list_string(const string &str) {
  list<string> tokens;

  istringstream reader(str);
  for (string token; reader>>token; )
    tokens.push_back(token);

  return tokens;
}

/******************************************************************************/
/*									      */
/*	CLASS NAME: RecoStats   		              	       	      */
/*			       					      */
/******************************************************************************/

// Default constructor: resets the counters through commonConstructor()
// and leaves the label list empty.  No per-label tables are sized here.
RecoStats::RecoStats()
{
  commonConstructor();
  labels.clear();
}

// Constructs a statistics object for a given label set: the counters are
// reset through commonConstructor(), then set_labels() stores the labels
// and sizes the per-label tables.
RecoStats::RecoStats(const list<string> &label_list)
{
  commonConstructor();
  set_labels(label_list);
}


// Destructor: the body is intentionally empty.
RecoStats::~RecoStats()
{
}

// Code that is common to all constructors
void RecoStats::commonConstructor() {
  ntokens_ref=ntokens_test=0;
  hits=dels=ins=subs=0;
}

void RecoStats::calc_stats(const list<string> &test, const list<string> &ref) {
  
  ntokens_test=test.size();
  ntokens_ref=ref.size();

  vector< vector<GridPos> > grid;

  align_strings(test, ref, &grid);

  dels+=(grid[ntokens_test][ntokens_ref]).del;
  ins+=(grid[ntokens_test][ntokens_ref]).ins;
  subs+=(grid[ntokens_test][ntokens_ref]).subs;
  hits+=(grid[ntokens_test][ntokens_ref]).hits;

  Boolean go=true;

  int i=ntokens_test;
  int j=ntokens_ref;

  list<string>::const_iterator tp=test.end();
  list<string>::const_iterator rp=ref.end();
  --tp; --rp;
  
  int x,y;
  
  while (go) {
    switch ((grid[i][j]).dir) {
    case NILDIR:
      go=0;
      break;
    case DIAGDIR:
      x=label_num[*rp--]; j-=1;
      y=label_num[*tp--]; i-=1;
      confusions[x][y]+=1;
      break;
    case VERTDIR:
      x=label_num[*rp--]; j-=1;
      deletions[x]+=1;
      break;
    case HORDIR:
      x=label_num[*tp--]; i-=1;
      insertions[x]+=1;
      break;
    }
    if (go) {
      go = !((i==0) && (j==0));
    }
  }

  correctness=compute_correctness(hits, ntokens_ref);
  accuracy=compute_accuracy(hits, ins, dels, subs);
  
}


// Writes a one-line summary to `outfile` with no trailing newline: the
// reference and test token counts, the hit / deletion / substitution /
// insertion counts, and the stored correctness and accuracy values.
void RecoStats::print_stats(FILE *outfile) {
  fprintf(outfile,
          "Nin: %d Nout: %d (H=%d D=%d S=%d I=%d) Cor: %5.3f Acc %5.3f --",
          ntokens_ref, ntokens_test,
          hits, dels, subs, ins,
          correctness, accuracy);
}

void RecoStats::print_confusions(FILE *outfile) {

  fprintf(outfile,"\t");

  list<string>::const_iterator lp;
  list<string>::const_iterator lp_end=labels.end();
  for (lp=labels.begin(); lp!=lp_end; ++lp)
    fprintf(outfile,"%s\t",lp->c_str());
  fprintf(outfile,"DELS\tN\tCorr\tAcc\n");

  lp=labels.begin();
  for (int i=0; i<nlabels; ++i, ++lp) {
    int nout=0, nj;
    fprintf(outfile,"%s:\t",lp->c_str());
    for (int j=0; j<nlabels; ++j) {
      nout+=(nj=confusions[i][j]);
      fprintf(outfile,"%d\t",nj);
    }

    int D=deletions[i];
    int N=nout+D;
    fprintf(outfile,"%d\t",D);
    fprintf(outfile,"%d\t",N);  // Times occurring in test set
    if (N>0) {
      int H=confusions[i][i];  // Hits
      int S=nout-H;              // Substitutions
      fprintf(outfile,"%5.3f\t",compute_correctness(H, N)); // Correctness for this symbol
      fprintf(outfile,"%5.3f\t",compute_accuracy(H, insertions[i], D, S)); // Accuracy for this symbol
    }
    fprintf(outfile,"\n");
  }

  fprintf(outfile,"INS:\t");
  for (int i=0; i<nlabels; ++i)
    fprintf(outfile,"%d\t",insertions[i]);
  fprintf(outfile,"\n");
  
}


// Accumulates another set of statistics into this one.
//
// If this object has no labels yet, it adopts the labels of `stats`,
// which sizes the per-label tables through set_labels().  The token
// counts and the hit / deletion / insertion / substitution counts are
// then added, the per-label tables are added element by element, and
// correctness and accuracy are recomputed from the new totals.
//
// Per-label entries are only read from `stats` where they exist: each
// index is checked against the size of the corresponding table in
// `stats`.  Adding an object with fewer labels (or none) therefore
// contributes only its aggregate counts, instead of indexing past the
// end of its tables.
// NOTE(review): entries are combined by index, so both objects are
// assumed to use the same label order; nothing here verifies that.
void RecoStats::operator+=(const RecoStats &stats) {

  if (labels.size()==0) {
    set_labels(stats.labels);
  }

  ntokens_ref+=stats.ntokens_ref;
  ntokens_test+=stats.ntokens_test;

  hits+=stats.hits;
  dels+=stats.dels;
  ins+=stats.ins;
  subs+=stats.subs;

  const int their_dels=static_cast<int>(stats.deletions.size());
  const int their_ins=static_cast<int>(stats.insertions.size());
  const int their_rows=static_cast<int>(stats.confusions.size());

  for (int i=0; i<nlabels; ++i) {
    if (i<their_dels)
      deletions[i] += stats.deletions[i];
    if (i<their_ins)
      insertions[i] += stats.insertions[i];
  }

  for (int i=0; i<nlabels && i<their_rows; ++i) {
    const int their_cols=static_cast<int>(stats.confusions[i].size());
    for (int j=0; j<nlabels && j<their_cols; ++j) {
      confusions[i][j]+=stats.confusions[i][j];
    }
  }

  correctness=compute_correctness(hits, ntokens_ref);
  accuracy=compute_accuracy(hits, ins, dels, subs);

}

/**** Private methods *****/

void RecoStats::align_strings(const list<string> &test_string, const list<string> &ref_string, vector<vector<GridPos> > *grid) {

  const int dtw_subpen=10;
  const int dtw_inspen=7;
  const int dtw_delpen=7;
  

  int test_size=test_string.size();
  int ref_size=ref_string.size();

  grid->resize(test_size+1);
  for (int i=0; i<=test_size; ++i)
    (*grid)[i].resize(ref_size+1);

  for (int i=1; i<=test_size; ++i) {
    (*grid)[i][0] = (*grid)[i-1][0];
    (*grid)[i][0].dir = HORDIR;
    (*grid)[i][0].score += dtw_inspen;
    (*grid)[i][0].ins += 1;
  }
  
  for (int i=1; i<=ref_size; ++i) {
    (*grid)[0][i] = (*grid)[0][i-1];
    (*grid)[0][i].dir = VERTDIR;
    (*grid)[0][i].score += dtw_delpen;
    (*grid)[0][i].del += 1;
  }

  int h, v, d;

  list<string>::const_iterator tp=test_string.begin();
  for (int i=1; i<test_size+1; ++i, ++tp) {
    list<string>::const_iterator rp=ref_string.begin();
    for (int j=1; j<ref_size+1; ++j, ++rp) {
      h = (*grid)[i-1][j].score + dtw_inspen;
      d = (*grid)[i-1][j-1].score;
      if (*tp != *rp)
	d += dtw_subpen;
      v = (*grid)[i][j-1].score + dtw_delpen;
      if ((d <= h) && (d <= v)) {
	(*grid)[i][j] = (*grid)[i-1][j-1];
	(*grid)[i][j].score = d;
	(*grid)[i][j].dir = DIAGDIR;
	if (*rp == *tp)
	  (*grid)[i][j].hits += 1;
	else
	  (*grid)[i][j].subs += 1;
      } else if (h<v) {
	(*grid)[i][j] = (*grid)[i-1][j];
	(*grid)[i][j].score = h;
	(*grid)[i][j].dir = HORDIR;
	(*grid)[i][j].ins += 1;
      } else {
	(*grid)[i][j] = (*grid)[i][j-1];
	(*grid)[i][j].score = v;
	(*grid)[i][j].dir = VERTDIR;
	(*grid)[i][j].del += 1;
      }
    }
  }  
  (*grid)[0][0].dir=NILDIR;
  
}

// Installs the label set used for the per-label tables.
//
// Stores a copy of `some_labels`, records how many there are, resizes
// the insertion and deletion tables to one entry per label and the
// confusion table to labels x labels, and maps each label string to its
// position in the list through label_num.
// The tables are resized, not cleared, and existing label_num entries
// are not removed.
void RecoStats::set_labels(const list<string> &some_labels) {

  labels=some_labels;
  nlabels=labels.size();

  insertions.resize(nlabels);
  deletions.resize(nlabels);

  confusions.resize(nlabels);
  for (int row=0; row<nlabels; ++row)
    confusions[row].resize(nlabels);

  int position=0;
  list<string>::const_iterator label=labels.begin();
  while (position<nlabels) {
    label_num[*label]=position;
    ++label;
    ++position;
  }

}

// Percentage of reference tokens that were recognised correctly.
//
//   H - number of hits.
//   N - number of reference tokens.
//
// Returns 100*H/N.  When N is zero there is nothing to score and 0 is
// returned instead of dividing by zero.
Float RecoStats::compute_correctness(int H, int N) {
  if (N==0)
    return 0.0;
  return (100.0*H)/N;
}

// Accuracy as a percentage: hits minus insertions, relative to the
// number of reference tokens (hits + deletions + substitutions).
//
//   H - hits, I - insertions, D - deletions, S - substitutions.
//
// Returns 100*(H-I)/(H+D+S).  When H+D+S is zero there is no reference
// material and 0 is returned instead of dividing by zero.
Float RecoStats::compute_accuracy(int H, int I, int D, int S) {
  const int n_ref=H+D+S;
  if (n_ref==0)
    return 0.0;
  return (100.0*(H-I))/n_ref;
}

  
 /* End of ctk_reco.cpp */
