#include "stdafx.h"

#include "skeletonizerthread.h"

SkeletonizerThread::SkeletonizerThread(const ParallelSkeletonizer* skeletonizer, std::shared_ptr<SkeletonizerThreadContext> context,
	int threadID, int threads)
	: wxThread(wxTHREAD_JOINABLE), skeletonizer(skeletonizer), context(context),
	threadID(threadID), threads(threads), removedVoxels(0), acceptedVoxels(0)
{
	// Inputs fetched from the skeletonizer (note: the parameters shadow the members here)
	indexField = skeletonizer->GetIndexField();
	boundaryField = skeletonizer->GetBoundaryField();
	adjacencyLists = skeletonizer->GetAdjacencyLists();
	featureTransform = skeletonizer->GetFeatureTransform();

	// Extended feature transform storage obtained from the thread context; written by ComputeExtendedFT
	extendedFT = context->GetExtendedFT();

	// Area cap for the importance measure: 0.5 * tau * boundaryField->getMaxIndex(), rounded to nearest
	const double halfTau = 0.5 * skeletonizer->GetTau();
	maxArea = static_cast<int>(halfTau * boundaryField->getMaxIndex() + 0.5);
}

int SkeletonizerThread::GetThreadID() const
{
	return threadID;
}

int SkeletonizerThread::GetRemovedVoxels() const
{
	return removedVoxels;
}

int SkeletonizerThread::GetAcceptedVoxels() const
{
	return acceptedVoxels;
}

wxThread::ExitCode SkeletonizerThread::Entry()
{
	// Round-robin partition of the voxel indices: this thread handles
	// threadID, threadID + threads, threadID + 2 * threads, ...
	int index = threadID;
	while (index < indexField->getMaxIndex())
	{
		ProcessVoxel(index);
		index += threads;
	}

	return 0;
}

void SkeletonizerThread::ProcessVoxel(unsigned int index)
{
	ComputeExtendedFT(index);

	if (extendedFT->at(index).size() == 0) return;

	TShortestPathSet pathSet = ComputePathSet(index);
	context->GetShortestPathSetField()->m_Values[index] = new TShortestPathSet(pathSet);

	std::shared_ptr<TIndexedOrigins_Vector> mergedGeodesics = pathSet.m_PathSet;

	TIndexedOrigins_Vector dilatedShortestPaths = Dilate(*mergedGeodesics);
	std::vector<unsigned int> components = CountComponents(dilatedShortestPaths);

	bool isCurveSkeleton = components.size() >= 2;
	if (isCurveSkeleton)
	{
		context->AddCurveSkeletonPoint(index);
		context->GetLengthField()->m_Values[index] = mergedGeodesics->size();

		if (skeletonizer->ShouldComputeAngle())
		{
			double cosine = ComputeGeodesicCosine(index, pathSet);
			context->GetAngleField()->m_Values[index] = (float)(1.0 - (cosine + 1.0) / 2.0);

			if (cosine > skeletonizer->GetMaxCosine())
			{
				context->GetImportanceMeasureField()->m_Values[index] = 0.0f;
				removedVoxels++;
				return;
			}

			if (cosine < skeletonizer->GetAcceptCosine())
			{
				context->GetImportanceMeasureField()->m_Values[index] = 1.0f;
				acceptedVoxels++;
				return;
			}
		}

		if (skeletonizer->ShouldComputeImportance())
		{
			std::vector<int> componentAreas = ComputeComponentAreas(dilatedShortestPaths, components);

			double importance = ComputeImportanceMeasure(componentAreas);
			context->GetImportanceMeasureField()->m_Values[index] = (float)importance;
		}
	}
}

