/******************************************************************************/
/*									      */
/*	ctk_reducer.cpp	    		      			              */
/*									      */
/*	Block for reducing N dimensional frames down to N-1 dimensional frames*/
/*      in various different ways                             */
/*									      */
/*	Author: Jon Barker, Sheffield University			      */
/*									      */
/*      CTK VERSION 1.3.5  Apr 22, 2007		         	      */
/*									      */
/******************************************************************************/

#include "ctk-config.h"

#include "ctk_reducer.hh"
 
#include <cmath>
#include <vector>
#include <numeric>
#include <algorithm>

#include "ctk_local.hh"
#include "ctk_error.hh"

#include "ctk_socket.hh"
#include "ctk_param.hh"
#include "ctk_data_descriptor.hh"


/******************************************************************************/
/*                                                                            */
/*       CLASS NAME: ReducerBlockAbstract                                         */
/*                                                                            */
/******************************************************************************/


/*
 * Common base for blocks that collapse one axis of their input frames.
 * Creates and registers the AXIS parameter shared by all such blocks.
 */
ReducerBlockAbstract::ReducerBlockAbstract(const string &a_name, const string &a_type)
  : CTKObject(a_name), Block(a_name, a_type) {

  // AXIS selects the dimension to collapse; when it is left unset the outer axis is used.
  ParamString *axis = new ParamString("AXIS");
  axis->set_helptext("The name of the AXIS on which to operate. <p> If the parameter is unset then the outer axis is used.");
  AXIS_param = axis;
  parameters->register_parameter(AXIS_param);
}
 

/*
 * Resets the Block, then re-reads the geometry of the input frames (number of
 * dimensions, dimension sizes, strides) into the members used by compute().
 * Throws CTKError if the input frames have more than 2 dimensions.
 */
void ReducerBlockAbstract::reset() {
  Block::reset();

  n_dims=(*input_sockets)[0]->get_data_descriptor()->get_n_dimensions();
  dim_sizes=(*input_sockets)[0]->get_data_descriptor()->get_dimension_sizes();
  strides=(*input_sockets)[0]->get_data_descriptor()->get_strides();

  // Check validity of input data. Every block built on this base reaches this check,
  // not only the Averager, so the message must not name one specific operation.
  if (n_dims>2) {
    cerr << "Error in network at block: " << getname() << endl;
    cerr << "Attempting to reduce over " << n_dims << "-dimensional frames." << endl;
    cerr << "Block can only handle 1 or 2 dimensional data at present." << endl;
    throw(CTKError(__FILE__, __LINE__));
  }
}

void ReducerBlockAbstract::build_output_data_descriptors() {
  const DataDescriptor *idd = (*input_sockets)[0]->get_data_descriptor();
  
  DataDescriptor *dd = new DataDescriptor(*idd);


  int ndim;
  
  // If remove the dimension specified by AXIS parameter - if AXIS paremeter is unset
  // then default behaviour is to remove the outer dimension
  if ((ndim=dd->get_n_dimensions())>0) {  // Note, do nothing for 0 dimensional sample data
    if (AXIS_param->get_set_flag()==false || (AXIS_param->get_value()==string(""))) {
      dd->remove_outer_dimension();
      op_dim=1;  // Operate on the outer (i.e. 1st) dimension 
    } else {
      if (dd->remove_dimension(AXIS_param->get_value())==false) {
	cerr << "Error in network at block: " << getname() << endl;
	cerr << "Attempting to average over an axis named: " << AXIS_param->get_value() << "." << endl;
	cerr << "The input data contains no axis of this name." << endl;
	throw(CTKError(__FILE__, __LINE__));
      }
      op_dim=idd->get_dimension(AXIS_param->get_value())->get_index(); // Get the index of the named dimension
    }

    if (ndim==2) {
      // Precompute for optimized 2-D case
      Integer x=1, y=0;
      if (op_dim==1) { x=0; y=1;}
      dims_x=dim_sizes[x];
      dims_y=dim_sizes[y];
      stride_x=strides[x];
      stride_y=strides[y];
    }
  }
  

  
  out_storage=dd->get_storage();
  
  // Pass a copy of the descriptor to all the output socket
  for (vector<Socket*>::iterator socketp=output_sockets->begin(); socketp<output_sockets->end(); ++socketp) {
    (*socketp)->set_data_descriptor(new DataDescriptor(*dd));
  }
  delete dd;
  
}


