#include "edt3d.h"

#include <math.h>

#include <algorithm>
#include <vector>

#define INFTY 100000001

/** 
 **************************************************
 * @b sum
 * Addition on integers extended with INFTY: an INFTY operand absorbs
 * the other one. Finite sums are not clamped.
 * @param a Long number, possibly INFTY
 * @param b Long number, possibly INFTY
 * @return INFTY if a or b is INFTY, otherwise a + b
 **************************************************/
static inline long sum(long a, long b) 
{
  const bool infinite = (a == INFTY) || (b == INFTY);
  return infinite ? INFTY : a + b;
}

/** 
 **************************************************
 * @b prod
 * Multiplication on integers extended with INFTY. Note that
 * prod(0, INFTY) is INFTY as well.
 * @param a Long number, possibly INFTY
 * @param b Long number, possibly INFTY
 * @return INFTY if a or b is INFTY, otherwise a * b
 **************************************************/
static inline long prod(long a, long b) 
{
  if (a != INFTY && b != INFTY)
    return a * b;
  return INFTY;
}
/** 
 **************************************************
 * @b opp
 * Negation on integers extended with INFTY. INFTY is mapped to itself
 * (there is no -INFTY), so expressions containing it stay INFTY.
 * @param a Long number, possibly INFTY
 * @return INFTY if a is INFTY, otherwise -a
 **************************************************/
static inline long opp (long a) {
  return (a == INFTY) ? INFTY : -a;
}

/** 
 **************************************************
 * @b intdivint
 * Integer division on integers extended with INFTY.
 * @param divid Dividend, possibly INFTY
 * @param divis Divisor
 * @return INFTY if divis is 0 or divid is INFTY, otherwise divid / divis
 *         (C++ integer division, truncating toward zero)
 **************************************************/
static inline long intdivint (long divid, long divis) {
  // A zero divisor and an infinite dividend both yield INFTY.
  if (divis != 0 && divid != INFTY)
    return divid / divis;
  return INFTY;
}

////////// Functions F and Sep for the SDT labelling
/** 
 **************************************************
 * @b F
 * Evaluates at abscissa x the parabola with apex at abscissa i and
 * vertical offset gi2.
 * @param x evaluation abscissa
 * @param i abscissa of the parabola's apex
 * @param gi2 offset of the parabola (a squared distance in this file's
 *        callers), possibly INFTY
 * @return (x - i)^2 + gi2, or INFTY when gi2 is INFTY
 **************************************************/
static long F(int x, int i, long gi2)
{
  const int d = x - i;
  return sum(d * d, gi2);
}

/** 
 **************************************************
 * @b Sep
 * Separation abscissa of two parabolas F(., i, gi2) and F(., u, gu2).
 * For u > i this is the last integer abscissa at which parabola i is
 * not above parabola u (when the quotient is nonnegative).
 * @param i apex abscissa of the first parabola
 * @param u apex abscissa of the second parabola
 * @param gi2 offset of parabola i, possibly INFTY
 * @param gu2 offset of parabola u, possibly INFTY
 * @return (u^2 - i^2 + gu2 - gi2) / (2 (u - i)) as an integer quotient,
 *         or INFTY if gi2 or gu2 is INFTY or u == i
 **************************************************/
static long Sep(int i, int u, long gi2, long gu2) {
  const long numer = sum(sum((long) (u*u - i*i), gu2), opp(gi2));
  const long denom = 2*(u-i);
  return intdivint(numer, denom);
}
/////////


/** 
 **************************************************
 * @b phaseSaitoX
 * First step of the Saito algorithm: one-dimensional distances along x.
 * Warning: this stores the plain (non-squared) distance; the next phase
 * squares it.
 * @param V Input volume
 * @param sdt_x Output: for each voxel, the distance along its x-row to the
 *        nearest voxel whose value is <= thr, or INFTY if the row has none
 * @param thr Voxels with value > thr are foreground, the others background
 **************************************************/
