/*
This is the code of the model described in:

S. Saeb, C. Weber, J. Triesch (2011)
Learning the Optimal Control of Coordinated Eye and Head Movements.
PLoS Computational Biology, 7(11): e1002253. Doi: 10.1371/journal.pcbi.1002253

Some parameter values differ from those in the paper, and the analysis procedures are not included.
Output is written to four files, which can be plotted with gnuplot using the commands printed at the end of execution.

Compile with:
g++ eyehead.cpp

Then run:
./a.out
*/

#include <stdio.h>
#include <math.h>

#include <vector>

//Sign of v as an int: +1 if v is positive, -1 if negative, 0 otherwise
//(zero, negative zero and NaN all give 0, since both comparisons are false).
int sign(double v)
{
  //each comparison evaluates to 1 or 0, so their difference is +1, -1 or 0
  return (v > 0) - (v < 0);
}

//Rounds seconds/dt to the nearest whole number of time steps. A plain cast truncates, so a
//quotient landing a hair below an integer would lose a step.
static int n_steps(double seconds, double dt)
{
  return (int) floor(seconds/dt + 0.5);
}

//Fills `templt` with A*exp(-t^2/(2*sigma^2)) sampled at t = (i - c)*dt, where c = (size-1)/2 is
//the centre sample. Returns the width measure used to align the delay line:
//(first index at or below 0.1*A after the first index above 0.1*A) - (first index above 0.1*A) + 1.
//If no sample exceeds 0.1*A the result equals the template size.
static int gaussian_template(std::vector<double> &templt, double A, double sigma, double dt)
{
  const int n = (int) templt.size();
  const int centre = (n - 1)/2;
  int first_above = 0;
  int first_below = n - 1;
  bool above_found = false;
  bool below_found = false;

  for (int i = 0; i < n; i++) {
    //time derived from the integer index, so there is no drift from repeated float additions
    const double t = (i - centre)*dt;
    templt[i] = A*exp(-t*t/(2*sigma*sigma));

    if (!above_found && templt[i] > 0.1*A) {
      first_above = i;
      above_found = true;
    } else if (above_found && !below_found && templt[i] <= 0.1*A) {
      first_below = i;
      below_found = true;
    }
  }
  return first_below - first_above + 1;
}

//Delay-line activity s[u][j] of unit u at time step j: the template shifted by one sample per
//unit and per step, s[u][j] = templt[u + offset - j]. Indices outside the template give 0.
static std::vector<std::vector<double> > delay_line_activity(const std::vector<double> &templt,
                                                             int offset, int n_units, int n_t)
{
  std::vector<std::vector<double> > s(n_units, std::vector<double>(n_t, 0.0));
  const int n_templt = (int) templt.size();
  for (int u = 0; u < n_units; u++) {
    for (int j = 0; j < n_t; j++) {
      const int idx = u + offset - j;
      if (idx >= 0 && idx < n_templt) s[u][j] = templt[idx];
    }
  }
  return s;
}

//Convolves the command sequence `cmd` with a plant kernel sampled at lags 0, dt, 2*dt, ...:
//  r[0] = gain*cmd[0]*kernel[0]*dt,   r[i] = gain*dt*sum_{j<i} cmd[j]*kernel[i-j]   (i >= 1).
//`kernel` and `r` must hold at least cmd.size() samples.
static void plant_response(const std::vector<double> &cmd, const std::vector<double> &kernel,
                           double gain, double dt, std::vector<double> &r)
{
  const int n = (int) cmd.size();
  if (n == 0) return;
  r[0] = gain*cmd[0]*kernel[0]*dt;
  for (int i = 1; i < n; i++) {
    double acc = 0.0;
    for (int j = 0; j < i; j++) acc += cmd[j]*kernel[i-j];
    r[i] = gain*acc*dt;
  }
}

