#include "perceptron/pmc_network.hh"
#include "perceptron/pmc_datatypes.hh"

#include <time.h>
#include <unistd.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include <ext/stdio_filebuf.h>

static const double pi = 3.1415926535897931;
typedef std::pair<pmc::datas_type, pmc::datas_type> db_t;

/*
** A class where you can place all your dirty code...
*/
class Framework {

public:
  
  Framework(): nn(NULL) {
    int in[2]; /* Gnuplot in */
    int out[2]; /* Gnuplot out*/

    green.push_back(1);
    green.push_back(-1);
    red.push_back(-1);
    red.push_back(1);

    layers.push_back(20);
    layers.push_back(10);
    layers.push_back(15);

    rungp(in, out);
    gpcin = new std::istream(new __gnu_cxx::stdio_filebuf<char>
			     (in[0], std::ios_base::in, true, 1));
    gpcout = new std::ostream(new __gnu_cxx::stdio_filebuf<char>
			      (out[1], std::ios_base::out, true, 8192));
    
    (*gpcout) << " set mouse" << std::endl;
    (*gpcout) << " set print \"-\"" << std::endl;
    (*gpcout) << " set terminal x11" << std::endl;
    (*gpcout) << " plot [-1:1] [-1:1] -2 notitle" << std::endl;
    srand(time(NULL));
  }

  void rungp(int in[2], int out[2]) {
    int pid;
    
    /* first, create pipes. */
    if (pipe(in))
      perror("pipe 1");
    if (pipe(out))
      perror("pipe 2");
    
    pid = fork();
    if (pid > 0) { /* Parent */
      /* first, close unnecessary file descriptors */
      close(in[1]);  /* we don't need to write to this pipe.  */
      close(out[0]); /* we don't need to read from this pipe. */
      return ;
    } else if (pid == 0) { /* Child */
      close(out[1]); /* we don't need to write to this pipe.  */
      close(in[0]);  /* we don't need to read from this pipe. */
      dup2(in[1], STDOUT_FILENO);
      dup2(out[0], STDIN_FILENO);
      execlp("gnuplot", "gnuplot", 0);
      perror("exec");
      exit(1);
    } else { /* pid < 0 : Error*/
      perror("fork");
    }
  }

  std::pair<double,double> &tkmouse()
  {
    *gpcout << " pause mouse" << std::endl
	    << " print (MOUSE_X)"<< std::endl
	    << " print (MOUSE_Y)" << std::endl;
    *gpcin >> mouse.first >> mouse.second;
    //std::cout << mouse.first << " and " << mouse.second << std::endl;
    return mouse;
  }

  void init_normalize(std::vector<db_t> &db) {
    centerx = 0;
    centery = 0;
    for (unsigned i = 0; i < db.size(); ++i) {
      centerx += db[i].first[0];
      centery += db[i].first[1];
    }
    centerx /= db.size();
    centery /= db.size();
    scalex = 0;
    scaley = 0;
    for (unsigned i = 0; i < db.size(); ++i) {
      double x =  db[i].first[0] - centerx;
      double y =  db[i].first[1] - centery;
      if (fabs(x) > scalex)
	scalex = fabs(x);
      if (fabs(y) > scaley)
	scaley = fabs(x);
    }
  }

  std::vector<double> &normalize(std::vector<double> &pt) {
      pt[0] -= centerx;
      pt[1] -= centery;
      pt[0] /= scalex;
      pt[1] /= scaley;
      return pt;
  }

  std::vector<double> &unnormalize(std::vector<double> &pt) {
      pt[0] *= scalex;
      pt[1] *= scaley;
      pt[0] += centerx;
      pt[1] += centery;
      return pt;
  }

