/******************************************************************************/
/*									      */
/*	ctk_binaural.cpp	    		      			      */
/*									      */
/*	Block for binaural (localisation) processing                          */
/*									      */
/*	Author: Sue Harding, Sheffield University			      */
/*									      */
/*      CTK VERSION 1.3.5  Apr 22, 2007		         	      */
/*									      */
/******************************************************************************/

#include "ctk-config.h"

#include "ctk_binaural.hh"
 
#include <cmath>
#include <vector>
#include <numeric>
#include <algorithm>

#include "ctk_local.hh"
#include "ctk_error.hh"

#include "ctk_socket.hh"
#include "ctk_param.hh"
#include "ctk_data_descriptor.hh"


/******************************************************************************/
/*                                                                            */
/*       CLASS NAME: SkeletoniseBlock                                         */
/*                                                                            */
/******************************************************************************/

// Registered type name of this block, and the one-line description shown as its help
const string SkeletoniseBlock::type_name("Skeletonise");
const string SkeletoniseBlock::help_text("Skeletonise by sharpening peaks along one dimension");

SkeletoniseBlock::SkeletoniseBlock(const string &a_name):CTKObject(a_name),Block(a_name, type_name) {

  // One input socket (frames to skeletonise) and one output socket
  make_input_sockets(1);
  make_output_sockets(1);

  // --- Create the parameters (same creation order as they are registered) ---
  CHANNEL_AXIS_param = new ParamString("CHANNEL_AXIS");
  PEAK_AXIS_param = new ParamString("PEAK_AXIS");
  LOW_CHAN_STDDEV_param = new ParamFloat("LOW_CHAN_STDDEV");
  HIGH_CHAN_STDDEV_param = new ParamFloat("HIGH_CHAN_STDDEV");
  GAUSS_LIMIT_param = new ParamFloat("GAUSS_LIMIT");

  // --- Numeric parameters each get a VLOWER validator with bound 0.0 ---
  LOW_CHAN_STDDEV_param->install_validator(new Validator(Validator::VLOWER, 0.0));
  HIGH_CHAN_STDDEV_param->install_validator(new Validator(Validator::VLOWER, 0.0));
  GAUSS_LIMIT_param->install_validator(new Validator(Validator::VLOWER, 0.0));

  // --- Help text for every parameter ---
  CHANNEL_AXIS_param->set_helptext("The name of the axis along which Gaussians vary.<p>If not specified, FREQUENCY axis will be used if it exists; otherwise the inner axis will be used.<p><p>Not used for 1-D data.");
  PEAK_AXIS_param->set_helptext("The name of the axis along which peaks will be found.<p>If not specified, the outer axis will be used.");
  LOW_CHAN_STDDEV_param->set_helptext("The standard deviation of the Gaussian applied to the lowest (or only) channel.");
  HIGH_CHAN_STDDEV_param->set_helptext("The standard deviation of the Gaussian applied to the highest channel.<p>Not used for 1-D data.");
  GAUSS_LIMIT_param->set_helptext("The number of standard deviations from the mean of the Gaussian used during convolution.");

  // --- Register, fully configured, in a fixed order ---
  parameters->register_parameter(CHANNEL_AXIS_param);
  parameters->register_parameter(PEAK_AXIS_param);
  parameters->register_parameter(LOW_CHAN_STDDEV_param);
  parameters->register_parameter(HIGH_CHAN_STDDEV_param);
  parameters->register_parameter(GAUSS_LIMIT_param);

}

/////////////////////////////////////////////////////////////////////

Block* SkeletoniseBlock::clone(const string &n) const{
  // Build a new SkeletoniseBlock named n (or this block's own name when n is
  // empty) and pass it to copy_this_block_to(), whose result is returned.
  Block *duplicate = n.empty() ? new SkeletoniseBlock(getname())
                               : new SkeletoniseBlock(n);
  return copy_this_block_to(duplicate);
}

/////////////////////////////////////////////////////////////////////

