/******************************************************************************/
/*									      */
/*	ctk_correlator.cpp   	    	       			              */
/*									      */
/*	Block for applying auto or cross-correlation to a frame               */
/*									      */
/*	Author: Sue Harding, Sheffield University			      */
/*	based on autocorrel by Jon Barker, Sheffield University		      */
/*									      */
/*      CTK VERSION 1.3.5  Apr 22, 2007		         	      */
/*									      */
/******************************************************************************/
 
#include "ctk_correlator.hh"
#include <cmath>
#include <vector>

#include <functional>
#include <numeric>
#include <algorithm>

#include "ctk_local.hh"

#include "ctk_error.hh"
#include "ctk_dsp.hh"

#include "ctk_param.hh"
#include "ctk_socket.hh"
#include "ctk_data_descriptor.hh"



// Correlation kernels used by CorrelatorBlock::compute(), defined at the end of
// this file. Both write 2*max_lag+1 values to odp: first the negative lags
// (-max_lag .. -1), then lags 0 .. max_lag.
//   correlate_STAF          - short-time correlation; the number of products
//                             summed shrinks as the lag grows (window_size-lag).
//   correlate_modified_STAF - every lag sums window_size products, so the
//                             buffers need window_size+max_lag samples.
void correlate_STAF(vector<Float> &buffer0, vector<Float> &buffer1, Float *odp, int window_size, int max_lag);
void correlate_modified_STAF(vector<Float> &buffer0, vector<Float> &buffer1, Float *odp,  int window_size, int max_lag);
// Rescales, in place, the 2*max_lag+1 correlation values starting at fp
// (applied by compute() when the NORMALISE parameter is on).
void corr_normalise_lag(Float *fp, int max_lag);

  
/******************************************************************************/
/*                                                                            */
/*       CLASS NAME: CorrelatorBlock                                          */
/*                                                                            */
/******************************************************************************/

// Values accepted by the VERSION parameter (passed to ParamEnumerated in the
// constructor). The trailing "\0" entry is presumably the list terminator
// expected by ParamEnumerated -- confirm before reordering.
const char *CorrelatorBlock::PARAM_VERSIONS[] = {
  "STAF",
  "MODIFIED_STAF",
  "\0"
};
// VERSION used when the parameter is not set.
const char *CorrelatorBlock::PARAM_DEFAULT_VERSION = "STAF";

// Block type name handed to the Block base constructor, and its short help.
const string CorrelatorBlock::type_name("Correlator");
const string CorrelatorBlock::help_text("Auto- or cross-correlation");

/* Create a Correlator block called a_name.
 *
 * Sockets: in1 (required), in2 (optional - if left unconnected, in1 is
 * autocorrelated) and a single output.
 *
 * Parameters, registered in this order:
 *   AXIS_IN1, AXIS_IN2 - axis of each input to correlate along
 *   NORMALISE          - scale the result to 1.0 at zero lag
 *   WINDOW_SIZE        - correlation window (>= 1); if unset, reset() derives
 *                        a version-specific default from the input width
 *   MAX_LAG            - largest lag computed (>= 0); if unset, reset()
 *                        derives a version-specific default
 *   VERSION            - "STAF" or "MODIFIED_STAF"
 */
