#include "stdafx.h"

#include "rootgraph.h"

// Creates an edge running from node `beginIndex` to node `endIndex` and stores its length.
RootEdge::RootEdge(unsigned int beginIndex, unsigned int endIndex, double length)
	: beginIndex(beginIndex)
	, endIndex(endIndex)
	, length(length)
{
}

unsigned int RootEdge::GetBeginIndex() const
{
	return beginIndex;
}

unsigned int RootEdge::GetEndIndex() const
{
	return endIndex;
}

double RootEdge::GetLength() const
{
	return length;
}

// Creates a placeholder node whose index, coordinates and component are all set to the
// INVALID_* sentinels. Presumably required so RootNodeMap::operator[] can create entries
// (used by RootGraph::AddNode).
RootNode::RootNode()
	: index(INVALID_ROOT_NODE_INDEX)
	, coords(INVALID_ROOT_NODE_COORDS)
	, component(INVALID_ROOT_GRAPH_COMPONENT)
{
}

// Creates a node with the given index and voxel coordinates. It has no edges yet, and its
// component is set to INVALID_ROOT_GRAPH_COMPONENT until components are counted.
RootNode::RootNode(unsigned int index, const TCoord3& coords)
	: index(index)
	, coords(coords)
	, component(INVALID_ROOT_GRAPH_COMPONENT)
{
}

// Voxel coordinates of this node (reference valid as long as the node lives).
const TCoord3& RootNode::GetCoords() const
{
	return this->coords;
}

void RootNode::AddEdge(unsigned int endIndex, double length)
{
	RootEdge edge(index, endIndex, length);
	edges.push_back(edge);
}

void RootNode::RemoveEdge(unsigned int endIndex)
{
	for (RootEdgeVector::const_iterator i = edges.begin();
		i != edges.end(); i++)
	{
		if (i->GetEndIndex() == endIndex)
		{
			edges.erase(i);
			break;
		}
	}
}

// Outgoing edges of this node, in insertion order.
const RootEdgeVector& RootNode::GetEdges() const
{
	return this->edges;
}

unsigned int RootNode::GetIndex() const
{
	return index;
}

int RootNode::GetComponent() const
{
	return component;
}

// Records which connected component this node belongs to.
void RootNode::SetComponent(int component)
{
	// The parameter shadows the member, so the member must be qualified.
	this->component = component;
}

// Creates an empty graph remembering the dimensions of the volume it was built from,
// so ToField can rasterise it back into a volume of the same size.
RootGraph::RootGraph(int volumeWidth, int volumeHeight, int volumeDepth)
	: volumeWidth(volumeWidth)
	, volumeHeight(volumeHeight)
	, volumeDepth(volumeDepth)
{
}

// Inserts (or overwrites) the node stored under `index` and returns a reference to the
// stored node so the caller can attach edges to it.
RootNode& RootGraph::AddNode(unsigned int index, const TCoord3& coords)
{
	RootNode& slot = nodes[index];
	slot = RootNode(index, coords);
	return slot;
}

// Looks up the node stored under `index`; throws (via `at`) when it does not exist.
const RootNode& RootGraph::GetNode(unsigned int index) const
{
	const RootNode& found = nodes.at(index);
	return found;
}

// All nodes, keyed by linear voxel index.
const RootNodeMap& RootGraph::GetNodes() const
{
	return this->nodes;
}

// Converts voxel coordinates into the x-fastest linear index used as the node key:
// index = (z * dimY + y) * dimX + x, evaluated in unsigned (modular) arithmetic.
unsigned int RootGraph::GetIndex(std::shared_ptr<const FIELD3<float>> field, const TCoord3& coords)
{
	const unsigned int width = field->dimX();
	const unsigned int height = field->dimY();

	return (height * coords.z + coords.y) * width + coords.x;
}

