/***************************************************************************
 *   Copyright (C) 2009 by Hessel Hoogendorp                               *
 *   bugs.ccc@gmail.com                                                    *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/


// ----------------------------------------------------------------------------
// Includes
// ----------------------------------------------------------------------------
#include "ccie_definition_filter.h"
#include "ccie_debug_printing.h"
#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>


// ----------------------------------------------------------------------------
// Static accessor for convenience
// ----------------------------------------------------------------------------
void CCIE_DefinitionFilter::Filter(/*CCIE_TranslationUnitInfo & tuInfo*/)
{
	// Convenience entry point: build a throw-away filter instance and run the
	// complete filtering pass on it.
	CCIE_DefinitionFilter()._Filter();
}


// ----------------------------------------------------------------------------
// Construction & destruction
// ----------------------------------------------------------------------------
CCIE_DefinitionFilter::CCIE_DefinitionFilter()
{
	// Nothing to initialize here: CalculateFunctionSet() and
	// CalculateCallableSet() clear their respective sets before filling them.
}


// ----------------------------------------------------------------------------
// Filtering
// ----------------------------------------------------------------------------
void CCIE_DefinitionFilter::_Filter()
{
	// Calculate the set of all function signatures and the set of signatures
	// of callable functions.
	
	DEBUG_PRINT("Calculating function set...");
	CalculateFunctionSet();
	DEBUG_PRINT("Calculating callable function set...");
	CalculateCallableSet();

	// Calculate the set of signatures of non-callable functions.
	DEBUG_PRINT("Calculating non-callable function set...");
	StringSet nonCallables;
	std::insert_iterator<StringSet> nonCallablesIns(nonCallables, nonCallables.begin());
	std::set_difference(m_functions.begin(), m_functions.end(), m_callables.begin(), m_callables.end(), nonCallablesIns);

	// Remove all functions that are non-callable.
	DEBUG_PRINT("Removing non callable functions...");
	StringSet::iterator itNonCallable = nonCallables.begin();
	while(itNonCallable != nonCallables.end())
	{
		CCIE_TranslationUnitInfo::m_callCandidates.RemoveFunction(*itNonCallable);
		DEBUG_PRINT("Removed non-callable function: " << (*itNonCallable));
		itNonCallable++;
	}

	DEBUG_PRINT("Removed " << nonCallables.size() << " of " << m_functions.size() << " functions, leaving " << m_callables.size() << " callable functions.");
}

void CCIE_DefinitionFilter::CalculateFunctionSet()
{
	// Start with an empty set.
	m_functions.clear();

	// Add all function signatures to the set.
	LocalFunctionMap* pLocalFunctionMap = CCIE_TranslationUnitInfo::m_localFunctionRepository.GetFunctions();
	LocalFunctionMap::iterator itLocalFunction = pLocalFunctionMap->begin();
	while(itLocalFunction != pLocalFunctionMap->end())
	{
		m_functions.insert(itLocalFunction->first);
		itLocalFunction++;
	}
}

void CCIE_DefinitionFilter::CalculateCallableSet()
{
	// Two phases, each followed by a debug dump of the intermediate result:
	//   1. the initial set of entry points;
	//   2. its closure under "may call".
	m_callables.clear();

	CalculateInitialCallableSet();
	PrintCallableSet("", "Initial Callable Set");

	CalculateCallableSetClosure();
	PrintCallableSet("", "Callable Set Closure");
}

