/******************************************************************************/
/*									      */
/*	ctk_decoder.hh   			             		      */
/*									      */
/*	Class declarations for ctk_decoder.cpp	              	      */
/*									      */
/*	Author: Jon Barker, Sheffield University			      */
/*									      */
/*      CTK VERSION 1.3.5  Apr 22, 2007			     	      */
/*									      */
/******************************************************************************/

#ifndef CTK_DECODER_HH
#define CTK_DECODER_HH

#include <sys/types.h>
#include <regex.h> /* Provides regular expression matching */

#include "ctk_local.hh"

#include <list>
#include <vector>
#include <string>

#include "ctk_HMM.hh"
#include "ctk_HMM_types.hh"
#include "ctk_group.hh"

class HMM;
class HMMState;



// Tuning constants for the N-best traceback.
const Integer FILTER_NBEST_N_CONST = 50;                    // Size of NBest list to process when employing filter
const Integer MIN_BACK_TRACE_THRESHOLD_CONST = -20000;    // Minimum likelihood difference threshold for NBest traceback
const Integer MIN_BACK_TRACE_THRESHOLD_CONST2 = -1000;    // Alternative (less negative) minimum threshold for NBest traceback. NOTE(review): this header does not show when it is used instead of the value above - confirm in ctk_decoder.cpp
const Float NBEST_THRESHOLD_STEP_SIZE = 5.0;              // Threshold step size; its negation is the default threshold argument of Decoder::backTrace

/******************************************************************************/
/*									      */
/*	CLASS NAME: GroupHypothesis	      	         	              */
/*									      */
/******************************************************************************/

// Stores a group hypothesis

class GroupHypothesis {

private:
  
  vector<Integer> groups; // The list of group numbers that make up this hypothesis
  string unique_name;     // The group mask stored as a string of 0's and 1's
  string masked_name;     // 0's,1's replaced with X when group ends - used to find decoders to merge
  int slot;               // The token slot that this hypothesis employs in the decoder
  
public:

  GroupHypothesis();
  explicit GroupHypothesis(const string &a_group_record);
  
  GroupHypothesis(const GroupHypothesis &hyp);
  GroupHypothesis &operator=(const GroupHypothesis &gh);

  
  void make_discrete_mask(const vector<Integer> &ingroup_mask, const vector<Float> &inmask, vector<Float> &outmask, bool inverted);
  void make_soft_mask(const vector<Integer> &ingroup_mask, const vector<Float> &inmask, vector<Float> &outmask, bool inverted);

  void make_masked_name(const Group &group);
  void make_masked_name_all_groups_dead();
  
  void set_group_record(const string &group_record_string);

  //  void add_group_background(const Group &group) {add_group(group.number(), false);}
  //  void add_group_speech(const Group &group) {add_group(group.number(), true);}
  void add_group_background(const Group &group) {add_group(group, false);}
  void add_group_speech(const Group &group) {add_group(group, true);}

  int get_slot() const {return slot;}
  void set_slot(int a_slot){slot=a_slot;}
  
  string get_masked_name() const {return masked_name;}
  
  friend ostream& operator<< (ostream& out, const GroupHypothesis& x);

  friend bool comp_group_hypotheses(const GroupHypothesis &d1, const GroupHypothesis &d2);

  // Return true if the given group is present in this hypothesis
  bool has_group(int group_number) { return unique_name[group_number]=='1';}
  
private:

  //  void add_group(Integer group_number, Boolean mask_flag);
  void add_group(const Group &group, Boolean mask_flag);


};

//
// Stores decoder data
//

class RecoHypothesis;
class Decoder;


/******************************************************************************/
/*									      */
/*	CLASS NAME: Token  	              	         	              */
/*									      */
/******************************************************************************/

class WordRecord;  
class EmittingNodeBase;


// A token carried through the network: a hypothesis score together with its
// word record, group inclusion record, pending output label and state trace.
class Token {

private:

  // A token counts as valid only while its score is strictly above this bound
  static HMMFloat MAXVALIDPROB;

  friend bool compareToken(const Token &t1, const Token &t2);

  WordRecord *wrp;           // Word record for the current token
  HMMFloat score;            // Hypothesis score
  string group_record;       // CASA Group inclusion record for token
  const string *labelp;      // Pointer to output label - NULL if not ready to output
  vector<int> state_trace;   // Record of state numbers

public:

  Token();

  // Copy construction and assignment
  Token(const Token &atoken);
  Token &operator=(const Token &atoken);

  // Amended assignment (the method keeps its historical spelling)
  void ammend(const Token &atoken, HMMFloat score, const string *labelp);

  // Append a state number to the recorded state trace
  void addState(int state_number) {state_trace.push_back(state_number);}

  // Number of entries in the state trace
  int getNStates() const {
    return static_cast<int>(state_trace.size());
  }

  // Copy of the state trace
  vector<int> getStateTrace() const {
    return state_trace;
  }

  // Deactivate the token, then restore the initial group record and drop the word record
  void reset() {
    deactivate();
    group_record = "_\0";
    wrp = NULL;
  }

  HMMFloat getScore() const {
    return score;
  }
  void setScore(HMMFloat ascore) {
    score = ascore;
  }

  void addGroup(Boolean flag);

  void activate();
  void deactivate();

  // Add x to the hypothesis score
  void addProb(HMMFloat x) {
    score += x;
  }

  void extendWordRecord(int frame, Float word_creation_penalty);
  void extendWordRecord(int frame, Float word_creation_penalty, const list<Token> &exit_tokens);