void SkeletoniseBlock::reset() {
  // Re-read the input descriptor and the parameters, choose the channel and
  // peak axes, and precompute the per-channel Gaussian standard deviations.
  // Throws CTKError for unsupported dimensionality, unknown axis names, or
  // identical channel and peak axes.
  Block::reset();

  // Get some data descriptor info
  const DataDescriptor *ddp = (*input_sockets)[0]->get_data_descriptor();

  n_dims=ddp->get_n_dimensions();	// no. of dimensions
  dim_sizes=ddp->get_dimension_sizes();	// size of each dimension
  strides=ddp->get_strides();		// length of each stride
  out_storage=ddp->get_storage();	// size of output

  // Check validity of input data: only 1-D and 2-D frames are handled
  if (n_dims>2 || n_dims < 1) {
    cerr << "Error in network at block: " << getname() << endl;
    cerr << "Attempting to skeletonise " << n_dims << "-dimensional frames." << endl;
    cerr << "Block can only handle 1 or 2 dimensional data at present." << endl;
    throw(CTKError(__FILE__, __LINE__));
  }


  // Check the parameters

  string inner_axis_name=(ddp->get_inner_dimension ())->get_name();
  string outer_axis_name=(ddp->get_outer_dimension ())->get_name();

  // First sort out the axes: channel and peak axes must be different.
  // Channel defaults to PARAM_DEFAULT_CHANNEL_AXIS if that axis exists and is
  // not the peak axis, otherwise to the axis that is not the peak axis, or
  // else to the inner dimension.
  // Peak defaults to the axis that is not the channel axis.

  // Get the axis along which to find the peaks ("" = choose a default below)
  if (PEAK_AXIS_param->get_set_flag()) 	// peak axis set
    {
    peak_axis_name=PEAK_AXIS_param->get_value();
    if (ddp->get_dimension(peak_axis_name)== NULL)
      {
      cerr << "Error in network at block: " << getname() << endl;
      cerr << "Cannot find axis named: " << peak_axis_name << endl;
      throw(CTKError(__FILE__, __LINE__));
      }
    }
  else	// peak axis not set
    {
    peak_axis_name="";
    }

  // Get the axis along which Gaussians vary ("" = no channel axis)
  channel_axis_name = "";
  if (CHANNEL_AXIS_param->get_set_flag()) 	// channel axis set
    {
    if (n_dims==1)
      {
      cerr << "Warning at block: " << getname() << endl;
      cerr << "Channel axis is not applicable to 1-D data" << endl;
      }
    else	// n_dims>1
      {
      channel_axis_name=CHANNEL_AXIS_param->get_value();
      // Check the channel axis is valid
      if (ddp->get_dimension(channel_axis_name)== NULL)
        {
        cerr << "Error in network at block: " << getname() << endl;
        cerr << "Cannot find axis named: " << channel_axis_name << endl;
        throw(CTKError(__FILE__, __LINE__));
        }
      }
    }
  else if (n_dims>1)	// channel axis not set: pick a default
    {
    // Set channel to default unless this conflicts with peak axis
    if (peak_axis_name != PARAM_DEFAULT_CHANNEL_AXIS)
      {
      channel_axis_name=PARAM_DEFAULT_CHANNEL_AXIS;
      if (ddp->get_dimension(channel_axis_name)== NULL)
        channel_axis_name = "";	// default axis is not present in the data
      }
    if (channel_axis_name == "")	// channel axis still isn't set
      {
      if (peak_axis_name == inner_axis_name)
        channel_axis_name = outer_axis_name;
      else
        channel_axis_name = inner_axis_name;
      }
    }

  // Sort out the default peak axis, or validate an explicit one
  if (peak_axis_name == "")	// peak axis isn't set
    {
    if (channel_axis_name == inner_axis_name)
      peak_axis_name = outer_axis_name;
    else
      peak_axis_name = inner_axis_name;
    }
  else if (channel_axis_name == peak_axis_name)	// peak axis is set: axes must differ
    {
    cerr << "Error in network at block: " << getname() << endl;
    cerr << "Channel and peak axes cannot be the same: " << channel_axis_name << endl;
    throw(CTKError(__FILE__, __LINE__));
    }

  // Get the standard deviations of the Gaussians for the lowest and highest
  // channels; unset or non-positive values fall back to the defaults
  if (LOW_CHAN_STDDEV_param->get_set_flag())
    low_chan_stddev=LOW_CHAN_STDDEV_param->get_value();
  else low_chan_stddev=0;
  if (low_chan_stddev <= 0)
    low_chan_stddev = PARAM_DEFAULT_LOW_CHAN_STDDEV;
  // HIGH_CHAN_STDDEV isn't used if data is 1-D
  if (n_dims>1)
    {
    if (HIGH_CHAN_STDDEV_param->get_set_flag())
      high_chan_stddev=HIGH_CHAN_STDDEV_param->get_value();
    else high_chan_stddev=0;
    if (high_chan_stddev <= 0)
      high_chan_stddev = PARAM_DEFAULT_HIGH_CHAN_STDDEV;
    }
  else
    {
    if (HIGH_CHAN_STDDEV_param->get_set_flag())
      {
      cerr << "Warning at block: " << getname() << endl;
      cerr << "High channel std. dev. is not used for 1-D data" << endl;
      }
    // Use low_chan_stddev for both values (only one is used)
    high_chan_stddev = low_chan_stddev;
    }


  // Precompute dimensions and strides to use for 2-D case

  // Get the index of the peak dimension
  op_dim=ddp->get_dimension(peak_axis_name)->get_index() - 1; // dimension no. starts from 1, array from 0

  if (n_dims>=2)
    {
    Integer x=op_dim, y;
    // Get the index of the channel dimension
    y=ddp->get_dimension(channel_axis_name)->get_index() - 1; // dimension no. starts from 1, array from 0
    // x = peak axis, y = channel axis; cached for compute()
    dims_x=dim_sizes[x];
    dims_y=dim_sizes[y];
    stride_x=strides[x];
    stride_y=strides[y];
    }


  // Get the no. of standard deviations from the mean to be used when
  // convolving with Gaussians; unset or non-positive means the default
  if (GAUSS_LIMIT_param->get_set_flag())
    gauss_limit=GAUSS_LIMIT_param->get_value();
  else gauss_limit=0;
  if (gauss_limit <= 0)
    gauss_limit = PARAM_DEFAULT_GAUSS_LIMIT;


  // Calculate sigma for gaussian for each channel.
  // Start from an empty list: reset() can run more than once, and appending
  // to the previous contents would leave stale values where compute() reads.
  chan_stddev.clear();
  if (n_dims == 1)
    {
    n_chans = 1;	// only one channel
    chan_stddev.push_back(low_chan_stddev);
    }
  else
    {
    n_chans=dims_y;	// channel is on y axis
    if (n_chans == 1)
      {
      // A single channel: the interpolation below would divide 0 by 0
      chan_stddev.push_back(low_chan_stddev);
      }
    else
      {
      // Linearly spaced sigma values from low (channel 0) to high (last channel)
      for (Integer i=0;  i < n_chans; i++)
        chan_stddev.push_back(low_chan_stddev + i * (high_chan_stddev - low_chan_stddev) / (n_chans - 1));
      }
    }

}

