#include "SkeletonModel.h"
#include "utils.h"
#include "SkeletonReconstructor.h"

#include <algorithm>
#include <array>
#include <vector>
#include <Eigen/Core>
#include <Eigen/SVD>

#include "VolumeHelpers.h"

// Captures the engine's importance volume, EDT, thin image, dimensions and
// maximum distance. When `build` is true the point sets and normals are
// computed immediately via Update(); otherwise they are left for later.
SkeletonModel::SkeletonModel(collapseSkel3d& skelEngine, float impThreshold, bool build)
  : importance(&skelEngine.getImportance()),
    edt(skelEngine.getEDT()),
    thinImage(skelEngine.getThinImage()),
    width(edt.getWidth()),
    height(edt.getHeight()),
    depth(edt.getDepth()),
    maxDistance(skelEngine.getMaxDst()),
    impThreshold(impThreshold)
{
  if (!build)
    return;

  Update(skelEngine, impThreshold);
}

// No custom teardown: members are destroyed implicitly.
SkeletonModel::~SkeletonModel() = default;

// Recomputes the thin-set points, the skeleton points and their normals from
// the current state of `skelEngine`.
//
// skelEngine   : supplies the EDT (dimensions), the maximum distance and IsReady().
// impThreshold : stored on the model; skeleton points are voxels whose
//                importance is strictly above it.
// importance   : importance volume to sample. Its ADDRESS is stored in
//                this->importance, so it must outlive this model.
//                NOTE(review): passing a temporary here would leave a dangling pointer.
// keepLargest  : when true, skeleton points come from a copy of `importance`
//                processed by FilterLargestComponent (presumably keeping only
//                the largest component above the threshold - confirm).
//
// Skeleton points/normals are only refreshed when skelEngine.IsReady().
void SkeletonModel::Update(collapseSkel3d& skelEngine, float impThreshold,
                           const Volume<float>& importance, bool keepLargest)
{
  width = skelEngine.getEDT().getWidth();
  height = skelEngine.getEDT().getHeight();
  depth = skelEngine.getEDT().getDepth();

  // The point search is built lazily from the skeleton points, which may be
  // about to change; drop it so getSkelPointSearch() rebuilds it on demand.
  skelPointSearch.reset();

  maxDistance = skelEngine.getMaxDst();
  this->impThreshold = impThreshold;
  // The member pointer is (presumably) non-const; this function only reads
  // through it, so casting away const is confined and explicit here.
  this->importance = const_cast<Volume<float>*>(&importance);

  // Assigning a returned temporary already move-assigns; std::move is not needed.
  thinPoints = createThinSet(impThreshold, thinImage);
  createNormals(thinImage, thinPoints, importance, thinNormals, 0);

  Volume<float> impLargest;
  if (keepLargest)
  {
    impLargest = importance;
    FilterLargestComponent(impLargest, impThreshold);
  }

  const Volume<float>& imp = keepLargest ? impLargest : importance;

  if (skelEngine.IsReady())
  {
    skelPoints = createSkeletonPoints(imp, impThreshold);
    // NOTE(review): neighbours for these normals are looked up in the
    // unfiltered `importance`, not `imp`; confirm this is intended when
    // keepLargest is set.
    createNormals(thinImage, skelPoints, importance, skelNormals, impThreshold);
  }
}

// Convenience overload: re-runs the full Update against the importance volume
// currently referenced by this model, without largest-component filtering.
void SkeletonModel::Update(collapseSkel3d& skelEngine, float impThreshold)
{
  const bool keepLargestComponent = false;
  Update(skelEngine, impThreshold, *importance, keepLargestComponent);
}

// Convenience overload: re-runs Update with the importance threshold already
// stored on this model.
void SkeletonModel::Update(collapseSkel3d& skelEngine)
{
  Update(skelEngine, this->impThreshold);
}

