#include "PointSetRenderer.h"

#include <glm/gtc/matrix_transform.hpp>
#include <glm/gtc/matrix_inverse.hpp>
#include <glm/gtc/type_ptr.hpp>
#include <GlutWindow.h>

PointSetRenderer::PointSetRenderer(
  const std::vector<Surfel> &surfels,
  glm::ivec2 canvasSize)
  : canvasSize(canvasSize),
    lightPosition(500, 500, 1000)
{
  // GLSL programs. QueryUniforms looks up a uniform in splattingProgram,
  // so it has to run after the programs are compiled.
  splattingProgram = CompileProgramFromFiles("point_splatting.vert", "point_splatting.frag");
  shadingProgram   = CompileProgramFromFiles("point_shading.vert", "point_shading.frag");
  positionsProgram = CompileProgramFromFiles("point_shading.vert", "point_positions.frag");
  QueryUniforms();

  // Geometry: the surfel cloud and the [-1, 1] screen quad used by the
  // shading and position passes. ROI weights are not loaded here; they are
  // supplied later through SetRoiWeigths.
  LoadPointSet(surfels);
  LoadQuad();

  // Uniform block and render targets, all sized from canvasSize.
  ReloadShaderData();
  ResetFramebuffer();
  ResetPositionsFramebuffer();
}


PointSetRenderer::~PointSetRenderer()
{
  // Releases every GL object this renderer creates.
  // NOTE(review): requires the owning GL context to still be current — confirm
  // at the call site. Also, this class holds raw GL names; if the header does
  // not delete the copy operations, a copy would release them twice.

  // All three programs compiled in the constructor, with their shaders.
  // (shadingProgram and positionsProgram are built from the same .vert file;
  // if CompileProgramFromFiles ever shares that shader object, the second
  // glDeleteShader only raises a GL error and is otherwise harmless.)
  glDeleteShader(splattingProgram.VertexShader);
  glDeleteShader(splattingProgram.FragmentShader);
  glDeleteProgram(splattingProgram.Program);

  glDeleteShader(shadingProgram.VertexShader);
  glDeleteShader(shadingProgram.FragmentShader);
  glDeleteProgram(shadingProgram.Program);

  glDeleteShader(positionsProgram.VertexShader);
  glDeleteShader(positionsProgram.FragmentShader);
  glDeleteProgram(positionsProgram.Program);

  // Vertex data: surfel point set, screen quad, uniform block, ROI weights.
  glDeleteBuffers(1, &pointSet.Buffer);
  glDeleteVertexArrays(1, &pointSet.Array);
  glDeleteBuffers(1, &quad.Buffer);
  glDeleteVertexArrays(1, &quad.Array);
  glDeleteBuffers(1, &shaderDataBuffer);
  glDeleteBuffers(1, &roiWeightsBuffer);

  // Splatting framebuffer and its attachments.
  glDeleteFramebuffers(1, &framebuffer);
  glDeleteRenderbuffers(1, &depthBuffer);
  glDeleteTextures(1, &colorWeightTexture);
  glDeleteTextures(1, &normalLambdaTexture);

  // Positions (picking) framebuffer and its attachment.
  glDeleteFramebuffers(1, &positionsFramebuffer);
  glDeleteTextures(1, &positionsTexture);
}


void PointSetRenderer::LoadPointSet(const std::vector<Surfel> &points)
{
  pointSet.Count = points.size();

  if (!pointSet.Buffer)
  {
    glGenBuffers(1, &pointSet.Buffer);
    assert(pointSet.Buffer);
  }
  if (points.size() != 0)
  {
    glBindBuffer(GL_ARRAY_BUFFER, pointSet.Buffer);
    glBufferData(GL_ARRAY_BUFFER, points.size() * sizeof(Surfel), points.data(), GL_STATIC_DRAW);
    glBindBuffer(GL_ARRAY_BUFFER, 0);
  }


  if (!pointSet.Array)
  {
    glGenVertexArrays(1, &pointSet.Array);
    assert(pointSet.Array);

    glBindVertexArray(pointSet.Array);

    glBindBuffer(GL_ARRAY_BUFFER, pointSet.Buffer);

    glEnableVertexAttribArray(0); // Position
    glVertexAttribPointer(0, 3, GL_FLOAT, false, sizeof(Surfel),
      reinterpret_cast<void*>(offsetof(Surfel, Position)));

    glEnableVertexAttribArray(1); // Color
    glVertexAttribPointer(1, 3, GL_FLOAT, false, sizeof(Surfel),
      reinterpret_cast<void*>(offsetof(Surfel, Color)));

    glEnableVertexAttribArray(2); // U
    glVertexAttribPointer(2, 3, GL_FLOAT, false, sizeof(Surfel),
      reinterpret_cast<void*>(offsetof(Surfel, U)));

    glEnableVertexAttribArray(3); // V
    glVertexAttribPointer(3, 3, GL_FLOAT, false, sizeof(Surfel),
      reinterpret_cast<void*>(offsetof(Surfel, V)));

    if (roiWeightsBuffer)
    {
      glBindBuffer(GL_ARRAY_BUFFER, roiWeightsBuffer);

      glEnableVertexAttribArray(4); // ROI weight
      glVertexAttribPointer(4, 1, GL_FLOAT, false, sizeof(float), 0);
    }

    glBindBuffer(GL_ARRAY_BUFFER, 0);

    glBindVertexArray(0);
  }
}