static void phaseSaitoX(Volume<byte> &V, Volume<int> &sdt_x, int thr)
{
  const int width = V.getWidth();

  for (int z = 0; z < V.getDepth(); z++)
    for (int y = 0; y < V.getHeight(); y++)
      {
        // Left-to-right: distance to the nearest background voxel on the left.
        sdt_x(0,y,z) = (V(0,y,z) > thr) ? INFTY : 0;
        for (int x = 1; x < width; x++)
          sdt_x(x,y,z) = (V(x,y,z) > thr) ? sum(1, sdt_x(x-1,y,z)) : 0;

        // Right-to-left: take the right-hand distance where it is smaller.
        for (int x = width - 2; x >= 0; x--)
          {
            const int fromRight = sdt_x(x+1,y,z);
            if (fromRight < sdt_x(x,y,z))
              sdt_x(x,y,z) = sum(1, fromRight);
          }
      }
}

/** 
 **************************************************
 * @b phaseSaitoY
 * Second step of the Saito algorithm, using the
 * [Meijster/Roerdink/Hesselink] lower-envelope optimization.
 * For every (x, z) column it computes
 *   sdt_xy(x,u,z) = min over y of ( (u - y)^2 + sdt_x(x,y,z)^2 ),
 * with INFTY propagated through sum/prod.
 * @param sdt_x the 1D (non-squared) distances along the x-direction
 * @param sdt_xy output: the squared EDT in the xy-slices
 **************************************************/
static void phaseSaitoY(Volume<int> &sdt_x, Volume<int> &sdt_xy)
{
  const int height = sdt_x.getHeight();
  if (height <= 0)
    return; // no voxels; also avoids writing s[0]/t[0] out of bounds

  // std::vector instead of variable-length arrays, which are not ISO C++.
  std::vector<int> s(height); // Apex index of each parabola on the lower envelope
  std::vector<int> t(height); // First index at which each envelope parabola is minimal

  for (int z = 0; z < sdt_x.getDepth(); z++)
    for (int x = 0; x < sdt_x.getWidth(); x++)
      {
        int q = 0;
        s[0] = 0;
        t[0] = 0;

        //Forward Scan: build the lower envelope of the parabolas.
        for (int u = 1; u < height; u++)
          {
            // Drop envelope parabolas that u already beats at their start index.
            while ((q >= 0) &&
                   (F(t[q], s[q], prod(sdt_x(x,s[q],z), sdt_x(x,s[q],z))) >
                    F(t[q], u, prod(sdt_x(x,u,z), sdt_x(x,u,z)))))
              q--;

            if (q < 0)
              {
                // u beat every kept parabola: restart the envelope with u.
                q = 0;
                s[0] = u;
              }
            else
              {
                const int w = 1 + Sep(s[q],
                                      u,
                                      prod(sdt_x(x,s[q],z), sdt_x(x,s[q],z)),
                                      prod(sdt_x(x,u,z), sdt_x(x,u,z)));
                if (w < height)
                  {
                    q++;
                    s[q] = u;
                    t[q] = w;
                  }
              }
          }

        //Backward Scan: evaluate the envelope at every index.
        for (int u = height - 1; u >= 0; --u)
          {
            sdt_xy(x,u,z) = F(u, s[q], prod(sdt_x(x,s[q],z), sdt_x(x,s[q],z)));
            if (u == t[q])
              q--;
          }
      }
}

/** 
 **************************************************
 * @b phaseSaitoZ
 * Third step of the Saito algorithm, using the
 * [Meijster/Roerdink/Hesselink] lower-envelope optimization.
 * For every (x, y) line along z it computes
 *   sdt_xyz(x,y,u) = min over z of ( (u - z)^2 + sdt_xy(x,y,z) ),
 * with INFTY propagated through sum.
 * @param sdt_xy the squared EDT in the xy-slices
 * @param sdt_xyz output: the final squared EDT
 **************************************************/
