#include "FeaturePoints.h"
#include "SkeletonReconstructor.h"
#include "coord3.h"
#include "Volume.h"
#include "utils.h"
#include "surface/Graph.h"

#include <Eigen/Core>
#include <Eigen/SVD>
#include <Eigen/Dense>

#include <algorithm>
#include <cmath>
#include <limits>
#include <stdexcept>
#include <utility>
#include <vector>



typedef coord3s(*PrevToLocal)(coord3s, int, int);

/// One step of an axis sweep: chooses the skeleton origin that best covers `cur`.
///
/// Candidates are `cur` itself (only when v(cur) > 0, i.e. `cur` is marked in the
/// skeleton mask) and the origins already stored in `sdt_x` at the 3x3 predecessor
/// voxels given by T(cur, u, v). A candidate's score is
/// old_edt(origin) - |cur - origin|; the highest score wins, and later candidates
/// win ties. (-1, -1, -1) means "no origin".
///
/// @tparam T      maps (cur, u, v) to a predecessor voxel in the previous slice
/// @param sdt_x   sweep buffer holding already-computed origins
/// @param v       skeleton mask
/// @param cur     voxel being filled
/// @param old_edt distance value per skeleton origin
/// @return the best origin, or (-1, -1, -1) if none qualifies
template <coord3s(*T)(coord3s, int, int)>
FORCE_INLINE coord3s findShortestPrevs(SimpleVolume<coord3s>& sdt_x, Volume<byte>& v, coord3s cur, const Volume<float>& old_edt)
{
  const bool curIsSkeleton = v(cur) > 0;
  coord3s best = curIsSkeleton ? cur : coord3s(-1, -1, -1);
  float bestScore = curIsSkeleton ? old_edt(cur) : 0.0f;

  // du/dv are named so they do not shadow the mask parameter `v`.
  for (int du = -1; du <= 1; ++du)
  {
    for (int dv = -1; dv <= 1; ++dv)
    {
      coord3s neighbour = T(cur, du, dv);
      coord3s candidate = sdt_x(neighbour);

      // Skip empty cells and the origin we already hold.
      if (candidate.x == -1 || candidate == best)
        continue;

      const float score = old_edt(candidate) - sqrtf((cur - candidate).lengthSquared());
      if (score < bestScore)
        continue;

      best = candidate;
      bestScore = score;
    }
  }

  return best;
}

inline coord3s ForwardPrevZ(coord3s cur, int u, int v) { return coord3s(cur.x + v, cur.y + u, cur.z - 1); }
inline coord3s BackwardPrevZ(coord3s cur, int u, int v) { return coord3s(cur.x + v, cur.y + u, cur.z + 1); }

// x-axis sweep predecessors: (u, v) offset z and y; the forward sweep looks one
// column back in x, the backward sweep one column ahead.
inline coord3s ForwardPrevX(coord3s cur, int u, int v)
{
  const int px = cur.x - 1;
  const int py = cur.y + v;
  const int pz = cur.z + u;
  return coord3s(px, py, pz);
}
inline coord3s BackwardPrevX(coord3s cur, int u, int v)
{
  const int px = cur.x + 1;
  const int py = cur.y + v;
  const int pz = cur.z + u;
  return coord3s(px, py, pz);
}

// y-axis sweep predecessors: (u, v) offset z and x; the forward sweep looks one
// row back in y, the backward sweep one row ahead.
inline coord3s ForwardPrevY(coord3s cur, int u, int v)
{
  const int px = cur.x + v;
  const int py = cur.y - 1;
  const int pz = cur.z + u;
  return coord3s(px, py, pz);
}
inline coord3s BackwardPrevY(coord3s cur, int u, int v)
{
  const int px = cur.x + v;
  const int py = cur.y + 1;
  const int pz = cur.z + u;
  return coord3s(px, py, pz);
}