CorrelatorBlock::CorrelatorBlock(string a_name):CTKObject(a_name),Block(a_name, type_name) {
  make_input_sockets(2);
  input_sockets->set_description("in1", "1st input");
  input_sockets->set_description("in2", "optional 2nd input - performs autocorrelation on 1st input if left unconnected");
  input_sockets->set_optional("in2");
  make_output_sockets(1);

  // Set up AXIS0 parameter
  AXIS0_param= new ParamString("AXIS_IN1");
  AXIS0_param->set_helptext("The name of the axis from input in1 on which to operate. <p> If left unset then the axis named 'TIME' is used. If this does not exist then the outer axis is used.");
  parameters->register_parameter(AXIS0_param);

  // Set up AXIS1 parameter
  AXIS1_param= new ParamString("AXIS_IN2");
  AXIS1_param->set_helptext("The name of the axis from input in2 on which to operate. <p> If left unset then the axis named 'TIME' is used. If this does not exist then the outer axis is used.");
  parameters->register_parameter(AXIS1_param);

  // Set up NORMALISE parameter
  NORMALISE_param= new ParamBool("NORMALISE", false);
  NORMALISE_param->set_helptext("If set ON then the correlation is linearly scaled to have a value of 1.0 at zero lag.");
  parameters->register_parameter(NORMALISE_param);

  // Set up WINDOW_SIZE parameter (must be at least 1)
  WINDOW_SIZE_param = new ParamInt("WINDOW_SIZE");
  WINDOW_SIZE_param->set_helptext("Window size to use for computing correlation.");
  WINDOW_SIZE_param->install_validator(new Validator(Validator::VLOWER, 1.0));
  parameters->register_parameter(WINDOW_SIZE_param);

  // Set up MAX_LAG parameter (must be non-negative)
  MAX_LAG_param = new ParamInt("MAX_LAG");
  MAX_LAG_param->set_helptext("The maximum lag over which to compute correlation.");
  MAX_LAG_param->install_validator(new Validator(Validator::VLOWER, 0.0));
  parameters->register_parameter(MAX_LAG_param);

  // Set up VERSION parameter
  VERSION_param= new ParamEnumerated("VERSION", PARAM_VERSIONS, PARAM_DEFAULT_VERSION);
  VERSION_param->set_helptext("There are two variants: STAF - short time autocorrelation function; MODIFIED_STAF - modified short time autocorrelation function. The modified STAF requires WINDOW_SIZE multiplications at every lag, and requires frames of data of at least WINDOW_SIZE+MAX_LAG points.");
  parameters->register_parameter(VERSION_param);

}
 
/******************************************************************************/