static void phaseSaitoZ(Volume<int> &sdt_xy, Volume<int> &sdt_xyz)
{
  const int depth = sdt_xy.getDepth();
  if (depth <= 0)
    return; // no voxels; also avoids writing s[0]/t[0] out of bounds

  // std::vector instead of variable-length arrays, which are not ISO C++.
  std::vector<int> s(depth); // Apex index of each parabola on the lower envelope
  std::vector<int> t(depth); // First index at which each envelope parabola is minimal

  for (int y = 0; y < sdt_xy.getHeight(); y++)
    for (int x = 0; x < sdt_xy.getWidth(); x++)
      {
        int q = 0;
        s[0] = 0;
        t[0] = 0;

        //Forward Scan: build the lower envelope of the parabolas.
        for (int u = 1; u < depth; u++)
          {
            // Drop envelope parabolas that u already beats at their start index.
            while ((q >= 0) &&
                   (F(t[q], s[q], sdt_xy(x,y,s[q])) >
                    F(t[q], u, sdt_xy(x,y,u))))
              q--;

            if (q < 0)
              {
                // u beat every kept parabola: restart the envelope with u.
                q = 0;
                s[0] = u;
              }
            else
              {
                const int w = 1 + Sep(s[q],
                                      u,
                                      sdt_xy(x,y,s[q]),
                                      sdt_xy(x,y,u));
                if (w < depth)
                  {
                    q++;
                    s[q] = u;
                    t[q] = w;
                  }
              }
          }

        //Backward Scan: evaluate the envelope at every index.
        for (int u = depth - 1; u >= 0; --u)
          {
            sdt_xyz(x,y,u) = F(u, s[q], sdt_xy(x,y,s[q]));
            if (u == t[q])
              q--;
          }
      }
}

/* Squared Euclidean distance transform of v.
   Voxels of v with value <= thr are background and get 0. Every other voxel
   gets the squared distance to the nearest background voxel, or INFTY when
   v contains no background voxel at all.
   sdt_x must already have the same dimensions as v. It receives the result,
   and it also serves as scratch storage for the first pass.
*/
void edt_3d(Volume<int> &sdt_x, Volume<byte> &v, int thr)
{
  // Holds the output of the y-pass; the z-pass then writes back into sdt_x.
  Volume<int> sdt_xy(v.getWidth(), v.getHeight(), v.getDepth());

  phaseSaitoX(v, sdt_x, thr);   // 1D distances along x
  phaseSaitoY(sdt_x, sdt_xy);   // squared distances within each xy-slice
  phaseSaitoZ(sdt_xy, sdt_x);   // full 3D squared distances, stored in sdt_x
}

/* One line of the Meijster/Roerdink/Hesselink lower-envelope scan, used by
   ft_3d.
   For every index j in [0, lim), it picks an index i minimising
   (j - i)^2 + dtline[i] and writes ft[j] = ftline[i] * lim + i. This appends
   the line coordinate i to the feature code ftline[i].
   ss and tt are scratch arrays. They need at least max(lim, 1) entries,
   because ss[0] and tt[0] are always written. */
static void doubleScan (int *ft, int lim, int *ftline, int *dtline, int *ss, int *tt) 
{
  int top = 0;   // index of the last parabola on the envelope
  ss[0] = 0;
  tt[0] = 0;

  // Forward pass: build the lower envelope of the parabolas
  // f_i(j) = (j - i)^2 + dtline[i].
  for (int cur = 1; cur < lim; ++cur) {
    // Pop parabolas that `cur` already beats at the start of their interval.
    while (top >= 0) {
      const int c = ss[top];
      const int lhs = (cur - c) * (cur + c - 2 * tt[top]);
      if (lhs >= dtline[c] - dtline[cur])
        break;
      --top;
    }

    if (top < 0) {
      // `cur` beat every kept parabola: restart the envelope with it.
      top = 0;
      ss[0] = cur;
      continue;
    }

    // First index from which `cur` is strictly below parabola ss[top].
    const int c = ss[top];
    const int start = 1 + ((cur + c) * (cur - c) + dtline[cur] - dtline[c]) / (2 * (cur - c));
    if (start < lim) {
      ++top;
      ss[top] = cur;
      tt[top] = start;
    }
  }

  // Backward pass: find the owning parabola of every index and encode it.
  for (int idx = lim - 1; idx >= 0; --idx) {
    const int owner = ss[top];
    ft[idx] = ftline[owner] * lim + owner;
    if (idx == tt[top])
      --top;
  }
}

