#include "SkeletonController.h"
#include "VolumeHelpers.h"
#include "ima/ImaSkel.h"
#include <chrono>
#include "memusage.h"
#include <iostream>


using namespace std;


// Scans the interior voxels of the skeleton's importance volume (the outermost
// one-voxel shell is skipped along every axis) and stores the smallest and
// largest value found in renderer.imp_min / renderer.imp_max.
// When the volume has no interior voxels the sentinels +1e7 / -1e7 are left
// in place.
// TODO: refactor this piece of ugly code.
void SetImportanceMinMax(Renderer& renderer, collapseSkel3d& skel)
{
  Volume<float> &importance = skel.getImportance();
  const int lastX = importance.getWidth() - 1;
  const int lastY = importance.getHeight() - 1;
  const int lastZ = importance.getDepth() - 1;

  // Start from an inverted range so that samples tighten both bounds.
  renderer.imp_min = 1.0e+7;
  renderer.imp_max = -renderer.imp_min;

  // The resulting range is used afterwards for color mapping.
  for (int z = 1; z < lastZ; z++)
  {
    for (int y = 1; y < lastY; y++)
    {
      for (int x = 1; x < lastX; x++)
      {
        const float sample = importance[z][y][x];
        if (sample < renderer.imp_min) renderer.imp_min = sample;
        if (sample > renderer.imp_max) renderer.imp_max = sample;
      }
    }
  }
}

// Gives mutable access to the controller's graph member.
surface::Graph &SkeletonController::GetGraph()
{
  return this->graph;
}


// Returns the skeleton engine this controller was constructed with.
collapseSkel3d &SkeletonController::GetSkel() const
{
  return this->skel;
}

// Returns the renderer this controller was constructed with.
Renderer &SkeletonController::GetRenderer() const
{
  return this->renderer;
}

std::shared_ptr<SkeletonModel> SkeletonController::GetSkeletonModel() const
{
  return skeletonModel;
}

std::shared_ptr<SkeletonModel> SkeletonController::GetInputModel() const
{
  return inputModel;
}

std::shared_ptr<SkeletonModel> SkeletonController::GetReconstructionModel() const
{
  return reconstructionModel;
}

// Dereferences the owned FeaturePoints object and returns it by reference.
FeaturePoints& SkeletonController::GetFeaturePoints()
{
  return *(this->featurePoints);
}



std::shared_ptr<AddNoisePickHandler> SkeletonController::GetNoisePickHandler() const
{
  return noisePickHandler;
}

// Returns a non-owning pointer to the saliency volume. The volume is created
// by ComputeSaliency(), so the pointer is null before that has run.
ImportanceVolume *SkeletonController::GetSaliencyVolume() const
{
  ImportanceVolume *volume = saliencyVolume.get();
  return volume;
}

void SkeletonController::SetNoisePickHandler(const std::shared_ptr<AddNoisePickHandler> &value)
{
  noisePickHandler = value;
}

// Stores a new importance threshold, re-thresholds the skeleton model with it
// and refreshes the rendered skeleton points.
// Does nothing unless the importance scalar field is the one being viewed.
void SkeletonController::UpdateImportance(float impThreshold)
{
  const bool viewingImportance = (viewScalarFieldType == ScalarFieldType::Importance);
  if (viewingImportance)
  {
    // The parameter shadows the member, hence the explicit this->.
    this->impThreshold = impThreshold;
    skeletonModel->Update(skel, impThreshold);
    renderer.pointRenderer->UpdateSkeletonPoints();
  }
}

void SkeletonController::ComputeSaliency()
{
  Volume<float> saliencyMetric;
  switch (saliencyMethod)
  {
  case SaliencyMethod::Classical:
    saliencyMetric = saliencyMetricEngine.Compute(impThreshold);
    break;
  case SaliencyMethod::Derrivative:
    saliencyMetric = saliencyMetricEngine.ComputeDerivative(skel.getV(), graph, impThreshold);
    break;
  case SaliencyMethod::InvFeature:
    saliencyMetric = saliencyMetricEngine.ComputeReverseEDT(skeletonModel, impThreshold);
    break;
  case SaliencyMethod::Streamline:
    saliencyMetric = saliencyMetricEngine.ComputeStreamlines(skel.getV(), graph, impThreshold);
    break;
  case SaliencyMethod::GlobalImportance:
    saliencyMetric = saliencyMetricEngine.ComputeGlobalMonotonic(skel.getV(), graph, impThreshold);
    break;
  default: break;
  }

  saliencyMetricVolume.reset(new Volume<float>(saliencyMetric));
  saliencyVolume.reset(new ImportanceVolume(*saliencyMetricVolume.get(), 1000));
}