void CorrelatorBlock::reset() {
  Block::reset();
  
  // Check for correct data format (vector frames)
  if (inputs_are_all_sample_data) {
    cerr << "Error in network at block: " << getname() << endl;
    cerr << "Attempting to apply Correl to sample data." << endl;
    cerr << "Correl can only be applied to vector frame data." << endl;
    throw(CTKError(__FILE__, __LINE__));
    }

  // Check for 1 or 2 input sockets
  Integer no_inputs = 1 + (*input_sockets)[1]->connected();
  // next bit should be taken care of by max and min params to make_input_sockets
  if (no_inputs < 1 || no_inputs > 2) {
    cerr << "Error in network at block: " << getname() << endl;
    cerr << "Invalid no. of input sockets - must be 1 or 2" << endl;
    throw(CTKError(__FILE__, __LINE__));
    }

  string other_axis0,other_axis1;
  string inner_axis0,outer_axis0;
  string inner_axis1,outer_axis1;

  // Get axis to use for input in1
  inner_axis0=((*input_sockets)[0]->get_data_descriptor()->get_inner_dimension())->get_name();
  outer_axis0=((*input_sockets)[0]->get_data_descriptor()->get_outer_dimension())->get_name();
  if (AXIS0_param->get_set_flag()) 
    {
    axis_to_use0=AXIS0_param->get_value();
    if (axis_to_use0!=inner_axis0 && axis_to_use0 != outer_axis0)
      {
      cerr << "Error in network at block: " << getname() << endl;
      cerr << "Axis " << axis_to_use0 << " not found in data from input in1" << endl;
      throw(CTKError(__FILE__, __LINE__));
      }
    }
  else
    // Get the axis from the first input
    // if axis name isn't specified, use dimension time if it exists
    if (inner_axis0=="TIME")
      {
      axis_to_use0="TIME";
      other_axis0=outer_axis0;
      }
    else
      if (outer_axis0=="TIME")
        {
        axis_to_use0="TIME";
        other_axis0=inner_axis0;
        }
      else
	{
        axis_to_use0=outer_axis0;
	other_axis0=inner_axis0;
	}
  // Get the no. of dimensions and the index no. of the x axis for the 
  // correlation - first input
  n_dims[0]=(*input_sockets)[0]->get_data_descriptor()->get_n_dimensions();
  x_dim[0]=(*input_sockets)[0]->get_data_descriptor()->get_dimension(axis_to_use0)->get_index();


  // Give a warning if 2nd axis specified and only one input
  if (no_inputs == 1)
    if (AXIS1_param->get_set_flag()) 
      // Parameter will be ignored if only one input
      {
      cerr << "Warning at block: " << getname() << endl;
      cerr << "Ignoring axis parameter " << axis_to_use1 << " for input in2 - only one input" << endl;
      }

  if (no_inputs == 2)
    {
    // Get axis to use for input in2
    inner_axis1=((*input_sockets)[1]->get_data_descriptor()->get_inner_dimension())->get_name();
    outer_axis1=((*input_sockets)[1]->get_data_descriptor()->get_outer_dimension())->get_name();
    if (AXIS1_param->get_set_flag()) 
      {
      axis_to_use1=AXIS1_param->get_value();
      if (axis_to_use1!=inner_axis1 && axis_to_use1 != outer_axis1)
        {
        cerr << "Error in network at block: " << getname() << endl;
        cerr << "Axis " << axis_to_use1 << " not found in data from input in2" << endl;
        throw(CTKError(__FILE__, __LINE__));
        }
      }
    else
      // Get the axis from the second input
      // if axis name isn't specified, use dimension time if it exists
      if (inner_axis1=="TIME")
        {
        axis_to_use1="TIME";
        other_axis1=outer_axis1;
        }
      else
        if (outer_axis1=="TIME")
          {
          axis_to_use1="TIME";
          other_axis1=inner_axis1;
          }
        else
	  {
          axis_to_use1=outer_axis1;
	  other_axis1=inner_axis1;
	  }
    // Get the no. of dimensions and the index no. of the x axis for the 
    // correlation - second input
    n_dims[1]=(*input_sockets)[0]->get_data_descriptor()->get_n_dimensions();
    x_dim[1]=(*input_sockets)[1]->get_data_descriptor()->get_dimension(axis_to_use1)->get_index();

    // Check axes match
    if (n_dims[0]!=n_dims[1])
      {
      cerr << "Error in network at block: " << getname() << endl;
      cerr << n_dims[0] << "-dimensional input in1 is incompatible with " << n_dims[1] << "-dimensional input in2" << endl;
      throw(CTKError(__FILE__, __LINE__));
      }
    if (axis_to_use0 != axis_to_use1)
      {
      cerr << "Warning at block: " << getname() << endl;
      cerr << "Correlation axes " << axis_to_use0 << " and " << axis_to_use1 << " differ" << endl;
      }
    if (other_axis0 != other_axis1)
      {
      cerr << "Warning at block: " << getname() << endl;
      cerr << "Input axes " << other_axis0 << " and " << other_axis1 << " differ" << endl;
      }
    } // if (no_inputs == 2)
  
  // Get the input data dimensions 
  for (Integer i=0; i < no_inputs; i++)
    {
    dim_sizes[i]=(*input_sockets)[i]->get_data_descriptor()->get_dimension_sizes();
    strides[i]=(*input_sockets)[i]->get_data_descriptor()->get_strides();

    if (n_dims[i]>2) {
      cerr << "Error in network at block: " << getname() << endl;
      cerr << "Attempting to apply correlation to " << n_dims[i] << "-dimensional frames" << endl;
      cerr << "Block can only handle 1 or 2 dimensional data at present" << endl;
      throw(CTKError(__FILE__, __LINE__));
      }

    } // end for (checking input parameters)

  // Find the dimension size and strides for each axis for each input
  // If there is only one input, second input values will be copies of first input
  // If there is only one dimension, y values will be equal to 1
  for (Integer i=0; i<2; i++)
    {
    Integer j=i;
    if (no_inputs==1)	// duplicate first input if necessary
      j=0;
    Integer x=1, y=0;	// assume first dim is y
    if (x_dim[j]==1) 
      { x=0; y=1;} // first dim is x
    dims_x[i]=dim_sizes[j][x];
    stride_x[i]=strides[j][x];
    if (n_dims[j] > 1)
      {
      dims_y[i]=dim_sizes[j][y];
      stride_y[i]=strides[j][y];
      }
    else	// only one dimension - no y
      {
      dims_y[i]=1;
      stride_y[i]=0;
      }
    }

  min_dims_y=dims_y[0];
  
  // If two inputs, check that the format of the two inputs is the same
  if ( no_inputs == 2 )
    {
    if (dims_y[0]!=dims_y[1])
      {
      cerr << "Warning at block: " << getname() << endl;
      cerr << "Inputs have different y dimension lengths: " << dims_y[0] << " and " << dims_y[1] << "; using smallest" << endl;
      if (dims_y[1] < min_dims_y)
        min_dims_y=dims_y[1];	// save the minimum
      }
    } // end (if no_inputs==2)

  normalise=NORMALISE_param->get_value();

  // Get the dimension size of the axis to use for in1
  int N0=((*input_sockets)[0]->get_data_descriptor()->get_dimension(axis_to_use0))->size();
  int N1;

  // Get the size of the 2nd input in2 if it exists
  if ( no_inputs == 2 )
    N1=((*input_sockets)[1]->get_data_descriptor()->get_dimension(axis_to_use1))->size();
  else
    N1=N0;	// Only one input - duplicate size

  // Find minimum x dimension size
  min_dims_x = N0;
  if (min_dims_x > N1)
    min_dims_x = N1;

  // Display parameters used
  cerr << "Information from block: " << getname() << endl;
  cerr << "  Version = " << VERSION_param->get_value() << endl;
  cerr << "  Input 1:" << endl;
  cerr << "    Correlating on axis " << axis_to_use0 << endl;
  cerr << "    Dimension size = " << N0 << endl;
  if (no_inputs > 1)
    {
    cerr << "  Input 2:" << endl;
    cerr << "    Correlating on axis " << axis_to_use1 << endl;
    cerr << "    Dimension size = " << N1 << endl;
    }

  // Get default window size and lag for STAF
  if (VERSION_param->get_value()==string("STAF")) {
    if (WINDOW_SIZE_param->get_set_flag()==false)
      // Use the minimum dimension size as default window size
      window_size = min_dims_x;
    else 
      window_size=WINDOW_SIZE_param->get_value();

    if (MAX_LAG_param->get_set_flag()==false)
      max_lag=window_size;
    else 
      max_lag=MAX_LAG_param->get_value();

    //cerr << "  maximum lag = " << max_lag << endl;
    //cerr << "  window size =  " << window_size << endl;

    if (window_size>min_dims_x) {
      cerr << "Error in network at block: " << getname() << endl;
      cerr << "Window size " << window_size << " is greater than minimum data width " << min_dims_x << endl;
      throw(CTKError(__FILE__, __LINE__));
      }
    }

  // Get default window size and lag for modified STAF
  if (VERSION_param->get_value()==string("MODIFIED_STAF")) {
    if (WINDOW_SIZE_param->get_set_flag()==false)
      // Use half minimum dimension size as default for correlation
      window_size = min_dims_x/2;
    else 
      window_size=WINDOW_SIZE_param->get_value();

    if (MAX_LAG_param->get_set_flag()==false)
      // Default is minimum dimension size minus window size
      max_lag=min_dims_x-window_size;
    else 
      max_lag=MAX_LAG_param->get_value();
    
    //cerr << "  maximum lag = " << max_lag << endl;
    //cerr << "  window size =  " << window_size << endl;

    if (window_size+max_lag>min_dims_x) {
      cerr << "Error in network at block: " << getname() << endl;
      cerr << "For Modified STAF,  window size + max lag (" << window_size << "+" << max_lag << ")" << endl;
      cerr << "cannot be greater than smallest data width (" << min_dims_x << ")" << endl;
      throw(CTKError(__FILE__, __LINE__));
      }    
    }
  
  if (max_lag>window_size) {
    cerr << "Error in network at block: " << getname() << endl;
    cerr << "Cannot compute correlation for lags greater than specified window size" << endl;
    throw(CTKError(__FILE__, __LINE__));
  }   
    

    
}

 
/******************************************************************************/