/******************************************************************************/
/*                                                                            */
/*       CLASS NAME: AveragerBlock                                             */
/*                                                                            */
/******************************************************************************/

const string AveragerBlock::type_name("Averager");
const string AveragerBlock::help_text(AVERAGER_BLOCK_HELP_TEXT);

/*
 * Averager: one input socket and one output socket. It collapses the AXIS
 * dimension of its input frames by taking the arithmetic mean.
 */
AveragerBlock::AveragerBlock(const string &a_name)
  : CTKObject(a_name), ReducerBlockAbstract(a_name, type_name) {
  make_input_sockets(1);
  make_output_sockets(1);
}
 
/*
 * Creates a new AveragerBlock named n (or this block's own name when n is empty)
 * and hands it to copy_this_block_to() to complete the copy.
 */
Block* AveragerBlock::clone(const string &n) const {
  const string new_name = n.empty() ? getname() : n;
  return copy_this_block_to(new AveragerBlock(new_name));
}


void AveragerBlock::compute() {

  CTKVector *invec;
  CTKVector *outvec;
  
  if (inputs_are_all_sample_data) {
    // Multidimensional sample data case
    Float x1;
    (*input_sockets)[0]->get_sample(x1);
    (*output_sockets)[0]->put_sample(x1);
  } else {
    // Vector data case
    (*input_sockets)[0]->get_vector(invec);
    if (n_dims==1) {
      // 1-D case where the output is a Sample
      Integer N=dim_sizes[0];
      (*output_sockets)[0]->put_sample(accumulate(&(*invec)[0], &(*invec)[N], 0.0)/N);
    } else if (n_dims==2) {
      // 2-D case - specialized from the N-D case for the sake of optimization
      outvec= new CTKVector (out_storage);
      Float *odp=&((*outvec)[0]);
      for (Integer start=0, end=stride_x*dims_x; start<dims_y*stride_y; start+=stride_y, end+=stride_y) {
	*odp++=accumulate(array_iterator<Float>(&(*invec)[start], stride_x), array_iterator<Float>(&(*invec)[end], stride_x), 0.0)/dims_x;
      }
      
      (*output_sockets)[0]->put_vector(outvec);
      delete invec;
      
    } else { // n_dims > 2
      // The general N-D case.
      // JON - ADD STUFF HERE
      cerr << "Averager ndims>2 not yet implemented" << endl;
      throw(CTKError(__FILE__, __LINE__));
    }
  }

}

/******************************************************************************/
/*                                                                            */
/*       CLASS NAME: SumBlock                                             */
/*                                                                            */
/******************************************************************************/


//
// This shares 99% of its code with Averager ... need to decide on a satisfactory way of dealing with this -JON
//

const string SumBlock::type_name("Sum");
const string SumBlock::help_text(SUM_BLOCK_HELP_TEXT);

/*
 * Sum: one input socket and one output socket. It collapses the AXIS
 * dimension of its input frames by adding the elements together.
 */
SumBlock::SumBlock(const string &a_name)
  : CTKObject(a_name), ReducerBlockAbstract(a_name, type_name) {
  make_input_sockets(1);
  make_output_sockets(1);
}
 
/*
 * Creates a new SumBlock named n (or this block's own name when n is empty)
 * and hands it to copy_this_block_to() to complete the copy.
 */
Block* SumBlock::clone(const string &n) const {
  const string new_name = n.empty() ? getname() : n;
  return copy_this_block_to(new SumBlock(new_name));
}


void SumBlock::compute() {
  
  CTKVector *invec;
  CTKVector *outvec;
  
  if (inputs_are_all_sample_data) {
    // Multidimensional sample data case
    Float x1;
    (*input_sockets)[0]->get_sample(x1);
    (*output_sockets)[0]->put_sample(x1);
  } else {
    // Vector data case
    (*input_sockets)[0]->get_vector(invec);
    if (n_dims==1) {
      // 1-D case where the output is a Sample
      Integer N=dim_sizes[0];
      (*output_sockets)[0]->put_sample(accumulate(&(*invec)[0], &(*invec)[N], 0.0));
    } else if (n_dims==2) {
      // 2-D case - specialized from the N-D case for the sake of optimization
      outvec= new CTKVector (out_storage);
      Float *odp=&((*outvec)[0]);
      for (Integer start=0, end=stride_x*dims_x; start<dims_y*stride_y; start+=stride_y, end+=stride_y) {
	*odp++=accumulate(array_iterator<Float>(&(*invec)[start], stride_x), array_iterator<Float>(&(*invec)[end], stride_x), 0.0);
      }
      
      (*output_sockets)[0]->put_vector(outvec);
      delete invec;
      
    } else { // n_dims > 2
      // The general N-D case.
      // JON - ADD STUFF HERE
      cerr << "Sum ndims>2 not yet implemented" << endl;
      throw(CTKError(__FILE__, __LINE__));
    }
  }

}