//Gradient, w.r.t. one plant's delay-line weights, of sum_i sqrt(|r_vis[i]|) for i = 0 .. n_t-2.
//The position at step i is taken to depend on s[u][j]*kernel[i-j] for j = i-n_window+1 .. i
//(clipped at 0). Steps with r_vis[i] exactly 0 are skipped: the derivative of sqrt(|x|) is
//unbounded there, and 0*inf would otherwise turn the weights into NaN.
static void cost_gradient(const std::vector<std::vector<double> > &s, const std::vector<double> &r_vis,
                          const std::vector<double> &kernel, double gain, double dt, int n_window,
                          std::vector<double> &dedw)
{
  const int n_units = (int) s.size();
  const int n_t = (int) r_vis.size();
  dedw.assign(n_units, 0.0);

  for (int i = 0; i + 1 < n_t; i++) {
    if (r_vis[i] == 0.0) continue;
    //d sqrt(|r|)/dr = sign(r)/(2*sqrt|r|); negated because r_vis = target - eye - head
    const double factor = -sign(r_vis[i])*0.5*pow(fabs(r_vis[i]), -0.5)*gain*dt*dt;
    const int j0 = (i - n_window + 1 > 0) ? i - n_window + 1 : 0;
    for (int u = 0; u < n_units; u++) {
      double drdw = 0.0;
      for (int j = j0; j <= i; j++) drdw += s[u][j]*kernel[i-j];
      dedw[u] += factor*drdw;
    }
  }
}

//Learning-rate multiplier from the product of the previous and the new gradient component:
//unchanged if the product is zero, m_inc if the signs agree, m_dec otherwise.
static double lr_factor(double product, double m_inc, double m_dec)
{
  if (product == 0) return 1.0;
  return (product > 0) ? m_inc : m_dec;
}

//Appends one value per line to `path`, followed by an empty line separating successive
//snapshots. Returns false (after printing a diagnostic) if the file cannot be opened or closed.
static bool append_column(const char *path, const std::vector<double> &values)
{
  FILE *file = fopen(path, "a");
  if (file == NULL) {
    perror(path);
    return false;
  }
  for (size_t j = 0; j < values.size(); j++) fprintf(file, "%lf\n", values[j]);
  fprintf(file, "\n");
  if (fclose(file) != 0) {
    perror(path);
    return false;
  }
  return true;
}