/* Feature transform of v (Hesselink's three-phase algorithm).
   A voxel belongs to the boundary (feature set) iff v[z][y][x] <= thr.
   On return, ft[z][y][x] holds the encoded coordinates (fx, fy, fz) of a
   nearest boundary voxel, with
     encode(fx, fy, fz) = fz + zdim * (fy + ydim * fx).
   If v has no boundary voxel, the decoded fx lies on the virtual plane
   fx = xdim - 1 + infty (so fx >= xdim), as described below.
   ft is indexed with v's extents, so it must have the same dimensions. */
void ft_3d(Volume<int> &ft, Volume<byte> &v, int thr) 
{
  const int xdim = v.getWidth(), ydim = v.getHeight(), zdim = v.getDepth();
  if (xdim <= 0 || ydim <= 0 || zdim <= 0)
    return; // nothing to transform (and dtline[xdim - 1] would be invalid)

  /* The pure algorithm requires a nonempty boundary, and the encoding requires
   * all coordinates to be nonnegative. We therefore extend the boundary with
   * the virtual plane x = xdim - 1 + infty. The following must hold:
   *   (xdim-1)^2 + (ydim-1)^2 + (zdim-1)^2 < infty^2
   *     (every real boundary voxel is closer than the virtual plane), and
   *   (xdim - 1 + infty) * ydim * zdim <= INT_MAX
   *     (encoded coordinates and squared distances fit in an int).
   * The global INFTY (100000001) violates the second condition: squaring
   * (INFTY - x) or encoding an INFTY coordinate overflows int. So we use the
   * smallest infty that satisfies the first condition instead.
   * Very large volumes (for a cube, above roughly 920^3 voxels) still violate
   * the second condition and are not supported. */
  const int infty =
    1 + (int)sqrt((double)xdim * xdim + (double)ydim * ydim + (double)zdim * zdim);

  const int lim = std::max(xdim, std::max(ydim, zdim));
  // Per-line scratch buffers, reused by every phase.
  std::vector<int> ftline(lim), ftline1(lim), dtline(lim);
  std::vector<int> ss(lim), tt(lim); // scratch for doubleScan

  /* first phase: feature transform along x. ft[z][y][x] receives the x
   * coordinate of a nearest boundary voxel within the same x-row. */
  for (int y = 0; y < ydim; y++) {
    for (int z = 0; z < zdim; z++) {
      // dtline[x]: distance from x to the nearest boundary voxel at or right
      // of x. infty simulates a boundary outside the image.
      int right = (v[z][y][xdim-1] <= thr ? 0 : infty);
      dtline[xdim - 1] = right;
      for (int x = xdim - 2; x >= 0; x--) {
        right = (v[z][y][x] <= thr ? 0 : right + 1);
        dtline[x] = right;
      }
      // left: feature x coordinate chosen for the previous voxel.
      int left = dtline[0];
      ft[z][y][0] = left;
      for (int x = 1; x < xdim; x++) {
        right = dtline[x];
        left = (x - left <= right ? left : x + right);
        ft[z][y][x] = left;
      }
    }
  }

  /* second phase: construction of feature transform in the xy-direction 
   * based on feature transform in the x-direction */
  for (int z = 0; z < zdim; z++) {
    for (int x = 0; x < xdim; x++) {
      for (int y = 0; y < ydim; y++) {
        const int fx = ft[z][y][x];
        ftline[y] = fx;
        dtline[y] = (fx - x) * (fx - x);
      }
      doubleScan(&ftline1[0], ydim, &ftline[0], &dtline[0], &ss[0], &tt[0]);
      for (int y = 0; y < ydim; y++)
        ft[z][y][x] = ftline1[y];   // now encodes fx * ydim + fy
    }
  }

  /* third phase: construction of feature transform in the xyz-direction 
   * based on feature transform in the xy-direction */
  for (int x = 0; x < xdim; x++) {
    for (int y = 0; y < ydim; y++) {
      for (int z = 0; z < zdim; z++) {
        const int fxy = ft[z][y][x];
        const int dx = fxy / ydim - x;
        const int dy = fxy % ydim - y;
        ftline[z] = fxy;
        dtline[z] = dx * dx + dy * dy;
      }
      doubleScan(&ftline1[0], zdim, &ftline[0], &dtline[0], &ss[0], &tt[0]);
      for (int z = 0; z < zdim; z++)
        ft[z][y][x] = ftline1[z];   // now encodes (fx * ydim + fy) * zdim + fz
    }
  }
}