// Re-thresholds the skeleton model against the stored saliency metric and
// refreshes the rendered skeleton points.
// Does nothing when no saliency metric has been computed yet.
void SkeletonController::UpdateSaliency(float threshold, bool onlyKeepLargestComponent)
{
  if (!saliencyMetricVolume)
    return;

  skeletonModel->Update(skel, threshold, *saliencyMetricVolume, onlyKeepLargestComponent);
  renderer.pointRenderer->UpdateSkeletonPoints();
}

// Forwards to ImportanceVolume::GetEqualizedImportance on the importance volume.
// NOTE(review): importanceVolume is dereferenced unchecked; it is only created
// in PostSkeletonProcessing()/ImportImportanceVolume() - confirm callers never
// reach this before one of those has run.
float SkeletonController::GetEqualizedImportance(float normThreshold)
{
  const float equalized = importanceVolume->GetEqualizedImportance(normThreshold);
  return equalized;
}


// Builds a byte volume with the same dimensions as `imp` in which every voxel
// whose importance is strictly greater than `t` is set to 128.
// `imp` is first passed through FilterLargestComponent(imp, t), so the input
// volume itself may be modified (presumably reduced to its largest connected
// component - confirm in VolumeHelpers).
// Voxels that do not exceed `t` keep whatever value the Volume constructor
// gave them.
inline Volume<byte> toByteVolume(Volume<float>& imp, float t = 0)
{
  Volume<byte> mask(imp.getWidth(), imp.getHeight(), imp.getDepth());

  FilterLargestComponent(imp, t);

  const int sizeX = imp.getWidth();
  const int sizeY = imp.getHeight();
  const int sizeZ = imp.getDepth();

  for (int x = 0; x < sizeX; ++x)
  {
    for (int y = 0; y < sizeY; ++y)
    {
      for (int z = 0; z < sizeZ; ++z)
      {
        const float value = imp(x, y, z);
        if (value > t)
          mask(x, y, z) = 128;
      }
    }
  }

  return mask;
}

// Installs an externally supplied importance volume on the skeleton engine.
// Returns false (and changes nothing) while the skeleton has not been computed
// yet; otherwise updates the skeleton model, the renderer's importance range,
// the cached ImportanceVolume and renderer.vis_imp_max, and returns true.
bool SkeletonController::ImportImportanceVolume(Volume<float>& newImportance)
{
  const bool skeletonReady = this->IsDone();

  if (skeletonReady)
  {
    skel.setImportance(newImportance);
    skeletonModel->Update(skel, renderer.imp_thr);

    // Recompute imp_min / imp_max for the new data before using imp_max below.
    SetImportanceMinMax(renderer, skel);
    importanceVolume.reset(new ImportanceVolume(skel.getImportance(), 1000));
    renderer.vis_imp_max = renderer.imp_max;
  }

  return skeletonReady;
}


std::shared_ptr<ShowInfoHandler> SkeletonController::getShowInfoHandler() const
{
  return showInfoHandler;
}

void SkeletonController::setShowInfoHandler(const std::shared_ptr<ShowInfoHandler> &value)
{
  showInfoHandler = value;
}

std::shared_ptr<ShowCorridorGraphHandler> SkeletonController::getShowCorridorGraphHandler() const
{
  return showCorridorGraphHandler;
}

void SkeletonController::setShowCorridorGraphHandler(const std::shared_ptr<ShowCorridorGraphHandler> &value)
{
  showCorridorGraphHandler = value;
}

// Dereferences the owned cut controller and returns it by reference.
cuts::CutController &SkeletonController::getCutController()
{
  return *(this->cutController);
}

// Takes ownership of `value`; the previously held cut controller is released
// by the move assignment.
void SkeletonController::setCutController(std::unique_ptr<cuts::CutController> value)
{
  this->cutController = std::move(value);
}

// Dereferences the owned segmentation handler and returns it by reference.
SegmentationHandler &SkeletonController::getSegmentationHandler()
{
  return *(this->segmentationHandler);
}

void SkeletonController::setSegmentationHandler(std::unique_ptr<SegmentationHandler> value)
{
  segmentationHandler = std::move(value);
}

// Gives mutable access to the saliency metric engine member.
SaliencyMetricEngine& SkeletonController::GetSaliencyMetricEngine()
{
  return this->saliencyMetricEngine;
}

float SkeletonController::GetImpThreshold()
{
  return impThreshold;
}