void CorrelatorBlock::build_output_data_descriptors() {
  
  const DataDescriptor *iddp0 = (*input_sockets)[0]->get_data_descriptor();
  
  DataDescriptor *oddp = new DataDescriptor(*iddp0);

  // Remove the TIME dimension (or dimension specified by parameter AXIS)
  // and replace it with an inner dimension called LAG
  DimensionDescriptor correl_dim = *oddp->get_dimension(axis_to_use0);
  // By making LAG the inner dimension the writing of the output array is simplified
  oddp->remove_dimension(axis_to_use0);

  // The output correlation axis is twice the maximum lag plus 1 (for zero lag)
  CTKVector lag_axis;
  lag_axis.resize(2*max_lag+1);	// Make sure new axis is correct size
  
  // Add the scale for the lag axis based on the TIME dimension
  // such that the central value is 0
  CTKVector time_axis=correl_dim.get_axis();
  Float mintime=time_axis[0];	// First time
  Float maxtime=time_axis[max_lag-1];	// Last time (at max lag)
  
  // First do negative lag (excluding 0)
  copy(time_axis.begin(), time_axis.begin()+max_lag-1, lag_axis.begin());
 
 
  for (CTKVector::iterator lagp=lag_axis.begin(); lagp!=lag_axis.begin()+max_lag; lagp++)
    *lagp-=(mintime+maxtime);	// make sure central lag value will be 0
    
  // Then do zero lag
  lag_axis[max_lag]=0;
  // Then do positive lag (excluding zero)
  copy(time_axis.begin(), time_axis.begin()+max_lag, lag_axis.begin()+max_lag+1);
  oddp->add_inner_dimension("LAG", lag_axis);
  out_storage=oddp->get_storage();
  
  // Pass the descriptor to the output channel
  (*output_sockets)[0]->set_data_descriptor(oddp);

}


 
/******************************************************************************/