void SkeletonizerThread::ComputeExtendedFT(unsigned int index)
{
	const TCoord3& point = indexField->vidx2coord(index);

	TIndexedOrigins_Set origins;

	// All neighbours in one octant, including self
	for (int z = 0; z <= 1; z++)
	{
		for (int y = 0; y <= 1; y++)
		{
			for (int x = 0; x <= 1; x++)
			{
				TCoord3 neighbourPoint(point.x + x, point.y + y, point.z + z);

				if (indexField->vinside(neighbourPoint))
				{
					unsigned int neighbourIndex = indexField->valuep(neighbourPoint);

					const TIndexedOrigins_Vector* neighbourOrigins = &featureTransform->at(neighbourIndex);
					origins.merge(neighbourOrigins);
				}
			}
		}
	}

	extendedFT->at(index) = origins;
}

TShortestPathSet SkeletonizerThread::ComputePathSet(unsigned int index)
{
	const TCoord3& voxel = indexField->vidx2coord(index);

	// The path set owns its own copy of this voxel's extended feature transform
	std::shared_ptr<TIndexedOrigins_Vector> featurePoints = std::make_shared<TIndexedOrigins_Vector>(extendedFT->at(index));
	TShortestPathSet result(voxel, featurePoints);

	// One geodesic per unordered pair of feature points
	for (TIndexedOrigins_Vector::const_iterator first = featurePoints->begin(); first != featurePoints->end(); ++first)
	{
		for (TIndexedOrigins_Vector::const_iterator second = first + 1; second != featurePoints->end(); ++second)
		{
			const unsigned int from = *first;
			const unsigned int to = *second;

			// Use the shared geodesic cache when one is configured, otherwise search directly
			std::shared_ptr<TShortestPath> geodesic;
			if (context->GetGeodesicCache() == nullptr)
				geodesic = ComputeShortestPath(from, to);
			else
				geodesic = context->GetGeodesicCache()->GetShortestPath(from, to, *this);

			// A null path means the pair could not be connected; skip it
			if (geodesic == nullptr)
				continue;

			//ComputeTangentPlane(voxel, *geodesic); // Unused?
			result.m_Paths.push_back(geodesic);
		}
	}

	// Rebuild the merged voxel set of all geodesics from scratch
	result.m_PathSet.reset(new TIndexedOrigins_Vector());
	result.updatePathSet();

	return result;
}

// Min-priority queue of (priority, voxel index) pairs. std::priority_queue is a max-heap by
// default; std::greater flips the ordering so top() yields the smallest priority.
using PriorityQueue = std::priority_queue<std::pair<float, unsigned int>, std::vector<std::pair<float, unsigned int>>,
	std::greater<std::pair<float, unsigned int>>>;

// Edge costs indexed by TAdjacencyStruct edge m_Weight. Presumably face/edge/vertex-neighbour
// steps (the values match the 3-D chamfer weights of Kiryati & Szekely) — confirm against TDeltaOmega.
static const float WEIGHTS[3] = { 0.9016f, 1.289f, 1.615f };