  const string* getLabelp() const;

  void setLabelp(const string *alabelp) {
    labelp = alabelp;
  }

  // True while the score lies strictly above MAXVALIDPROB
  bool isValid() const {
    return MAXVALIDPROB < score;
  }

  const WordRecord* getWRP() const {
    return wrp;
  }
  string getGroupRecord() const {
    return group_record;
  }
  void clearGroupRecord() {
    group_record.clear();
  }

  // Set the validity bound, or restore it to the most negative finite float
  static void SetMaxValidProb(HMMFloat prob) {
    MAXVALIDPROB = prob;
  }
  static void ResetMaxValidProb() {
    MAXVALIDPROB = -numeric_limits<float>::max();
  }

  friend ostream& operator<< (ostream& out, const Token& x);

};

bool compareToken(const Token &t1, const Token &t2); 

/******************************************************************************/
/*									      */
/*	STRUCT NAME: UpdatedToken                 	         	              */
/*									      */
/******************************************************************************/

// Lightweight description of a token update: a pointer to an existing token
// plus the values supplied at construction.  The pointed-to token is not copied.
struct UpdatedToken {
  const Token *tp_;        // Token referred to by this update
  HMMFloat newprob_;       // Value passed as `newprob` - presumably the updated score (confirm in ctk_decoder.cpp)
  const string *labelp_;   // Label pointer, or NULL when constructed without one

  // Update carrying an explicit label pointer
  UpdatedToken(const Token *tp, HMMFloat newprob, const string *labelp)
    : tp_(tp),
      newprob_(newprob),
      labelp_(labelp)
  {}

  // Update without a label: labelp_ is set to NULL
  UpdatedToken(const Token *tp, HMMFloat newprob)
    : tp_(tp),
      newprob_(newprob),
      labelp_(NULL)
  {}
};

/******************************************************************************/
/*									      */
/*	CLASS NAME: WordRecord                 	         	              */
/*									      */
/******************************************************************************/

// Record created at the end of a word: the frame at which the word ended, the
// best token arriving there and (optionally) all tokens that arrived.
// Pointers to records are kept in a class-wide list; records cannot be copied.
class WordRecord {

  // Class-wide list of record pointers (see destroyAllWordRecords)
  static list<WordRecord *> record_list;

private:
  int frame;                // The frame at which this word ended
  Token best_token;         // Highest scoring token arriving at the end of the word
  list<Token> all_tokens;   // All the tokens that arrived at the end of the word

public:

  WordRecord(int aframe, Token best, const list<Token> &all);

  WordRecord(int aframe, Token best);

  const WordRecord *getPrevWordRecord() const;

  vector<RecoHypothesis*> backTrace(vector<RecoHypothesis*> &hyps, HMMFloat threshold, list<string> labelsofar, HMMFloat scoresofar, HMMFloat relative_scoresofar, list<int> boundsofar, list<vector<int> > statesofar, string groups) const;

  // Number of entries currently in the class-wide record list
  static int getRecordListSize() {
    return static_cast<int>(record_list.size());
  }

  // Empty the class-wide record list (only the list entries are removed here)
  static void initialiseWordRecordList() {
    record_list.clear();
  }

  static void destroyAllWordRecords();

  friend ostream& operator<< (ostream& out, const WordRecord& x);

private:

  const string *getLabelp() const;

  // Copying is disabled: these are declared private and deliberately given no definition
  WordRecord(const WordRecord &wr);
  WordRecord &operator= (const WordRecord &wr);

};




/******************************************************************************/
/*									      */
/*	CLASS NAME: Node  	              	         	              */
/*									      */
/******************************************************************************/

class Transition;
class PushTransition;
class PullTransition;
class NonEmittingNode;

// Node in the constructed HMM network
// Node in the constructed HMM network.
//
// Abstract base of the emitting and non-emitting nodes.  A node holds one
// token per parallel decoder token slot, the push transitions leaving it,
// and bookkeeping lists of the transitions that push tokens to it or pull
// tokens from it.
class Node {
  friend class Transition;
  friend class PushTransition;
  friend class PullTransition;

public:

  // Class-wide maximum observed probability: reset to the most negative finite float, or query
  static void ResetMaxObservedProb(){MAXOBSERVEDPROB = -numeric_limits<float>::max();}
  static HMMFloat GetMaxObservedProb(){return MAXOBSERVEDPROB;}

  // Class-wide pruning switch
  static void SetPruning(bool x){PRUNING=x;}
  static bool Pruning(){return PRUNING;}

private:

  static HMMFloat MAXOBSERVEDPROB;
  static bool PRUNING;

  static int new_node_number;            // Number to assign to newly constructed nodes

  list<PushTransition*> push_transitions_received;  // Push transitions that push token to here

  int node_number;            // A unique node number held for debug purposes
  int state_num;   // Index of the HMM state within the HMM model or -1 for non-emitting nodes
  int node_id;     // A user settable ID number

protected:

  list<PushTransition*> push_transitions;     // To nodes
  list<PullTransition*> pull_transitions_supplied;  // Pull transitions that pull token from here

  static int token_slot;                 // Current token slot

  vector<Token> token;             // Current tokens - 1 for each parallel decoder token slot

  bool active;  // stores whether node has active token

  list<const UpdatedToken*> updated_token_pointers;  // Incoming token pointers
  list<Token> new_tokens; // Constructed incoming tokens - needed only for NBest lists

