// see License.txt for copyright and terms of use

#include "prcheck.h"        // this module
#include "prcheck_cmd.h"    // PrcheckCmd
#include "prcheck_global.h"
#include "oink.gr.gen.h"        // CCParse_Oink
#include "strutil.h"            // quoted
#include "oink_util.h"
#include "expr_visitor.h"
#include "squash_util.h"
#include <set>
#include <map>
#include <list>
#include <sstream>
#include <cppundolog.h>
#include "patcher.h"

using namespace std;

// all of the prbool "variables" go into here
// (every non-function declarator whose type specifier is a PRBool alias is
// recorded by PrVisitor::add during the gather phase -- this includes
// PRBool* pointers; checkTarget/isBool deal with the pointer levels)
static set<Variable*> prboolSet;

// all prbool returning functions/methods
// (also function-pointer variables whose typedef'd type returns PRBool)
static set<Variable*> prboolFunctionSet;

// function -> prbool parameter index list
// Indices are 0-based argument positions in ascending order; visitE_funCall
// walks the call's arguments and this list in lockstep.
typedef map<Variable*, list<int> > arg_map; 
static arg_map prboolParams;

// the following two mirror the above..they are used to detect typedefed entities
// typedefs for function pointers returning prbool
// (names declared inside a class are re-keyed as "Class::Name" when the class
// body closes, see PrVisitor::postvisitTypeSpecifier)
static set<StringRef> prboolTypedefs;
// function -> prbool parameter index list
typedef map<StringRef, list<int> > typedef_arg_map; 
static typedef_arg_map prboolTypedefParams;

// type names treated as PRBool: seeded by the PrVisitor constructor and
// extended by `typedef PRBool Alias;` declarations
static set<StringRef> prboolAliases;

// result of PrVisitor::isTypedefedFunc: a lookup in each typedef table
// (each half is that table's end() when the name is absent)
typedef pair<set<StringRef>::iterator, typedef_arg_map::iterator> typedef_it_pair;

// Classification returned by PRBoolDetector::isPRBool.  NOT_A_BOOL is 0 so
// the value can be tested directly in boolean context.
enum BoolType {
  NOT_A_BOOL = 0,
  PRBOOL 
  //  LIKE_PRBOOL
}; 

/* Can check function parameters, [global|class] variables, 
 * function return values */
class PrVisitor : public ExpressionVisitor {
public:
  PrVisitor(bool verifyPhase, Patcher &patcher, bool debug): 
    returnPrbool(false),
    patcher(patcher),
    debug(debug),
    verifyPhase(verifyPhase)
  {
    prboolAliases.insert(globalStrTable("PRBool"));
    prboolAliases.insert(globalStrTable("JSBool"));
    prboolAliases.insert(globalStrTable("PRPackedBool")); 
    prboolAliases.insert(globalStrTable("JSPackedBool")); 
    // makes GTK-interfacing code happier
    prboolAliases.insert(globalStrTable("gboolean")); 
    prboolAliases.insert(globalStrTable("cairo_bool_t")); 
  }

  virtual bool visitFunction(Function *f);
  // not really needed because in order for a return statement to occur
  // a function must start first. Still nice this to have for symmetry
  virtual void postvisitFunction(Function *f);
  virtual bool visitDeclaration(Declaration *d);
  virtual bool visitDeclarator(Declarator *d);
  virtual bool visitE_funCall(E_funCall *f);
  virtual bool visitASTTypeId(ASTTypeId *a);
  virtual bool visitS_return(S_return *s);
  virtual bool visitE_assign(E_assign *e);
  virtual bool visitE_cast(E_cast *e);
  virtual bool visitE_keywordcast(E_cast *e);
  virtual bool visitTypeSpecifier(TypeSpecifier *t);
  virtual void postvisitTypeSpecifier(TypeSpecifier *t);
  // need to check class member inits too
private:
  // add an item to prboolSet && check initializer
  void add(Declarator *d, BoolType reason);

  void checkPRBool(Expression *e, bool outerBracket = false);
  void checkFunctionArguments(Declarator *d);
  void peelAndCheckPRBool(Expression *e, Type *t);
  void makeBool(PairLoc loc, bool bracket, bool outerBracket);
  void checkForPrBoolTypedefs(Declaration *d);
  