//For every object position in r_obj_range, trains the weights from a Gaussian delay line onto an
//eye plant and a head plant so that eye + head position approaches the object. Training uses
//gradient descent with per-weight adaptive learning rates. Every 100 iterations the best
//trajectories found so far are appended to four .dat files; gnuplot commands for viewing them are
//printed at the end. Returns 1 if a results file cannot be written.
int main()
{
  //object positions
  //const int N_OBJ = 7;
  //double r_obj_range[N_OBJ] = {10, 20, 30, 40, 50, 60, 70};
  const int N_OBJ = 1;
  double r_obj_range[N_OBJ] = {30};

  //simulation parameters
  double dt = 0.01; //Temporal resolution
  double T = 2.0;    //Integration duration (see the paper)
  const int dur = n_steps(T, dt) + 1;   //samples per trial: t = 0, dt, ..., T

  //delay-line parameters
  //1.25*T/dt units: the whole product is rounded, not just the literal 1.25 (which would give T/dt)
  const int n_delay = n_steps(1.25*T, dt);
  int A = 200;           //Height of the Gaussian profile
  double sigma = 0.002;  //Width of the Gaussian profile

  //creating Gaussian profiles
  std::vector<double> templt(n_steps(4*T, dt) + 1);
  const int gausswidth = gaussian_template(templt, A, sigma, dt);
  const int gaussmean = n_steps(4*T, dt)/2;
  //the delay-line activity depends only on the template, so it is built once for all iterations
  const std::vector<std::vector<double> > s =
      delay_line_activity(templt, gaussmean - gausswidth/2, n_delay, dur);

  //eye plant parameters
  double k_eyeplant = 1.0/4.6;
  double T1_eye = 223e-3;
  double T2_eye = 14e-3;
  double T3_eye = 4e-3;
  double AA = T1_eye / (T1_eye*T1_eye + T2_eye*T3_eye - T3_eye*T1_eye - T1_eye*T2_eye);
  double BB = T2_eye / (T2_eye*T2_eye + T3_eye*T1_eye - T1_eye*T2_eye - T2_eye*T3_eye);
  double CC = T3_eye / (T3_eye*T3_eye + T1_eye*T2_eye - T2_eye*T3_eye - T3_eye*T1_eye);

  //head plant parameters
  double k_headplant = (4.0/6.54)*2.81;
  double T1_head = 9844e-3;
  double T2_head = 156e-3;

  //plant kernels (without gain and dt) sampled at lags 0, dt, ..., (dur-1)*dt
  std::vector<double> eye_kernel(dur), head_kernel(dur);
  for (int l = 0; l < dur; l++) {
    const double lag = l*dt;
    eye_kernel[l] = AA*exp(-lag/T1_eye) + BB*exp(-lag/T2_eye) + CC*exp(-lag/T3_eye);
    head_kernel[l] = exp(-lag/T1_head)/(T1_head-T2_head) + exp(-lag/T2_head)/(T2_head-T1_head);
  }

  //Learning Parameters
  int n_iter = 400;      // was 300  //number of iterations
  double m_inc = 1.01;   //increase factor in learning rate
  double m_dec = 0.95;   //decrease factor in learning rate

  //please set these values according to the paper:
  double k_reg_eye = 0.016;
  double k_reg_head = 0.0016;

  double alpha = 0.0;  // learning momentum
  const int n_window = n_steps(0.75, dt);  //steps of plant history used in each gradient term

  std::vector<double> w_eye(n_delay), w_head(n_delay);
  std::vector<double> grad_eye(n_delay), grad_head(n_delay);
  std::vector<double> lr_eye(n_delay), lr_head(n_delay);
  std::vector<double> dedw_eye(n_delay), dedw_head(n_delay);

  std::vector<double> out_eye(dur), out_head(dur);
  std::vector<double> r_eye(dur), r_head(dur);
  std::vector<double> v_eye(dur), v_head(dur);
  std::vector<double> r_vis(dur);

  std::vector<double> r_eye_best(dur), r_head_best(dur);
  std::vector<double> v_eye_best(dur), v_head_best(dur);
  std::vector<double> out_eye_best(dur), out_head_best(dur);

  std::vector<double> vpeak(N_OBJ), duration(N_OBJ);

  fprintf(stderr, "going into k2...N_OBJ=%d loop\n", N_OBJ);

  for (int k2 = 0; k2 < N_OBJ; k2++) {

    fprintf(stderr, " k2=%d ", k2);

    //initialization
    for (int u = 0; u < n_delay; u++) {
      grad_eye[u] = 0.0;
      grad_head[u] = 0.0;
      lr_eye[u] = 0.01;
      lr_head[u] = 0.01;
      w_eye[u] = 0.0;
      w_head[u] = 0.0;
    }

    double error_best = HUGE_VAL;

    fprintf(stderr, "going into k...n_iter=%d loop\n", n_iter);

    for (int k = 0; k <= n_iter; k++) {

      fprintf(stderr, "%d ", k);

      //motor commands: weighted sums of the delay-line activity
      for (int j = 0; j < dur; j++) {
        double eye = 0.0, head = 0.0;
        for (int u = 0; u < n_delay; u++) {
          eye += w_eye[u]*s[u][j];
          head += w_head[u]*s[u][j];
        }
        out_eye[j] = eye;
        out_head[j] = head;
      }

      //positions produced by the plants
      plant_response(out_eye, eye_kernel, k_eyeplant, dt, r_eye);
      plant_response(out_head, head_kernel, k_headplant, dt, r_head);

      //visual error: sample 0 holds the target position itself; the plant outputs are
      //subtracted from t = dt on
      r_vis[0] = r_obj_range[k2];
      for (int j = 1; j < dur; j++) r_vis[j] = r_obj_range[k2] - r_eye[j] - r_head[j];

      //velocities by backward differences
      v_eye[0] = 0.0;
      v_head[0] = 0.0;
      for (int j = 1; j < dur; j++) {
        v_eye[j] = (r_eye[j] - r_eye[j-1])/dt;
        v_head[j] = (r_head[j] - r_head[j-1])/dt;
      }

      //calculating the cost
      //NOTE(review): this sums |r_vis|, whereas cost_gradient differentiates sum sqrt(|r_vis|);
      //confirm which objective the paper specifies.
      double error = 0.0;
      for (int j = 0; j < dur; j++) error += fabs(r_vis[j]);
      for (int u = 0; u < n_delay; u++)
        error += k_reg_eye*pow(w_eye[u], 4.0) + k_reg_head*pow(w_head[u], 4.0);

      //saving the values if better than before
      if (error < error_best) {
        error_best = error;
        r_eye_best = r_eye;
        v_eye_best = v_eye;
        r_head_best = r_head;
        v_head_best = v_head;
        out_eye_best = out_eye;
        out_head_best = out_head;
      }

      //Learning
      cost_gradient(s, r_vis, eye_kernel, k_eyeplant, dt, n_window, dedw_eye);
      cost_gradient(s, r_vis, head_kernel, k_headplant, dt, n_window, dedw_head);

      for (int u = 0; u < n_delay; u++) {
        lr_eye[u] *= lr_factor(grad_eye[u]*dedw_eye[u], m_inc, m_dec);
        lr_head[u] *= lr_factor(grad_head[u]*dedw_head[u], m_inc, m_dec);

        grad_eye[u] = alpha*grad_eye[u] + (1.0-alpha)*dedw_eye[u];
        grad_head[u] = alpha*grad_head[u] + (1.0-alpha)*dedw_head[u];

        //gradient of the quartic regularizer plus the error gradient
        w_eye[u] -= lr_eye[u]*(4.0*k_reg_eye*pow(w_eye[u], 3.0) + grad_eye[u]);
        w_head[u] -= lr_head[u]*(4.0*k_reg_head*pow(w_head[u], 3.0) + grad_head[u]);
      }

      if (k % 100 == 0) {
        //writing the results to disk
        fprintf(stderr, "\nWriting files. It is k=%d, k2=%d, obj_pos=%lf\n", k, k2, r_obj_range[k2]);

        if (!append_column("results_control_eye.dat", out_eye_best) ||
            !append_column("results_control_head.dat", out_head_best) ||
            !append_column("results_pos_eye.dat", r_eye_best) ||
            !append_column("results_pos_head.dat", r_head_best))
          return 1;
      }
    }

    //calculating peak eye velocity and the first sample at which the eye exceeds 99% of the target
    vpeak[k2] = 0.0;
    duration[k2] = 0.0;
    bool found = false;
    for (int j = 0; j < dur; j++) {
      if (vpeak[k2] < v_eye_best[j]) vpeak[k2] = v_eye_best[j];
      if (!found && r_eye_best[j] > 0.99*r_obj_range[k2]) {
        duration[k2] = j;
        found = true;
      }
    }
    if (found)
      fprintf(stderr, "\nobj_pos=%lf vpeak=%lf duration=%d samples\n",
              r_obj_range[k2], vpeak[k2], (int) duration[k2]);
    else
      fprintf(stderr, "\nobj_pos=%lf vpeak=%lf duration=n/a (eye never exceeded 99%% of the target)\n",
              r_obj_range[k2], vpeak[k2]);
  }

  fprintf(stderr, "\ngnuplot");
  fprintf(stderr, "\nset style data lines");
  fprintf(stderr, "\nplot \"results_pos_head.dat\", \"results_pos_eye.dat\", \"results_control_head.dat\", \"results_control_eye.dat\"\n");

  return 0;
}