/// Picks whichever of two skeleton origins covers `origin` best.
///
/// (-1, -1, -1) marks "no origin"; if either argument is invalid the other is
/// returned unchanged. Otherwise the score old_edt(o) - |o - origin| is compared
/// and `a` wins only when strictly better (ties return `b`).
///
/// @param old_edt distance value per skeleton origin
/// @param origin  voxel being evaluated (despite the name, this is the query voxel)
/// @param a,b     candidate skeleton origins
coord3s GetClosestCoord(const Volume<float>& old_edt, coord3s origin, coord3s a, coord3s b)
{
  if (a.x == -1)
    return b;
  if (b.x == -1)
    return a;

  const float scoreA = old_edt(a) - sqrtf((a - origin).lengthSquared());
  const float scoreB = old_edt(b) - sqrtf((b - origin).lengthSquared());

  if (scoreA > scoreB)
    return a;
  return b;
}

/// Merges two sweep buffers voxel-by-voxel into `edt_out` using GetClosestCoord.
///
/// Only interior voxels (1 .. dim-2 along each axis of `edt_a`) are written.
/// When `min_out` is true the value already in `edt_out` also competes, so
/// successive calls accumulate the best origin across axes.
void GetMinEdt(SimpleVolume<coord3s>& edt_a, SimpleVolume<coord3s>& edt_b, SimpleVolume<coord3s>& edt_out,
  const Volume<float>& old_edt, bool min_out = false)
{
  const auto zEnd = edt_a.getDepth() - 1;
  const auto yEnd = edt_a.getHeight() - 1;
  const auto xEnd = edt_a.getWidth() - 1;

#pragma omp parallel for
  for (int z = 1; z < zEnd; ++z)
  {
    for (int y = 1; y < yEnd; ++y)
    {
      for (int x = 1; x < xEnd; ++x)
      {
        coord3s voxel(x, y, z);
        coord3s best = GetClosestCoord(old_edt, voxel, edt_a(voxel), edt_b(voxel));

        if (min_out)
          best = GetClosestCoord(old_edt, voxel, edt_out(voxel), best);

        edt_out(voxel) = best;
      }
    }
  }
}

/// Sets every voxel on the six faces of `sdt` to (-1, -1, -1) ("no origin").
///
/// The axis sweeps only write interior voxels and, for the first interior
/// slice, read their predecessors from this border, so the border must be
/// initialised before sweeping.
void InvEdtInitPaddedAreas(SimpleVolume<coord3s>& sdt)
{
  const coord3s none(-1, -1, -1);
  const auto W = sdt.getWidth();
  const auto H = sdt.getHeight();
  const auto D = sdt.getDepth();

  // Faces at z = 0 and z = D - 1.
  for (int y = 0; y < H; ++y)
  {
    for (int x = 0; x < W; ++x)
    {
      sdt(x, y, 0) = none;
      sdt(x, y, D - 1) = none;
    }
  }

  // Faces at y = 0 and y = H - 1.
  for (int z = 0; z < D; ++z)
  {
    for (int x = 0; x < W; ++x)
    {
      sdt(x, 0, z) = none;
      sdt(x, H - 1, z) = none;
    }
  }

  // Faces at x = 0 and x = W - 1.
  for (int z = 0; z < D; ++z)
  {
    for (int y = 0; y < H; ++y)
    {
      sdt(0, y, z) = none;
      sdt(W - 1, y, z) = none;
    }
  }
}


/// Runs the z-axis sweeps and folds the result into `sdt_out`.
///
/// The forward sweep (increasing z) fills `sdt_zf` and the backward sweep
/// (decreasing z) fills `sdt_zb`. Each voxel is computed from the previous slice
/// only, so each slice is parallelised across y. The two buffers are then merged
/// into `sdt_out`, keeping `sdt_out`'s existing value when it scores better.
/// The borders of `sdt_zf`/`sdt_zb` must already hold (-1, -1, -1).
void InvEdtZ(SimpleVolume<coord3s>& sdt_zf, SimpleVolume<coord3s>& sdt_zb, SimpleVolume<coord3s>& sdt_out,
  Volume<byte>& v, const Volume<float>& old_edt)
{
  const auto zEnd = v.getDepth() - 1;
  const auto yEnd = v.getHeight() - 1;
  const auto xEnd = v.getWidth() - 1;

  // Forward sweep: slice z depends on slice z - 1.
  for (int z = 1; z < zEnd; ++z)
  {
#pragma omp parallel for
    for (int y = 1; y < yEnd; ++y)
      for (int x = 1; x < xEnd; ++x)
        sdt_zf(x, y, z) = findShortestPrevs<ForwardPrevZ>(sdt_zf, v, coord3s(x, y, z), old_edt);
  }

  // Backward sweep: slice z depends on slice z + 1.
  for (int z = v.getDepth() - 2; z > 0; --z)
  {
#pragma omp parallel for
    for (int y = 1; y < yEnd; ++y)
      for (int x = 1; x < xEnd; ++x)
        sdt_zb(x, y, z) = findShortestPrevs<BackwardPrevZ>(sdt_zb, v, coord3s(x, y, z), old_edt);
  }

  GetMinEdt(sdt_zf, sdt_zb, sdt_out, old_edt, true);
}