  // check if this type refers to typedef for function dealing with prbools
  typedef_it_pair isTypedefedFunc(TypeSpecifier *ts);
  
  bool returnPrbool;
  Patcher &patcher;
  bool debug;
  // used to help with typedefs within a class scope
  // these will blow away a typedef if it is defined in global scope
  // and then in a class one, but i'll fix that once i run into that issue
  vector<list<set<StringRef>::iterator> > prboolTypedefsStack;
  vector<list<typedef_arg_map::iterator> > prboolTypedefParamsStack;

public:
  // true -> prbool value checks are enabled
  // false -> prbool functions & vars are gathered
  bool verifyPhase;
};

void Prcheck::check(bool debug) {
  Patcher patcher;
  PrVisitor prGatherer(false, patcher, debug);
  PrVisitor prVerifier(true, patcher, debug);
  foreachSourceFile {
    File *file = files.data();
    patcher.setFile(file->name);
    
    maybeSetInputLangFromSuffix(file);
    TranslationUnit *unit = file2unit.get(file);
    unit->traverse(prGatherer.loweredVisitor);
    unit->traverse(prVerifier.loweredVisitor);
  }
}

class PRBoolDetector : public ExpressionVisitor {
  virtual bool visitE_variable(E_variable *v) {
    if (prboolSet.find(v->var) != prboolSet.end()) {
      throw v->var;
    }
    return false;
  }

public: 


  static BoolType isPRBool(TypeSpecifier *ts) {    
    TS_name *tsname = ts->ifTS_name();
    if (!tsname) return NOT_A_BOOL;
    
    PQ_name *pqname = tsname->name->ifPQ_name();
    if (!pqname) return NOT_A_BOOL;
    
    return prboolAliases.find(pqname->name) != prboolAliases.end() 
      ? PRBOOL 
      : NOT_A_BOOL;
  }
  
  static bool checkTarget(Expression *e) {
    try {
      PRBoolDetector d;
      // got no recursion, but exceptions still work =D
      e->traverse(d.loweredVisitor);
    } catch (Variable *v) {
      Type *t = v->type;
      // hack to deal with *prbool = val;
      while(t->isPointerType()) {
	t = t->asPointerType()->atType;
	if (!e->isE_deref()) return false;
	e = e->asE_deref()->ptr;
      }
      return true;
    }
    return false;
  } 
};

// Look up the type name spelled by `ts` (plain `Name` or qualified
// `Outer::Name`) in the PRBool typedef tables.  Each half of the result is the
// matching iterator, or that table's end() when the name is absent; a type
// specifier that is not a name looks up NULL, which is never stored.
typedef_it_pair PrVisitor::isTypedefedFunc(TypeSpecifier *ts) {
  StringRef key = NULL;
  TS_name *named = ts->ifTS_name();
  if (named) {
    PQ_name *plain = named->name->ifPQ_name();
    PQ_qualifier *qualified = plain ? NULL : named->name->ifPQ_qualifier();
    if (plain) {
      key = plain->name;
    } else if (qualified) {
      key = globalStrTable(qualified->toString());
    }
  }
  set<StringRef>::iterator retHit = prboolTypedefs.find(key);
  typedef_arg_map::iterator paramHit = prboolTypedefParams.find(key);
  return typedef_it_pair(retHit, paramHit);
}

// Every declaration: record PRBool typedefs/aliases.  In the gather phase, also
// propagate typedef'd function-pointer information to the declared names.
// Finally, track (and check initializers of) declarators whose type specifier
// is a PRBool alias.
bool PrVisitor::visitDeclaration(Declaration *d) {
  checkForPrBoolTypedefs(d);

  // check for typedefed func pointers
  typedef_it_pair hit = isTypedefedFunc(d->spec);
  if (!verifyPhase) {
    bool returnsBool = hit.first != prboolTypedefs.end();
    bool takesBool = hit.second != prboolTypedefParams.end();
    for (FakeList<Declarator> *node = d->decllist; !node->isEmpty();
         node = node->butFirst()) {
      Variable *var = node->first()->var;
      if (returnsBool) prboolFunctionSet.insert(var);
      if (takesBool) prboolParams[var] = hit.second->second;
    }
  }

  BoolType kind = PRBoolDetector::isPRBool(d->spec);
  if (kind == NOT_A_BOOL) return true;

  for (FakeList<Declarator> *node = d->decllist; !node->isEmpty();
       node = node->butFirst()) {
    add(node->first(), kind);
  }
  return true;
}