/* Make a copy of this block. The copy is named n, or shares this block's name
 * when n is empty. Block settings are copied across by copy_this_block_to(). */
Block* CorrelatorBlock::clone(const string &n) const{
  const string clone_name = n.empty() ? getname() : n;
  CorrelatorBlock *twin = new CorrelatorBlock(clone_name);
  return copy_this_block_to(twin);
}

 
/******************************************************************************/

void CorrelatorBlock::compute() {
  
  CTKVector *invec0;
  CTKVector *invec1;
  CTKVector *outvec=NULL;
  
  Integer no_inputs = 1 + (*input_sockets)[1]->connected();
  (*input_sockets)[0]->get_vector(invec0);
  if ( no_inputs > 1 )  // cross correlation
    (*input_sockets)[1]->get_vector(invec1);
  else
    invec1 = invec0;    // autocorrelation
  
  // Buffers to hold copies of current vector
  CTKVector buffer0(min_dims_x);
  CTKVector buffer1(min_dims_x);

  //////////////////////////////////////////////////////////
  if (n_dims[0]==1) {
    // 1-D case

    // include room for -ve lag in output vector
    outvec = new CTKVector(max_lag*2+1);
    
    // Copy invec0 to buffer
    copy(invec0->begin(), invec0->begin()+min_dims_x, buffer0.begin());
    // Copy invec1 to buffer - will be the same as invec0 if only one input 
    copy(invec1->begin(), invec1->begin()+min_dims_x, buffer1.begin());

    if (VERSION_param->get_value()==string("STAF"))
      correlate_STAF(buffer0, buffer1, &(*outvec)[0], window_size, max_lag);
    else if (VERSION_param->get_value()==string("MODIFIED_STAF"))
      correlate_modified_STAF(buffer0, buffer1, &(*outvec)[0], window_size, max_lag);

    if (normalise) corr_normalise_lag(&(*outvec)[0], max_lag);
    
  //////////////////////////////////////////////////////////
  } else if (n_dims[0]==2)  {
    // The 2-D case. 


    // Make output buffer with dimension equal to the maximum lag * 2 + 1 
    //(to include -ve lag)
    outvec = new CTKVector((max_lag*2+1)*min_dims_y);
    vector<int> dims=(*input_sockets)[0]->get_data_descriptor()->get_dimension_sizes();
    
    Float *odp=&((*outvec)[0]);
    
    // Each time vector is first copied into a contiguous 1-D vector 
    // Once this is done the inner_products can be calculated using simpler iterators.
    // - Trying to operate on the data while leaving it in place turns out to be much slower.

    for (Integer chan=0; chan!=min_dims_y; chan++) {

      // Copy invec0
      for (vector<Float>::iterator bp=buffer0.begin(), ip=invec0->begin()+chan*stride_y[0]; bp!=buffer0.begin()+min_dims_x; ++bp, ip+=stride_x[0]) 
        *bp=*ip;
      // Copy invec1 - will be the same as invec0 if only one input
      for (vector<Float>::iterator bp=buffer1.begin(), ip=invec1->begin()+chan*stride_y[1]; bp!=buffer1.begin()+min_dims_x; ++bp, ip+=stride_x[1]) 
        *bp=*ip;

      if (VERSION_param->get_value()==string("STAF"))
	correlate_STAF(buffer0, buffer1, odp, window_size, max_lag);
      else if (VERSION_param->get_value()==string("MODIFIED_STAF"))
	correlate_modified_STAF(buffer0, buffer1, odp, window_size, max_lag);

      if (normalise) corr_normalise_lag(odp, max_lag);

      odp+=2*max_lag+1;	// move along output buffer
      
    }
    
  } else { // n_dims[0] > 2
    // The general N-D case.    
    // JON - ADD STUFF HERE
  }

  (*output_sockets)[0]->put_vector(outvec);  

  delete invec0;
  if ( no_inputs > 1 )  // cross correlation
    delete invec1;
}