/// Runs the y-axis sweeps and folds the result into `sdt_out`.
///
/// The forward sweep (increasing y) fills `sdt_yf` and the backward sweep
/// (decreasing y) fills `sdt_yb`; each row depends only on the previous row, so
/// each row is parallelised across z. The buffers are then merged into `sdt_out`,
/// keeping `sdt_out`'s existing value when it scores better.
/// The borders of `sdt_yf`/`sdt_yb` must already hold (-1, -1, -1).
void InvEdtY(SimpleVolume<coord3s>& sdt_yf, SimpleVolume<coord3s>& sdt_yb, SimpleVolume<coord3s>& sdt_out,
  Volume<byte>& v, const Volume<float>& old_edt)
{
  const auto yEnd = v.getHeight() - 1;
  const auto zEnd = v.getDepth() - 1;
  const auto xEnd = v.getWidth() - 1;

  // Forward sweep: row y depends on row y - 1.
  for (int y = 1; y < yEnd; ++y)
  {
#pragma omp parallel for
    for (int z = 1; z < zEnd; ++z)
      for (int x = 1; x < xEnd; ++x)
        sdt_yf(x, y, z) = findShortestPrevs<ForwardPrevY>(sdt_yf, v, coord3s(x, y, z), old_edt);
  }

  // Backward sweep: row y depends on row y + 1.
  for (int y = v.getHeight() - 2; y > 0; --y)
  {
#pragma omp parallel for
    for (int z = 1; z < zEnd; ++z)
      for (int x = 1; x < xEnd; ++x)
        sdt_yb(x, y, z) = findShortestPrevs<BackwardPrevY>(sdt_yb, v, coord3s(x, y, z), old_edt);
  }

  GetMinEdt(sdt_yf, sdt_yb, sdt_out, old_edt, true);
}

/// Runs the x-axis sweeps and writes their merge into `sdt_out`.
///
/// The forward sweep (increasing x) fills `sdt_xf` and the backward sweep
/// (decreasing x) fills `sdt_xb`; each column depends only on the previous
/// column, so each column is parallelised across z. Unlike the y/z passes, the
/// merge overwrites `sdt_out` (min_out = false), so this pass must run first.
/// The borders of `sdt_xf`/`sdt_xb` must already hold (-1, -1, -1).
void InvEdtX(SimpleVolume<coord3s>& sdt_xf, SimpleVolume<coord3s>& sdt_xb, SimpleVolume<coord3s>& sdt_out,
  Volume<byte>& v, const Volume<float>& old_edt)
{
  const auto xEnd = v.getWidth() - 1;
  const auto zEnd = v.getDepth() - 1;
  const auto yEnd = v.getHeight() - 1;

  // Forward sweep: column x depends on column x - 1.
  for (int x = 1; x < xEnd; ++x)
  {
#pragma omp parallel for
    for (int z = 1; z < zEnd; ++z)
      for (int y = 1; y < yEnd; ++y)
        sdt_xf(x, y, z) = findShortestPrevs<ForwardPrevX>(sdt_xf, v, coord3s(x, y, z), old_edt);
  }

  // Backward sweep: column x depends on column x + 1.
  for (int x = v.getWidth() - 2; x > 0; --x)
  {
#pragma omp parallel for
    for (int z = 1; z < zEnd; ++z)
      for (int y = 1; y < yEnd; ++y)
        sdt_xb(x, y, z) = findShortestPrevs<BackwardPrevX>(sdt_xb, v, coord3s(x, y, z), old_edt);
  }

  GetMinEdt(sdt_xf, sdt_xb, sdt_out, old_edt);
}