// Builds a controller around an existing skeleton engine and renderer.
//
// skel         - skeleton engine; stored by the `skel` member and used to
//                build the skeleton/input models.
// impThreshold - initial importance threshold for both models.
// renderer     - renderer stored by the `renderer` member and handed to the
//                noise pick handler.
//
// NOTE(review): C++ initializes members in their declaration order in the
// class, not in the order written below. featurePoints reads inputModel,
// noisePickHandler reads featurePoints/inputModel/skeletonModel/graph, and
// saliencyMetricEngine reads featurePoints - confirm the header declares the
// members in a compatible order.
// NOTE(review): the handlers below receive *this before construction has
// finished; they should only store the reference, not use it, at this point.
SkeletonController::SkeletonController(collapseSkel3d& skel, float impThreshold, Renderer& renderer) :
  skel(skel),
  impThreshold(impThreshold),
  skeletonModel(new SkeletonModel(skel, impThreshold)),
  //inputModel(new SkeletonModel(*skeletonModel.get())), //do we need a copy here?
  // inputModel is built independently from the same arguments rather than
  // copied from skeletonModel.
  inputModel(new SkeletonModel(skel, impThreshold)),
  featurePoints(new FeaturePoints(inputModel->GetThinPoints())),
  noisePickHandler(new AddNoisePickHandler(*featurePoints, inputModel, skeletonModel, graph, skel, renderer)),
  showInfoHandler(new ShowInfoHandler(*this)),
  showCorridorGraphHandler(new ShowCorridorGraphHandler(*this)),
  renderer(renderer),
  saliencyMetricEngine(*featurePoints.get(), skel.getEDT(), skel.getImportance()),
  cutController(new cuts::CutController),
  segmentationHandler(new SegmentationHandler(*this))
{
  // Prepare the engine for iterative collapse, then build the graph from the
  // input model's thin points.
  skel.init_iteration(0.99990f);
  graph.construct(inputModel->GetThinPoints());

}

// Arms an incremental collapse run: records whether intermediate results
// should be visualized, stamps the start time (steady clock, in seconds),
// resets the iteration counter, remembers the current number of interface
// points as the border voxel count and marks the collapse as active.
void SkeletonController::StartCollapse(bool visualize)
{
  using SecondsD = std::chrono::duration<double, std::chrono::seconds::period>;

  visualizeSkeleton = visualize;

  const auto sinceEpoch = std::chrono::steady_clock::now().time_since_epoch();
  startTime = std::chrono::duration_cast<SecondsD>(sinceEpoch).count();

  iterations = 0;
  bordervoxels = skel.getIPoints().size();
  collapseActive = true;
}

// Finalizes a freshly computed skeleton:
//  1. re-thresholds the skeleton model and rebuilds the feature points;
//  2. unless the importance type is TPAMI, replaces the engine's importance
//     with a geodesic-based (Geodesic / GeodesicEstimate) or estimated
//     Euclidean measure and re-thresholds again;
//  3. rebuilds the cached ImportanceVolume, derives the visible importance
//     maximum, refreshes the renderer's importance range and scalar field;
//  4. marks the skeleton as computed.
void SkeletonController::PostSkeletonProcessing()
{
  skeletonModel->Update(skel, renderer.imp_thr);
  featurePoints->construct(skeletonModel->GetSkelPoints(), skel.getV(), skel.getImportance(), skel.getEDT());

  //replace importance with pairwise geodesics
  if (importanceType != SkeletonImportanceType::TPAMI)
  {
    Volume<float> newImp;

    std::cout << "Computing geodesics..." << std::endl;
    bool useEstimate = importanceType == SkeletonImportanceType::GeodesicEstimate;
    if (importanceType == SkeletonImportanceType::Geodesic || useEstimate)
      newImp = saliencyMetricEngine.ComputeGeodesicRaw(graph, useEstimate);
    else
      newImp = saliencyMetricEngine.ComputeEstimatedEuclidean(2 * sqrt(2));

    skel.setImportance(newImp);
    skeletonModel->Update(skel, renderer.imp_thr);
  }

  // Fetch the importance only now, after any replacement above, so the
  // reference is never held across setImportance().
  Volume<float> &imp = skel.getImportance();

  importanceVolume.reset(new ImportanceVolume(imp, 1000));
  float diffThreshold = importanceVolume->GetLargestDifferenceThreshold();
  renderer.vis_imp_max = importanceType == SkeletonImportanceType::TPAMI ? diffThreshold : importanceVolume->GetMaxImportance();

  SetImportanceMinMax(renderer, skel);
  renderer.setScalar(&imp, renderer.imp_min, renderer.vis_imp_max);
  skeletonComputed = true;
}