  Token token_copy;  // Temporary storage for incoming token

  bool record_state_path;  // Set to true if this node should be added to token's state history.

  // Protected constructor.
  // Takes the next node number and starts with a single token slot.  Every
  // scalar member is given a defined value: state_num starts at -1 (the
  // documented value for non-emitting nodes) and active starts false, so that
  // getHMMStateNumber() and the activation logic never read an indeterminate
  // value.  The initialiser order follows the member declaration order.
  Node(bool record_state_path_value)
    : node_number(++new_node_number),
      state_num(-1),
      node_id(0),
      active(false),
      record_state_path(record_state_path_value) {
    token.resize(1);
  }

public:
  virtual ~Node();

  void activate();
  void deactivate();

  void mergeTokens(const vector<vector<int> >&merge_slot);

  // Add `adjustment` to the score of the token in the given slot (slot is not range checked)
  void adjustTokenScore(int slot, float adjustment) {token[slot].addProb(adjustment);}

  int getNodeNumber() const {return node_number;}
  int getHMMStateNumber() const {return state_num;}

  void setID(int id) {node_id=id;}
  int getID() const {return node_id;}

  // Class-wide settings: the current token slot and the node numbering counter
  static void setTokenSlot(int n) {token_slot=n;}
  static void resetNewNodeNumber() {new_node_number=0;}

  PushTransition *connectTo(Node *to_node, HMMFloat prob);

  // New group started - double up tokens and add group backtrace info
  void newGroup();

  void pushToken();      // push tokens to next state

  void realiseUpdatedTokens(); // Turns UpdatedToken pointers into a list of real tokens
  void compareTokens(bool debug=false);    // Replace token with best of new_tokens
  void updateToken();   // Move temp token (token_copy) to the token slot
  void clearUpdatedTokenPointers();

  // Token held in the current token slot
  const Token &getToken() const {return token[token_slot];}

  // True if the token in the current token slot is valid
  inline bool hasValidToken() const {return token[token_slot].isValid();}
  void reset();

  virtual bool isEmitting() const = 0;

  // Register a transition that pulls from / pushes to this node
  void addPullTransitionSupplied(PullTransition *pullp){pull_transitions_supplied.push_back(pullp);}
  void addPushTransitionReceived(PushTransition *pushp){push_transitions_received.push_back(pushp);}

  list<PushTransition*>::const_iterator getPushTransitionsBegin() const {return push_transitions.begin();}
  list<PushTransition*>::const_iterator getPushTransitionsEnd() const {return push_transitions.end();}

  int nPullTransitionsSupplied() const {return pull_transitions_supplied.size();}
  int nPushTransitionsReceived() const {return push_transitions_received.size();}
  int nPushTransitions() const {return push_transitions.size();}

  virtual void reattachTransitions(NonEmittingNode &redirect_node);

  // Default implementation finds no loop; overridden by NonEmittingNode
  virtual bool recursiveNENLoopSearch(list<NonEmittingNode*> &) {return false;}

  virtual ostream &display(ostream& out) const;

  void removePushTransition(PushTransition *trans);

  void removePushTransitionReceived(PushTransition *trans);

  void removePullTransitionSupplied(PullTransition *trans);

  virtual void recordStateToToken(Token &token_copy) const =0;

protected:
  // Queue an incoming token update; only the pointer is stored
  inline void receiveToken(const UpdatedToken *uptp){
    updated_token_pointers.push_back(uptp);
  }

  virtual void setFlag(){}   // These operations do nothing by default
  virtual void clearFlag(){} //

};


/******************************************************************************/
/*									      */
/*	STRUCT NAME: NodePair   	      	         	              */
/*									      */
/******************************************************************************/

// A pair of non-emitting nodes marking the start and end of a network section.
// A default-constructed pair holds two NULL pointers and reports empty();
// a pair constructed from two node pointers reports non-empty, whatever the
// pointer values are.
struct NodePair {
  class NonEmittingNode *start;
  class NonEmittingNode *end;
private:
  bool empty_;   // true only for a default-constructed pair

public:
  NodePair(NonEmittingNode *astart, NonEmittingNode *aend):start(astart), end(aend), empty_(false) {}

  NodePair():start(NULL), end(NULL), empty_(true) {}

  // Const so that the query can also be made on a const NodePair
  bool empty() const {return empty_;}
};

/******************************************************************************/
/*									      */
/*	CLASS NAME: NonEmittingNode  	      	         	              */
/*									      */
/******************************************************************************/

// End Node for collecting tokens from word endings and passing to word beginnings

// Non-emitting node: collects tokens from word endings and passes them on to
// word beginnings.  It owns pull transitions and a word creation penalty, and
// carries a flag used by the network analysis algorithms.
class NonEmittingNode : public Node {

private:
  HMMFloat word_creation_penalty;   // Additional penalty added to likelihood when entering this state
  bool save_nbest_info;             // If set true then data for nbest lists is saved
  bool flag;                        // Flag used in the network analysis algorithms
  bool word_end_node;               // true if node represents the end of a word in the grammar

public:

  list<PullTransition*> pull_transitions;

  // Plain node: zero penalty, no nbest info, flag cleared, no state path recording
  NonEmittingNode(bool word_end=true)
    : Node(false),
      word_creation_penalty(0.0),
      save_nbest_info(false),
      flag(false),
      word_end_node(word_end)
  {}

  // Construct a node by merging a set of nodes
  NonEmittingNode(list<NonEmittingNode*> &node_loop);