void PointSetRenderer::LoadRoiWeights(std::vector<float> const& roiWeights)
{
  // Replaces the per-point ROI weights and recomputes their [min, max] range.

  // The buffer object is created on first use.
  if (!roiWeightsBuffer)
  {
    glGenBuffers(1, &roiWeightsBuffer);
    assert(roiWeightsBuffer);
  }

  glBindBuffer(GL_ARRAY_BUFFER, roiWeightsBuffer);
  glBufferData(GL_ARRAY_BUFFER, roiWeights.size() * sizeof(float), roiWeights.data(), GL_DYNAMIC_DRAW);
  glBindBuffer(GL_ARRAY_BUFFER, 0);

  // NOTE(review): LoadPointSet attaches this buffer as vertex attribute 4 only
  // when it already exists at the moment the point-set VAO is first created.
  // The constructor does not load ROI weights, so weights set later through
  // this function are presumably never fed to the vertex shader. Also confirm
  // that roiWeights.size() matches the point count before wiring that up.

  // An empty input has no range. std::minmax_element would return end()
  // iterators, and dereferencing them is undefined behaviour, so collapse the
  // range to zero instead. GetShaderData forwards this range to the shaders.
  if (roiWeights.empty())
  {
    minRoiWeight = 0.f;
    maxRoiWeight = 0.f;
    return;
  }

  const auto minMax = std::minmax_element(roiWeights.cbegin(), roiWeights.cend());
  minRoiWeight = *minMax.first;
  maxRoiWeight = *minMax.second;
}


void PointSetRenderer::LoadQuad()
{
  // Four interleaved vertices drawn as a triangle strip. Each vertex is a
  // (u, v) texture coordinate followed by an (x, y) position in [-1, 1].
  float vertices[] =
  {
    // u    v      x     y
     0.f, 0.f,  -1.f, -1.f,
     1.f, 0.f,   1.f, -1.f,
     0.f, 1.f,  -1.f,  1.f,
     1.f, 1.f,   1.f,  1.f
  };
  const int stride = 4 * sizeof(float);
  const auto positionOffset = 2 * sizeof(float);

  quad.Count = 4;

  glGenBuffers(1, &quad.Buffer);
  assert(quad.Buffer);
  glBindBuffer(GL_ARRAY_BUFFER, quad.Buffer);
  glBufferData(GL_ARRAY_BUFFER, sizeof(vertices), vertices, GL_STATIC_DRAW);

  glGenVertexArrays(1, &quad.Array);
  assert(quad.Array);
  glBindVertexArray(quad.Array);

  // Attribute 1: texture coordinates, at the start of each vertex.
  glEnableVertexAttribArray(1);
  glVertexAttribPointer(1, 2, GL_FLOAT, false, stride, 0);

  // Attribute 0: position, after the two texture-coordinate floats.
  glEnableVertexAttribArray(0);
  glVertexAttribPointer(0, 2, GL_FLOAT, false, stride,
    reinterpret_cast<void*>(positionOffset));

  glBindVertexArray(0);
  glBindBuffer(GL_ARRAY_BUFFER, 0);
}