//End an earlier-started collapse process, display results/timings.
// Resets the iteration counter, clears the active flag and runs
// PostSkeletonProcessing(). For the TPAMI importance type it additionally
// prints timing, peak memory and volume statistics to stdout.
void SkeletonController::StopCollapse()
{
  iterations = 0;
  collapseActive = false;

  PostSkeletonProcessing();

  // The report below is only produced for the TPAMI importance type.
  if (importanceType != SkeletonImportanceType::TPAMI)
    return;

  Volume<float> &imp = skel.getImportance();

  // Multiply in 64 bits: depth*height*width overflows a 32-bit int (undefined
  // behaviour) once the volume exceeds 2^31 voxels.
  const long long voxelCount =
    static_cast<long long>(imp.getDepth()) * imp.getHeight() * imp.getWidth();

  //Report peak memory usage
  int foreground = skel.getLastForegroundCount();
  size_t memMB = getPeakRSS() / size_t(1024 * 1024);
  std::cout << "============================================" << std::endl;
  std::cout << "Total time:            " << execTime << " secs" << std::endl;
  std::cout << "Peak memory usage:     " << memMB << " MB." << std::endl;
  std::cout << "Object:                border: " << bordervoxels << " foreground: " <<
    foreground << " volume: " << voxelCount << std::endl;
  std::cout << "Object Kvoxels/second: " << float(foreground) / execTime / 1000 <<
    std::endl;
  std::cout << "Importance range:      " << renderer.imp_min << "," << renderer.imp_max <<
    std::endl;
}

//Do one collapse step
// Runs a single collapse iteration on the engine, updates the iteration count
// and elapsed time, optionally refreshes the visualization, and finishes the
// collapse via StopCollapse() once collapse_iteration() reports 0.
void SkeletonController::DoCollapseStep()
{
  using SecondsD = std::chrono::duration<double, std::chrono::seconds::period>;

  const int remainingInterface = skel.collapse_iteration(1.0f);
  iterations++;

  // Elapsed time since StartCollapse(), measured on the steady clock.
  const auto sinceEpoch = std::chrono::steady_clock::now().time_since_epoch();
  execTime = std::chrono::duration_cast<SecondsD>(sinceEpoch).count() - startTime;

  if (visualizeSkeleton)
  {
    skeletonModel->Update(skel, impThreshold);
    SetImportanceMinMax(renderer, skel);
    // Cap the displayed maximum at the measured maximum.
    renderer.setScalar(&skel.getImportance(), renderer.imp_min, std::min(renderer.vis_imp_max, renderer.imp_max));
  }

  // Nothing left on the interface: the collapse is complete.
  if (remainingInterface == 0)
    StopCollapse();
}

void SkeletonController::ComputeSkeleton()
{
  if (extractionMethod == SkeletonExtractionMethod::TPAMI ||
    importanceType == SkeletonImportanceType::TPAMI)
  {
    StartCollapse();
    while (collapseActive)
      DoCollapseStep();
  }
  else
  {
    Volume<float> newImportance;
    if(extractionMethod == SkeletonExtractionMethod::ImplicitEuclidean)
     newImportance = saliencyMetricEngine.ComputeEstimatedEuclidean(2.0 * sqrt(2));
    else /* if ima*/
     newImportance = ComputeImaSkeletonFloat(skel.getThinImage());

    Volume<Vector> newV = saliencyMetricEngine.BuildGradientEdtVelocity(graph);
    skel.setImportance(newImportance);
    skel.setV(newV);
    skel.set_ready(true);
    skel.getThinImage().clearVolume();

    PostSkeletonProcessing();
  }
}

void SkeletonController::ResetSkelEngine(const SkeletonCreationOptions& options)
{
  featurePoints->useFeatureReflection = static_cast<int>(options.flags & SkeletonCreationFlags::EnableFpReflection) > 0;
  importanceType = options.importanceType;
  extractionMethod = options.extractionMethod;
  bool smoothing = static_cast<int>(options.flags & SkeletonCreationFlags::SmoothEdt) > 0;
  bool importanceBoosting = static_cast<int>(options.flags & SkeletonCreationFlags::ImportanceBoosting) > 0;
 
  Volume<byte> copy(skel.getThinImage());
  ResetSkelEngine(copy, smoothing, importanceBoosting);
  lastInputVolume = copy;
}