void CCIE_DefinitionFilter::CalculateInitialCallableSet()
{
	// Add all functions that have been specified as an entry point to the callable set.
	if(CCIE_Cmd::FilterDefinitionsKeepSignatures()->size() > 0)
	{
//		// No entry points were specified explicitly, so add the default 'main'
//		// entry point.
//		m_callables.insert("void (main)()");
//		m_callables.insert("void (main)(int, char**)");
//		m_callables.insert("int (main)()");
//		m_callables.insert("int (main)(int, char**)");
//	}
//	else
//	{
		// Entry points were specified explicitly, so add those to the entry
		// point map.
		StringVector* pKeepSignatures = CCIE_Cmd::FilterDefinitionsKeepSignatures();
		m_callables.insert(pKeepSignatures->begin(), pKeepSignatures->end());
	}

	// Also, all functions that do not have static linkage and do have a
	// definition are entry-points.
	LocalFunctionMap* pLocalFunctions = CCIE_TranslationUnitInfo::m_localFunctionRepository.GetFunctions();
	LocalFunctionMap::iterator itLocalFunction = pLocalFunctions->begin();
	while(itLocalFunction != pLocalFunctions->end())
	{
		CCI_LocalFunction* pLocalFunction = itLocalFunction->second;

		if(pLocalFunction->HasFunctionDefinition())
		{
			// If this function does not have static linkage, then it is always
			// an entry point.
			if(!pLocalFunction->IsFlagStaticLinkage())
			{
				m_callables.insert(pLocalFunction->GetSignature());
			}

			// If this function does have static linkage, then it is an entry
			// point iff:
			// * We are in conservative mode, and
			// (
			//     * The function is virtual, or
			//     * The function is a c-style function
			// )

			// If we are in conservative mode, then we have even more entry points.
			else if(CCIE_Cmd::FilterDefinitionsConservative())
			{
				// If this function is a virtual method, or a c-style function,
				// then it is also an entry point.
				if((pLocalFunction->IsFlagMethod() && pLocalFunction->IsFlagVirtual()) ||
				   !pLocalFunction->IsFlagMethod())
				{
					m_callables.insert(pLocalFunction->GetSignature());
				}
			}
			else
			{
				/* WE ARE MERELY COMPOSING THE INITIAL CALLABLE SET HERE. IT MIGHT
				   VERY WELL BE THAT ALL THESE FUNCTIONS ARE ADDED TO THE CALLABLE
				   SET DURING CLOSURE, SO WE CANNOT YET ISSUE WARNINGS HERE.
				// If this function is a virtual method, or a c-style function,
				// then we need to issue a warning: This function is going to
				// be filtered out, but it might be that this function is a
				// call candidate in another translation unit.
				if((pLocalFunction->IsFlagMethod() && pLocalFunction->IsFlagVirtual()) ||
				   !pLocalFunction->IsFlagMethod())
				{
					// Only issue a warning here if we are supposed to.
					if(CCIE_Cmd::PrintConservativenessWarnings())
					{
						std::cout << "WARNING: Function that is a potential call-candidate in another translation unit is being filtered:" << std::endl;
						std::cout << "{" << std::endl;
						std::cout << "\tLocation         : " << pLocalFunction->GetFunctionDefinition()->GetSourceLoc()->ToString("") << std::endl;
						std::cout << "\tSignature        : " << pLocalFunction->GetSignature() << std::endl;
						std::cout << "\tPointer signature: " << pLocalFunction->GetPointerSignature() << std::endl;
						std::cout << "}" << std::endl;
					}
				}
				*/
			}
		}

		itLocalFunction++;
	}

	// Lastly, all functions that are called by initializers and finalizers are
	// also entry points.
	FunctionCallVector::iterator itInitializingCall = CCIE_TranslationUnitInfo::m_initializingCalls.begin();
	while(itInitializingCall != CCIE_TranslationUnitInfo::m_initializingCalls.end())
	{
		// Retrieve the call candidates of this function call.
		CCI_FunctionCall* pFunctionCall = *itInitializingCall;
		StringVector callCandidateSignatures;
		pFunctionCall->GetCallCandidates(CCIE_TranslationUnitInfo::m_callCandidates, callCandidateSignatures);
		// Add the call candidates to the set of entry points.
		m_callables.insert(callCandidateSignatures.begin(), callCandidateSignatures.end());
		itInitializingCall++;
	}

	FunctionCallVector::iterator itFinalizingCall = CCIE_TranslationUnitInfo::m_finalizingCalls.begin();
	while(itFinalizingCall != CCIE_TranslationUnitInfo::m_finalizingCalls.end())
	{
		// Retrieve the call candidates of this function call.
		CCI_FunctionCall* pFunctionCall = *itFinalizingCall;
		StringVector callCandidateSignatures;
		pFunctionCall->GetCallCandidates(CCIE_TranslationUnitInfo::m_callCandidates, callCandidateSignatures);
		// Add the call candidates to the set of entry points.
		m_callables.insert(callCandidateSignatures.begin(), callCandidateSignatures.end());
		itFinalizingCall++;
	}
}

void CCIE_DefinitionFilter::CalculateCallableSetClosure()
{
	// Walk a snapshot of the initial callable set. AddCallablesFromFunction()
	// inserts into m_callables and already follows every signature it
	// inserts. Walking the live set would therefore process each newly added
	// entry a second time, for those sorted after the current position, and
	// it would rely on the set's iterator-stability guarantees during
	// mutation.
	const std::vector<std::string> initialCallables(m_callables.begin(), m_callables.end());
	for(std::vector<std::string>::const_iterator itCallable = initialCallables.begin(); itCallable != initialCallables.end(); ++itCallable)
	{
		AddCallablesFromFunction(*itCallable);
	}
}