// Adds an edge from `node` to each of its (up to six) face neighbours whose field value is
// strictly positive. Edge length is the Euclidean distance between the voxel centres.
// Neighbours are tested in the order -z, -y, -x, +x, +y, +z. This matches the original
// z/y/x scan, so edge order is unchanged.
// Neighbours outside the volume are skipped. This also avoids the unsigned wrap-around the
// old `coords.z - 1` loops had when a coordinate was 0.
void RootGraph::AddEdges(std::shared_ptr<const FIELD3<float>> field, RootNode& node)
{
	static const int faceOffsets[6][3] = {
		{ 0, 0, -1 },
		{ 0, -1, 0 },
		{ -1, 0, 0 },
		{ 1, 0, 0 },
		{ 0, 1, 0 },
		{ 0, 0, 1 }
	};

	const TCoord3& coords = node.GetCoords();

	const int width = static_cast<int>(field->dimX());
	const int height = static_cast<int>(field->dimY());
	const int depth = static_cast<int>(field->dimZ());

	for (int k = 0; k < 6; k++)
	{
		const int x = static_cast<int>(coords.x) + faceOffsets[k][0];
		const int y = static_cast<int>(coords.y) + faceOffsets[k][1];
		const int z = static_cast<int>(coords.z) + faceOffsets[k][2];

		// No voxel exists outside the volume, so there is nothing to connect to.
		if (x < 0 || y < 0 || z < 0 || x >= width || y >= height || z >= depth)
			continue;

		if (field->value(x, y, z) > 0.0f)
		{
			const TCoord3 neighbourCoords(x, y, z);

			const unsigned int neighbourIndex = GetIndex(field, neighbourCoords);
			const double length = neighbourCoords.toPoint().distance(coords.toPoint());

			node.AddEdge(neighbourIndex, length);
		}
	}
}

// Builds a graph with one node per voxel whose value is strictly positive, then connects
// each new node to its positive face neighbours via AddEdges.
// Only the interior is scanned: the reading routine pads each field with a one-voxel border
// that is expected to be empty.
std::shared_ptr<RootGraph> RootGraph::FromField(std::shared_ptr<const FIELD3<float>> field)
{
	std::shared_ptr<RootGraph> graph(new RootGraph(field->dimX(), field->dimY(), field->dimZ()));

	// Exclusive upper bounds of the interior region (the padding border is skipped).
	const auto endX = field->dimX() - 1;
	const auto endY = field->dimY() - 1;
	const auto endZ = field->dimZ() - 1;

	for (unsigned int z = 1; z < endZ; ++z)
		for (unsigned int y = 1; y < endY; ++y)
			for (unsigned int x = 1; x < endX; ++x)
			{
				// Written as !(v > 0) so NaN voxels are skipped exactly as before.
				if (!(field->value(x, y, z) > 0.0f))
					continue;

				const TCoord3 coords(x, y, z);
				const unsigned int key = GetIndex(field, coords);

				AddEdges(field, graph->AddNode(key, coords));
			}

	return graph;
}

// Rasterises the graph into a new volume of the size the graph was built from.
// Each interior voxel that holds a node receives `visitor->Visit(node)`, or 1.0f when no
// visitor is given. Every other voxel, including the one-voxel border, is set to 0.0f.
// Nodes are visited in ascending z, then y, then x voxel order.
std::shared_ptr<TFloatField3> RootGraph::ToField(RootNodeVisitor* visitor) const
{
	std::shared_ptr<TFloatField3> field(new TFloatField3(volumeWidth, volumeHeight, volumeDepth));

	const unsigned int width = field->dimX();
	const unsigned int height = field->dimY();
	const unsigned int depth = field->dimZ();

	for (unsigned int z = 0; z < depth; z++)
	{
		for (unsigned int y = 0; y < height; y++)
		{
			for (unsigned int x = 0; x < width; x++)
			{
				// Border voxels are written explicitly rather than relying on the
				// TFloatField3 constructor (not visible here) to zero-fill them.
				const bool onBorder = x == 0 || y == 0 || z == 0 ||
					x + 1 == width || y + 1 == height || z + 1 == depth;
				if (onBorder)
				{
					field->value(x, y, z) = 0.0f;
					continue;
				}

				const unsigned int index = GetIndex(field, TCoord3(x, y, z));
				RootNodeMap::const_iterator i = nodes.find(index);

				if (i == nodes.end())
					field->value(x, y, z) = 0.0f;
				else if (visitor != nullptr)
					field->value(x, y, z) = visitor->Visit(i->second);
				else
					field->value(x, y, z) = 1.0f;
			}
		}
	}

	return field;
}