  list<PullTransition*>::const_iterator getPullTransitionsBegin() const {
    return pull_transitions.begin();
  }
  list<PullTransition*>::const_iterator getPullTransitionsEnd() const {
    return pull_transitions.end();
  }

  virtual ~NonEmittingNode();

  HMMFloat getWordCreationPenalty() const {
    return word_creation_penalty;
  }

  virtual void reattachTransitions(NonEmittingNode &redirect_node);

  PullTransition *connectFrom(Node *from_node, HMMFloat prob);

  void pullToken();   // Pull tokens from previous states

  void extendWordRecord(int frame);

  void setWordCreationPenalty(HMMFloat penalty) {
    word_creation_penalty = penalty;
  }

  virtual bool recursiveNENLoopSearch(list<NonEmittingNode*> &node_loop);

  bool isWordEndNode() const {
    return word_end_node;
  }

  void flagFromNodes();

  // Network-analysis flag: query, set and clear
  bool flagged() {
    return flag;
  }
  virtual void setFlag() {
    flag = true;
  }
  virtual void clearFlag() {
    flag = false;
  }

  virtual bool isEmitting() const {
    return false;
  }

  void activateNBestInfo(bool x) {
    save_nbest_info = x;
  }

  int nPullTransitions() const {
    return pull_transitions.size();
  }

  // Bypasses the node and returns true if it is redundant;
  // bypassed nodes should then be deleted
  bool bypassIfRedundant();

  // Remove any transitions that refer a node back to itself
  void removeSelfTransitions();

  void removePullTransition(PullTransition *trans);

  friend ostream& operator<< (ostream& out, const NonEmittingNode& x);

  // Non-emitting nodes add nothing to a token's state history
  virtual void recordStateToToken(Token &) const {}

  virtual ostream &display(ostream& out) const;

};

/******************************************************************************/
/*									      */
/*	CLASS NAME: EmittingNodeBase  	       	         	              */
/*									      */
/******************************************************************************/
//
// Base class for all types of emitting node
//

// Abstract interface shared by the emitting node classes (EmittingNode and
// CompositeEmittingNode derive from it).  It fixes isEmitting() to true and
// leaves emission, state access and labelling to the derived classes.
class EmittingNodeBase:public Node {

private:


public:
  EmittingNodeBase();
  
  virtual ~EmittingNodeBase(){}

  // Implemented by each derived class (defined in ctk_decoder.cpp)
  virtual void addEmissionProb()=0;

  // HMM state attached to the node (CompositeEmittingNode returns NULL)
  virtual const HMMState *getHMMState() const = 0;

  // Every class derived from this base is an emitting node
  virtual bool isEmitting() const {return true;}

  // State number of the node (CompositeEmittingNode returns 0)
  virtual int getStateNum() const = 0;

  // Turn state path recording on or off for this node
  void activateStatePathRecording(bool x);

  // Output label access
  virtual void setLabel(const string &alabel)=0;
  virtual const string &getLabel() const = 0;

  virtual ostream &display(ostream & out) const = 0;
};


/******************************************************************************/
/*									      */
/*	CLASS NAME: EmittingNode  	       	         	              */
/*									      */
/******************************************************************************/

// Emitting node attached to a single HMM state.
class EmittingNode : public EmittingNodeBase {

private:

  string label;                   // Optional output label for this state
  bool has_label;
  const HMMState * const state;   // HMM state to which this node is attached
  int state_num;                  // State number given at construction; reported by getStateNum()

public:
  EmittingNode(const HMMState *a_state, int a_state_num);

  virtual ~EmittingNode() {}

  const HMMState *getHMMState() const {
    return state;
  }
  int getStateNum() const {
    return state_num;
  }

  void addEmissionProb();

  void setLabel(const string &alabel);
  const string &getLabel() const;

  virtual ostream &display(ostream & out) const;

  // Append this node's state number to the token's state trace
  virtual void recordStateToToken(Token &token_copy) const {
    token_copy.addState(state_num);
  }
};


/******************************************************************************/
/*									      */
/*	CLASS NAME: CompositeEmittingNode  	       	         	              */
/*									      */
/******************************************************************************/

// Emitting node attached to a pair of HMM states (state 1 and state 2), each
// with its own optional label and state number.
//
// The single-state parts of the EmittingNodeBase interface are stubs here:
// setLabel() does nothing, getLabel() terminates the program, getStateNum()
// returns 0 and getHMMState() returns NULL.  Use the ...1 / ...2 accessors.
class CompositeEmittingNode:public EmittingNodeBase {

private:

  string label1;      // Optional output label for state 1
  bool has_label1;
  const HMMState * const state1; // First HMM state to which this node is attached
  const int state_num1;

  string label2;      // Optional output label for state 2
  bool has_label2;
  const HMMState * const state2; // Second HMM state to which this node is attached
  const int state_num2;

public:
  CompositeEmittingNode(const HMMState *astate1, int astate_num1, const HMMState *astate2, int astate_num2);

  virtual ~CompositeEmittingNode(){}

  // Single-label interface is not supported: setting is ignored and getting
  // exits the program with status -1 (the return statement is never reached).
  void setLabel(const string &){}
  virtual const string &getLabel() const {exit(-1); return label1;}

  void setLabel1(const string &alabel);
  void setLabel2(const string &alabel);

  const string &getLabel1() const;
  const string &getLabel2() const;