// Reconstructs a volume from this model's skeleton points and EDT using
// SkeletonReconstructor, writes it to `outRecon`, and returns a new model
// (constructed with build = false) whose thin-set points and thin normals are
// computed from that reconstruction.
//
// smoothing, radius : forwarded unchanged to SkeletonReconstructor::Reconstruct.
// outRecon          : receives the reconstructed volume.
std::shared_ptr<SkeletonModel> SkeletonModel::Reconstruct(collapseSkel3d& skelEngine,
  ReconstructionSmoothingType smoothing, int radius, Volume<byte>& outRecon)
{
  std::shared_ptr<SkeletonModel> model(new SkeletonModel(skelEngine, impThreshold, false));

  SkeletonReconstructor reconstructor(skelPoints, edt);
  outRecon = reconstructor.Reconstruct(smoothing, radius,
    importance, impThreshold);

  // TODO: not pretty, should be refactored - the new model's thin-set is
  // filled in from outside instead of by the model itself.
  model->thinPoints = createThinSet(impThreshold, outRecon);
  model->createNormals(outRecon, model->thinPoints, *importance, model->thinNormals, 0);

  return model;
}

// Same as the plain Reconstruct overload, but also hands `graph` and
// `featurePoints` (by address) to SkeletonReconstructor::Reconstruct.
// Returns a new model (build = false) whose thin-set points and thin normals
// are computed from the reconstruction written to `outRecon`.
std::shared_ptr<SkeletonModel> SkeletonModel::Reconstruct(collapseSkel3d& skelEngine,
  ReconstructionSmoothingType smoothing, int radius,
  surface::Graph& graph, FeaturePoints& featurePoints, Volume<byte>& outRecon)
{
  std::shared_ptr<SkeletonModel> model(new SkeletonModel(skelEngine, impThreshold, false));

  SkeletonReconstructor reconstructor(skelPoints, edt);
  outRecon = reconstructor.Reconstruct(smoothing, radius,
    importance, impThreshold, &graph, &featurePoints);

  // TODO: not pretty, should be refactored - the new model's thin-set is
  // filled in from outside instead of by the model itself.
  model->thinPoints = createThinSet(impThreshold, outRecon);
  model->createNormals(outRecon, model->thinPoints, *importance, model->thinNormals, 0);

  return model;
}

// Returns the skeleton points (importance above impThreshold, tagged with EDT
// values) of a copy of the importance volume after `dilate(1, ..., 0)`.
// Also prints, as one diagnostic line, how many points the undilated and
// dilated volumes yield.
std::vector<Point3d> SkeletonModel::CreateThickSet()
{
  Volume<float> thickVolume = *importance;
  // NOTE(review): the same volume is passed as source and destination;
  // confirm Volume::dilate supports in-place operation.
  thickVolume.dilate(1, thickVolume, 0);

  auto pointsDilated = createSkeletonPoints(thickVolume, impThreshold);
  auto points = createSkeletonPoints(*importance, impThreshold);

  // Terminate the diagnostic line so subsequent output does not run onto it.
  std::cout << "points: " << points.size() << ", dilated: " << pointsDilated.size() << '\n';
  return pointsDilated;
}

// Returns the point-search structure over the skeleton points, constructing it
// the first time it is requested (Update() resets it, so it is rebuilt after
// the skeleton points change).
PointSearch &SkeletonModel::getSkelPointSearch()
{
  const bool needsBuild = !skelPointSearch;
  if (needsBuild)
    skelPointSearch.reset(new PointSearch(GetSkelPoints(), 1000));

  return *skelPointSearch;
}