void CCIE_DefinitionFilter::AddCallablesFromFunction(std::string szSignature)
{
	// Pre-condition: szSignature is present in the callable set.
	//
	// Adds to m_callables every call candidate transitively reachable from
	// szSignature through function definitions. An explicit worklist replaces
	// recursion so that long call chains cannot exhaust the stack; the
	// resulting set is the same.
	std::vector<std::string> pending;
	pending.push_back(szSignature);

	while(!pending.empty())
	{
		std::string szCurrent = pending.back();
		pending.pop_back();

		// Find the CCI_LocalFunction belonging to this signature. It can be
		// missing because:
		//
		// 1. The user specified an invalid entry-point signature.
		// 2. The function was filtered out by declaration filtering.
		//
		// In both cases, there is nothing to follow.
		CCI_LocalFunction* pLocalFunction = CCIE_TranslationUnitInfo::m_localFunctionRepository.GetFunction(szCurrent);
		if(pLocalFunction == NULL)
			continue;

		// A function without a definition cannot call other functions.
		if(!pLocalFunction->HasFunctionDefinition())
			continue;

		// Visit all the function calls that this function makes.
		CCI_FunctionDefinition* pFunctionDefinition = pLocalFunction->GetFunctionDefinition();
		FunctionCallVector* pFunctionCalls = pFunctionDefinition->GetFunctionCalls();
		for(FunctionCallVector::iterator itFunctionCall = pFunctionCalls->begin(); itFunctionCall != pFunctionCalls->end(); ++itFunctionCall)
		{
			StringVector callCandidates;
			(*itFunctionCall)->GetCallCandidates(CCIE_TranslationUnitInfo::m_callCandidates, callCandidates);

			for(StringVector::iterator itCallCandidate = callCandidates.begin(); itCallCandidate != callCandidates.end(); ++itCallCandidate)
			{
				// insert().second is true only for signatures that were not
				// yet callable; only those still need their calls followed.
				if(m_callables.insert(*itCallCandidate).second)
					pending.push_back(*itCallCandidate);
			}
		}
	}
}

void CCIE_DefinitionFilter::WarnIfNecessary(std::string szSignature)
{
	// Only issue a warning if we are supposed to.
	if(!CCIE_Cmd::PrintConservativenessWarnings())
		return;

	// Retrieve the local function belonging to this signature.
	CCI_LocalFunction* pLocalFunction = CCIE_TranslationUnitInfo::m_localFunctionRepository.GetFunction(szSignature);
	if(pLocalFunction == NULL)
		return;

	// The warning below dereferences GetFunctionDefinition() for the source
	// location, so functions without a definition are skipped. This mirrors
	// CalculateInitialCallableSet, which only considers defined functions.
	if(!pLocalFunction->HasFunctionDefinition())
		return;

	// Only virtual methods and c-style (non-method) functions are warned
	// about. Such a function is going to be filtered out, but it might be a
	// call candidate in another translation unit.
	if(pLocalFunction->IsFlagMethod() && !pLocalFunction->IsFlagVirtual())
		return;

	std::cout << "WARNING: Function that is a potential call-candidate in another translation unit is being filtered:" << std::endl;
	std::cout << "{" << std::endl;
	std::cout << "\tLocation         : " << pLocalFunction->GetFunctionDefinition()->GetSourceLoc()->ToString("") << std::endl;
	std::cout << "\tSignature        : " << pLocalFunction->GetSignature() << std::endl;
	std::cout << "\tPointer signature: " << pLocalFunction->GetPointerSignature() << std::endl;
	std::cout << "}" << std::endl;
}

void CCIE_DefinitionFilter::PrintCallableSet(std::string szPrefix, std::string szMessage)
{
	// If debug printing is enabled, print the callable set.
	if(IS_DEBUG_PRINTING)
	{
		std::cout << szPrefix << szMessage << " (" << m_callables.size() << "):" << std::endl;
		std::cout << szPrefix << "{" << std::endl;
		StringSet::iterator itCallable = m_callables.begin();
		while(itCallable != m_callables.end())
		{
			std::cout << szPrefix << "\t" << (*itCallable) << std::endl;
			itCallable++;
		}
		std::cout << szPrefix << "}" << std::endl;
	}
}