ShaderData PointSetRenderer::GetShaderData() const
{
  // Builds the uniform-block contents from the current GL matrices and the
  // renderer's lighting/splatting/ROI settings.

  // Frustum extents at the near plane for a fixed 60-degree vertical FOV.
  const float nearVal = 10.0f;
  const float fov = glm::radians(60.f);
  const float aspectRatio = static_cast<float>(canvasSize.x) / canvasSize.y;
  const float top = glm::tan(fov / 2.f) * nearVal;
  const float right = top * aspectRatio;

  // zpr work-around: the transforms come from the legacy GL matrix stacks
  // rather than being built here, and the far plane comes from the window.
  glm::mat4 proj, modelview;
  glGetFloatv(GL_PROJECTION_MATRIX, glm::value_ptr(proj));
  glGetFloatv(GL_MODELVIEW_MATRIX, glm::value_ptr(modelview));
  const float farVal = GlutWindow::GetInstance().getMaxDepth();

  // NOTE(review): top/right/near are derived from the fixed FOV above, while
  // the projection matrix is read from GL; the two only agree if the GL
  // projection uses the same FOV and near plane — confirm.

  ShaderData data;

  // Transforms.
  data.ModelView = modelview;
  data.ModelViewProj = proj * modelview;
  data.InverseModelView = glm::affineInverse(data.ModelView);

  // Projection / unprojection and depth-linearisation parameters.
  data.Near = nearVal;
  data.ProjectScale = 2.f * -nearVal * canvasSize.y / (2.f * top);
  data.UnprojectScale = glm::vec2(2.f * right / canvasSize.x, 2.f * top / canvasSize.y);
  data.UnprojectOffset = glm::vec2(right, top);
  data.DepthScale = farVal * nearVal / (farVal - nearVal);
  data.DepthOffset = farVal / (farVal - nearVal);

  // Lighting; the ambient term is whatever the diffuse term leaves over.
  data.LightPosition = lightPosition;
  data.Shininess = shininess;
  data.SpecularCoeff = specular;
  data.DiffuseCoeff = diffuse;
  data.AmbientCoeff = 1.f - diffuse;

  // Splatting and hole filling.
  data.SplatScale = splatScale;
  data.HoleFillingFilterRadius = holeFillingFilterRadius;

  // ROI visualisation: a degenerate range collapses the lower bound to 0.
  const bool hasRoiRange = glm::abs(maxRoiWeight - minRoiWeight) > 1.E-6f;
  data.MinRoiWeight = hasRoiRange ? minRoiWeight : 0.f;
  data.MaxRoiWeight = maxRoiWeight;
  data.VisualizeRoiWeights = visualizeRoiWeights;
  data.BackfaceCulling = backfaceCulling;

  return data;
}

void PointSetRenderer::ReloadShaderData()
{
  ShaderData data(GetShaderData());

  if (!shaderDataBuffer)
  {
    glGenBuffers(1, &shaderDataBuffer);
    assert(shaderDataBuffer);
  }

  glBindBuffer(GL_UNIFORM_BUFFER, shaderDataBuffer);
  glBufferData(GL_UNIFORM_BUFFER, sizeof(data), &data, GL_STATIC_DRAW);
  glBindBuffer(GL_UNIFORM_BUFFER, 0);
}


void PointSetRenderer::QueryUniforms()
{
  // The only loose (non-block) uniform this class sets is the splatting
  // program's depthEpsilon (see SplattingPass). A location of -1 means the
  // uniform is not active in the linked program.
  const auto location = glGetUniformLocation(splattingProgram.Program, "depthEpsilon");
  depthEpsilonLocation = location;
  assert(location != -1);
}


