#include "TestHelpers.h"
#include <SaliencyMetricEngine.h>

bool TestNoiseReductionDensitySaliency()
{
  Volume<byte> data;

  if (!data.readAVSField("horse.fld", 2))
  {
    std::cerr << "Loading horse.fld failed" << std::endl;
    return false;
  }
  Volume<byte> original = data;

  std::ofstream logfile("nremoval_sal-eft_horse_with_smoothing_and_noise.csv", std::ios::out);

  collapseSkel3d skel(false, true);
  int borderVoxels;

  std::cout << "Computing Skeleton...." << std::endl;
  skel.init(data, 0, borderVoxels, false);
  std::shared_ptr<SkeletonModel> inputModel(new SkeletonModel(skel, 0.0f));
  surface::Graph graph;
  graph.construct(inputModel->GetThinPoints());

  // add noise, then reset skeleton
  NoiseCreator noiseCreator(data, graph, inputModel);
  noiseCreator.GeneratePoints(0.020, NoisePickType::Convex, NoiseShape::Ball, 3.0f);

  skel.reset();
  skel.init(data, 0, borderVoxels, true);
  while (skel.collapse_iteration(1.0f));

  // reprocess input model and graph after adding noise and ocmputing the skeleton
  graph.reset();
  inputModel->Update(skel, 0.3e-5f);
  graph.construct(inputModel->GetThinPoints());

  FeaturePoints featurePoints(inputModel->GetThinPoints());
  featurePoints.construct(inputModel->GetSkelPoints(), skel.getV(), skel.getImportance(), skel.getEDT());
  SaliencyMetricEngine saliencyEngine = constructSaliencyEngine(skel, featurePoints);
  int totalPoints = getForegroundPoints(original);
  int totalSkeletonPoints = getForegroundPoints(skel.getImportance());

  std::cout << "Calculating saliency Metric" << std::endl;
  Volume<float> saliencyImportance = saliencyEngine.ComputeGeodesic(graph);
  Volume<float>& importanceVol = saliencyImportance;

  std::cout << "Creating importance index..." << std::endl;
  std::vector<float> sortedImportancePoints = CreateImportanceIndex(importanceVol, 30);


  SkeletonModel skeletonModel(skel, sortedImportancePoints[0]);
  for (float imp : sortedImportancePoints)
  {
    skeletonModel.Update(skel, imp, importanceVol, true);
    Volume<byte> reconstruction;
    
    reconstruction.makeVolume(data.getWidth(), data.getHeight(), data.getDepth());
    auto skelPoints = skeletonModel.GetSkelPoints();


    for (auto& p : skelPoints)
      reconstruction.addSphere(coord3s(p.x, p.y, p.z), 1, skel.getEDT()(p.x, p.y, p.z) - 0.01f);


    float error = static_cast<float>(volumeDiff(reconstruction, original)) / totalPoints;
    float volumeRatio = static_cast<float>(skelPoints.size()) / totalPoints;
    logfile << imp << ", " << error << ", " << volumeRatio << std::endl;
    std::cout << imp << ", " << error << ", " << volumeRatio << std::endl;
  } 

}