/******************************************************************************/
/*                                                                            */
/*       CLASS NAME: ReducerBlock                                             */
/*                                                                            */
/******************************************************************************/

const string ReducerBlock::type_name("Reducer");
const string ReducerBlock::help_text(REDUCER_BLOCK_HELP_TEXT);

/*
 * Reducer: two inputs and one output.
 * in1 ("data") carries the frames to reduce, and in2 ("index") carries the
 * position along the reduced axis of the element to keep.
 */
ReducerBlock::ReducerBlock(const string &a_name)
  : CTKObject(a_name), ReducerBlockAbstract(a_name, type_name) {
  make_input_sockets(2);
  make_output_sockets(1);

  input_sockets->set_description("in1", "data");
  input_sockets->set_description("in2", "index");
}
 
/*
 * Creates a new ReducerBlock named n (or this block's own name when n is empty)
 * and hands it to copy_this_block_to() to complete the copy.
 */
Block* ReducerBlock::clone(const string &n) const {
  const string new_name = n.empty() ? getname() : n;
  return copy_this_block_to(new ReducerBlock(new_name));
}

/*
 * Performs the common reducer reset, then additionally rejects sample
 * (0-dimensional) input: there is no axis to select an element from.
 * Throws CTKError in that case.
 */
void ReducerBlock::reset() {
  ReducerBlockAbstract::reset();

  if (n_dims!=0) return;  // frame data: nothing more to check

  cerr << "Error in network at block: " << getname() << endl;
  cerr << "Attempting to apply Reducer to sample data." << endl;
  cerr << "Reducer can only be applied to frame data." << endl;
  throw(CTKError(__FILE__, __LINE__));
}



void ReducerBlock::compute() {

  CTKVector *invec;
  CTKVector *outvec;
  Float chan;
  Integer ichan;
  
  (*input_sockets)[1]->get_sample(chan);

  ichan = (int)chan;
  
  if ((*input_sockets)[0]->get_data_descriptor()->is_sample_data()) {
    // Multidimensional sample data case
    Float x1;
    (*input_sockets)[0]->get_sample(x1);
    (*output_sockets)[0]->put_sample(x1);
  } else {
    // Vector data case
    (*input_sockets)[0]->get_vector(invec);
    if (n_dims==1) {
      // 1-D case where the output is a Sample
      if (ichan<0 || (unsigned)ichan>(*invec).size()) {
	cerr << "ReducerBlock: index error : index= " << ichan << endl;
	throw(CTKError(__FILE__, __LINE__));
      }
      (*output_sockets)[0]->put_sample((*invec)[ichan]);
    } else if (n_dims==2) {
      // 2-D case - specialized from the N-D case for the sake of optimization
      
      outvec= new CTKVector (out_storage);
      Float *odp=&((*outvec)[0]);

      if (ichan<0 || (unsigned)(stride_x*ichan+stride_y*(dims_y-1))>(*invec).size()) {
	cerr << "ReducerBlock: index error : index= " << ichan << endl;
	throw(CTKError(__FILE__, __LINE__));
      }
      
      for (Integer start=0, end=stride_x*dims_x; start<dims_y*stride_y; start+=stride_y, end+=stride_y) {
	*odp++=(*invec)[start+stride_x*ichan];
      }
      
      (*output_sockets)[0]->put_vector(outvec);
      
    } else { // n_dims > 2
      // The general N-D case.
      // JON - ADD STUFF HERE
      cerr << "Reducer ndims>2 not yet implemented" << endl;
      throw(CTKError(__FILE__, __LINE__));
    }

    delete invec;
  }

}