std::vector<Point3d> SkeletonModel::createThinSet(float impThresh, const Volume<byte>& thinImage)
{
  std::vector<Point3d> points;

  for (int i = 1; i < depth - 1; i++)  
    for (int j = 1; j < height - 1; j++)
      for (int k = 1; k < width - 1; k++)
      {
        coord3s c(k, j, i);   
        byte tval =  thinImage(c);    //select all nonzero thin-volume voxels on the thin-volume surface
        if (tval)                     //(that is, foreground ones not surrounded by ONLY foreground voxels)
        {
          //copy them with their values from the thin-volume, normalized to [0..1]
          bool draw = false;
          for (int m = -1; m <= 1; m++)
            for (int n = -1; n <= 1; n++)
              for (int p = -1; p <= 1; p++)
                if (!thinImage(k + p, j + n, i + m) != 0)
                {
                  draw = true;
                  goto BREAK;
                }

BREAK:
          if (draw)
          {
            points.emplace_back(k, j, i, tval / 255.0f);
          }
        }
      }

  return points;
}

std::vector<Point3d> SkeletonModel::createSkeletonPoints(const Volume<float>& importance, float impThresh)
{
  std::vector<Point3d> points;

  for (int i = 1; i < depth - 1; i++)
    for (int j = 1; j < height - 1; j++)
      for (int k = 1; k < width - 1; k++)
      {
        coord3s c(k, j, i);                 //select all skel voxels above threshold
        float ival = importance(c);        //copy them with their EDT values
        if (ival > impThresh)
          points.emplace_back(k, j, i, edt(c));
      }

  return points;
}

namespace
{
void GetNeightbours(const Point3d& p, const int NSZ, const Volume<byte>& thinImage, const Volume<float>& importance, 
  float impThr, std::vector<std::array<float, 3>>& nbs)
{
  nbs.clear();
  int i = p.z, j = p.y, k = p.x;

  for (int m = i - NSZ; m <= i + NSZ; m++)
    for (int n = j - NSZ; n <= j + NSZ; n++)
      for (int p = k - NSZ; p <= k + NSZ; p++)
      {
        bool inside = impThr == 0 ? thinImage(p, n, m) != 0
          : importance(p, n, m) > impThr;

        if (inside)
          nbs.push_back({
                static_cast<float>(p),
                static_cast<float>(n),
                static_cast<float>(m)});
      }
}
}

void SkeletonModel::createNormals(const Volume<byte>& thinImage, std::vector<Point3d>& points, const Volume<float>& importance,
  std::vector<Point3d>& normals, float impThr, const int NSZ)
{
  int NP = points.size();
  normals.resize(NP);
  std::vector<std::array<float, 3>> nbs;
  std::array<float, 3> norm;
  float *nrm;

  for (int pid = 0; pid < NP; ++pid)
  {
    const Point3d& p = points[pid];

    //Estimate skeleton-normal by PCA on skeleton neighborhood
    //The normal is the smallest eigenvector
    GetNeightbours(p, NSZ, thinImage, importance, impThr, nbs);
    //Estimate plane only if there are enough neighbors...
    Point3d normal;
    if (nbs.size() >= 3)
    {
      nrm = norm.data();
      Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
        pointsMat(&nbs[0][0], nbs.size(), 3);
      Eigen::MatrixXf centered = pointsMat.rowwise() - pointsMat.colwise().mean();
      //use SVD to get principal directions
      Eigen::JacobiSVD<Eigen::MatrixXf> svd(centered, Eigen::ComputeThinV);
      Eigen::Vector3f eigenNormal(svd.matrixV().col(2));
      norm = {eigenNormal(0), eigenNormal(1), eigenNormal(2)};
      normal = Point3d(nrm);
      // Correct the orientation of the normal if pointing inwards
      size_t reverseCount = 0;
      for (auto& neigharr : nbs)
      {
        Point3d neigh(neigharr.data());
        Point3d dir = p - neigh;
        if (dir.dot(normal) < 0)
          reverseCount++;
      }
      if (reverseCount > nbs.size() - reverseCount)
        normal *= -1;
      normal.normalize();
    }
    else nrm = nullptr;

    normals[pid] = (nrm) ? normal : Point3d();
  }
}