// Re-initializes the skeleton engine from `volume` and brings the dependent
// state (feature point search, graph, input/skeleton models, rendered points)
// back in sync.
//
// volume             - input volume handed to skel.init() and afterwards
//                      stored as the engine's thin image.
// smoothing          - forwarded to skel.init().
// importanceBoosting - forwarded to skel.init().
void SkeletonController::ResetSkelEngine(Volume<byte>& volume, bool smoothing, bool importanceBoosting)
{
  // Passed to skel.init() and never read afterwards.
  // NOTE(review): left uninitialized - presumably an out-parameter; confirm
  // init() does not read it.
  int dummy;
  skel.reset();
  skel.init(volume, 0, dummy, smoothing, importanceBoosting);
  featurePoints->refreshPointSearch();
  graph.reset();
  // NOTE(review): the graph is rebuilt from inputModel's thin points *before*
  // inputModel->Update(skel) runs below - confirm the points are not stale at
  // this point (or that the graph only keeps a reference to the container).
  graph.construct(inputModel->GetThinPoints());
  inputModel->Update(skel);
  skeletonModel->Update(skel);
  skel.setThinImage(volume);

  // Push the refreshed models to the point renderer.
  renderer.pointRenderer->UpdateInputModelPoints();
  renderer.pointRenderer->UpdateSkeletonPoints();
}


// Rebuilds the reconstruction model from the current skeleton model and hands
// the result to both the renderer and its point renderer.
// lastReconVolume is passed to SkeletonModel::Reconstruct (presumably filled
// with the reconstructed volume - confirm in SkeletonModel).
void SkeletonController::Reconstruct(ReconstructionSmoothingType smoothingType, int radius)
{
  this->reconstructionModel =
    skeletonModel->Reconstruct(skel, smoothingType, radius, graph, *featurePoints, lastReconVolume);

  renderer.reconstructionModel = this->reconstructionModel;
  renderer.pointRenderer->SetReconstructionModel(this->reconstructionModel);
}

// Advances an active collapse by one step.
// Returns whether the collapse is still active after the step (false when no
// collapse was running in the first place).
bool SkeletonController::Update()
{
  if (!collapseActive)
    return false;

  DoCollapseStep();
  return collapseActive;
}

// Selects the saliency method used by the next ComputeSaliency() call.
void SkeletonController::SetSaliencyMethod(SaliencyMethod value)
{
  this->saliencyMethod = value;
}

bool SkeletonController::IsDone()
{
  return skeletonComputed;
}

// Selects which scalar field the renderer displays.
// The EDT volume is created on first use. The importance and saliency volumes
// must already exist, otherwise false is returned and the renderer is left
// untouched. Note that viewScalarFieldType is updated even in that case.
// Returns true in every other situation, including for a field value that
// matches none of the handled cases.
bool SkeletonController::SetScalarField(ScalarFieldType field)
{
  viewScalarFieldType = field;

  if (field == ScalarFieldType::EDT)
  {
    // Lazily build the wrapper around the engine's EDT.
    if (!edtVolume)
      edtVolume.reset(new ImportanceVolume(skel.getEDT(), 1000));
    renderer.setScalar(*edtVolume);
  }
  else if (field == ScalarFieldType::Importance)
  {
    if (!importanceVolume)
      return false;
    renderer.setScalar(*importanceVolume);
  }
  else if (field == ScalarFieldType::Saliency)
  {
    if (!saliencyVolume)
      return false;
    renderer.setScalar(*saliencyVolume);
  }

  return true;
}

// Returns the byte volume that matches what is currently shown:
//  - the engine's thin image while the engine is not ready;
//  - the last reconstruction when the renderer draws reconstructions and one
//    with non-zero depth exists;
//  - the last input volume otherwise.
Volume<byte>& SkeletonController::GetVisibleVolume()
{
  if (!skel.IsReady())
    return skel.getThinImage();

  const bool showReconstruction =
    renderer.drawReconstruction && lastReconVolume.getDepth();

  return showReconstruction ? lastReconVolume : lastInputVolume;
}

// Returns the float volume behind the scalar field that is currently selected
// for viewing, or nullptr when the engine is not ready or that field's volume
// has not been created yet.
//
// Each case returns on its own. Previously a missing volume fell through to
// the next case label, so e.g. a missing EDT volume silently yielded the
// importance (or saliency) volume instead of the selected field.
Volume<float>* SkeletonController::GetVisibleImportance()
{
  if (!skel.IsReady())
    return nullptr;

  switch (viewScalarFieldType)
  {
  case ScalarFieldType::EDT:
    if (edtVolume)
      return &edtVolume->GetInnerVolume();
    return nullptr;

  case ScalarFieldType::Importance:
    if (importanceVolume)
      return &importanceVolume->GetInnerVolume();
    return nullptr;

  case ScalarFieldType::Saliency:
    if (saliencyVolume)
      return &saliencyVolume->GetInnerVolume();
    return nullptr;

  default:
    return nullptr;
  }
}