  int getStateNum() const {return 0;}
  const HMMState *getHMMState() const {return NULL;}

  int getStateNum1() const {return state_num1;}
  // Returns the second state's number (previously returned state_num1 by mistake)
  int getStateNum2() const {return state_num2;}

  void addEmissionProb();

  virtual ostream &display(ostream & out) const;
  friend ostream& operator<< (ostream& out, const CompositeEmittingNode& x);

  // Append both state numbers, state 1 first, to the token's state trace
  virtual void recordStateToToken(Token &token_copy) const {
    token_copy.addState(state_num1);
    token_copy.addState(state_num2);
  }
};

/******************************************************************************/
/*									      */
/*	CLASS NAME: Transition                 	         	              */
/*									      */
/******************************************************************************/

// Common base of the push and pull transitions: a link between a base node
// (the owner) and a second node, with an associated log probability.
// It can only be constructed, copied and destroyed through derived classes.
class Transition {

protected:

  Node *base_node;          // Node to which the transition 'belongs'
  Node *node;               // Node to which the transition pushes or from which it pulls
  const HMMFloat log_prob;  // Value reported by getLogProb()

  Transition(Node *abase_node, Node *anode, HMMFloat aprob);

  // Copy constructor: duplicates both node pointers and the log probability
  Transition(const Transition &trans)
    : base_node(trans.base_node),
      node(trans.node),
      log_prob(trans.log_prob)
  {}

  virtual ~Transition();

public:
  Node *getNode() const {
    return node;
  }
  Node *getBaseNode() const {
    return base_node;
  }
  HMMFloat getLogProb() const {
    return log_prob;
  }

  // True if x is the node at the far end of this transition
  bool pointsTo(const Node *x) const {
    return node == x;
  }

  virtual void changeDistalNode(Node *anode) = 0;

  // The operations below are forwarded to the far-end node
  bool recursiveNENLoopSearch(list<NonEmittingNode*> &node_loop) {
    return node->recursiveNENLoopSearch(node_loop);
  }

  bool hasValidToken() const {
    return node->hasValidToken();
  }

  void deactivate() {
    node->deactivate();
  }

  void flagNode() {
    node->setFlag();
  }

  // True if both ends of the transition are the same node
  bool selfTransition() const {
    return node == base_node;
  }

  virtual ostream& display(ostream& out) const;
};

/******************************************************************************/
/*									      */
/*	CLASS NAME: PullTransition            	         	              */
/*									      */
/******************************************************************************/

// Transition for pulling token from previous state

class PullTransition: public Transition {

  string label;
  bool has_label;
public:
  
  PullTransition(Node *abase_node, Node *anode, HMMFloat aprob):Transition(abase_node, anode, aprob), label(), has_label(false) {}

  // Copy constructor for PullTransition
  PullTransition(const PullTransition &pull):Transition(pull), label(pull.label), has_label(pull.has_label) {
    pull.getNode()->addPullTransitionSupplied(this);
  };
  
  virtual ~PullTransition();

  virtual void changeProxalNode(NonEmittingNode *anode);
  virtual void changeDistalNode(Node *anode);

  void setLabel(const string &alabel){label=alabel; has_label=true;}
  string getLabel() const {return label;}
  
  UpdatedToken *pull() const;

};

/******************************************************************************/
/*									      */
/*	CLASS NAME: PushTransition            	         	              */
/*									      */
/******************************************************************************/

// Transition for pushing token to next state

class PushTransition: public Transition {

  
public:
  
  PushTransition(Node *abase_node, Node *anode, HMMFloat aprob):Transition(abase_node, anode, aprob){}

  // Copy constructor for PushTransition
  PushTransition(const PushTransition &push):Transition(push){
    push.getNode()->addPushTransitionReceived(this);
  };

  virtual ~PushTransition();
  
  virtual void changeProxalNode(Node *anode);
  virtual void changeDistalNode(Node *anode);

  void push(const Token *tokenp) const;

};


/******************************************************************************/
/*									      */
/*	CLASS NAME: PronunciationTree                	         	              */
/*									      */
/******************************************************************************/

// A tree of named nodes: each node has a name, a label and a list of child
// trees.  Sequences of names are added with addSequenceToTree().  The mutable
// members are position markers used while parsing, which is why the parsing
// methods can be const.
class PronunciationTree {

  // Position markers for parsing
  mutable bool visited;
  mutable list<class PronunciationTree>::const_iterator current_child;

  // Data

  string node_name;
  string label;
  list<class PronunciationTree> children;

public:

  // Both constructors value-initialise the position markers (visited starts
  // false) and leave the label and child list empty.
  PronunciationTree(): visited(), current_child(), node_name(), label(), children() {}
  PronunciationTree(const string &name):visited(), current_child(), node_name(name), label(), children() {}

  string getNodeName() const {return node_name;}
  string getLabel() const {return label;}

  // True if this node is called `aname`.  Const, like the other getters,
  // so that it can be used on const trees.
  bool isMyNodeName(const string &aname) const {return node_name==aname;}
  void addSequenceToTree(const list<string> &sequence, const string &label);

  class PronunciationTree& descendPronunciationTree(const string &name);
  class PronunciationTree& growPronunciationTree(const string &name);

  // Methods for parsing
  void resetTree() const;  // Must be called before parsing tree
  const PronunciationTree *parseTree(int &popped) const;

  friend ostream& operator<< (ostream& out, const PronunciationTree& tree);

};