// Entering a function definition.  Remember whether its return statements
// must produce 0/1 values, and register the function itself when its return
// type specifier is a PRBool alias.
bool PrVisitor::visitFunction(Function *f) {
  BoolType boolType = PRBoolDetector::isPRBool(f->retspec);
  // A PRBool* or PRBool& result is an address/alias, not a truth value:
  // rewriting `return &mFlag;` to `return !!(&mFlag);` would be wrong.  This
  // mirrors the pointer/reference exclusion in add() and isPRBoolTypeId().
  bool returnsIndirect = f->funcType && f->funcType->retType->isPtrOrRef();
  returnPrbool = boolType == PRBOOL && !returnsIndirect;
  if (boolType) {
    add(f->nameAndParams, boolType);
  }
  return true;
}

void PrVisitor::postvisitFunction(Function *f) {
  returnPrbool = false;
}

// Type-ids (parameters, cast targets, ...).  In the gather phase, a type-id
// spelled with a typedef'd function-pointer type inherits that typedef's
// PRBool return/parameter information.  A type-id spelled with a PRBool alias
// is tracked like a declaration.
bool PrVisitor::visitASTTypeId(ASTTypeId *a) {
  if (!verifyPhase) {
    typedef_it_pair hit = isTypedefedFunc(a->spec);
    if (hit.first != prboolTypedefs.end())
      prboolFunctionSet.insert(a->decl->var);
    if (hit.second != prboolTypedefParams.end())
      prboolParams[a->decl->var] = hit.second->second;
  }

  if (BoolType kind = PRBoolDetector::isPRBool(a->spec))
    add(a->decl, kind);
  return true;
}

// pull variables out of an ast
//
// VariableExtractor walks an expression (typically a call's callee) looking
// for the Variable it refers to.  It stops at the first match by throwing that
// Variable*, which getVariable() catches.  A thrown NULL means "too complex,
// give up".
class VariableExtractor : public ExpressionVisitor {
  // handy for finding stuff that we should be matching on
  // Runs for every expression node.  The base call is expected to dispatch to
  // the visitE_* overrides below, which throw for the kinds they understand.
  // NOTE(review): confirm against expr_visitor.h.
  virtual bool visitExpression(Expression *e) {
    ExpressionVisitor::visitExpression(e);
    // give up quietly on AST nodes too complex to follow
    if (e->isE_binary() || e->isE_cond()) throw (Variable*)NULL;
    // Arrow, deref and cast nodes are expected on the way to a variable and
    // are not reported.  Anything else is logged for debugging.  (Every
    // E_binary has already thrown above, so no operator detail is printed.)
    if (!e->isE_arrow() 
        && !e->isE_deref()
        && !e->isE_cast()) {
      cerr << toString(e->loc) << ": ";
      print(cerr, e);
      cerr << "<--VariableExtractor got confused with this expr:" << e->kindName() << endl;
    }
    return true;
  }
  
  virtual bool visitE_variable(E_variable *v) {
    throw v->var;
  }

  virtual bool visitE_fieldAcc(E_fieldAcc *f) {
    throw f->field;
  }
  
  // `obj->*memberPtr` names the member pointer; `fptr + n` names the pointer
  virtual bool visitE_binary(E_binary *b) {
    if (b->op == BIN_ARROW_STAR) {
      throw VariableExtractor::getVariable(b->e2);
    } else if (b->op == BIN_PLUS) {
      throw VariableExtractor::getVariable(b->e1);
    }
    return true;
  }
  
public:

  // The Variable referenced by `e`, or NULL when none could be determined.
  static Variable *getVariable(Expression *e) {
    VariableExtractor cc;
    try {
      e->traverse(cc.loweredVisitor);
    } catch (Variable *v) {
      return v;
    }
    return NULL;
  }
};