std::vector<int> RootGraph::CountComponents()
{
	std::unordered_set<unsigned int> visited;
	std::vector<int> components;

	for (RootNodeMap::iterator i = nodes.begin(); i != nodes.end(); i++)
	{
		RootNode& startNode = i->second;

		if (visited.insert(startNode.GetIndex()).second)
		{
			int currentComponent = components.size(), componentSize = 0;

			std::queue<unsigned int> queue;
			queue.push(startNode.GetIndex());

			while (!queue.empty())
			{
				RootNode& node = nodes[queue.front()];
				queue.pop();

				node.SetComponent(currentComponent);
				componentSize++;

				for (RootEdgeVector::const_iterator i = node.GetEdges().begin(); i != node.GetEdges().end(); i++)
				{
					const RootEdge& edge = *i;
					if (visited.insert(edge.GetEndIndex()).second)
						queue.push(edge.GetEndIndex());
				}
			}

			// Apparently there was another component
			components.push_back(componentSize);
		}
	}

	return components;
}

// Erases every node whose component id indexes an entry of `components` smaller than
// `minComponentSize`. Nodes whose id lies outside [0, components.size()) are kept.
// This is a single pass over the nodes. The previous version rescanned the whole map once
// per small component (O(components * nodes)); the final state is identical.
void RootGraph::RemoveComponents(int minComponentSize, const std::vector<int>& components)
{
	RootNodeMap::iterator j = nodes.begin();

	while (j != nodes.end())
	{
		const int component = j->second.GetComponent();
		const bool known = component >= 0 &&
			static_cast<std::vector<int>::size_type>(component) < components.size();

		if (known && components[component] < minComponentSize)
			j = nodes.erase(j);
		else
			++j;
	}
}

void RootGraph::RemoveSmallComponents(int minComponentSize)
{
	std::vector<int> components = CountComponents();
	RemoveComponents(minComponentSize, components);
}

void RootGraph::RemoveJunctionConnectedEndpoints(const JunctionEndPointDetector* detector)
{
	for (RootNodeIndexSet::const_iterator i = detector->GetEndPoints().begin();
		i != detector->GetEndPoints().end(); i++)
	{
		unsigned int index = *i;
		const RootNode& node = nodes[index];

		unsigned int neigbourIndex = node.GetEdges().front().GetEndIndex();

		if (detector->GetJunctions().find(neigbourIndex) !=
			detector->GetJunctions().end())
		{
			nodes.erase(index);
			RootNode& junctionNode = nodes[neigbourIndex];
			junctionNode.RemoveEdge(index);
		}
	}
}

// Starts with empty junction and end-point sets; Visit() fills them.
// NOTE(review): the `(0)` arguments look like an initial bucket count, which would imply
// RootNodeIndexSet is an unordered set. Confirm against the header. Do not rewrite them as
// braces: `{0}` would instead construct a set containing index 0.
JunctionEndPointDetector::JunctionEndPointDetector()
	: junctions(0), endPoints(0)
{

}

// Classifies a node by its number of edges and returns the value to paint into the field:
//   exactly 2 edges -> ordinary path node, 0.1f
//   more than 2     -> junction, recorded in `junctions`, 1.0f
//   fewer than 2    -> end point (including isolated nodes), recorded in `endPoints`, 0.55f
float JunctionEndPointDetector::Visit(const RootNode& node)
{
	const int degree = static_cast<int>(node.GetEdges().size());

	if (degree == 2)
		return 0.1f;

	if (degree > 2)
	{
		junctions.insert(node.GetIndex());
		return 1.0f;
	}

	endPoints.insert(node.GetIndex());
	return 0.55f;
}

// Indices of nodes seen by Visit() with more than two edges.
const RootNodeIndexSet& JunctionEndPointDetector::GetJunctions() const
{
	return this->junctions;
}

// Indices of nodes seen by Visit() with fewer than two edges.
const RootNodeIndexSet& JunctionEndPointDetector::GetEndPoints() const
{
	return this->endPoints;
}