/// Converts the per-voxel best origin into a remaining squared radius.
///
/// For every interior voxel: if `sdt_xyz` holds an origin o, the output is
/// round(old_edt(o)^2) - |cur - o|^2 clamped at 0; voxels without an origin get 0.
/// The one-voxel border of `sdt_x` is not written.
void InvEdtCombine(Volume<int>& sdt_x, SimpleVolume<coord3s>& sdt_xyz, const Volume<float>& old_edt)
{
#pragma omp parallel for
  for (int z = 1; z < sdt_x.getDepth() - 1; z++)
  {
    for (int y = 1; y < sdt_x.getHeight() - 1; y++)
    {
      for (int x = 1; x < sdt_x.getWidth() - 1; x++)
      {
        coord3s cur(x, y, z);
        coord3s origin = sdt_xyz(cur);

        int remainingSq = 0;
        if (origin.x != -1)
        {
          const float radius = old_edt(origin);
          const int radiusSq = int(round(radius * radius));
          const int candidate = radiusSq - (cur - origin).lengthSquared();
          if (candidate > 0)
            remainingSq = candidate;
        }

        sdt_x(cur) = remainingSq;
      }
    }
  }
}



/// Inverse distance transform: grows each skeleton voxel of `v` into a ball of
/// radius old_edt(voxel).
///
/// @param sdt_final receives the remaining squared radius per interior voxel
/// @param sdt_out   receives the best skeleton origin per interior voxel
/// @param v         skeleton mask (voxels > 0 are origins)
/// @param old_edt   radius per skeleton origin
void InvEdt(Volume<int>& sdt_final, SimpleVolume<coord3s>& sdt_out, Volume<byte>& v, const Volume<float>& old_edt)
{
  // One forward and one backward scratch buffer are shared by all three axes.
  // Reuse is safe: each sweep rewrites every interior voxel before the next
  // slice reads it, and the border (set once below) is never written.
  SimpleVolume<coord3s> forwardScratch(v.getWidth(), v.getHeight(), v.getDepth());
  SimpleVolume<coord3s> backwardScratch(v.getWidth(), v.getHeight(), v.getDepth());

  InvEdtInitPaddedAreas(forwardScratch);
  InvEdtInitPaddedAreas(backwardScratch);

  // X must run first: it initialises sdt_out; Y and Z then only improve it.
  InvEdtX(forwardScratch, backwardScratch, sdt_out, v, old_edt);
  InvEdtY(forwardScratch, backwardScratch, sdt_out, v, old_edt);
  InvEdtZ(forwardScratch, backwardScratch, sdt_out, v, old_edt);

  InvEdtCombine(sdt_final, sdt_out, old_edt);
}

/// Stores the skeleton points and distance field, and allocates an empty
/// skeleton mask with the distance field's dimensions.
SkeletonReconstructor::SkeletonReconstructor(const std::vector<Point3d>& skelPoints, const Volume<float>& edt)
  : skelPoints(skelPoints)
  , edt(edt)
{
  const auto w = edt.getWidth();
  const auto h = edt.getHeight();
  const auto d = edt.getDepth();

  // The mask starts empty; InitSkeletonMask() marks the skeleton voxels.
  skeletonMask.makeVolume(w, h, d);
  skeletonMask.clearVolume(0);
}