/////////////////////////////////////////////////////////////////////
void SkeletoniseBlock::build_output_data_descriptors() {

  // Reorganise the dimensions to make the channel axis the outer dimension
  // and the peak axis the next dimension in
  // This simplifies the writing of the output array

  // If there is only one dimension, only the peak axis is affected

  const DataDescriptor *iddp0 = (*input_sockets)[0]->get_data_descriptor();
  DataDescriptor *oddp = new DataDescriptor(*iddp0);

  DimensionDescriptor peak_dim = *oddp->get_dimension(peak_axis_name);
  CTKVector peak_axis=peak_dim.get_axis();

  // Remove the channel dimension then replace it in its new position
  oddp->remove_dimension(peak_axis_name);
  oddp->add_outer_dimension(peak_axis_name, peak_axis);

  if (n_dims > 1)	// no channel axis if only one dimension
    {
    DimensionDescriptor channel_dim = *oddp->get_dimension(channel_axis_name);
    CTKVector channel_axis=channel_dim.get_axis();
    // Remove the channel dimension then replace it in its new position
    oddp->remove_dimension(channel_axis_name);

    // peak axis is shifted in by channel
    oddp->add_outer_dimension(channel_axis_name, channel_axis);
    }

  // Pass the descriptor to the output channel
  (*output_sockets)[0]->set_data_descriptor(oddp);

}


/////////////////////////////////////////////////////////////////////
// Find all the peaks in a vector
// Mark every strict local maximum in the sequence [ipbegin, ipend).
//
// peakheightp and peakexistp must each have room for distance(ipbegin, ipend)
// values. Both are zeroed first. Then, for each interior index k whose value is
// strictly greater than both neighbours, peakheightp[k] gets that value and
// peakexistp[k] gets 1. The first and last elements are never marked, and a
// maximum with an equal neighbour (a plateau) is not marked.
// Returns true if at least one peak was marked.
bool SkeletoniseBlock::find_all_peaks(array_iterator<Float> ipbegin, array_iterator<Float> ipend, Float *peakheightp, Float *peakexistp) {

  const Integer len = distance(ipbegin, ipend);
  for (Integer k = 0; k < len; ++k)
    {
    peakheightp[k] = 0;
    peakexistp[k] = 0;
    }

  // Fewer than two samples: there is no interior element to test
  if (ipbegin == ipend || ipbegin + 1 == ipend)
    return false;

  // Slide a (left, middle, right) window along the input; k is middle's index
  bool any_peak = false;
  Float left = *ipbegin;
  Float middle = *(ipbegin + 1);
  Integer k = 1;
  for (array_iterator<Float> it = ipbegin + 2; it != ipend; ++it, ++k)
    {
    const Float right = *it;
    if (middle > left && middle > right)
      {
      peakheightp[k] = middle;
      peakexistp[k] = 1;
      any_peak = true;
      }
    left = middle;
    middle = right;
    }

  return any_peak;
}

