#include "surface/Graph.h"
#include <algorithm>
#include <iostream>
#include <utility>

using namespace std;

namespace surface
{

/// Packs integer voxel coordinates (k, j, i) into one key as k*10^6 + j*10^3 + i.
///
/// The arithmetic is done in unsigned int so that large or negative
/// coordinates wrap modulo 2^32 instead of overflowing a signed int
/// (undefined behaviour). For inputs where the old signed expression did not
/// overflow, the result is identical.
///
/// NOTE(review): keys collide when j or i fall outside [0, 999], or when k is
/// large enough for the product to wrap. Callers must keep coordinates in
/// that range for the key to be unique.
static inline unsigned int hashkey(int k, int j, int i)
{
  const unsigned int uk = static_cast<unsigned int>(k);
  const unsigned int uj = static_cast<unsigned int>(j);
  const unsigned int ui = static_cast<unsigned int>(i);
  return 1000u * 1000u * uk + 1000u * uj + ui;
}




/// Creates an empty graph. The boundary pointer stays unset until construct()
/// is called.
Graph::Graph() : boundary(nullptr)
{
}


/// Destroys the graph, releasing every Node it allocated (via reset()).
Graph::~Graph()
{
  this->reset();
}


/// Returns the graph to its empty state.
///
/// Deletes every Node owned through 'nodes'. nid2nodes holds the same
/// pointers (non-owning) and is only cleared. Also forgets the boundary.
void Graph::reset()
{
  for (auto &entry : nodes)       //each Node was allocated with new in construct()
    delete entry.second;

  boundary = nullptr;
  nid2nodes.clear();
  nodes.clear();
}

void Graph::construct(const vector<Point3d> &boundary_)
{
  boundary = &boundary_;

  //1. Make a Node for each boundary-point
  for (int i = 0, B = boundary->size(); i < B; ++i)
  {
    //   Add nodes to nodes[] using a coord-based hash-key
    const Point3d &bp = (*boundary)[i];
    //   This allows quickly finding the boundary-neighbors of a node
    unsigned int key = hashkey(bp.x, bp.y, bp.z);

    //make node for boundary-point bp
    Node* n = new Node(i, key);                   
    //store link to Point3d-index in boundary in node
    //store hash-key of node in node itself
    //add new node to nodes[]

    nodes.insert(make_pair(key, n));
    nid2nodes.insert(make_pair(i, n));
  }

  //2. Create neighbor-relations between all nodes:
  for (Nodes::iterator it = nodes.begin(); it != nodes.end(); ++it)
  {
    Node* n = it->second;

    const Point3d &np = (*boundary)[n->id];             //This is the boundary-point for node 'n'

    int nbs = 0;
    int x = np.x, y = np.y, z = np.z;                   //Find all boundary-neighbors of 'n'
    for (int xx = x - 1; xx <= x + 1; ++xx)
      for (int yy = y - 1; yy <= y + 1; ++yy)
        for (int zz = z - 1; zz <= z + 1; ++zz)
        {
          if (xx == x && yy == y &&zz == z) continue;   //'n' is not a neighbor of itself

          unsigned int nkey = hashkey(xx, yy, zz);      //Search boundary-node; if we found it,
          Nodes::const_iterator nit = nodes.find(nkey); //then we've got a neighbor-node of 'n'
          if (nit == nodes.end()) continue;

          Node* n2 = nit->second;                       //Connect neighbor with n bidirectionally

          const Point3d &np2 = (*boundary)[n2->id];

          float w0 = np.dist(np2);
#ifdef USE_MAGIC_GRAPH_WEIGHTS
          float w = (fabs(w0 - 1) < 1.0e-6) ? 0.9016 :
                    (fabs(w0 - sqrt(2)) < 1.0e-6) ? 1.289 : 1.615;
#else
          float w = w0;
#endif

          n->addNeighbor(n2, w);
          n2->addNeighbor(n, w);
          ++nbs;
        }
  }
}

///Given two graph nodes, computes shortest-path in graph between them.
void Graph::shortestPath(int pid1, int pid2, Path &path) const
{
  //Returns the length of this path.
  const float       INFTY = 1.0e+8;
  int           NN = nodes.size();
  vector<float>     dist(NN,
                         INFTY);        //distances of all vertices to pid1, initially set to infinity
  vector<Node*>     previous(NN,
                             (Node*)0);   //previous nodes of all vertices along shortest-paths, initially none
  multimap<float, Node*> PQ;              //priority queue for visiting vertices
  vector<multimap<float, Node*>::iterator> ptrs(
    NN);  //locations in PQ of all vertices; used for modifying priorities

  dist[pid1] = 0;                   //pid1 is at 0 distance from itself

  for (auto it = nodes.begin(), ie = nodes.end(); it != ie; ++it)
  {
    Node* v = it->second;             //add locations in priority queue of all nodes, so we can easily
    ptrs[v->id] = PQ.insert(make_pair(dist[v->id], v)); //update their priorities
  }

  const Point3d &target =
    (*boundary)[pid2];      //cache target voxel, for fast distance computations

  Node* last = 0;
  float length = 0;
  while (!PQ.empty())                 //do the search of pid2 from pid1 onwards:
  {
    Node* u = PQ.begin()->second;         //get min-distance vertex from the queue
    if (u->id == pid2)
    {
      last = u;
      length = dist[u->id];
      break;
    }
    //test if target reached
    PQ.erase(PQ.begin());             //remove min-distance vertex from queue
    ptrs[u->id] = PQ.end();             //mark this vertex as visited

    float hu = (*boundary)[u->id].dist(target);   //heuristic cost for vertex u

    for (Node::Neighbors::iterator ei = u->neighbors.begin(), ee = u->neighbors.end(); ei != ee; ++ei)
    {
      //search neighbors of u
      Node* v = ei->first;
      if (ptrs[v->id] == PQ.end()) continue;    //neighbor v visited, nothing to do

      float w = ei->second;
      float hv = (*boundary)[v->id].dist(target);
      float alt = dist[u->id] + w + hv -
                  hu;    //cost to go u->, using graph + A* heuristic

      if (alt < dist[v->id])            //found a shorter path from pid1 to v
      {
        dist[v->id] = alt;            //update distance of v
        previous[v->id] = u;          //mark that we can better reach v from u
        PQ.erase(ptrs[v->id]);          //update priority of v in queue
        ptrs[v->id] = PQ.insert(make_pair(alt, v));
      }
    }
  }

  Path revpath;

  for (;;)
  {
    revpath.push_back(last);
    if (last->id == pid1) break;
    last = previous[last->id];
  }

  path.clear();
  path.resize(revpath.size());
  for (int i = revpath.size() - 1; i >= 0; --i)
    path[revpath.size() - 1 - i] = revpath[i];

  length += (*boundary)[pid1].dist(target);

  path.length = length;
}


///Returns the node of this path whose distance along the path from the start
///is closest to length/2.
///
///Walks the path accumulating edge weights and stops at the node after which
///the error to length/2 starts growing. If the error never grows (e.g. for
///1- or 2-node paths), the last node is returned.
///
///@return the midpoint node. Returns nullptr for an empty path, or (after
///        printing an error) when two consecutive path nodes are not graph
///        neighbors.
///
///NOTE(review): assumes 'length' is the graph length of this open path, as
///set by Graph::shortestPath — TODO confirm for other producers of Path.
Graph::Node* Graph::Path::midPoint()
{
  const int N = size();
  if (N == 0) return nullptr;               //no nodes, no midpoint

  float crt_d = 0;                          //Distance along the path from start-point to crt
  float mid_d = length / 2;                 //Distance from start-point to midpoint
  float prev_err = mid_d;                   //|distance of prev to midpoint|; start-point is mid_d away

  Node* prev = (*this)[0];
  for (int i = 1; i < N; ++i)
  {
    Node* crt = (*this)[i];
    Node::Neighbors::iterator nit = prev->neighbors.find(crt);
    if (nit == prev->neighbors.end())
    {
      cout << "Error: Path::midPoint: missing connection: point " << prev->id <<
           " hasn't " << crt->id << " as neighbor" << endl;
      return nullptr;
    }

    crt_d += nit->second;                   //Distance from crt-point to startpoint
    float err = fabs(crt_d - mid_d);        //Distance from crt-point to midpoint
    if (err > prev_err)                     //Current point further than previous one from midpoint,
      return prev;                          //so midpoint is the previous point

    prev = crt;
    prev_err = err;
  }

  //The error never increased, so the last node is the closest to the midpoint.
  return prev;
}


void Graph::shortestPath(const FeaturePoints &fp, int skelpoint, Path &path) const
{
  //Computes best approximation of the shortest-path between the 'true' FPs of given
  //skeleton point. Returns the length of the path.
  const FeaturePoints::FeatureTriple &f = fp.featurePoints[skelpoint];

  shortestPath(f.first, f.second, path);
}



void Graph::geoFront(int pid1, int pid2, Path &path, vector<float> &dist)
{
  const float       INFTY = 1.0e+8;
  int           NN = nodes.size();
  vector<Node*>     previous(NN,
                             (Node*)0);   //previous nodes of all vertices along shortest-paths, initially none
  multimap<float, Node*> PQ;              //priority queue for visiting vertices
  vector<multimap<float, Node*>::iterator> ptrs(
    NN);  //locations in PQ of all vertices; used for modifying priorities

  dist[pid1] = 0;                   //pid1 is at 0 distance from itself

  for (Nodes::iterator it = nodes.begin(), ie = nodes.end(); it != ie; ++it)
  {
    Node* v = it->second;             //add locations in priority queue of all nodes, so we can easily
    ptrs[v->id] = PQ.insert(make_pair(dist[v->id], v)); //update their priorities
  }

  const Point3d &target =
    (*boundary)[pid2];      //cache target voxel, for fast distance computations

  Node* last = 0;
  float length = 0;
  while (!PQ.empty())                 //do the search of pid2 from pid1 onwards:
  {
    Node* u = PQ.begin()->second;         //get min-distance vertex from the queue

    PQ.erase(PQ.begin());             //remove min-distance vertex from queue
    ptrs[u->id] = PQ.end();             //mark this vertex as visited

    if (u->id == pid2)
    {
      last = u;  //test if target reached
      length = dist[u->id];
      break;
    }

    for (Node::Neighbors::iterator ei = u->neighbors.begin(),
         ee = u->neighbors.end(); ei != ee; ++ei)
    {
      //search neighbors of u
      Node* v = ei->first;
      if (ptrs[v->id] == PQ.end()) continue;    //neighbor v visited, nothing to do

      float w = ei->second;
      float alt = dist[u->id] + w;        //cost to go u->, using graph + A* heuristic

      if (alt < dist[v->id])            //found a shorter path from pid1 to v
      {
        dist[v->id] = alt;            //update distance of v
        previous[v->id] = u;          //mark that we can better reach v from u
        PQ.erase(ptrs[v->id]);          //update priority of v in queue
        ptrs[v->id] = PQ.insert(make_pair(alt, v));
      }
    }
  }

  //the distanceof PQ points is not final; set it to some known vaalue
  while (!PQ.empty())
  {
    Node* u = PQ.begin()->second;  //get min-distance vertex from the queue
    dist[u->id] = INFTY;
    PQ.erase(PQ.begin());   //remove point
  }

  Path revpath;
  for (;;)
  {
    revpath.push_back(last);
    if (last->id == pid1) break;
    last = previous[last->id];
  }

  path.clear();
  path.resize(revpath.size());
  for (int i = revpath.size() - 1; i >= 0; --i)
    path[revpath.size() - 1 - i] = revpath[i];

  length += (*boundary)[pid1].dist(target);

  path.length = length;
}



void Graph::geoLoci(const FeaturePoints &fp, int skelpoint, float eps,
                    vector<float> &dist)
{
  const FeaturePoints::FeatureTriple &f = fp.featurePoints[skelpoint];
  const float       INFTY = 1.0e+8;

  int NN = nodes.size();
  vector<float>  dist1(NN, INFTY);
  vector<float>  dist2(NN, INFTY);

  Path path1;
  geoFront(f.first, f.second, path1,
           dist1);          //DT of f.first on input surface
  Path path2;
  geoFront(f.second, f.first, path2,
           dist2);          //DT of f.second on input surface
  dist.resize(NN);

  float mind = INFTY;

  for (Nodes::iterator it = nodes.begin(), ie = nodes.end(); it != ie; ++it)
  {
    Node* v = it->second;
    int vid = v->id;

    dist[vid] = 0.0f;
    if (dist1[vid] < 0.5 * INFTY && dist2[vid] < 0.5 * INFTY)
    {
      float val = dist1[vid] + dist2[vid];
      mind = std::min(mind, val); //compute inf
    }
  }

  for (Nodes::iterator it = nodes.begin(), ie = nodes.end(); it != ie; ++it)
  {
    Node* v = it->second;
    int vid = v->id;

    if (dist1[vid] < 0.5 * INFTY && dist2[vid] < 0.5 * INFTY)
    {
      float val = dist1[vid] + dist2[vid];

      //second term increases 0..1 from f.second to f.first;
      if (val <= mind + eps)
      {
        dist[vid] = pow(val / (mind + eps), 20.0f);
        //dist[vid] = val/200;
      }

      val = fabsf(dist1[vid] - dist2[vid]);

      //second term increases 0..1 from f.second to f.first;
      if (val < 1.0f)
      {
        dist[vid] = 1.0f / (val + 0.1f);
      }


    }
  }

  //detectSaddles(dist);

}

/// Debug helper that looks for saddle-like nodes in the scalar field @p dist.
///
/// For each node with dist >= 1e-7, it counts the boundary nodes in its 3x3x3
/// voxel neighborhood (also with dist >= 1e-7) whose value is strictly lower
/// or strictly higher. When both counts equal 2, it prints "bla".
///
/// NOTE(review): debug stub. Nothing is stored or returned, and the call in
/// geoLoci is commented out. @p dist is only read.
void Graph::detectSaddles(vector<float> &dist)
{
  const float EPS = 1e-7f;

  for (auto it = nodes.begin(); it != nodes.end(); ++it)
  {
    const int vid = it->second->id;
    const float dv = dist[vid];
    if (dv < EPS) continue;

    const Point3d &p = (*boundary)[vid];
    const int x = p.x, y = p.y, z = p.z;
    int lower = 0, higher = 0;

    for (int dx = -1; dx <= 1; ++dx)
      for (int dy = -1; dy <= 1; ++dy)
        for (int dz = -1; dz <= 1; ++dz)
        {
          if (dx == 0 && dy == 0 && dz == 0) continue;   //skip the node itself

          Nodes::const_iterator nit = nodes.find(hashkey(x + dx, y + dy, z + dz));
          if (nit == nodes.end()) continue;

          const float dn = dist[nit->second->id];
          if (dn < EPS) continue;

          if (dv > dn)
            ++lower;
          else if (dv < dn)
            ++higher;
        }

    if (higher == 2 && lower == 2) printf("bla\n");
  }
}

//Recomputes p.length using the same distance-metric as the Graph underlying the path.
//
//The path is treated as a closed loop: the last node is joined back to the
//first. Consecutive nodes that are graph neighbors contribute their edge
//weight; any other pair contributes the Euclidean distance between their
//positions.
void Graph::updateLength(Path &p) const
{
  p.length = 0;

  const int count = p.size();
  for (int k = 0; k < count; ++k)
  {
    Node* from = p[k];
    Node* to = p[(k + 1 < count) ? k + 1 : 0];   //wraps around to the first node

    Node::Neighbors::iterator edge = from->neighbors.find(to);
    const bool adjacent = (edge != from->neighbors.end());
    if (adjacent)
    {
      p.length += edge->second;
    }
    else
    {
      const Point3d &a = nodePosition(from->id);
      const Point3d &b = nodePosition(to->id);
      p.length += a.dist(b);
    }
  }
}

}