// True when `e` names a function or function pointer (see VariableExtractor)
// that is known to return PRBool.
static bool isBoolFunc(Expression *e) {
  Variable *callee = VariableExtractor::getVariable(e);
  if (!callee) return false;
  return prboolFunctionSet.count(callee) != 0;
}

// True when variable `v` can only hold 0 or 1:
//  - it was recorded as a PRBool in prboolSet, or
//  - it is a 1-bit bitfield of bool or an unsigned integral type, or
//  - it has an enum type all of whose enumerators are 0 or 1.
// Signed 1-bit bitfields (int, long, plain char, ...) are deliberately not
// accepted: a signed 1-bit field holds 0 and -1, so a set bit reads back as -1.
static bool isPRBoolOrCompatible(Variable *v) {
  if (prboolSet.find(v) != prboolSet.end()) return true;
  AtomicType *t = NULL;
  if (v->type->isCVAtomicType()) {
    t = v->type->asCVAtomicType()->atomic;
  }
  if (v->isBitfield() 
      && v->getBitfieldSize() == 1 
      && t) {
    if (t->isSimpleType()) {
      switch (t->asSimpleType()->type) {
      case ST_BOOL:
      case ST_UNSIGNED_CHAR:
      case ST_UNSIGNED_SHORT_INT:
      case ST_UNSIGNED_INT:
      case ST_UNSIGNED_LONG_INT:
      case ST_UNSIGNED_LONG_LONG: 
        return true;
      default:
        // includes ST_LONG_INT and the other signed types (see above)
        break;
      }
    }
  }
  // binary enums are prbool compatible
  if (t && t->isEnumType()) {
    EnumType *et = t->asEnumType();
    for (StringObjDict<EnumType::Value>::Iter it(et->valueIndex);
         !it.isDone(); it = it.next()) {
      const int value = it.value()->value;
      if (value != 1 && value != 0)
        return false;
    }
    return true;
  }
  return false;
}

// For type-ids (parameters, cast targets): true when the type is spelled with
// a PRBool alias and the declarator makes it neither a pointer nor a reference.
static bool isPRBoolTypeId(ASTTypeId *type) {
  if (!PRBoolDetector::isPRBool(type->spec)) return false;
  // PRBool* / PRBool& are addresses or aliases, not truth values
  return !type->decl->var->type->isPtrOrRef();
}