// A* search over the boundary adjacency graph from beginIndex to endIndex.
// Returns the path (with its length and middle voxel), or nullptr (plus a warning) when endIndex
// is unreachable. Leaves the search state in pathNodes for ReconstructPath.
// NOTE(review): relies on fresh pathNodes entries having infinite m_Distance, m_IsKnown == false
// and m_Previous == INVALIDINDEX — confirm against TDeltaOmega::TAuxiliaryStruct.
std::shared_ptr<TShortestPath> SkeletonizerThread::ComputeShortestPath(unsigned int beginIndex, unsigned int endIndex)
{
	assert(beginIndex != endIndex);

	// Endpoint coordinates, held by value (cheap 3-component copies)
	const TCoord3 startPoint = boundaryField->vidx2coord(beginIndex);
	const TCoord3 endPoint = boundaryField->vidx2coord(endIndex);

	// A* only returns a shortest path if its heuristic never overestimates the remaining cost.
	// Assuming m_Weight 0/1/2 denote neighbours at Euclidean distance 1/sqrt(2)/sqrt(3) and that
	// TCoord3::distance is Euclidean (TODO confirm), the raw Euclidean distance overestimates:
	// a face step costs only WEIGHTS[0] = 0.9016 per unit length. The cheapest cost per unit
	// length is WEIGHTS[0] (1.289/sqrt(2) ~ 0.911, 1.615/sqrt(3) ~ 0.932), so scaling by it makes
	// the heuristic admissible and consistent, and endIndex is settled with its true distance.
	const float heuristicScale = WEIGHTS[0];

	pathNodes.clear();

	PriorityQueue queue;
	pathNodes[beginIndex].m_Distance = 0;
	queue.push(std::pair<float, unsigned int>(0.0f, beginIndex));

	while (!queue.empty())
	{
		unsigned int index = queue.top().second;
		queue.pop();

		TDeltaOmega::TAuxiliaryStruct& auxiliary = pathNodes[index];
		if (auxiliary.m_IsKnown) continue; // stale queue entry: node already settled
		auxiliary.m_IsKnown = true;

		if (index == endIndex) break;

		const TDeltaOmega::TAdjacencyStruct& adjacencyList = adjacencyLists->at(index);

		for (unsigned int neighbour = 0; neighbour < adjacencyList.m_EdgeCount; neighbour++)
		{
			unsigned int neighbourIndex = adjacencyList.m_Edges[neighbour].m_ToVertex;
			TDeltaOmega::TAuxiliaryStruct& neighbourAuxiliary = pathNodes[neighbourIndex];

			if (!neighbourAuxiliary.m_IsKnown)
			{
				float weight = WEIGHTS[adjacencyList.m_Edges[neighbour].m_Weight];
				float distance = auxiliary.m_Distance + weight;

				if (distance < neighbourAuxiliary.m_Distance)
				{
					neighbourAuxiliary.m_Distance = distance;
					neighbourAuxiliary.m_Previous = index;

					// Priority = cost so far + (scaled) estimate of the remaining cost
					const TCoord3& neighbourPoint = boundaryField->vidx2coord(neighbourIndex);
					queue.push(std::pair<float, unsigned int>(
						distance + heuristicScale * endPoint.distance(neighbourPoint), neighbourIndex));
				}
			}
		}
	}

	if (!pathNodes[endIndex].m_IsKnown)
	{
		wxLogWarning("Could not find path between (%d, %d, %d) and (%d, %d, %d)",
			startPoint.x, startPoint.y, startPoint.z, endPoint.x, endPoint.y, endPoint.z);

		return nullptr;
	}

	unsigned int middleIndex;
	std::shared_ptr<TShortestPath> path(new TShortestPath(beginIndex, endIndex, pathNodes[endIndex].m_Distance,
		ReconstructPath(endIndex, middleIndex)));
	path->m_Middle = middleIndex;

	return path;
}

std::shared_ptr<TIndexedOrigins_Vector> SkeletonizerThread::ReconstructPath(unsigned int endIndex, unsigned int& middleIndex)
{
	// pathNodes must have been initialized by ComputeShortestPath

	double pathLength = pathNodes[endIndex].m_Distance;
	TIndexedOrigins_Vector path;

	unsigned int index = endIndex;
	while (index != TIndexedOrigins::INVALIDINDEX)
	{
		path.add(index);

		if (pathNodes[index].m_Distance >= pathLength / 2.0)
			middleIndex = index;

		index = pathNodes[index].m_Previous;
	}

	return std::shared_ptr<TIndexedOrigins_Vector>(new TIndexedOrigins_Vector(path));
}

void SkeletonizerThread::ComputeTangentPlane(const TCoord3& point, TShortestPath& path) const
{
	const TCoord3& beginPoint = boundaryField->vidx2coord(path.m_Begin),
		endPoint = boundaryField->vidx2coord(path.m_End);

	// Plane equation: <N, (p - O)> = 0 
	TVector3 N; 
	TVector3 O;

	TVector3 a(beginPoint.x, beginPoint.y, beginPoint.z);
	a.x -= point.x; 
	a.y -= point.y; 
	a.z -= point.z;

	TVector3 b(endPoint.x, endPoint.y, endPoint.z);
	b.x -= point.x; 
	b.y -= point.y; 
	b.z -= point.z;

	// N = (a - b) / || a - b ||
	N.add(a);
	N.subtract(b);
	N.normalize();

	// O = (a + b) / 2
	O.add(a); 
	O.add(b); 
	O.scale(0.5f);

	O.x += point.x;
	O.y += point.y;
	O.z += point.z;

	path.m_LocalSheet.reset(new TPlane(O, N));
}