/******************************************************************************/
/*                                                                            */
/*       CLASS NAME: FindBlockAbstract                                        */
/*                                                                            */
/******************************************************************************/


/*
 * Common base for the "find" reducers (Peak, FindMin, FindMax). Concrete
 * subclasses create their own sockets in their constructors.
 */
FindBlockAbstract::FindBlockAbstract(const string &a_name, const string &a_type_name)
  : CTKObject(a_name), ReducerBlockAbstract(a_name, a_type_name) {
}

void FindBlockAbstract::compute() {

  CTKVector *invec;
  CTKVector *outvals, *outchans, *outisthing;
  
  
  if ((*input_sockets)[0]->get_data_descriptor()->is_sample_data()) {
    // Multidimensional sample data case
    Float x1;
    (*input_sockets)[0]->get_sample(x1);
    (*output_sockets)[0]->put_sample(x1);
    (*output_sockets)[1]->put_sample(1.0);
    if ((*output_sockets).size()>2)
      (*output_sockets)[2]->put_sample(1.0);
  } else {
    // Vector data case
    bool is_thing;
    (*input_sockets)[0]->get_vector(invec);
    if (n_dims==1) {
      Integer N=dim_sizes[0];
      Integer pos;
      // 1-D case where the output is a Sample
      array_iterator<Float> xp=find_thing(array_iterator<Float>(&(*invec)[0],1), array_iterator<Float>(&(*invec)[N], 1), pos, is_thing);
      (*output_sockets)[0]->put_sample(*xp);
      (*output_sockets)[1]->put_sample(pos);
      if ((*output_sockets).size()>2)
	(*output_sockets)[2]->put_sample(is_thing);
    } else if (n_dims==2) {
      // 2-D case - specialized from the N-D case for the sake of optimization
      
      outvals= new CTKVector (out_storage);
      outchans= new CTKVector (out_storage);
      outisthing= new CTKVector (out_storage);
      Float *valp=&((*outvals)[0]);
      Float *chanp=&((*outchans)[0]);
      Float *outisthingp=&((*outisthing)[0]);
      Integer pos;
      bool is_thing;
      for (Integer start=0, end=stride_x*dims_x; start<dims_y*stride_y; start+=stride_y, end+=stride_y) {
	array_iterator<Float> x=find_thing(array_iterator<Float>(&(*invec)[start], stride_x), array_iterator<Float>(&(*invec)[end], stride_x), pos, is_thing);
	*valp++=*x;
	*chanp++=pos;
	*outisthingp++=is_thing;
      }
      
      (*output_sockets)[0]->put_vector(outvals);
      (*output_sockets)[1]->put_vector(outchans);
      if ((*output_sockets).size()>2)
	(*output_sockets)[2]->put_vector(outisthing);
      else delete outisthing;
      
    } else { // n_dims > 2
      // The general N-D case.
      // JON - ADD STUFF HERE
      cerr << "Peak ndims>2 not yet implemented" << endl;
      throw(CTKError(__FILE__, __LINE__));
    }

    delete invec;   

  }

}


/******************************************************************************/
/*                                                                            */
/*       CLASS NAME: PeakBlock                                                */
/*                                                                            */
/******************************************************************************/

const string PeakBlock::type_name("Peak");
const string PeakBlock::help_text(PEAK_BLOCK_HELP_TEXT);

/*
 * Peak: one input and three outputs. The outputs are the peak value, the
 * peak index, and a flag saying whether a peak was found.
 */
PeakBlock::PeakBlock(const string &a_name)
  : CTKObject(a_name), FindBlockAbstract(a_name, type_name) {
  make_input_sockets(1);
  make_output_sockets(3);

  output_sockets->set_description("out1", "peak value");
  output_sockets->set_description("out2", "peak index");
  output_sockets->set_description("out3", "is peak?");
}
 
/*
 * Creates a new PeakBlock named n (or this block's own name when n is empty)
 * and hands it to copy_this_block_to() to complete the copy.
 */
Block* PeakBlock::clone(const string &n) const {
  const string new_name = n.empty() ? getname() : n;
  return copy_this_block_to(new PeakBlock(new_name));
}


/*
 * Finds the first strict local maximum in [pbegin, pend): the first element, scanning
 * from the second one, that is greater than both of its neighbours.
 * The first and last elements can never be peaks.
 *
 * On success, is_thing is true, pos is the peak's offset from pbegin, and an iterator
 * to the peak is returned.
 * Otherwise, including ranges of fewer than three elements, is_thing is false,
 * pos is 0, and pbegin is returned.
 */