/// Optionally smooths the distance field (or, for FlatProjectionPosition, the
/// skeleton positions), rebuilds the skeleton mask and reconstructs the volume.
///
/// @param smoothingType   smoothing applied before reconstruction
/// @param radius          radius/size parameter forwarded to the chosen smoothing
/// @param imp             importance/saliency volume. Required for
///                        HalfConstrainedOpening, LeastSquares, FlatProjectionEdt and
///                        FlatProjectionPosition. For Linear/Median/Min/Opening it is
///                        used as the threshold volume only when `constrainSearch`
///                        is true; otherwise `edt` with 0.001f is used.
/// @param threshold       threshold paired with `imp`
/// @param graph           required for the FlatProjection* modes
/// @param featurePoints   required for the FlatProjection* modes
/// @param constrainSearch see `imp`
/// @return binary reconstruction (1 inside, 0 outside)
/// @throws std::invalid_argument if a pointer required by `smoothingType` is null
Volume<byte> SkeletonReconstructor::Reconstruct(ReconstructionSmoothingType smoothingType, int radius,
  const Volume<float>* imp, float threshold, surface::Graph* graph, FeaturePoints* featurePoints, bool constrainSearch)
{
  // Distance field fed to the inverse EDT; redirected to edtCopy when a filter runs.
  const Volume<float>* edtUsed = &edt;
  Volume<float> edtCopy;

  const Volume<float>* thresholdVol = constrainSearch && imp ? imp : &edt;
  const float thresholdValue = constrainSearch && imp ? threshold : 0.001f;

  // Modes that dereference optional inputs validate them first instead of
  // crashing on a null pointer.
  auto requireImportance = [&]()
  {
    if (!imp)
      throw std::invalid_argument("SkeletonReconstructor::Reconstruct: importance volume is required for this smoothing type");
  };
  auto requireGraphInputs = [&]()
  {
    requireImportance();
    if (!featurePoints || !graph)
      throw std::invalid_argument("SkeletonReconstructor::Reconstruct: feature points and graph are required for this smoothing type");
  };

  switch (smoothingType)
  {
  case ReconstructionSmoothingType::Linear:
    edtCopy = edt.meanFilter(radius, *thresholdVol, thresholdValue);
    edtCopy = edtCopy.minFilter(edt);
    edtUsed = &edtCopy;
    break;
  case ReconstructionSmoothingType::Median:
    edtCopy = edt.medianFilter(radius, *thresholdVol, thresholdValue);
    edtCopy = edtCopy.minFilter(edt);
    edtUsed = &edtCopy;
    break;
  case ReconstructionSmoothingType::Min:
    edtCopy = edt.erode(radius, *thresholdVol, thresholdValue, true);
    edtUsed = &edtCopy;
    break;
  case ReconstructionSmoothingType::Opening:
    edtCopy = edt.erode(radius, *thresholdVol, thresholdValue, false);
    edtCopy = edtCopy.dilate(radius, *thresholdVol, thresholdValue);
    edtUsed = &edtCopy;
    break;
  case ReconstructionSmoothingType::HalfConstrainedOpening:
    requireImportance();
    edtCopy = edt.erode(radius, edt, 0.001, false);
    edtCopy = edtCopy.dilate(radius, *imp, threshold);
    edtUsed = &edtCopy;
    break;
  case ReconstructionSmoothingType::LeastSquares:
    requireImportance();
    edtCopy = LeastSquaresProjection(*imp, threshold, radius);
    edtUsed = &edtCopy;
    break;
  case ReconstructionSmoothingType::FlatProjectionEdt:
    requireGraphInputs();
    edtCopy = FlatProjectionEdt(*imp, *featurePoints, *graph, threshold, radius);
    edtUsed = &edtCopy;
    break;
  case ReconstructionSmoothingType::FlatProjectionPosition:
    requireGraphInputs();
    // Moves skelPoints; the distance field itself is left unchanged.
    FlatProjectionPosition(*imp, *featurePoints, *graph, threshold, radius);
    break;
  default:
    break;
  }

  // Skeleton points can be changed by the smoothing, thus we initialize the mask here.
  InitSkeletonMask();

  SimpleVolume<coord3s> ift_out(edt.getWidth(), edt.getHeight(), edt.getDepth());
  return Reconstruct(ift_out, *edtUsed);
}

/// Reconstructs from the current skeleton mask using the distance field stored
/// at construction (no smoothing). This overload does not call InitSkeletonMask().
Volume<byte> SkeletonReconstructor::Reconstruct(SimpleVolume<coord3s>& ift_out)
{
  const Volume<float>& storedEdt = edt;
  return Reconstruct(ift_out, storedEdt);
}