// Decide whether expression `e` is already guaranteed to be 0 or 1.
//
// Returns true when it is.  It never returns false: when `e` cannot be
// classified, an Expression* is thrown instead (`e` itself, or one of its
// sub-expressions via the recursive calls).  checkPRBool catches it and then
// patches its whole expression.
static bool isBool(Expression *e) {
  if (E_unary *u = e->ifE_unary()) {
    // logical not always yields 0/1
    if (u->op == UNY_NOT) return true;

  } else if (E_variable *v = e->ifE_variable()) {
    // enumerators whose value is 0 or 1
    if (v->var->isEnumerator()) {
      int value = v->var->getEnumeratorValue();
      if (value == 0 || value == 1) return true;
    }
    // tracked PRBools and bool-compatible bitfields/enums
    if (isPRBoolOrCompatible(v->var)) return true;
  } else if (E_fieldAcc *f = e->ifE_fieldAcc()) {
    if(isPRBoolOrCompatible(f->field)) return true;
    
    //    if (f->loc == SL_UNKNOWN) f->loc = f-fieldName->loc;
    //throw "Class member is not a PRBool";
  } else if (E_intLit *i = e->ifE_intLit()) {
    // literal 0 or 1
    if (i->i == 1 || i->i == 0) return true;

  } else if (e->isE_boolLit()) {
    return true;

  } else if (E_keywordCast *c = e->ifE_keywordCast()) {
    // a cast to non-pointer PRBool is trusted; otherwise look through the cast
    // enabling 2nd parameter to be used could allow stuff like *static_cast<PRBool*>(foo)
    // to pass, but I'm not convinced that would be a good idea.
    if (isPRBoolTypeId(c->ctype/*, c->type->asRval()->isPointerType()*/)) return true;
    return isBool(c->expr);

  } else if (E_cast *c = e->ifE_cast()) {
    // same rule for C-style casts
    if (isPRBoolTypeId(c->ctype/*, c->type->asRval()->isPointerType()*/)) return true;
    return isBool(c->expr);

  } else if (E_binary *b = e->ifE_binary()) {
    switch(b->op) {
    // relational and logical operators always yield 0/1
    case BIN_EQUAL:
    case BIN_NOTEQUAL:
    case BIN_LESS:
    case BIN_GREATER:
    case BIN_LESSEQ:
    case BIN_GREATEREQ:
    case BIN_AND:
    case BIN_OR:
      return true;
      // these are valid prbools, just ugly
    case BIN_BITAND:
      {
	//try to coerce either one into a prbool
	// (x & y is 0/1 when either side is 0/1.  catch(void const *) also
	// catches the thrown Expression*, via conversion to pointer-to-void.)
	int c = 0;
	try {
	  isBool(b->e1);
	  c++;
	} catch(void const *) {
	}
	try {
	  isBool(b->e2);
	  c++;
	} catch(void const *) {
	}
	if (c) return true;
	break;
      }
    case BIN_BITXOR:
    case BIN_BITOR:
      // both sides must be 0/1.  isBool never returns false, so this either
      // returns true or throws; control never falls through to BIN_RSHIFT.
      if (isBool(b->e1) && isBool(b->e2))
	return true;
      // detect JSVAL_TO_BOOLEAN macro
    case BIN_RSHIFT:
      {
	// a shift is accepted only when it comes from a JSVAL_TO_BOOLEAN expansion
	CPPSourceLoc csl(b->loc);
	static const StringRef str_JSVAL_TO_BOOLEAN = globalStrTable("JSVAL_TO_BOOLEAN");
	if (csl.macroExpansion
	    && csl.macroExpansion->name == str_JSVAL_TO_BOOLEAN) return true;
      }
      break;
    default:
      break;
    }

  } else if (E_funCall *f = e->ifE_funCall()) {
    // calls to functions known to return PRBool
    if (isBoolFunc(f->func)) return true;

  } else if (E_cond *c = e->ifE_cond()) {
    // c ? a : b is 0/1 when both arms are
    if (isBool(c->th) && isBool(c->el)) return true;

  } else if (E_deref *d = e->ifE_deref()) {
    // *p: defer to the pointer operand (PRBool* variables are in prboolSet,
    // since add() records pointer declarators too)
    return isBool(d->ptr);
    
  } else if (E_assign *a = e->ifE_assign()) {
    // not too sure about this
    // the more correct thing would be to check the target
    // and make sure that the target is also a valid bool elsewhere
    return isBool(a->src);

  }
  // not provably 0/1
  throw e;
}

// Replace the source text at `loc` with a double-negated copy of itself:
//   !!expr, !!(expr), (!!expr) or (!!(expr))
// `bracket` parenthesizes the original text; `outerBracket` parenthesizes the
// whole replacement.
void PrVisitor::makeBool(PairLoc loc, bool bracket, bool outerBracket) {
  const char *innerOpen  = bracket ? "(" : "";
  const char *innerClose = bracket ? ")" : "";
  const char *outerOpen  = outerBracket ? "(" : "";
  const char *outerClose = outerBracket ? ")" : "";

  ostringstream replacement;
  replacement << outerOpen << "!!" << innerOpen
              << patcher.getRange(loc)
              << innerClose << outerClose;

  patcher.printPatch(replacement.str(), loc);
}