  void learn(bool animate) {
    unsigned  i;
    error.clear();
    results.clear();
    
    //  for (int k = 0; k < 1000; ++k) {
    double errorgen = 1000000;
    

    init_normalize(db);

    for (i = 0; i < db.size(); ++i) {
      normalized.push_back(db_t(db[i].first, db[i].second));
      normalize(normalized.back().first);
    }
    for (int k = 0; k < 1000; ++k) {
      //do {
      error.push_back(errorgen);
      std::random_shuffle(normalized.begin(), normalized.end());
      for (i = 0; i < normalized.size(); ++i) {
	nn->learn(normalized[i].first, normalized[i].second, true, 0.2);
	nn->weightsUpdate();
      }
      errorgen = 0;
      for (i = 0; i < normalized.size(); ++i) {
	pmc::datas_type out((*nn)(normalized[i].first));
	for (unsigned j = 0; j < out.size(); ++j)
	  errorgen +=  sqrt((out[j] - normalized[i].second[j]) * (out[j] - normalized[i].second[j]));
      }
      if (animate) 
	redisp(get_results());
      std::cout << "Erreur: " << errorgen/(double) db.size() << std::endl;
      //} while (errorgen < error->back() * 1.2);
    }
    error.erase(error.begin()); // Not a real error...
}

  void savdatas(std::vector<db_t> &db,
		const char *dataGname, const char *dataRname) {
    std::vector<db_t>::iterator i;
    
    std::ofstream dataR(dataRname, std::ios_base::out);
    std::ofstream dataG(dataGname, std::ios_base::out);
    for (i = db.begin(); i != db.end(); ++i) {
      if ((*i).second == red)
	dataR << (*i).first[0] << "\t" << (*i).first[1] <<  std::endl;
      else if ((*i).second == green)
	dataG << (*i).first[0] << "\t" << (*i).first[1] <<  std::endl;
      else
	assert(1);
    }
  }

  void redisp(std::vector<db_t> &db) {
    std::vector<db_t>::iterator i;
    
    *gpcout << "plot  [-1:1] [-1:1] '-' notitle, '-' notitle" << std::endl;
    *gpcout << "-2\t-2" << std::endl;
    for (i = db.begin(); i != db.end(); ++i) {
      if ((*i).second == red)
	*gpcout << (*i).first[0] << "\t" << (*i).first[1] <<  std::endl;
    }
    *gpcout << "e" << std::endl;
    *gpcout << "-2\t-2" << std::endl;
    for (i = db.begin(); i != db.end(); ++i) {
      if ((*i).second == green)
	*gpcout << (*i).first[0] << "\t" << (*i).first[1] <<  std::endl;
    }
    *gpcout << "e" << std::endl;
  }

  pmc::datas_type &genpt(double origx, double origy, double sx,
			 double sy, double theta)
  {
    // Arg "new"!! Lost forever
    pmc::datas_type *ret = new pmc::datas_type();
    
    double s = sqrt(-2. * log(rand() / (double) RAND_MAX));
    double c = 2 * pi * (rand() / (double) RAND_MAX);
    double X = sx * s * cos(c);
    double Y = sy * s * sin(c);
    
    c = cos(theta * pi);
    s = sin(theta * pi);
    ret->push_back(origx + c * X - s * Y);
    ret->push_back(origy + s * X + c * Y);
    
    return *ret;
  }
  
  void mkgauss(pmc::datas_type &color)
  {
    double sx, sy, theta;
    int nbpts;
    
    std::cin >> sx >> sy >> theta >> nbpts;
    for (int i = 0; i < nbpts; ++i)
      db.push_back(db_t(genpt(mouse.first, mouse.second, sx, sy, theta), 
			color));
  }

  void mkone(pmc::datas_type &color)
  {
    // ARG!! "new" ! --> lost forever!
    pmc::datas_type *ret = new pmc::datas_type();
    
    ret->push_back(mouse.first);
    ret->push_back(mouse.second);
    db.push_back(db_t(*ret, color));
  }