/////////////////////////////////////////////////////////////



/* Short-time correlation (STAF) of buffer0 and buffer1 over window_size samples.
 * Writes 2*max_lag+1 values to odp, in this order:
 *   lags -max_lag .. -1 : sum over n < window_size-|k| of buffer0[n]*buffer1[n+|k|]
 *   lags  0 .. max_lag  : sum over n < window_size-k   of buffer1[n]*buffer0[n+k]
 * Requires max_lag <= window_size <= size of each buffer.
 * Sums are accumulated in double (0.0 initial value) and stored as Float.
 */
void correlate_STAF(vector<Float> &buffer0, vector<Float> &buffer1, Float *odp, int window_size, int max_lag) {

  const Float *x0=&buffer0[0];
  const Float *x1=&buffer1[0];
  Float *out=odp;

  // Negative lags, most negative first: buffer1 is the shifted signal
  for (int k=-max_lag; k<0; ++k) {
    const int shift=-k;
    *out++ = inner_product(x0, x0+(window_size-shift), x1+shift, 0.0);
  }

  // Zero and positive lags: buffer0 is the shifted signal
  for (int shift=0; shift<=max_lag; ++shift)
    *out++ = inner_product(x1, x1+(window_size-shift), x0+shift, 0.0);

}


/* Modified short-time correlation (as defined in Rabiner and Schafer, Digital
 * Processing of Speech Signals). Every lag sums a full window of window_size
 * products, so both buffers must hold at least window_size+max_lag samples.
 * Writes 2*max_lag+1 values to odp, in this order:
 *   lags -max_lag .. -1 : sum over n < window_size of buffer0[n]*buffer1[n+|k|]
 *   lags  0 .. max_lag  : sum over n < window_size of buffer1[n]*buffer0[n+k]
 * (For autocorrelation the caller passes identical buffers.)
 * Sums are accumulated in double (0.0 initial value) and stored as Float.
 */
void correlate_modified_STAF(vector<Float> &buffer0, vector<Float> &buffer1, Float *odp,  int window_size, int max_lag) {

  const Float *x0=&buffer0[0];
  const Float *x1=&buffer1[0];
  const Float *x0_end=x0+window_size;
  const Float *x1_end=x1+window_size;
  Float *out=odp;

  // Negative lags, most negative first: a window of buffer0 against shifted buffer1
  for (int k=-max_lag; k<0; ++k) {
    const int shift=-k;
    *out++ = inner_product(x0, x0_end, x1+shift, 0.0);
  }

  // Zero and positive lags: a window of buffer1 against shifted buffer0
  for (int shift=0; shift<=max_lag; ++shift)
    *out++ = inner_product(x1, x1_end, x0+shift, 0.0);

}



/* Normalise one run of correlation values in place so that zero lag is 1.0.
 *
 * fp points at 2*max_lag+1 values laid out as written by the correlate_*
 * kernels: lags -max_lag .. -1, then lag 0 at fp[max_lag], then lags
 * 1 .. max_lag. The divisor is therefore fp[max_lag], not fp[0] (which is the
 * value at lag -max_lag).
 *
 * If the zero-lag value is exactly 0 (e.g. an all-zero frame), the values are
 * left unchanged rather than being turned into inf/NaN.
 */
void corr_normalise_lag(Float *fp, int max_lag) {

  const Float zero_lag_value=fp[max_lag];
  if (zero_lag_value==0)
    return;

  for (Integer lag=0; lag<max_lag*2+1; ++lag)
    fp[lag]/=zero_lag_value;
}


/* End of ctk_correlator.cpp */
 


