#pragma once
#include <cassert>
#include <cstdint>

#include <Volume.h>
#include "UnionFind.h"

/// A connected component identified by a label, together with its size.
///
/// Both fields share the template type `T`.
/// NOTE(review): `size` is presumably the number of voxels in the component;
/// because it is stored as `T`, confirm `T` is wide enough for expected counts
/// (e.g. an 8-bit voxel type would overflow past 255).
template <typename T>
struct ConnectedComponent {
  T label;  ///< identifier of the component
  T size;   ///< size of the component (see note above)
};

/// Labels connected components of foreground voxels in a padded volume using a
/// single raster scan plus a union-find structure.
///
/// The volume is expected to carry a one-voxel border of padding on every side:
/// only interior voxels (coordinates 1 .. dim-2 on each axis) are labelled.
/// A voxel is foreground when its value is strictly greater than `threshold`.
/// Two foreground voxels belong to the same component when they are 26-adjacent,
/// i.e. they share a face, an edge or a corner.
template <class T>
class ComponentFinder
{
  Volume<T>& volume;
  UnionFind<int> unionFind;
  T threshold;

public:
  /// @param volume    padded volume to scan; held by reference, so it must
  ///                  outlive this object.
  /// @param threshold voxels with a value strictly greater than this are foreground.
  ComponentFinder(Volume<T>& volume, T threshold) :
    volume(volume),
    // Presumably the number of interior voxels (one element per key produced
    // by getKey()) — TODO confirm against Volume::getUnpaddedSize.
    unionFind(volume.getUnpaddedSize(2)),
    threshold(threshold)
  {
  }

  /// Scans the interior in raster order (x fastest, then y, then z) and merges
  /// every foreground voxel with its already-visited foreground 26-neighbours.
  void findAll()
  {
    for (int z = 1; z < volume.getDepth() - 1; z++)
      for (int y = 1; y < volume.getHeight() - 1; y++)
        for (int x = 1; x < volume.getWidth() - 1; x++)
          if (volume(x, y, z) > threshold)
            findNeighbours26(x, y, z);
  }

  /// Forwards to UnionFind::getLargestLabel().
  /// NOTE(review): presumably the label of the largest component — confirm in UnionFind.
  int getLargestLabel()
  {
    return unionFind.getLargestLabel();
  }

  /// Returns the union-find root of interior voxel (x, y, z). Voxels merged into
  /// the same component by findAll() share the same root.
  int getLabel(int x, int y, int z)
  {
    return unionFind.root(getKey(x, y, z));
  }

  /// Mutable access to the underlying union-find structure.
  UnionFind<int>& getUnionFind()
  {
    return unionFind;
  }

  /// Read-only access to the underlying union-find structure.
  /// (A const member function cannot hand out a non-const reference to a member.)
  const UnionFind<int>& getUnionFind() const
  {
    return unionFind;
  }


private:
  /// Merges voxel (x, y, z) with each foreground voxel among the 13 26-neighbours
  /// that precede it in findAll()'s raster order. Because adjacency is symmetric,
  /// every adjacent pair of foreground voxels is examined from the later voxel.
  void findNeighbours26(int x, int y, int z)
  {
    // (dx, dy, dz) of the already-visited neighbours: the whole previous slice
    // (9), the three voxels of the previous row (3) and the previous voxel (1).
    static const int backwardOffsets[13][3] = {
      {-1, -1, -1}, { 0, -1, -1}, { 1, -1, -1},
      {-1,  0, -1}, { 0,  0, -1}, { 1,  0, -1},
      {-1,  1, -1}, { 0,  1, -1}, { 1,  1, -1},
      {-1, -1,  0}, { 0, -1,  0}, { 1, -1,  0},
      {-1,  0,  0}
    };

    const int keyCur = getKey(x, y, z);

    for (const auto& offset : backwardOffsets)
    {
      const int i = x + offset[0];
      const int j = y + offset[1];
      const int k = z + offset[2];

      // Padding voxels have no union-find key; never look at them.
      if (!isInterior(i, j, k))
        continue;

      if (volume(i, j, k) <= threshold)
        continue;

      const int keyNeigh = getKey(i, j, k);
      if (!unionFind.find(keyCur, keyNeigh))
        unionFind.merge(keyCur, keyNeigh);
    }
  }

  /// True when (x, y, z) lies in the unpadded interior, i.e. getKey() accepts it.
  bool isInterior(int x, int y, int z) const
  {
    return x > 0 && x < volume.getWidth() - 1 &&
           y > 0 && y < volume.getHeight() - 1 &&
           z > 0 && z < volume.getDepth() - 1;
  }

  /// Maps an interior voxel to a dense 0-based index (x fastest, then y, then z)
  /// over the (width-2) x (height-2) x (depth-2) interior.
  int getKey(int x, int y, int z)
  {
    assert(x > 0 && x < volume.getWidth() - 1);
    assert(y > 0 && y < volume.getHeight() - 1);
    assert(z > 0 && z < volume.getDepth() - 1);

    return (x - 1) + (y - 1) * (volume.getWidth() - 2 ) + 
      (z - 1) * (volume.getHeight() - 2) * (volume.getWidth() - 2);
  }

};