// Verify phase: ensure `e` is already a 0/1 value.  If isBool() cannot prove
// that, rewrite the expression's source text to `!!(...)` through makeBool.
// `outerBracket` additionally wraps the whole rewrite in parentheses.
// Expressions lying (partly) inside a macro expansion cannot be patched in
// place; they are only reported.
void PrVisitor::checkPRBool(Expression *e, bool outerBracket) {
  if (!verifyPhase) return;
  try {
    isBool(e);
    return;                       // provably 0/1: nothing to do
  } catch (char const *str) {
    if (debug) {
      cerr << toString(e->loc) << ": " << str;
      cerr << endl;
    }
  } catch (Expression *culprit) {
    // isBool throws the (sub)expression it could not classify.  It is named
    // separately from `e` so it is clear the patch below targets all of `e`.
    if (debug) {
      cerr << toString(culprit->loc) << ": " << culprit->kindName() << ": ";
      if (E_binary *b = culprit->ifE_binary()) {
        cerr << toString(b->op)  << ": ";
      }
      print(cerr, culprit);
      cerr << endl;
    }
  }
  if (e->loc == SL_UNKNOWN || e->endloc == SL_UNKNOWN) return;

  CPPSourceLoc csl(e->loc);
  CPPSourceLoc endcsl(e->endloc);  

  if (!csl.hasExactPosition() || !endcsl.hasExactPosition()) {
    // Either end may be the one inside a macro, and the other end's
    // macroExpansion can be NULL.  Report whichever end has an expansion.
    CPPSourceLoc &inMacro = csl.macroExpansion ? csl : endcsl;
    if (!inMacro.macroExpansion) return;   // nothing to report, cannot patch

    char const *file = NULL;
    int line = 0;
    int col = 0;
    sourceLocManager->decodeLineCol(inMacro.macroExpansion->preStartLoc, file, line, col);

    cerr << "Expression is within a "
         << inMacro.macroExpansion->name
         << " macro at " << ::resolveAbsolutePath("", file) << ":" << line << ":" << col << endl;
    return;
  }
  // simple operands need no parentheses after "!!"
  bool no_need_tobracket = e->isE_variable() || e->isE_keywordCast()
    || e->isE_funCall() || e->isE_intLit();
  makeBool(PairLoc(csl, endcsl), !no_need_tobracket, outerBracket);
}

// Register a declarator whose type specifier is a PRBool alias and check its
// initializer.
//  - gather phase: function declarators go to prboolFunctionSet (and nothing
//    else happens); everything else goes to prboolSet.
//  - both phases: a plain `= expr` initializer of a non-pointer, non-reference
//    declarator is checked (checkPRBool itself does nothing in the gather
//    phase).  Other IN_* initializer kinds are not examined.
void PrVisitor::add(Declarator *d, BoolType reason) {
  xassert(reason != NOT_A_BOOL);

  bool gathering = !verifyPhase;
  if (gathering && d->decl->isD_func()) {
    prboolFunctionSet.insert(d->var);
    return;
  }
  if (gathering) {
    prboolSet.insert(d->var);
  }

  Initializer *init = d->init;
  if (init == NULL || d->var->type->isPtrOrRef()) return;

  if (IN_expr *exprInit = init->ifIN_expr()) {
    if (exprInit->e) checkPRBool(exprInit->e);
  }
}

// Inside a PRBool-returning function, every returned value must be 0/1.
bool PrVisitor::visitS_return(S_return *s) {
  // a bare `return;` has no FullExpression; there is nothing to check
  if (returnPrbool && s->expr && s->expr->expr) {
    checkPRBool(s->expr->expr);
  }
  return true;
}

// Assignments (including compound ones) whose target is a tracked PRBool
// (see PRBoolDetector::checkTarget) must assign a 0/1 value.
bool PrVisitor::visitE_assign(E_assign *e) {
  bool targetIsPRBool = PRBoolDetector::checkTarget(e->target);
  if (targetIsPRBool) {
    checkPRBool(e->src);
  }
  return true;
}

// C-style casts to (non-pointer) PRBool: the operand must already be 0/1.
// Any rewrite is wrapped in an extra pair of parentheses (outerBracket).
bool PrVisitor::visitE_cast(E_cast *c) {
  bool toPRBool = isPRBoolTypeId(c->ctype);
  if (toPRBool) {
    checkPRBool(c->expr, true);
  }
  return true;
}

// Keyword-cast flavour of visitE_cast, without the extra outer parentheses.
// NOTE(review): the lower-case "cast" in the name and the E_cast parameter do
// not match a visitor hook for E_keywordCast nodes, so this is presumably
// never invoked by the traversal -- confirm against ExpressionVisitor.
bool PrVisitor::visitE_keywordcast(E_cast *c) {
  if (!isPRBoolTypeId(c->ctype)) return true;
  checkPRBool(c->expr);
  return true;
}