/// Runs the inverse EDT from the current skeleton mask and binarises the result.
///
/// @param ift_out receives the best skeleton origin per interior voxel
///                ((-1, -1, -1) where none). It is allocated with `edt`'s
///                dimensions when empty.
/// @param edt     radius per skeleton origin (shadows the member of the same
///                name); assumed to match skeletonMask's dimensions — TODO confirm
/// @return 1 where the remaining squared radius is positive, 0 elsewhere
Volume<byte> SkeletonReconstructor::Reconstruct(SimpleVolume<coord3s>& ift_out, const Volume<float>& edt)
{
  const auto width = edt.getWidth();
  const auto height = edt.getHeight();
  const auto depth = edt.getDepth();

  Volume<int> edt_volume(width, height, depth);
  Volume<byte> out_volume(width, height, depth);

  if (ift_out.getSize() == 0)
    ift_out.makeVolume(width, height, depth);

  InvEdt(edt_volume, ift_out, skeletonMask, edt);

  // InvEdtCombine writes only voxels with 1 <= coord <= dim - 2. The one-voxel
  // border of edt_volume is never assigned, so it is not read here (its content
  // would depend on whether Volume<int>(w, h, d) zero-initialises); border
  // voxels are reported as outside.
  for (int z = 0; z < depth; z++)
  {
    const bool zInterior = z > 0 && z < depth - 1;
    for (int y = 0; y < height; y++)
    {
      const bool yzInterior = zInterior && y > 0 && y < height - 1;
      for (int x = 0; x < width; x++)
      {
        const bool interior = yzInterior && x > 0 && x < width - 1;
        out_volume(x, y, z) = byte(interior && edt_volume(x, y, z) > 0);
      }
    }
  }

  return out_volume;
}

/// Rebuilds skeletonMask so that exactly the voxels at skelPoints are 255 and
/// every other voxel is 0.
void SkeletonReconstructor::InitSkeletonMask()
{
  // Clear first: skelPoints can move between calls (e.g. FlatProjectionPosition),
  // and without this their previous positions would stay marked in the mask.
  skeletonMask.clearVolume(0);

  for (const auto& p : skelPoints)
    skeletonMask(coord3s(p.x, p.y, p.z)) = 255;
}

namespace 
{

  /// A neighbouring skeleton voxel plus the positions of the two graph nodes of
  /// its feature-point pair.
  struct NeighbourPoint
  {
    coord3s Coord;
    coord3s FP1;
    coord3s FP2;
  };

  /// Clamps the inclusive window [center - radius, center + radius] to [0, size - 1].
  /// Yields lo > hi (an empty window) when the clamped range is empty.
  inline void ClampWindow(int center, int radius, int size, int& lo, int& hi)
  {
    lo = std::max(center - radius, 0);
    hi = std::min(center + radius, size - 1);
  }

  /// Collects (x, y, z, edt) for every voxel in the cube of half-size NSZ around
  /// `p` whose importance exceeds `impThr`, iterating z, then y, then x.
  /// The cube is clipped to both volumes so skeleton points near the border do
  /// not cause out-of-bounds reads.
  void GetNeightbours(coord3s p, const int NSZ, const Volume<float>& edt, const Volume<float>& importance,
    float impThr, std::vector<Eigen::Vector4f>& nbs)
  {
    nbs.clear();

    const int width = std::min(static_cast<int>(importance.getWidth()), static_cast<int>(edt.getWidth()));
    const int height = std::min(static_cast<int>(importance.getHeight()), static_cast<int>(edt.getHeight()));
    const int depth = std::min(static_cast<int>(importance.getDepth()), static_cast<int>(edt.getDepth()));

    int xLo, xHi, yLo, yHi, zLo, zHi;
    ClampWindow(p.x, NSZ, width, xLo, xHi);
    ClampWindow(p.y, NSZ, height, yLo, yHi);
    ClampWindow(p.z, NSZ, depth, zLo, zHi);

    for (int z = zLo; z <= zHi; z++)
      for (int y = yLo; y <= yHi; y++)
        for (int x = xLo; x <= xHi; x++)
        {
          if (importance(x, y, z) > impThr)
            nbs.push_back({
              static_cast<float>(x),
              static_cast<float>(y),
              static_cast<float>(z),
              edt(x, y, z)
            });
        }
  }