double SkeletonizerThread::ComputeGeodesicCosine(unsigned int index, const TShortestPathSet& shortestPathSet) const
{
	TVector3 point = indexField->vidx2coord(index).toVector();
	double cosine = 1.0;

	for (TShortestPathSet::TPaths::const_iterator i = shortestPathSet.m_Paths.begin();
		i != shortestPathSet.m_Paths.end(); i++)
	{
		for (TShortestPathSet::TPaths::const_iterator j = i + 1;
			j != shortestPathSet.m_Paths.end(); j++)
		{
			const TShortestPath& a = **i, b = **j;

			TVector3 vectorA = boundaryField->vidx2coord(a.getMiddle()).toVector();
			TVector3 vectorB = boundaryField->vidx2coord(b.getMiddle()).toVector();

			vectorA.subtract(point); vectorA.normalize();
			vectorB.subtract(point); vectorB.normalize();

			cosine = (std::min)(cosine, (double)vectorA.dot(vectorB));

			// Early termination
			if (cosine < skeletonizer->GetAcceptCosine())
				return -1.0;
		}
	}

	return cosine > skeletonizer->GetMaxCosine() ? 1.0 : cosine;
}

TIndexedOrigins_Vector SkeletonizerThread::Dilate(const TIndexedOrigins_Vector& mergedGeodesics)
{
	// Multi-source Dijkstra from every geodesic voxel, restricted to distances strictly below
	// GetDilationDistance(). Returns all settled voxels (the dilated band) in settling order and
	// leaves the distances in pathNodes for FindEdgeVoxels.
	TIndexedOrigins_Vector band;

	pathNodes.clear();
	PriorityQueue frontier;

	const auto radius = skeletonizer->GetDilationDistance();

	// Every geodesic voxel starts at distance zero
	for (unsigned int seed : mergedGeodesics)
	{
		pathNodes[seed].m_Distance = 0.0f;
		frontier.push(std::pair<float, unsigned int>(0.0f, seed));
	}

	while (!frontier.empty())
	{
		const unsigned int current = frontier.top().second;
		frontier.pop();

		TDeltaOmega::TAuxiliaryStruct& node = pathNodes[current];
		if (node.m_IsKnown)
			continue; // stale queue entry
		node.m_IsKnown = true;
		band.push_back(current);

		const TDeltaOmega::TAdjacencyStruct& edges = adjacencyLists->at(current);

		for (unsigned int e = 0; e < edges.m_EdgeCount; ++e)
		{
			const unsigned int next = edges.m_Edges[e].m_ToVertex;
			TDeltaOmega::TAuxiliaryStruct& nextNode = pathNodes[next];

			if (nextNode.m_IsKnown)
				continue;

			const float candidate = node.m_Distance + WEIGHTS[edges.m_Edges[e].m_Weight];

			// Predecessors are not needed for dilation, only distances
			if (candidate < nextNode.m_Distance && candidate < radius)
			{
				nextNode.m_Distance = candidate;
				frontier.push(std::pair<float, unsigned int>(candidate, next));
			}
		}
	}

	return band;
}

std::unordered_set<unsigned int> SkeletonizerThread::FindEdgeVoxels(const TIndexedOrigins_Vector& dilated)
{
	// pathNodes must have been initialized by Dilate

	std::unordered_set<unsigned int> edgeVoxels;

	for (TIndexedOrigins_Vector::const_iterator i = dilated.begin(); i != dilated.end(); i++)
	{
		unsigned int index = *i;
		const TDeltaOmega::TAdjacencyStruct& adjacencyList = adjacencyLists->at(index);

		for (unsigned int neighbour = 0; neighbour < adjacencyList.m_EdgeCount; neighbour++)
		{
			unsigned int neighbourIndex = adjacencyList.m_Edges[neighbour].m_ToVertex;

			if (boost::math::isinf(pathNodes[neighbourIndex].m_Distance))
			{
				edgeVoxels.insert(index);
				break;
			}
		}
	}

	return edgeVoxels;
}