/******************************************************************************/
/*									      */
/*	CLASS NAME: Decoder                	         	              */
/*									      */
/******************************************************************************/

class Decoder {

private:
  // group_recording_switch is declared static as a handy way of making it available to the nodes (where it is used) this is a hack and will prevent the switch operating in the expected way in scripts that use more than one decoder. XXXX
  static bool group_recording_switch; // whether or not group info is recorded

  typedef list<Node*>::const_iterator Nit;
  typedef list<EmittingNodeBase*>::const_iterator ENit;
  typedef list<NonEmittingNode*>::const_iterator NENit;
  typedef list<PushTransition*>::const_iterator PushTit;
  typedef list<PullTransition*>::const_iterator PullTit;

  HMMFloat pruning_beamwidth;
  
  list<EmittingNodeBase*> emitting_nodes;   // HMM state nodes
  list<NonEmittingNode*> link_nodes;    // Non-emitting HMM linking nodes

  NonEmittingNode *network_start_node;          // Starting state - Where the viable token starts 
  NonEmittingNode *network_end_node;            // Ending state - Where the hypothesis must end

  const SetOfHMMs *hmms;                 // The set of HMMs from which the network may be built

  int current_frame;                           // The current frame being processed

  int nbest_list_size;                // Size of the Nbest list to compute
  bool state_path_recording_switch;   // whether or not state path recording is active

  unsigned int ntoken_slots;
  
  vector<GroupHypothesis> group_hypothesis;   // Set of hypotheses for the Multisource group masks

  vector<GroupHypothesis>::iterator current_group_hypothesis;              // Group hypothesis currently under consideration 

  //
  //
  //
  
public:

  Decoder(const SetOfHMMs *hmms, const string &prefix, const string &postfix, HMMFloat word_creation_penalty);

  Decoder(const SetOfHMMs *hmms,
	  HMMFloat word_creation_penalty,
	  const string &grammar_file,
	  const string &grammar_format        // Valid formats are, "SLF" (standard lattice) or "EBNF" (extended Backus-Naur)
);

  // Construct a new decoder by making the product of two existing decoders
  Decoder(const Decoder &decoder1, const Decoder &decoder2);
  
  ~Decoder();

  void activateStatePathRecording(bool x);  // Turn on or off state path recording
   
  // Set/query the size of the NBest list
  void setNBestListSize(int list_size);
  void setRequiresNBestTraceInfo(bool requires_nbest);
  int getNBestListSize() const {return nbest_list_size;}

  static void setRequiresGroupInfo(bool requires_group_info) {group_recording_switch=requires_group_info;}
  static bool requiresGroupInfo() {return group_recording_switch;}

  void setPruningBeamwidth(float width){
    pruning_beamwidth=width;
    Node::SetPruning(width>0.0);
  }
  
  void reset();  // Called before token passing begins. Puts network into good initial state.

  void nextFrame(); // Increase the current frame number
  int getFrame() const {return current_frame;}
  
  void passTokens(); // Step tokens forward one frame

  // Find the emitting node that has the token with the highest score so far
  const EmittingNodeBase *getTopEmittingNode() const;

  const EmittingNodeBase *getInstantTopEmittingNode() const;

  list<EmittingNodeBase*>::const_iterator getEmittingNodesBegin() const {return emitting_nodes.begin();}
  list<EmittingNodeBase*>::const_iterator getEmittingNodesEnd() const {return emitting_nodes.end();}
  
  list<NonEmittingNode*>::const_iterator getLinkNodesBegin() const {return link_nodes.begin();}
  list<NonEmittingNode*>::const_iterator getLinkNodesEnd() const {return link_nodes.end();}
  
  CTKStatus backTrace(vector<RecoHypothesis*> &hyps, HMMFloat threshold=-NBEST_THRESHOLD_STEP_SIZE);  // Perform backtrace and gather all hypotheses that score over a given threshold
  
  void displayLogProbs(ostream &outfile) const;

  int countValidTokens() const;
  
  NonEmittingNode *insertNodeAt(NonEmittingNode *at_node, bool is_optional, bool is_repeatable);
  NodePair insertNodePairAt(Node *at_node, bool is_optional, bool is_repeatable);
  NodePair insertNodePairBetween(NodePair &node_pair, bool is_optional, bool is_repeatable);
  
  NonEmittingNode *addHMMByNameAt(const string &hmm_name, NonEmittingNode *at_node);

  NonEmittingNode *addNEN(bool word_end=true);

  const HMM *getHMMByName(string hmm_label) const;
  
  // Add a pair of NEN's with an HMM sitting between. Return the NEN's. 
  NodePair addHMMByName(string hmm_label);
  
  CTKStatus addHMMByNameBetween(string hmm_name, NonEmittingNode *start_node, NonEmittingNode *end_node);

  void cleanUp();

  //
  //
  //

  
  // Find  a hypothesis that does not match regex
  RecoHypothesis *get_compliant_hyp(regex_t* filter);

  // A new potential group has started - spawn a complimentary set of hypotheses
  void spawn_group_hypotheses(const Group &new_group);
  
  // A group has end - merge complimentary set of hypotheses
  void merge_group_hypotheses(const Group &dead_group, float merging_parameter);

  // Make mask according to the current group hypothesis under consideration
  void make_discrete_mask_hypothesis(const vector<Integer> &ingroup_mask, const vector<Float> &inmask, vector<Float> &outmask, bool inverted);