  /// Collects the voxels in the cube of half-size NSZ around `p` whose saliency is
  /// non-zero and below `impThr`, together with the graph-node positions of each
  /// voxel's feature-point pair. The cube is clipped to the saliency volume.
  /// `edt` is currently unused.
  void GetNonSalientNeightbours(coord3s p, const int NSZ, const Volume<float>& edt, const Volume<float>& saliency,
    float impThr, std::vector<NeighbourPoint>& nbs, FeaturePoints::Coord2SKP& coord2SkelIdIndex, FeaturePoints& featurePoints,
    surface::Graph& graph)
  {
    nbs.clear();

    int xLo, xHi, yLo, yHi, zLo, zHi;
    ClampWindow(p.x, NSZ, static_cast<int>(saliency.getWidth()), xLo, xHi);
    ClampWindow(p.y, NSZ, static_cast<int>(saliency.getHeight()), yLo, yHi);
    ClampWindow(p.z, NSZ, static_cast<int>(saliency.getDepth()), zLo, zHi);

    for (int z = zLo; z <= zHi; z++)
      for (int y = yLo; y <= yHi; y++)
        for (int x = xLo; x <= xHi; x++)
        {
          const float s = saliency(x, y, z);
          if (!(s < impThr && s != 0.0f))
            continue;

          coord3s coord(x, y, z);

          // NOTE(review): if Coord2SKP is a std map type, operator[] inserts index 0
          // for a key that is missing, silently using feature point 0. This assumes
          // every voxel with non-zero saliency is in the index — confirm.
          int fpIndex = coord2SkelIdIndex[generateKey(coord)];

          auto& fp = featurePoints.featurePoints[fpIndex];

          Point3d a = graph.nodePosition(fp.first);
          Point3d b = graph.nodePosition(fp.second);

          nbs.push_back({
            coord,
            coord3s(a.x, a.y, a.z),
            coord3s(b.x, b.y, b.z)
          });
        }
  }


}


/// Smooths the distance value at each skeleton point by local PCA.
///
/// For each skeleton point with a non-zero distance, the (x, y, z, edt) samples of
/// neighbouring voxels (cube of half-size `r`, importance > `threshold`) are
/// centred. A 2-D subspace is taken from the first two right singular vectors,
/// and the point's own (x, y, z, edt) sample is projected onto it. The projected
/// edt component, rounded, becomes the new distance. Points with zero distance
/// or fewer than 3 neighbours keep their original value.
///
/// @return a volume holding the filtered distance at every skeleton point; other
///         voxels are left as constructed by Volume<float>(w, h, d)
Volume<float> SkeletonReconstructor::LeastSquaresProjection(const Volume<float>& imp, float threshold, int r)
{
  // NOTE(review): std::vector<Eigen::Vector4f> may require Eigen::aligned_allocator
  // when built before C++17 — confirm the build's language standard.
  std::vector<Eigen::Vector4f> nbs;
  Volume<float> filteredEdt(edt.getWidth(), edt.getHeight(), edt.getDepth());

  for (const auto& p : skelPoints)
  {
    coord3s cur(p.x, p.y, p.z);
    const float dist = edt(cur);

    if (dist == 0.0f)
    {
      filteredEdt(cur) = dist;
      continue;
    }

    GetNeightbours(cur, r, edt, imp, threshold, nbs);

    // A plane needs at least three samples; otherwise keep the original value.
    if (nbs.size() < 3)
    {
      filteredEdt(cur) = dist;
      continue;
    }

    Eigen::Vector4f point(p.x, p.y, p.z, dist);

    // View the neighbour list as an (n x 4) row-major matrix without copying.
    Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      X(&nbs[0][0], nbs.size(), 4);
    Eigen::Vector4f mean = X.colwise().mean();
    Eigen::MatrixXf centered = X.rowwise() - mean.transpose();

    // Principal directions from the SVD of the centred samples.
    Eigen::JacobiSVD<Eigen::MatrixXf> svd(centered, Eigen::ComputeThinV);
    Eigen::Matrix<float, 4, 2> W = svd.matrixV().leftCols(2);

    // Project the point onto the 2-D plane through the mean.
    Eigen::Vector4f reduced = W * (W.transpose() * (point - mean)) + mean;
    filteredEdt(cur) = round(reduced[3]);
  }

  return filteredEdt;
}