  // FIXME : make a lazy founction for this...
  std::vector<db_t> &get_results() 
  {
    results.clear();
    for (std::vector<db_t>::iterator i = db.begin(); i != db.end(); ++i) {
      pmc::datas_type *in = new pmc::datas_type() ;
      pmc::datas_type out;
      in->push_back((*i).first[0]);
      in->push_back((*i).first[1]);
      normalize(*in);
      out = (*nn)(*in);
      unnormalize(*in);
      if (out[0] > out[1])
	results.push_back(db_t(*in , green));
      else
	results.push_back(db_t(*in , red));
    }
    for (double y = -1; y < 1; y += 0.05) {
      for (double x = -1; x < 1; x += 0.05) {
	pmc::datas_type *in = new pmc::datas_type() ;
	pmc::datas_type out;
	in->push_back(x);
	in->push_back(y);
	normalize(*in);
	out = (*nn)(*in);
	unnormalize(*in);
	if (out[0] > out[1])
	  results.push_back(db_t(*in , green));
	else
	  results.push_back(db_t(*in , red));
      }
    }
    return results;
  }


  void run() 
  {
    while (42) {
      std::string button;
      std::cin >> button;
      std::cout << "received " << button << std::endl;
      if (button == "CR") {
	tkmouse();
	mkone(red);
	redisp(db);
      } else if (button == "CG") {
	tkmouse();
	mkone(green);
	redisp(db);
      } else if (button == "CRG") {
	tkmouse();
	mkgauss(red);
	redisp(db);
      } else if (button == "CGG") {
	tkmouse();
	mkgauss(green);
	redisp(db);
      } else if (button == "CL") {
	db.clear();
	redisp(db);
      } else if (button == "Q") {
	exit(0);
      } else if (button == "SDB") { /* Save  Datas */
	savdatas(db, "in_greens.dat", "in_reds.dat");
	if (! results.empty()) {
	  savdatas(results, "out_greens.dat", "out_reds.dat");
	}
	if (! error.empty()) {
	  std::ofstream errorstr("error.dat", std::ios_base::out);
	  for (std::vector<double>::iterator j = error.begin();
	       j != error.end(); ++j)
	    errorstr << (*j) << std::endl;
	}
      } else if (button == "RNN") { /* Reset NN*/
	if (nn != NULL) {
	  delete nn;
	  nn = NULL;
	}
      } else if (button == "L") { /* Learn */
	if (nn == NULL)
	  nn = new  pmc::Network(2, 2, layers, 0.05, new pmc::AFLogNormalized());
	learn(false);

	// results << db;
	std::cout << *nn << std::endl;
	redisp(get_results());
      } else if (button == "AL") { /* Animated Learn */
	if (nn == NULL)
	  nn = new  pmc::Network(2, 2, layers, 0.05, new pmc::AFLogNormalized());
	learn(true);

	// results << db;
	std::cout << *nn << std::endl;
	redisp(get_results());

      } else {
	std::cout << "Error in message: " << button << std::endl;
      }
    }
  }
  
  
private:
  std::istream *gpcin;
  std::ostream *gpcout;

  pmc::datas_type green;
  pmc::datas_type red;

  // Input
  std::vector<db_t> db;
  std::vector<db_t> normalized;
  // Output
  std::vector<db_t> results;
  // Error evolve
  std::vector<double> error;
  double centerx;
  double centery;
  double scalex;
  double scaley;


  // Init PMC
  std::vector<unsigned int> layers;
  pmc::Network *nn; //(2, 2, layers, 0.5, new pmc::AFLog());
  //std::cout << pmc << std::endl;

  std::pair<double,double> mouse;

};


/*
** Program entry point: builds a Framework and runs its command loop.
** The command line arguments are not used.  Returns 0 if run() returns.
*/
int main(int argc, char **argv) 
{
  (void) argc;
  (void) argv;

  /* Automatic storage: nothing to delete, no leaked instance. */
  Framework f;

  f.run();
  return 0;
}