// Gather phase: remember which parameter positions (0-based, ascending) of the
// function declared by `d` are PRBool values.  `d` must be a D_func
// declarator.  Functions without PRBool parameters get no entry.
void PrVisitor::checkFunctionArguments(Declarator *d) {
  if (verifyPhase) return;

  // don't process the same func twice
  if (prboolParams.find(d->var) != prboolParams.end()) return;

  D_func *fn = d->decl->asD_func();
  list<int> boolPositions;
  int pos = 0;
  for (FakeList<ASTTypeId> *p = fn->params; !p->isEmpty(); p = p->butFirst()) {
    if (isPRBoolTypeId(p->first())) boolPositions.push_back(pos);
    ++pos;
  }

  if (!boolPositions.empty()) {
    prboolParams[d->var] = boolPositions;
  }
}

// capture functions with PRBool arguments: every function declarator
// (prototype or definition) is inspected for PRBool parameters.
bool PrVisitor::visitDeclarator(Declarator *d) {
  bool isFunction = d->decl->isD_func();
  if (isFunction) {
    checkFunctionArguments(d);
  }
  return true;
}

// Verify phase: for a call to a function (or function pointer) with known
// PRBool parameters, check each argument passed in a PRBool position.
//
// The callee's type is not inspected: argument positions come straight from
// prboolParams.  An earlier unused walk to the FunctionType called
// asFunctionType() on whatever type remained after stripping pointers, which
// can be a reference or array type for typedef'd function-pointer variables.
bool PrVisitor::visitE_funCall(E_funCall *f) {
  if (!verifyPhase) return true;

  Variable *callee = VariableExtractor::getVariable(f->func);
  if (!callee) return true;
  arg_map::iterator entry = prboolParams.find(callee);
  if (entry == prboolParams.end()) return true;

  // walk the arguments and the ascending PRBool position list in lockstep
  list<int> const &boolPositions = entry->second;
  list<int>::const_iterator want = boolPositions.begin();
  int i = 0;
  for (FakeList <ArgExpression >* args = f->args; 
       !args->isEmpty() && want != boolPositions.end();
       args = args->butFirst(), i++) {
    if (*want != i) continue;

    checkPRBool(args->first()->expr);
    ++want;
  }
  return true;
}

// capture indirection created by typedefs
// ie typedefs that contain prbools
//
// Gather phase only.  For each declarator of a `typedef` declaration:
//  - `typedef PRBool Alias;`          -> Alias becomes a PRBool alias
//  - `typedef PRBool (*Fn)(...);`     -> Fn is recorded in prboolTypedefs
//  - `typedef R (*Fn)(.., PRBool, ..)`-> Fn's PRBool parameter positions are
//                                        recorded in prboolTypedefParams
// Entries made inside a class body are also pushed onto the current class
// frame, so postvisitTypeSpecifier can re-qualify them.
void PrVisitor::checkForPrBoolTypedefs(Declaration *d) {
  if (verifyPhase) return;

  if (!(d->dflags & DF_TYPEDEF)) return;
  
  for (FakeList <Declarator >* lsd = d->decllist;
       !lsd->isEmpty(); lsd = lsd->butFirst()) {
    Declarator *decl = lsd->first();
    D_func *f = decl->decl->ifD_func();

    // handle typedef PRBool alias; cases
    if (!f) {
      if (D_name *n = decl->decl->ifD_name()) {
        if (PRBoolDetector::isPRBool(d->spec)
            && n->name->isPQ_name()) { 
          prboolAliases.insert(n->name->asPQ_name()->name);
        }
      }
      continue;
    }

    // only the function-pointer shape `(*Name)(params)` is recognized:
    // D_func -> D_grouping -> D_pointer -> D_name with an unqualified name
    D_grouping *g = f->base->ifD_grouping();
    if (!g) continue;

    D_pointer *p = g->base->ifD_pointer();
    if (!p) continue;

    D_name *n = p->base->ifD_name();
    if (!n) continue;

    PQ_name *name = n->name->ifPQ_name();
    if (!name) continue;

    // does a func of this type return a prbool?
    if (PRBoolDetector::isPRBool(d->spec)) {
      set<StringRef>::iterator it = 
        prboolTypedefs.insert(name->name).first;
      if (prboolTypedefsStack.size()) {
        prboolTypedefsStack.back().push_back(it);
      }
    }

    // which parameters are PRBool? (an adaptation of checkFunctionArguments)
    list<int> boolPositions;
    int i = 0;
    for (FakeList <ASTTypeId >* params = f->params; !params->isEmpty();
         params = params->butFirst(), i++) {
      if (isPRBoolTypeId(params->first())) boolPositions.push_back(i);
    }
    if (boolPositions.empty()) continue;

    // The same typedef can be seen again (e.g. once per translation unit
    // that includes it).  Replace any existing position list rather than
    // appending to it, so repeated sightings do not pile up duplicate
    // positions and a different typedef reusing the name does not inherit
    // stale ones.
    typedef_arg_map::iterator it = prboolTypedefParams.insert(
      typedef_arg_map::value_type(name->name, list<int>())).first;
    it->second = boolPositions;
    if (prboolTypedefParamsStack.size()) {
      prboolTypedefParamsStack.back().push_back(it);
    }
  }
}