/// Caps the distance at each skeleton point by its distance to nearby "flat" lines.
///
/// For each non-salient neighbour (see GetNonSalientNeightbours), the line through
/// its two feature-point node positions f1, f2 is taken. The skeleton point's
/// distance is reduced to the minimum of its EDT value and its perpendicular
/// distances to those lines, computed as |(c - f1) x (c - f2)| / |f2 - f1|.
///
/// @return a volume holding the capped distance at every skeleton point; other
///         voxels are left as constructed by Volume<float>(w, h, d)
Volume<float> SkeletonReconstructor::FlatProjectionEdt(const Volume<float>& saliency, FeaturePoints& featurePoints, surface::Graph& graph, 
  float threshold, int r)
{
  std::vector<NeighbourPoint> nbs;
  Volume<float> filteredEdt(edt.getWidth(), edt.getHeight(), edt.getDepth());

  auto& coord2SkelPointIndex = featurePoints.GetOrConstructCoord2SkelPointIndex();

  for (const auto& p : skelPoints)
  {
    coord3s cur(p.x, p.y, p.z);

    GetNonSalientNeightbours(cur, r, edt, saliency, threshold, nbs, coord2SkelPointIndex, 
      featurePoints, graph);

    float minDist = edt(cur);
    Eigen::Vector3f c(cur.x, cur.y, cur.z);

    for (const NeighbourPoint& neigh : nbs)
    {
      Eigen::Vector3f f1(neigh.FP1.x, neigh.FP1.y, neigh.FP1.z);
      Eigen::Vector3f f2(neigh.FP2.x, neigh.FP2.y, neigh.FP2.z);

      // Coincident feature points define no line (0/0 below); skip them explicitly
      // rather than relying on std::min discarding the resulting NaN.
      const float baseLength = (f2 - f1).norm();
      if (baseLength == 0.0f)
        continue;

      Eigen::Vector3f cMinF1 = c - f1;
      Eigen::Vector3f cMinF2 = c - f2;
      const float crossNorm = cMinF1.cross(cMinF2).norm();

      minDist = std::min(minDist, crossNorm / baseLength);
    }

    filteredEdt(cur) = minDist;
  }

  return filteredEdt;
}

void SkeletonReconstructor::FlatProjectionPosition(const Volume<float>& saliency, FeaturePoints& featurePoints, surface::Graph& graph,
  float threshold, int r)
{
  std::vector<NeighbourPoint> nbs;

  auto& coord2SkelPointIndex = featurePoints.GetOrConstructCoord2SkelPointIndex();
  std::vector<Point3d> newSkelPoints;

  for (int_fast64_t i = 0; i < skelPoints.size(); i++)
  {
    auto& p = skelPoints[i];
    coord3s cur(p.x, p.y, p.z);
    float dist = edt(cur);

    GetNonSalientNeightbours(cur, r, edt, saliency, threshold, nbs, coord2SkelPointIndex,
      featurePoints, graph);

    float newDist = dist;
    float minDistToProjection = std::numeric_limits<float>::max();
    Eigen::Vector3f c(cur.x, cur.y, cur.z);

    const int steps = 50;
    for (int j = 0; j < steps; j++)
      for (NeighbourPoint& neigh : nbs)
      {
        Eigen::Vector3f f1(neigh.FP1.x, neigh.FP1.y, neigh.FP1.z);
        Eigen::Vector3f f2(neigh.FP2.x, neigh.FP2.y, neigh.FP2.z);

        Eigen::Vector3f f2MinF1 = f2 - f1;
        Eigen::Vector3f projectedVector = f1 + (c - f1).dot(f2MinF1) / f2MinF1.squaredNorm() * f2MinF1;
        Eigen::Vector3f dir = (c - projectedVector);
        float dirLength = dir.norm();
        if (dirLength > 0.0001f)
          dir /= dirLength;

        float distToProjection = (projectedVector - c).norm() + 0.01f;
        float change = newDist - distToProjection;
        c += dir * change / 5;
        newDist = edt.trilinear(coord3f(c[0], c[1], c[2]));

        if (j == (steps - 1) &&  distToProjection < minDistToProjection)
          minDistToProjection = distToProjection;
      }



    newSkelPoints.push_back(Point3d(round(c[0]), round(c[1]), round(c[2])));
  }

  std::vector<Point3d>* pts = (std::vector<Point3d>*) &skelPoints;
  *pts = newSkelPoints;
}