  // As above, but for use with the *soft* multisource decoder
  void make_soft_mask_hypothesis(const vector<Integer> &ingroup_mask, const vector<Float> &inmask, vector<Float> &outmask, bool inverted);

  void reset_current_group_hypothesis();
  bool next_group_hypothesis();
  int get_num_group_hypotheses() const {return group_hypothesis.size();}


  friend ostream& operator<< (ostream& out, const Decoder& x);

private:

  void displayStatusSummary(ostream &out, bool NEN_details=false);
  
  // Gently adapt pruning beamwidth to aim at a given target percentage of pruned tokens
  void adaptPruningBeamwidth(float target_percent);
  
  // Build network to represent a simple loop grammar with optional fixed prefix and postfix
  CTKStatus constructLoopGrammar(const SetOfHMMs *these_hmms, const string &prefix, const string &postfix);
  
  // Build network according to Extended Backus-Naur Form grammar file
  void constructGrammarFromEBNFFile(const string &grammar_file);  

  // Build network according to Standard Lattice Format file
  void constructGrammarFromSLFFile(const string &grammar_file);  

  void initialise(); // Called at end of construction to put network into a form suitable for the Viterbi token passing algorithm
  
  void setWordCreationPenalty(HMMFloat penalty);
  
  void setStartNode(NonEmittingNode *nenp);
  void setEndNode(NonEmittingNode *nenp);

  const NonEmittingNode *getStartNode() const {return network_start_node;}
  const NonEmittingNode *getEndNode() const {return network_end_node;}
  
  // Add 1 or more HMMs in parrallel after a given NEN - set the `optional' and `repeatable' status of these HMMs
  NonEmittingNode *addHMMAt(const HMM *hmm, NonEmittingNode *at_node);
  
  NonEmittingNode *addHMMByNameAt(const vector<string> &hmms, NonEmittingNode *at_node);

  CTKStatus addHMMByNameBetween(const vector<string> &some_hmms, NonEmittingNode *start_node, NonEmittingNode *end_node);

  NonEmittingNode *addBaseHMMByName(const string &name, NonEmittingNode *start_node, NonEmittingNode *end_node, const string &label);
  
  void addBaseHMMBetween(const HMM *hmm, NonEmittingNode *start_node, NonEmittingNode *end_node, const string &label);  // Add ENs for an HMM between two existing NENs
  
  void removeNENLoops();   // Factorize NEN loops out of a  network
  void simplifyNENs();   // Remove redundant NEN nodes from network

  bool mergeNENs(list<NonEmittingNode*> &node_loop);  // Turn loop of NENs into a single equivalent NEN

  void orderNENs();     // Arrange NENs into a sequentially executable order

  NodePair addNodePair(bool is_optional, bool is_repeatable);
 
  
  void renumber_group_hypotheses(vector<GroupHypothesis> &gh);


  // Stuff for pronunciation trees

  void buildPronunciationTree(PronunciationTree &tree, const vector<string> &some_hmms);

  void addPronunciationTreeBetween(const PronunciationTree &tree, NonEmittingNode *start_node, NonEmittingNode *end_node);

  // Stuff for Product HMMs

  
  // Returns the name string for the composite of node1 and node2. The
  // multiplyOut* templates below use this string as the key into their
  // name -> composite node maps, both when storing and when looking up.
  string compositeNodeName(const Node *node1, const Node *node2);

  template <class T1, class T2, class T3>
  void multiplyOutNodes(T1 npbegin1, T1 npend1, T2 npbegin2, T2 npend2, T3 &nodelist, map<string, Node*> &namemap) {
    
    T1 np1;
    T2 np2;
    for (np1=npbegin1; np1!= npend1; ++np1) {
      for (np2=npbegin2; np2!= npend2; ++np2) {
	Node *np = newCompositeNode(*np1, *np2, nodelist);
	string name = compositeNodeName(*np1, *np2);
	namemap[name]=np;
      }
    }
  }

  // Create the composite of two nodes. Called from multiplyOutNodes(), which
  // passes its node list as the third argument; the overload is selected by
  // the pointer types of the pair and the element type of that list
  // (emitting vs non-emitting).
  // NOTE(review): presumably the new node is also appended to the list that
  // is passed in - confirm against the implementation.
  Node *newCompositeNode(EmittingNodeBase *np1, EmittingNodeBase *np2, list<EmittingNodeBase*> &emit_node);
  Node *newCompositeNode(Node *np1, Node *np2, list<NonEmittingNode*> &non_emit_node);