std::vector<unsigned int> SkeletonizerThread::CountComponents(const TIndexedOrigins_Vector& dilated)
{
	std::unordered_set<unsigned int> edgeVoxels = FindEdgeVoxels(dilated), visited;

	std::vector<unsigned int> components;

	for (std::unordered_set<unsigned int>::const_iterator i = edgeVoxels.begin(); i != edgeVoxels.end(); i++)
	{
		unsigned int edgeVoxelIndex = *i;
		if (!visited.insert(edgeVoxelIndex).second) continue;

		// Apparently there's another component
		components.push_back(edgeVoxelIndex);

		std::queue<unsigned int> queue;
		queue.push(edgeVoxelIndex);

		while (!queue.empty())
		{
			unsigned int index = queue.front();
			queue.pop();

			const TDeltaOmega::TAdjacencyStruct& adjacencyList = adjacencyLists->at(index);

			for (unsigned int neighbour = 0; neighbour < adjacencyList.m_EdgeCount; neighbour++)
			{
				unsigned int neighbourIndex = adjacencyList.m_Edges[neighbour].m_ToVertex;
					
				if (edgeVoxels.find(neighbourIndex) != edgeVoxels.end() &&
					visited.insert(neighbourIndex).second)
				{
					queue.push(neighbourIndex);
				}
			}
		}
	}

	return components;
}

std::vector<int> SkeletonizerThread::ComputeComponentAreas(const TIndexedOrigins_Vector& dilated,
	const std::vector<unsigned int>& components)
{
	std::vector<int> componentAreas;
	componentAreas.resize(components.size());

	for (int currentComponent = 0; currentComponent < components.size(); currentComponent++)
	{
		unsigned int seedIndex = components[currentComponent];

		std::unordered_set<unsigned int> visitedIndices;

		// Exclude voxels in dilated geodesic from search
		visitedIndices.insert(dilated.begin(), dilated.end());

		visitedIndices.insert(seedIndex);
		componentAreas[currentComponent]++;

		std::queue<unsigned int> queue;
		queue.push(seedIndex);

		while (!queue.empty())
		{
			if (componentAreas[currentComponent] >= maxArea) break;

			unsigned int index = queue.front();
			queue.pop();

			const TDeltaOmega::TAdjacencyStruct& adjacencyList = adjacencyLists->at(index);

			for (int neighbour = 0; neighbour < adjacencyList.m_EdgeCount; neighbour++)
			{
				int neighbourIndex = adjacencyList.m_Edges[neighbour].m_ToVertex;
				
				if (visitedIndices.insert(neighbourIndex).second)
				{
					componentAreas[currentComponent]++;
					queue.push(neighbourIndex);
				}
			}
		}
	}

	return componentAreas;
}

// Importance of a curve-skeleton voxel from the areas of its components. The total area of all
// but the largest component is compared against maxArea (returns 1.0 if reached), otherwise it is
// normalized as 2 * area / boundaryField->getMaxIndex() and capped at 1.0.
// Side effect: sorts componentAreas in place (ascending).
double SkeletonizerThread::ComputeImportanceMeasure(std::vector<int>& componentAreas)
{
	// Sort ascending so the largest component ends up last
	std::sort(componentAreas.begin(), componentAreas.end());

	// Sum every component except the largest. Index-based so an empty vector is handled safely
	// (componentAreas.end() - 1 on an empty vector is undefined behaviour).
	int area = 0;
	for (std::size_t i = 0; i + 1 < componentAreas.size(); ++i)
		area += componentAreas[i];

	return area >= maxArea ? 1.0 : (std::min)(1.0, 2.0 * area / boundaryField->getMaxIndex());
}