/////////////////////////////////////////////////////////////////////
void SkeletoniseBlock::conv_with_gauss(Float *peakheightp, Float *skeletonp, Integer inarraylen, Float sigma, Float gauss_limit) {

  // Convolve the inarraylen peak heights at peakheightp with a sampled,
  // unnormalised Gaussian of standard deviation sigma. Write the inarraylen
  // centred ("same"-length) outputs to skeletonp:
  //
  //   skeletonp[k] = sum over j in [-intlim, intlim] with 0 <= k-j < inarraylen
  //                  of exp(-j^2 / (2 sigma^2)) * peakheightp[k-j]
  //
  // where intlim = ceil(2 * gauss_limit * sigma). Input samples outside the
  // array count as zero.
  //
  // NOTE(review): the GAUSS_LIMIT help text calls gauss_limit a number of
  // standard deviations from the mean, which would suggest a half-width of
  // gauss_limit*sigma. The factor of 2 is kept unchanged; confirm the intent.

  const Integer intlim=static_cast<Integer>(ceil(2 * gauss_limit * sigma));

  // Gaussian samples at offsets -intlim..intlim, stored at index offset+intlim.
  // The vector is sized exactly 2*intlim+1 and frees itself on return.
  std::vector<Float> kernel(2*intlim+1);
  for (Integer j=-intlim; j<=intlim; j++)
    kernel[j+intlim] = exp(-j*j/(2*sigma*sigma));

  for (Integer k=0; k<inarraylen; k++)
    {
    // Only visit offsets whose source index k-j lies inside the input, so no
    // zero-padded copy of the input is needed
    const Integer jlo=std::max(-intlim, k-(inarraylen-1));
    const Integer jhi=std::min(intlim, k);
    Float acc=0.0;
    for (Integer j=jlo; j<=jhi; j++)
      acc += kernel[j+intlim]*peakheightp[k-j];
    skeletonp[k]=acc;
    }

}

/////////////////////////////////////////////////////////////////////

void SkeletoniseBlock::compute() {

  CTKVector *invec;
  CTKVector *peakheight;
  CTKVector *peakexist;
  CTKVector *currvec;
  CTKVector *skeleton;
  
  if (inputs_are_all_sample_data) 
    {
    // Sample data case - pass input to output
    // (This shouldn't happen - trapped in reset())
    Float x1;
    (*input_sockets)[0]->get_sample(x1);
    (*output_sockets)[0]->put_sample(x1);
    }
  else 

    {
    // Vector data case
    (*input_sockets)[0]->get_vector(invec); // get input vector
    // set up output vector and intermediate arrays
    peakheight = new CTKVector (out_storage);	// array of peak heights
    peakexist = new CTKVector (out_storage);	// array of ones at peak positions
    currvec = new CTKVector (out_storage);	// current vector (to deal with axis shifts)
    skeleton = new CTKVector (out_storage);	// array of convolved peaks

    Float *hp=&((*peakheight)[0]);
    Float *ep=&((*peakexist)[0]);
    Float *sp=&((*skeleton)[0]);

    if (n_dims==1) 
      {
      // 1-D case 
      Integer N=dim_sizes[0];	// (should be same as out_storage for 1-D)

      // Find the peaks in the vector
      bool peakfound=find_all_peaks(array_iterator<Float>(&(*invec)[0],1), array_iterator<Float>(&(*invec)[N], 1), hp, ep);
      // skeletonise by convolving with gaussian
      if (peakfound)
        conv_with_gauss(hp, sp, N, chan_stddev[0], gauss_limit);

      (*output_sockets)[0]->put_vector(skeleton);
      delete invec;
      delete peakheight;
      delete peakexist;
      }

    else if (n_dims==2) 
      {
      // 2-D case - specialized from the N-D case for the sake of optimization
      // Find the peaks in each vector of each frame
      Integer N=dim_sizes[op_dim];	// length of each channel

      // Process each channel
      for (Integer chan=0, instart=0, inend=stride_x*dims_x, skelstart=0; chan < n_chans; chan++, instart+=stride_y, inend+=stride_y, skelstart+=N) 
        {
        // Find the peaks in each channel
        bool peakfound=find_all_peaks(array_iterator<Float>(&(*invec)[instart],stride_x), array_iterator<Float>(&(*invec)[inend], stride_x), hp, ep);
        // skeletonise by convolving with gaussian
        if (peakfound)
          conv_with_gauss(hp, sp+skelstart, N, chan_stddev[chan], gauss_limit);
        }
      (*output_sockets)[0]->put_vector(skeleton);
      delete invec;
      delete peakheight;
      delete peakexist;
      }

      
    else 
      { // n_dims > 2
      // The general N-D case.
      // ADD STUFF HERE
      cerr << "Error in network at block: " << getname() << endl;
      cerr << "Skeletonise for more than 2 dimensions not yet implemented" << endl;
      throw(CTKError(__FILE__, __LINE__));
      }

  }

}

/////////////////////////////////////////////////////////////////////
/******************************************************************************/
 

/* End of ctk_binaural.cpp */
 