// Entering a class/struct body: open a fresh frame on both typedef stacks, so
// PRBool typedefs declared inside can be re-qualified when the body closes
// (see postvisitTypeSpecifier).
bool PrVisitor::visitTypeSpecifier(TypeSpecifier *t) {
  TS_classSpec *cls = t->ifTS_classSpec();
  bool opensScope = cls && (cls->keyword == TI_CLASS || cls->keyword == TI_STRUCT);
  if (opensScope) {
    prboolTypedefsStack.push_back(list<set<StringRef>::iterator>());
    prboolTypedefParamsStack.push_back(list<typedef_arg_map::iterator>());
  }
  return true;
}

// Leaving a class/struct body.  Every PRBool typedef recorded while the body
// was open is re-keyed as "ClassName::TypedefName" and handed to the enclosing
// class frame (if any), so nested classes accumulate their full qualification.
// The body's frame is then popped from both stacks.
void PrVisitor::postvisitTypeSpecifier(TypeSpecifier *t) {
  TS_classSpec *c = t->ifTS_classSpec();
  if (!c || !(c->keyword == TI_CLASS || c->keyword == TI_STRUCT)) return;

  // Anonymous classes/structs have no name at all.  Like qualified class
  // names, they get an empty prefix.
  PQ_name *pq = c->name ? c->name->ifPQ_name() : NULL;
  string prefix = pq ? pq->name : "";

  // typedefs of functions returning PRBool
  list<set<StringRef>::iterator> &retFrame = prboolTypedefsStack.back();
  for (list<set<StringRef>::iterator>::iterator it = retFrame.begin();
       it != retFrame.end(); ++it) {
    string name = prefix + "::" + *(*it);
    prboolTypedefs.erase(*it);
    set<StringRef>::iterator newIt
      = prboolTypedefs.insert(globalStrTable(name.c_str())).first;
    if (prboolTypedefsStack.size() > 1) {
      prboolTypedefsStack[prboolTypedefsStack.size() - 2].push_back(newIt);
    }
  }
  prboolTypedefsStack.pop_back();

  // typedefs of functions taking PRBool parameters
  list<typedef_arg_map::iterator> &paramFrame = prboolTypedefParamsStack.back();
  for (list<typedef_arg_map::iterator>::iterator it = paramFrame.begin();
       it != paramFrame.end(); ++it) {
    string name = prefix + "::" + (*it)->first;
    typedef_arg_map::iterator newIt = prboolTypedefParams.insert(
      typedef_arg_map::value_type(globalStrTable(name.c_str()), list<int>())).first;
    // overwrite any entry left by an earlier sighting: latest definition wins
    newIt->second = (*it)->second;
    prboolTypedefParams.erase(*it);
    if (prboolTypedefParamsStack.size() > 1) {
      prboolTypedefParamsStack[prboolTypedefParamsStack.size() - 2].push_back(newIt);
    }
  }

  prboolTypedefParamsStack.pop_back();
}