/* Fills weights with the <3,4,5> chamfer mask used by wedt_3d:
   face step 3, edge step 4, corner step 5. weights needs 3 entries. */
void set_weight_345(int *weights)
{
  static const int w345[3] = { 3, 4, 5 };
  for (int n = 0; n < 3; ++n)
    weights[n] = w345[n];
}

/* Fills weights for the city-block (6-neighbour) metric used by wedt_3d.
   Face steps cost 1; edge and corner steps cost INFTY, so they are
   effectively forbidden. weights needs 3 entries. */
void set_weight_city_block(int *weights)
{
  weights[0] = 1;
  weights[2] = weights[1] = INFTY;
}

/* Fills weights for the chessboard (26-neighbour) metric used by wedt_3d.
   Face, edge and corner steps all cost 1. weights needs 3 entries. */
void set_weight_chessboard(int *weights)
{
  for (int n = 0; n < 3; ++n)
    weights[n] = 1;
}

void wedt_3d(Volume<int> &dt, Volume<byte> &v, int thr, int *weights)
{
  int xdim = v.getWidth(), ydim = v.getHeight(), zdim = v.getDepth();

  for(int i=0;i<zdim;i++)
    for(int j=0;j<ydim;j++) 
      for(int k=0;k<xdim;k++) 
	dt[i][j][k] =  (v[i][j][k] <= thr) ? 0 : INFTY;

  for(int i=1;i<zdim;i++)
    for(int j=1;j<ydim-1;j++) 
      for(int k=1;k<xdim-1;k++) {
	int mval = dt[i][j][k];
	mval = std::min(mval, dt[i-1][j-1][k-1]+weights[2]);
	mval = std::min(mval, dt[i-1][j-1][k]+weights[1]);
	mval = std::min(mval, dt[i-1][j-1][k+1]+weights[2]);
	mval = std::min(mval, dt[i-1][j][k-1]+weights[1]);
	mval = std::min(mval, dt[i-1][j][k]+weights[0]);
	mval = std::min(mval, dt[i-1][j][k+1]+weights[1]);
	mval = std::min(mval, dt[i-1][j+1][k-1]+weights[2]);
	mval = std::min(mval, dt[i-1][j+1][k]+weights[1]);
	mval = std::min(mval, dt[i-1][j+1][k+1]+weights[2]);
	mval = std::min(mval, dt[i][j-1][k-1]+weights[1]);
	mval = std::min(mval, dt[i][j-1][k]+weights[0]);
	mval = std::min(mval, dt[i][j-1][k+1]+weights[1]);
	mval = std::min(mval, dt[i][j][k-1]+weights[0]);
	dt[i][j][k] = mval;
      }

  for(int i=zdim-2;i>=0;i--)
    for(int j=ydim-2;j>=1;j--) 
      for(int k=xdim-2;k>=1;k--) {
	int mval = dt[i][j][k];
	mval = std::min(mval, dt[i+1][j+1][k+1]+weights[2]);
	mval = std::min(mval, dt[i+1][j+1][k]+weights[1]);
	mval = std::min(mval, dt[i+1][j+1][k-1]+weights[2]);
	mval = std::min(mval, dt[i+1][j][k+1]+weights[1]);
	mval = std::min(mval, dt[i+1][j][k]+weights[0]);
	mval = std::min(mval, dt[i+1][j][k-1]+weights[1]);
	mval = std::min(mval, dt[i+1][j-1][k+1]+weights[2]);
	mval = std::min(mval, dt[i+1][j-1][k]+weights[1]);
	mval = std::min(mval, dt[i+1][j-1][k-1]+weights[2]);
	mval = std::min(mval, dt[i][j+1][k+1]+weights[1]);
	mval = std::min(mval, dt[i][j+1][k]+weights[0]);
	mval = std::min(mval, dt[i][j+1][k-1]+weights[1]);
	mval = std::min(mval, dt[i][j][k+1]+weights[0]);
	dt[i][j][k] = mval;
      }
}