array_iterator<Float> PeakBlock::find_thing(array_iterator<Float> pbegin, array_iterator<Float> pend, Integer &pos, bool &is_thing) {

  is_thing=false;
  pos=0;

  // Not even a left neighbour plus a candidate.
  if (pbegin==pend || pbegin+1==pend) return pbegin;

  // Sliding window (left, centre, *right). centre is the value at `candidate`,
  // which sits `offset` elements after pbegin.
  Float left=*pbegin;
  Float centre=*(pbegin+1);
  array_iterator<Float> candidate=pbegin+1;
  Integer offset=1;

  for (array_iterator<Float> right=pbegin+2; right!=pend; ++right) {
    if (centre>left && centre>*right) {
      is_thing=true;
      pos=offset;
      return candidate;
    }
    left=centre;
    centre=*right;
    ++candidate;
    ++offset;
  }

  return pbegin;
}



/******************************************************************************/
/*                                                                            */
/*       CLASS NAME: FindMinBlock                                             */
/*                                                                            */
/******************************************************************************/

const string FindMinBlock::type_name("FindMin");
const string FindMinBlock::help_text(FIND_MIN_BLOCK_HELP_TEXT);

/*
 * FindMin: one input and two outputs. The outputs are the smallest value
 * along the reduced axis and its index.
 */
FindMinBlock::FindMinBlock(const string &a_name)
  : CTKObject(a_name), FindBlockAbstract(a_name, type_name) {
  make_input_sockets(1);
  make_output_sockets(2);

  output_sockets->set_description("out1", "Min value");
  output_sockets->set_description("out2", "Min value index");
}
 
/*
 * Creates a new FindMinBlock named n (or this block's own name when n is empty)
 * and hands it to copy_this_block_to() to complete the copy.
 */
Block* FindMinBlock::clone(const string &n) const {
  const string new_name = n.empty() ? getname() : n;
  return copy_this_block_to(new FindMinBlock(new_name));
}


/*
 * Locates the smallest element of [pbegin, pend), taking the first one if there are ties.
 * pos receives its offset from pbegin, and is_thing is always set to true.
 * For an empty range, pend is returned and must not be dereferenced.
 */
array_iterator<Float> FindMinBlock::find_thing(array_iterator<Float> pbegin, array_iterator<Float> pend, Integer &pos, bool &is_thing) {
  is_thing=true;

  array_iterator<Float> smallest=min_element(pbegin, pend);
  pos=smallest-pbegin;

  return smallest;
}


/******************************************************************************/
/*                                                                            */
/*       CLASS NAME: FindMaxBlock                                             */
/*                                                                            */
/******************************************************************************/

const string FindMaxBlock::type_name("FindMax");
const string FindMaxBlock::help_text(FIND_MAX_BLOCK_HELP_TEXT);

/*
 * FindMax: one input and two outputs. The outputs are the largest value
 * along the reduced axis and its index.
 */
FindMaxBlock::FindMaxBlock(const string &a_name)
  : CTKObject(a_name), FindBlockAbstract(a_name, type_name) {
  make_input_sockets(1);
  make_output_sockets(2);

  output_sockets->set_description("out1", "Max value");
  output_sockets->set_description("out2", "Max value index");
}
 
/*
 * Creates a new FindMaxBlock named n (or this block's own name when n is empty)
 * and hands it to copy_this_block_to() to complete the copy.
 */
Block* FindMaxBlock::clone(const string &n) const {
  const string new_name = n.empty() ? getname() : n;
  return copy_this_block_to(new FindMaxBlock(new_name));
}

/*
 * Locates the largest element of [pbegin, pend), taking the first one if there are ties.
 * pos receives its offset from pbegin, and is_thing is always set to true.
 * For an empty range, pend is returned and must not be dereferenced.
 */
array_iterator<Float> FindMaxBlock::find_thing(array_iterator<Float> pbegin, array_iterator<Float> pend, Integer &pos, bool &is_thing) {
  is_thing=true;

  array_iterator<Float> largest=max_element(pbegin, pend);
  pos=largest-pbegin;

  return largest;
}




///////////////


/* End of ctk_reducer.cpp */
 