void PointSetRenderer::ResetFramebuffer()
{
  glDeleteFramebuffers(1, &framebuffer);
  glDeleteRenderbuffers(1, &depthBuffer);
  glDeleteTextures(1, &colorWeightTexture);
  glDeleteTextures(1, &normalLambdaTexture);

  glGenRenderbuffers(1, &depthBuffer);
  glBindRenderbuffer(GL_RENDERBUFFER, depthBuffer);
  glRenderbufferStorage(GL_RENDERBUFFER, GL_DEPTH_COMPONENT24, canvasSize.x, canvasSize.y);
  glBindRenderbuffer(GL_RENDERBUFFER, 0);

  glGenTextures(1, &colorWeightTexture);
  glBindTexture(GL_TEXTURE_2D, colorWeightTexture);
  glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA16F, canvasSize.x, canvasSize.y, 0, GL_RGBA, GL_FLOAT, nullptr);
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
  glBindTexture(GL_TEXTURE_2D, 0);

  glGenTextures(1, &normalLambdaTexture);
  glBindTexture(GL_TEXTURE_2D, normalLambdaTexture);
  glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA16F, canvasSize.x, canvasSize.y, 0, GL_RGBA, GL_FLOAT, nullptr);
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
  glBindTexture(GL_TEXTURE_2D, 0);

  glGenFramebuffers(1, &framebuffer);
  glBindFramebuffer(GL_FRAMEBUFFER, framebuffer);
  glFramebufferRenderbuffer(GL_FRAMEBUFFER, GL_DEPTH_ATTACHMENT, GL_RENDERBUFFER, depthBuffer);
  glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, colorWeightTexture, 0);
  glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT1, GL_TEXTURE_2D, normalLambdaTexture, 0);

  unsigned int buffers[] = { GL_COLOR_ATTACHMENT0, GL_COLOR_ATTACHMENT1 };
  glDrawBuffers(2, buffers);

  assert(glCheckFramebufferStatus(GL_FRAMEBUFFER) == GL_FRAMEBUFFER_COMPLETE);
  glBindFramebuffer(GL_FRAMEBUFFER, 0);
}


void PointSetRenderer::ResetPositionsFramebuffer()
{
  // (Re)creates the framebuffer that RetrieveModelPositions renders into:
  // a single RGBA16F colour target and no depth attachment.
  glDeleteFramebuffers(1, &positionsFramebuffer);
  glDeleteTextures(1, &positionsTexture);

  glGenTextures(1, &positionsTexture);
  glBindTexture(GL_TEXTURE_2D, positionsTexture);
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
  glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA16F, canvasSize.x, canvasSize.y, 0, GL_RGBA, GL_FLOAT, nullptr);
  glBindTexture(GL_TEXTURE_2D, 0);

  glGenFramebuffers(1, &positionsFramebuffer);
  glBindFramebuffer(GL_FRAMEBUFFER, positionsFramebuffer);
  glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, positionsTexture, 0);

  const unsigned int drawTarget = GL_COLOR_ATTACHMENT0;
  glDrawBuffers(1, &drawTarget);

  assert(glCheckFramebufferStatus(GL_FRAMEBUFFER) == GL_FRAMEBUFFER_COMPLETE);
  glBindFramebuffer(GL_FRAMEBUFFER, 0);
}


void PointSetRenderer::ApplyRotation(glm::quat rotationToApply)
{
  // Compose the new rotation after the accumulated one.
  rotation = rotationToApply * rotation;

  // The view changed: drop cached picking positions and refresh the uniforms.
  // NOTE(review): GetShaderData in this file reads the GL matrix stacks, not
  // `rotation` — confirm whether `rotation` is consumed anywhere.
  modelPositions.reset();
  ReloadShaderData();
}


void PointSetRenderer::MoveCamera(float cameraDistanceDelta)
{
  // Shift the camera distance, keeping it within [1, 25].
  const auto requested = cameraDistance + cameraDistanceDelta;
  cameraDistance = glm::clamp(requested, 1.f, 25.f);

  // The view changed: drop cached picking positions and refresh the uniforms.
  modelPositions.reset();
  ReloadShaderData();
}




void PointSetRenderer::ResizeCanvas(glm::ivec2 canvasSize)
{
  // The parameter shadows the member, hence the explicit this->.
  this->canvasSize = canvasSize;

  // Cached picking positions refer to the old pixel grid.
  modelPositions.reset();

  // Everything sized from canvasSize is rebuilt.
  ReloadShaderData();
  ResetFramebuffer();
  ResetPositionsFramebuffer();
}


void PointSetRenderer::Draw()
{
  // GetShaderData samples the current GL matrices, so refresh every frame.
  ReloadShaderData();

  // The uniform block is exposed on binding point 0 for both passes
  // (presumably matching the shaders' block binding).
  const unsigned int shaderDataBinding = 0;
  glBindBufferBase(GL_UNIFORM_BUFFER, shaderDataBinding, shaderDataBuffer);

  SplattingPass();
  ShadingPass();

  glBindBufferBase(GL_UNIFORM_BUFFER, shaderDataBinding, 0);
}

