#pragma once

#include "stdafx.h"
#include "field.h"

// Sentinel values whose names mark an "invalid" node index, graph component and
// node coordinate respectively.
// NOTE(review): the parentheses around `std::numeric_limits<...>::max` are
// presumably there to stop a Windows-style `max` macro from expanding — keep them.
const unsigned int INVALID_ROOT_NODE_INDEX{ (std::numeric_limits<unsigned int>::max)() };
const int INVALID_ROOT_GRAPH_COMPONENT{ -1 };
const TCoord3 INVALID_ROOT_NODE_COORDS(-1, -1, -1);

/// An edge between two root nodes, identified by their node indices, carrying a length.
class RootEdge
{
public:
	/// Creates an edge from node `beginIndex` to node `endIndex` with the given `length`.
	RootEdge(unsigned int beginIndex, unsigned int endIndex, double length);

	/// Index of the node this edge begins at.
	unsigned int GetBeginIndex() const;
	/// Index of the node this edge ends at.
	unsigned int GetEndIndex() const;
	/// Length stored for this edge.
	double GetLength() const;

private:
	unsigned int beginIndex;
	unsigned int endIndex;
	double length;
};

/// Edge list owned by a RootNode.
using RootEdgeVector = std::vector<RootEdge>;

/// A vertex of the root graph: its index, its coordinates, the edges stored on
/// it and the connected-component label assigned to it.
class RootNode
{
public:
	/// Default constructor.
	/// NOTE(review): presumably required because RootNodeMap (a std::unordered_map)
	/// needs a default-constructible mapped type for operator[]; the member values it
	/// sets live in the implementation file — confirm there.
	RootNode();
	/// Creates a node with the given index and coordinates.
	RootNode(unsigned int index, const TCoord3& coords);

	/// Index identifying this node.
	unsigned int GetIndex() const;
	/// Coordinates of this node.
	const TCoord3& GetCoords() const;

	/// Stores an edge from this node to the node with index `endIndex`.
	void AddEdge(unsigned int endIndex, double length);
	/// Removes the edge ending at node `endIndex` (see implementation for how
	/// duplicates are handled).
	void RemoveEdge(unsigned int endIndex);
	/// Edges currently stored on this node.
	const RootEdgeVector& GetEdges() const;

	/// Connected-component label of this node.
	int GetComponent() const;
	/// Assigns the connected-component label of this node.
	void SetComponent(int component);

private:
	unsigned int index;
	TCoord3 coords;
	RootEdgeVector edges;
	int component;
};

/// All nodes of a graph, keyed by node index.
using RootNodeMap = std::unordered_map<unsigned int, RootNode>;
/// A set of node indices.
using RootNodeIndexSet = std::unordered_set<unsigned int>;

/// Interface for objects that are shown RootNodes one at a time and return a value
/// for each one.
/// NOTE(review): RootGraph::ToField accepts an optional visitor; the returned float is
/// presumably the value written for that node into the output field — confirm.
class RootNodeVisitor
{
public:
	// Virtual so that a derived visitor deleted through a RootNodeVisitor pointer
	// runs its own destructor (deleting without it is undefined behaviour).
	virtual ~RootNodeVisitor() = default;

	/// Called once per visited node; derived classes define the returned value.
	virtual float Visit(const RootNode& node) = 0;
};

class JunctionEndPointDetector : public RootNodeVisitor
{
	RootNodeIndexSet junctions, endPoints;

public:
	JunctionEndPointDetector();

	virtual float Visit(const RootNode& node);

	const RootNodeIndexSet& GetJunctions() const;
	const RootNodeIndexSet& GetEndPoints() const;
};

/// Graph of root nodes built from a 3D float field, with nodes stored by index.
/// The constructor is private; graphs are created through FromField().
class RootGraph
{
	// Volume dimensions the graph was built for.
	// NOTE(review): presumably the source field's size, reused by ToField() — confirm.
	int volumeWidth, volumeHeight, volumeDepth;
	// All nodes of the graph, keyed by node index.
	RootNodeMap nodes;

	// Private: use FromField() to build a graph.
	RootGraph(int volumeWidth, int volumeHeight, int volumeDepth);

	// Adds edges to `node` derived from `field`.
	// NOTE(review): presumably edges to neighbouring voxels of the field — confirm.
	static void AddEdges(std::shared_ptr<const FIELD3<float>> field, RootNode& node);
	// Computes the node index for `coords` within `field`.
	// NOTE(review): presumably a linearised voxel index — confirm.
	static unsigned int GetIndex(std::shared_ptr<const FIELD3<float>> field, const TCoord3& coords);

	// NOTE(review): presumably labels every node with its connected component and
	// returns per-component data (e.g. sizes indexed by component id) — confirm.
	std::vector<int> CountComponents();
	// Removes nodes according to `minComponentSize` and the per-component data in
	// `components` (presumably the result of CountComponents()).
	void RemoveComponents(int minComponentSize, const std::vector<int>& components);

public:
	/// Adds a node with the given index and coordinates and returns a reference to it.
	RootNode& AddNode(unsigned int index, const TCoord3& coords);

	/// Returns the node stored under `index`.
	const RootNode& GetNode(unsigned int index) const;
	/// Returns all nodes, keyed by index.
	const RootNodeMap& GetNodes() const;

	/// Builds a graph from `field`.
	static std::shared_ptr<RootGraph> FromField(std::shared_ptr<const FIELD3<float>> field);
	/// Renders the graph back into a field; `visitor` is optional (nullptr by default)
	/// and is presumably used to choose the value written for each node.
	std::shared_ptr<TFloatField3> ToField(RootNodeVisitor* visitor = nullptr) const;

	/// Removes connected components with fewer than `minComponentSize` nodes
	/// (presumably via CountComponents() and RemoveComponents() — confirm).
	void RemoveSmallComponents(int minComponentSize);
	/// Removes end points that are connected to junctions, using the sets gathered
	/// by `detector`.
	void RemoveJunctionConnectedEndpoints(const JunctionEndPointDetector* detector);
};