  template <class T1, class T2>
  void multiplyOutPushTransitions(T1 npbegin1, T1 npend1, T2 npbegin2, T2 npend2, map<string, Node*> &namemap) {

    T1 np1;
    T2 np2;
    
    for (np1=npbegin1; np1!= npend1; ++np1) {
      for (np2=npbegin2; np2!= npend2; ++np2) {
	string from_namestr = compositeNodeName(*np1, *np2);
	Node *from_node= namemap[from_namestr];
	
	if (from_node!=NULL) {
	  PushTit pt1, pt2;
	  PushTit ptend1 = (*np1)->getPushTransitionsEnd();
	  PushTit ptend2 = (*np2)->getPushTransitionsEnd();
	  
	  float total_logprob= -numeric_limits<float>::max(); // large -ve number
	  for (pt1=(*np1)->getPushTransitionsBegin(); pt1!= ptend1; ++pt1) {
	    float prob1 = (*pt1)->getLogProb();
	    for (pt2=(*np2)->getPushTransitionsBegin(); pt2!= ptend2; ++pt2) {
	      float prob2=(*pt2)->getLogProb();
	      total_logprob=log_add(total_logprob, prob1+prob2);
	    }
	  }
	  
	  for (pt1=(*np1)->getPushTransitionsBegin(); pt1!= ptend1; ++pt1) {
	    float prob1=(*pt1)->getLogProb();
	    for (pt2=(*np2)->getPushTransitionsBegin(); pt2!= ptend2; ++pt2) {
	      float prob2 = (*pt2)->getLogProb();
	      string to_namestr=compositeNodeName((*pt1)->getNode(), (*pt2)->getNode());
	      Node *to_node=namemap[to_namestr];
	      if (to_node!=NULL) {
		//		 cout << "Connecting node " << from_node->getNodeNumber() << " to " << to_node->getNodeNumber() << " with prob " << prob1+prob2-total_logprob << "\n";
		from_node->connectTo(to_node, exp(prob1+prob2-total_logprob)); // JON!!!
		//		from_node->connectTo(to_node, exp(prob1+prob2)); // Testing without prob renormalisation
	      } else {
		cerr << "failed to find composite to_node: " << to_namestr << "\n";
	      }
	    }
	  }
	} else {
	  cerr << "failed to find composite from_node: " << from_namestr << "\n";
	}
      }
    }    
  }

  // Non-template overload for the case where both node ranges are NEN ranges.
  // Being a non-template, it is preferred by overload resolution over the
  // template of the same name below when all four iterators are NENit.
  // The implementation is not in this header.
  void multiplyOutPullTransitions(NENit npbegin1, NENit npend1, NENit npbegin2, NENit npend2, map<string, Node*> &namemap);

  // Given one NENit and one ENit, in either order, hand back the NENit.
  // The emitting iterator is ignored.
  inline NENit getNENit(NENit nen_it, ENit /* ignored */)
  {
    return nen_it;
  }
  inline NENit getNENit(ENit /* ignored */, NENit nen_it)
  {
    return nen_it;
  }
  // Given one ENit and one NENit, in either order, hand back the ENit.
  // The non-emitting iterator is ignored.
  inline ENit getENit(ENit en_it, NENit /* ignored */)
  {
    return en_it;
  }
  inline ENit getENit(NENit /* ignored */, ENit en_it)
  {
    return en_it;
  }

  template <class T1, class T2>
  void multiplyOutPullTransitions(T1 npbegin1, T1 npend1, T2 npbegin2, T2 npend2, map<string, Node*> &namemap) {
    
    T1 np1;
    T2 np2;
    
    ENit np;

    // One of npbegin1 and npbegin2 is emitting the other is non-emitting - return the emitting one
    ENit npemitting = getENit(npbegin1, npbegin2);

    for (np1=npbegin1; np1!= npend1; ++np1) {
      for (np2=npbegin2; np2!= npend2; ++np2) {

	string to_namestr = compositeNodeName(*np1, *np2);
	NonEmittingNode *to_node= dynamic_cast<NonEmittingNode*>(namemap[to_namestr]);
	
	if (to_node!=NULL) {
	  NENit nenit = getNENit(np1, np2);
	  ENit enit = getENit(np1, np2);
	  PullTit pt1;
	  PullTit ptend1 = (*nenit)->getPullTransitionsEnd();
	  
	  for (pt1=(*nenit)->getPullTransitionsBegin(); pt1!= ptend1; ++pt1) {
	    string from_namestr;
	    Node *np;
	    string label;
	    //	    if (npemitting==npbegin1) {  // Change for gcc3.4.x
	    if ((*npbegin1)->isEmitting()) {     //
	      from_namestr=compositeNodeName(*enit, np=(*pt1)->getNode());
	      EmittingNodeBase *enbp = dynamic_cast<EmittingNodeBase*>(np);
	      if (enbp!=NULL) {
		label=string("2:")+ enbp->getLabel(); 
		//		cerr << "label = " << label << "\n";
	      }
	    } else { // npbegin2 is emitting
	      from_namestr=compositeNodeName(np=(*pt1)->getNode(), *enit);
	      EmittingNodeBase *enbp = dynamic_cast<EmittingNodeBase*>(np);
	      if (enbp!=NULL) {
		label=string("1:")+ enbp->getLabel(); 
		//	cerr << "label = " << label << "\n";
	      }
	    }
	    Node *from_node=namemap[from_namestr];
	    if (from_node!=NULL) {
	      // cout << "Connecting node " << to_node->getNodeNumber() << " from " << from_node->getNodeNumber() << "\n";
	      PullTransition *ptp=to_node->connectFrom(from_node, 1.0);
	      if (label.size()>0)  // ignore empty labels
		ptp->setLabel(label);
	    } else {
	      cerr << "failed to find composite from_node: " << from_namestr << "\n";
	    } 
	  }
	} else {
	  cerr << "failed to find composite to_node: " << to_namestr << "\n";
	}
      }
    }    
  }

};



////////////////////////////////////////


// Binary predicate over two RecoHypothesis pointers.
// NOTE(review): presumably an ordering predicate used to sort hypotheses; the
// ordering criterion is not visible in this header - confirm.
// NOTE(review): declared `inline' but no definition is visible in this header;
// an inline function must be defined in every translation unit that uses it,
// so check where the definition lives.
inline Boolean compare_hypotheses(RecoHypothesis *hyp1, RecoHypothesis *hyp2 );

#endif

/* End of ctk_decoder.hh */