std::vector<Surfel> PointSetRenderer::GetPoints() const
{
  std::vector<Surfel> surfels;
  surfels.resize(pointSet.Count);

  glBindBuffer(GL_ARRAY_BUFFER, pointSet.Buffer);
  Surfel *pts = static_cast<Surfel*>(glMapBuffer(GL_ARRAY_BUFFER, GL_READ_ONLY));
  std::memcpy(surfels.data(), pts, pointSet.Count * sizeof(Surfel));
  glUnmapBuffer(GL_ARRAY_BUFFER);
  glBindBuffer(GL_ARRAY_BUFFER, 0);

  return surfels;
}

void PointSetRenderer::SetPoints(std::vector<Surfel> const& surfels)
{
  // New geometry invalidates any cached picking positions.
  modelPositions.reset();
  LoadPointSet(surfels);
}


void PointSetRenderer::SetRoiWeigths(const std::vector<float> &roiWeights)
{
  // (sic: the public name is misspelled; it is kept as-is for callers.)
  // Upload the weights first: that recomputes the min/max range which the
  // uniform block then picks up.
  LoadRoiWeights(roiWeights);
  ReloadShaderData();
}


void PointSetRenderer::UpdateRoiWeights(const std::vector<std::pair<size_t, float>> &updates, float roiWeight)
{
  // Sets every point listed in `updates` to `roiWeight` in the GPU ROI buffer.
  // Only the index (`first`) of each pair is used; `second` is ignored.

  // No ROI buffer has been loaded yet, so there is nothing to update.
  // Mapping buffer 0 would fail.
  if (!roiWeightsBuffer)
    return;

  glBindBuffer(GL_ARRAY_BUFFER, roiWeightsBuffer);

  float *roiWeights = static_cast<float*>(glMapBuffer(GL_ARRAY_BUFFER, GL_WRITE_ONLY));
  if (!roiWeights)
  {
    // Mapping failed: leave both the buffer and the range unchanged.
    glBindBuffer(GL_ARRAY_BUFFER, 0);
    return;
  }

  // NOTE(review): indices are not range-checked against the buffer size;
  // callers must pass valid point indices.
  for (const auto &update : updates) roiWeights[update.first] = roiWeight;
  glUnmapBuffer(GL_ARRAY_BUFFER);

  glBindBuffer(GL_ARRAY_BUFFER, 0);

  // The range will only get wider this way.
  maxRoiWeight = std::max(maxRoiWeight, roiWeight);
  minRoiWeight = std::min(minRoiWeight, roiWeight);
  ReloadShaderData();
}


// Renders every surfel into `framebuffer` twice:
//  1. Visibility pass: depth writes only (colour writes masked off), with the
//     shader's depthEpsilon uniform set to the member `depthEpsilon`.
//  2. Accumulation pass: additive blending (GL_ONE, GL_ONE) into both colour
//     attachments, depth test GL_LEQUAL against the pass-1 depth, depth writes
//     disabled and depthEpsilon set to 0.
// NOTE(review): GL_DEPTH_TEST is not enabled here (presumably it is enabled
// elsewhere), and glDepthFunc is left at GL_LEQUAL after this returns, so it
// also applies to the next frame's visibility pass.
void PointSetRenderer::SplattingPass() const
{
  glBindFramebuffer(GL_FRAMEBUFFER, framebuffer);

  // Draw buffer 0 starts with a tiny alpha (weight). Presumably this lets the
  // shading pass normalise by the accumulated weight without dividing by zero;
  // confirm in point_shading.frag.
  const float initialColorWeight[] = { 0.f, 0.f, 0.f, 1.E-6f };
  const float initialNormalLambda[] = { 0.f, 0.f, 0.f, 0.f };
  glClearBufferfv(GL_COLOR, 0, initialColorWeight);
  glClearBufferfv(GL_COLOR, 1, initialNormalLambda);
  glClear(GL_DEPTH_BUFFER_BIT);

  glUseProgram(splattingProgram.Program);
  glBindVertexArray(pointSet.Array);

  // Pass 1: depth only. The epsilon presumably pushes splat depths back so that
  // overlapping splats of the same surface still pass the LEQUAL test in pass 2.
  glUniform1f(depthEpsilonLocation, depthEpsilon);
  glColorMask(false, false, false, false);
  glDrawArrays(GL_POINTS, 0, pointSet.Count);
  glColorMask(true, true, true, true);

  // Pass 2: add up contributions of all splats that lie within the pass-1 depth,
  // without updating depth.
  glEnable(GL_BLEND);
  glBlendFunc(GL_ONE, GL_ONE);
  glDepthFunc(GL_LEQUAL);
  glUniform1f(depthEpsilonLocation, 0.f);
  glDepthMask(false);
  glDrawArrays(GL_POINTS, 0, pointSet.Count);
  glDepthMask(true);
  glDisable(GL_BLEND);

  glBindVertexArray(0);
  glUseProgram(0);

  glBindFramebuffer(GL_FRAMEBUFFER, 0);
}


void PointSetRenderer::ShadingPass() const
{
  // Full-screen quad that shades from the splatting results.
  // Texture unit i samples inputs[i]: 0 = colour/weight, 1 = normal/lambda, 2 = depth.
  // NOTE(review): depthBuffer is created with glGenRenderbuffers in
  // ResetFramebuffer, not as a texture. glBindTexture therefore binds whichever
  // texture (if any) happens to share that numeric name, not the depth
  // attachment. Sampling depth here would need a depth-texture attachment;
  // confirm what point_shading.frag expects on unit 2.
  const unsigned int inputs[] = { colorWeightTexture, normalLambdaTexture, depthBuffer };
  const int inputCount = 3;

  glUseProgram(shadingProgram.Program);
  glBindVertexArray(quad.Array);

  for (int unit = 0; unit < inputCount; ++unit)
  {
    glActiveTexture(GL_TEXTURE0 + unit);
    glBindTexture(GL_TEXTURE_2D, inputs[unit]);
  }

  glDrawArrays(GL_TRIANGLE_STRIP, 0, quad.Count);

  // Unbind from the highest unit down, so GL_TEXTURE0 is left active.
  for (int unit = inputCount - 1; unit >= 0; --unit)
  {
    glActiveTexture(GL_TEXTURE0 + unit);
    glBindTexture(GL_TEXTURE_2D, 0);
  }

  glBindVertexArray(0);
  glUseProgram(0);
}



void PointSetRenderer::RetrieveModelPositions()
{
  if (!modelPositions)
  {
    glBindBufferBase(GL_UNIFORM_BUFFER, 0, shaderDataBuffer);
    glBindFramebuffer(GL_FRAMEBUFFER, positionsFramebuffer);

    const float initialPositions[] = { 0.f, 0.f, -1000.f, 0.f };
    glClearBufferfv(GL_COLOR, 0, initialPositions);

    glUseProgram(positionsProgram.Program);
    glBindVertexArray(quad.Array);

    glBindTexture(GL_TEXTURE_2D, colorWeightTexture);
    glActiveTexture(GL_TEXTURE1);
    glBindTexture(GL_TEXTURE_2D, normalLambdaTexture);

    glDrawArrays(GL_TRIANGLE_STRIP, 0, quad.Count);

    glBindTexture(GL_TEXTURE_2D, 0);
    glActiveTexture(GL_TEXTURE0);
    glBindTexture(GL_TEXTURE_2D, 0);

    glBindVertexArray(0);
    glUseProgram(0);

    glBindFramebuffer(GL_FRAMEBUFFER, 0);
    glBindBufferBase(GL_UNIFORM_BUFFER, 0, 0);

    modelPositions.reset(new glm::vec3[canvasSize.x * canvasSize.y]);
    glBindTexture(GL_TEXTURE_2D, positionsTexture);
    glGetTexImage(GL_TEXTURE_2D, 0, GL_RGB, GL_FLOAT, modelPositions.get());
    glBindTexture(GL_TEXTURE_2D, 0);
  }
}

glm::vec3 PointSetRenderer::GetModelPosition(int x, int y)
{
  // Returns the position stored for canvas pixel (x, y), computing the
  // positions buffer first if it is not cached.
  const bool inside = x >= 0 && y >= 0 && x < canvasSize.x && y < canvasSize.y;

  // Debug builds still trap out-of-canvas queries. In release builds the assert
  // is compiled out, and indexing would read out of bounds; coordinates can
  // fall outside the window, e.g. while dragging. Return the same value an
  // uncovered pixel holds: the clear value used by RetrieveModelPositions.
  assert(inside);
  if (!inside)
    return glm::vec3(0.f, 0.f, -1000.f);

  RetrieveModelPositions();

  // Flip y: the caller's y is presumably top-down (window coordinates), while
  // the readback stores the bottom row first.
  const int row = canvasSize.y - y - 1;
  return modelPositions[row * canvasSize.x + x];
}
