// see License.txt for copyright and terms of use

#include "thrower.h"        // this module
#include "thrower_cmd.h"    // ThrowerCmd
#include "thrower_global.h"
#include "oink.gr.gen.h"        // CCParse_Oink
#include "strutil.h"            // quoted
#include "oink_util.h"
#include "expr_visitor.h"
#include "squash_util.h"
#include "patcher.h"
#include "thrower_cmd.h"
#include <sstream>
#include <set>
#include <list>
#include "thrower_analyzer.h"

using namespace std;
/* Thrower rewrites nsresult-based error handling into exception-based code.
 * It rewrites `nsresult` return types of declarations to `void`
 * (visitTopForm / visitMember / visitFunction) and, for each function
 * definition, patches the body according to what RvAnalyzer reports.
 * Every edit is emitted through the Patcher passed to the constructor. */
class Thrower : public ExpressionVisitor {
public:
  Thrower(ThrowerCmd const &cmd, Patcher &patcher): 
    cmd(cmd),
    patcher(patcher),
    str_nsresult(globalStrTable("nsresult"))
  {
  }

  // AST visitor hooks.  NOTE(review): the bool result presumably follows
  // the ASTVisitor convention (true = also visit the node's children).
  virtual bool visitFunction(Function *f);
  virtual bool visitTopForm(TopForm *t);
  virtual bool visitMember(Member *m);
  virtual bool visitTypeSpecifier(TypeSpecifier *ts);
private:
  // Rewrites an nsresult return type of `d` to void; true if handled.
  bool processFunctionDecl(Declarator *d, TypeSpecifier *retspec);
  // Removes eCond from the condition of `s`; returns the range to reuse.
  PairLoc eliminateConditional(S_if *s, E_funCall *eCond);
  // Patches eCond (and the operator joining it) out of a binary condition.
  void eliminateFromE_binary(E_binary *e, Expression* eCond);

  // True iff ts is a TS_name whose name is a simple PQ_name spelled
  // `nsresult` (other name forms, e.g. qualified ones, do not match).
  bool is_nsresult(TypeSpecifier *ts) const {
    TS_name *tsname = ts->ifTS_name();
    if (!tsname) return false;
    
    PQ_name *pqname = tsname->name->ifPQ_name();
    return (pqname && pqname->name == str_nsresult);
  }

  ThrowerCmd const &cmd;
  // receives every source rewrite produced by this visitor
  Patcher &patcher;
  // interned "nsresult" identifier; compared with == against PQ_name names
  const StringRef str_nsresult;
};

// Runs one Thrower over every source file.  A single Patcher is shared by
// all files; it is pointed at each file before that file's translation unit
// is traversed.  NOTE(review): the meaning of Patcher's `true` constructor
// argument is defined in patcher.h.
void ThrowerProcessor::process(ThrowerCmd const &cmd) {
  Patcher patcher(true);
  Thrower thrower(cmd, patcher);
  foreachSourceFile {
    File *file = files.data();
    patcher.setFile(file->name);

    maybeSetInputLangFromSuffix(file);
    file2unit.get(file)->traverse(thrower.loweredVisitor);
  }
}

// Rewrites the `nsresult` type specifier `retspec` of declaration `decl`
// to `void`.
//
// - Exactly-located text: the span from the type name to the declarator is
//   replaced by "void " (or "void" when the declarator part comes from a
//   macro parameter, whose trailing text is then overridden).
// - Text produced by a macro: known wrapper macros (NS_IMETHOD,
//   NS_IMETHODIMP, NS_CALLBACK, NS_METHOD, XPCOM_API) are replaced by their
//   `void` variants; NS_DECL_NSI* macros are reported as needing a hacked
//   xpidlgen; any other macro is reported and blocks the rewrite.
//
// Returns true if the declaration was rewritten or is an NS_DECL_NSI* macro,
// false if it was left alone (const, not `nsresult`, or an unknown macro).
//
// NOTE(review): nothing here checks that `decl` declares a function, so a
// non-const `nsresult` variable declaration reaching this function would be
// rewritten to `void` too -- confirm that is intended.
bool Thrower::processFunctionDecl(Declarator *decl, TypeSpecifier *retspec) {
  Variable *var = decl->var;
  // const check is there to avoid rewriting error codes
  if (var->type->isConst() || !is_nsresult(retspec)) {
    return false;
  }

  // span: from the `nsresult` type name up to the declarator
  PairLoc span(retspec->asTS_name()->name->loc, decl->decl->loc);
  string replacement = "void ";
  MacroUndoEntry *macro = span.getMacro();

  if (span.hasExactPosition()) {
    if (macro && macro->isParam()) {
      // otherwise trailing part of macro is overridden
      span.second.overrideLoc(macro->preEndLoc);
      replacement = "void";
    }
    patcher.printPatch(replacement, span);
    return true;
  }

  // Not exactly located: the type comes from a macro expansion, so the
  // whole macro invocation is rewritten instead.
  xassert(macro);
  static const StringRef str_NS_IMETHOD = globalStrTable("NS_IMETHOD"); 
  static const StringRef str_NS_IMETHODIMP = globalStrTable("NS_IMETHODIMP"); 
  static const StringRef str_NS_CALLBACK = globalStrTable("NS_CALLBACK");
  static const StringRef str_NS_METHOD = globalStrTable("NS_METHOD"); 
  static const StringRef str_XPCOM_API = globalStrTable("XPCOM_API"); 

  if (macro->name == str_NS_IMETHOD) {
    replacement = "NS_IMETHOD_(void)";
  } else if (macro->name == str_NS_IMETHODIMP) {
    replacement = "NS_IMETHODIMP_(void)";
  } else if (macro->name == str_NS_CALLBACK) {
    // carry the text of the macro's first argument over into NS_CALLBACK_
    MacroDefinition *firstParam = m_firstParamPlaceholder_unused_guard(macro);
    PairLoc paramSpan(firstParam->fromLoc, firstParam->toLoc);
    xassert(paramSpan.hasExactPosition());
    replacement = "NS_CALLBACK_(void, "
      + patcher.getRange(paramSpan) + ")";
  } else if (macro->name == str_NS_METHOD) {
    replacement = "NS_METHOD_(void)";
  } else if (macro->name == str_XPCOM_API) {
    replacement = "XPCOM_API(void)";
  } else {
    cerr << toString(macro->preStartLoc) << ": Macro " << macro->name;
    bool isNsDeclNsiMacro = strncmp(macro->name, "NS_DECL_NSI", 11) == 0;
    if (!isNsDeclNsiMacro) {
      cerr << " prevented rewriting of " << var->name << endl;
      return false;
    }
    cerr  << " needs to be modified by a hacked xpidlgen" << endl;
    return true;
  }

  span.first.overrideLoc(macro->preStartLoc);
  span.second.overrideLoc(macro->preEndLoc);
  patcher.printPatch(replacement, span);
  return true;
}

// Returns `str` with `count` spaces inserted at its start and after every
// '\n'.  A trailing newline is therefore followed by an indent, and an
// empty string becomes just the indent.
//
// A non-positive count returns `str` unchanged.  Without that guard, a
// negative count (the caller computes `indentSize - 2`, which is -1 for a
// column of 0) would make `string(count, ' ')` throw std::length_error.
//
// Known limitation: indenting textually like this doesn't work particularly
// well; it needs to be done at patcher level to get it closer to decent.
static string indent(int count, string str) {
  if (count <= 0) return str;
  const string strIndent(count, ' ');
  string result;
  result.reserve(str.size() + strIndent.size());
  result += strIndent;
  for (string::size_type i = 0; i < str.size(); ++i) {
    result += str[i];
    if (str[i] == '\n') result += strIndent;
  }
  return result;
}

// TODO: add endloc to typespecifiers
//
// Rewrites one function definition into exception style.  First the
// nsresult return type is handled by processFunctionDecl, and whether that
// succeeded is passed to RvAnalyzer.  Then, for what RvAnalyzer reports in
// the body:
//  - returnNS_OKs:    the `return NS_OK` statement (or, when no statement
//                     is given, just the expression) is deleted;
//  - returns:         `return <expr>` becomes `throw nsexception(<expr>)`,
//                     or `wrap_nsexception(<expr>)` when flagged;
//  - ignoredErrors:   the call is wrapped in IGNORE_NSEXCEPTION(...);
//  - automaticErrors: the whole if statement is deleted;
//  - nsresultVars:    listed uses become exc.getCode(); each assigning call
//                     is either kept as a plain statement (transitive) or
//                     wrapped in try/catch, folding the NS_SUCCEEDED /
//                     NS_FAILED if statements into the try / catch blocks.
// Anything whose source range is not exactly known (hasExactPosition() is
// false) is skipped.  Always returns false (NOTE(review): presumably so the
// visitor does not descend into the function any further).
bool Thrower::visitFunction(Function *f) {
  bool in_nsresult = processFunctionDecl(f->nameAndParams, f->retspec);
  RvAnalyzer rva(in_nsresult);
  rva.analyze(f->body->asS_compound());

  // deal with return NS_OK
  for (list<RvAnalyzer::NS_OKinfo>::iterator it = rva.returnNS_OKs.begin();
       it != rva.returnNS_OKs.end();
       ++it) {
    Expression *e = it->first;
    // when this is not null may nuke the return NS_OK
    Statement *snuke = it->second;
    SourceLoc loc;
    SourceLoc endloc;
    if (snuke) {
      loc = snuke->loc;
      endloc = snuke->endloc;
    } else {
      loc = e->loc;
      endloc = e->endloc;
    }
    PairLoc pairLoc(loc, endloc);
    if (!pairLoc.hasExactPosition()) continue;
    patcher.printPatch("", pairLoc);
  }

  // deal with error returns: replace `return <e>` with a throw/wrap of <e>
  for (list<RvAnalyzer::ReturnWrapped>::iterator it = rva.returns.begin();
       it != rva.returns.end();
       ++it) {
    S_return *s = it->first;
    Expression *e = getCtorArg(s->ctorStatement);
    PairLoc eLoc(e->loc, e->endloc);
    PairLoc pairLoc(s->loc, e->endloc);
    if (! (eLoc.hasExactPosition() && pairLoc.hasExactPosition())) continue;
    string strWrap = it->second ? "wrap_nsexception("
      : "throw nsexception(";
    patcher.printPatch(strWrap + patcher.getRange(eLoc) + ")" , pairLoc);
  }

  // deal with functions where ret value is ignored
  // (`call` rather than `f`, so the Function parameter is not shadowed)
  for (list<E_funCall*>::iterator gt = rva.ignoredErrors.begin();
       gt != rva.ignoredErrors.end();
       ++gt) {
    E_funCall *call = *gt;
    PairLoc pairLoc(call->loc, call->endloc);
    if (!pairLoc.hasExactPosition()) continue;
    patcher.printPatch("IGNORE_NSEXCEPTION(" 
                       + patcher.getRange(pairLoc) + ")", pairLoc);
  }

  // optimize away stuff provided by exceptions
  for (list<S_if*>::iterator st = rva.automaticErrors.begin();
       st != rva.automaticErrors.end();
       ++st) {
    S_if *check = *st;
    PairLoc pairLoc(check->loc, check->endloc);
    if (!pairLoc.hasExactPosition()) continue;
    patcher.printPatch("", pairLoc);
  }

  // deal with all nsresult vars and their usage
  for (RvAnalyzer::rvmap::iterator it = rva.nsresultVars.begin();
       it != rva.nsresultVars.end();
       ++it) {
    for (list<Expression*>::iterator xt = it->second.valueUsed.begin();
         xt != it->second.valueUsed.end();
         ++xt) {
      Expression *use = *xt;
      PairLoc pairLoc(use->loc, use->endloc);
      if (!pairLoc.hasExactPosition()) continue;
      patcher.printPatch("exc.getCode()", pairLoc);
    }
    for (list<RvAnalyzer::FcallBlock>::iterator yt =
           it->second.nsassignFuncs.begin();
         yt != it->second.nsassignFuncs.end();
         ++yt) {
      PairLoc pairLoc(yt->wrapper->loc, yt->wrapper->endloc);
      PairLoc pairFcall(yt->fcall->loc, yt->fcall->endloc);
      if (! (pairLoc.hasExactPosition() && pairFcall.hasExactPosition())) {
        continue;
      }
      if (yt->isTransitive()) {
        // no try/catch is needed, eliminate rv
        patcher.printPatch(patcher.getRange(pairFcall) + ";", pairLoc);
        //delete the rets
        continue;
      }

      // not transitive: emit
      //   try {
      //     <call>;
      //     <code reused from the NS_SUCCEEDED if, if any>
      //   } catch (nsexception &exc) {
      //     <code reused from the NS_FAILED if, if any>
      //   }
      // and extend the patched range over each if statement consumed.
      stringstream ss;
      UnboxedPairLoc unboxedPairLoc(pairLoc);
      int indentSize = 2 + (unboxedPairLoc.first.col - 1);
      ss << "try {\n";
      ss << indent(indentSize, patcher.getRange(pairFcall)) << ";\n";
      if (yt->nssucceeded.size()) {
        //remove(NS_SUCCEEDED from if)
        xassert(yt->nssucceeded.size() == 1);
        S_if *s = yt->nssucceeded.front().second;
        PairLoc pairSucc =
          eliminateConditional(s, yt->nssucceeded.front().first);
        pairLoc.second = s->endloc;
        ss << indent(indentSize, patcher.getRange(pairSucc)) << '\n';
        // a success branch is only expected together with a failure branch
        xassert(yt->nsfailed.size());
      }
      ss << indent(indentSize - 2, "} catch (nsexception &exc) {") << '\n';
      if (yt->nsfailed.size()) {
        S_if *s = yt->nsfailed.front().second;
        PairLoc pairFailed =
          eliminateConditional(s, yt->nsfailed.front().first);
        ss << indent(indentSize, patcher.getRange(pairFailed));
        pairLoc.second = s->endloc;
      }
      ss << '\n' << indent(indentSize - 2, "}");
      patcher.printPatch(ss.str(), pairLoc, true);
    }
  }
  return false;
}

// Peels off any chain of E_unary wrappers around `e` and returns the
// innermost non-unary operand (or `e` itself if it is not an E_unary).
inline Expression *skipE_unary(Expression* e) {
  Expression *inner = e;
  while (inner->isE_unary()) {
    E_unary *wrapper = inner->asE_unary();
    inner = wrapper->expr;
  }
  return inner;
}

// Searches the binary-expression tree rooted at `e` for the operand `eCond`
// and patches it out together with the operator joining it to its sibling.
// Operands are compared after skipping unary wrappers (skipE_unary), and
// those wrappers are removed along with eCond on either side:
//   "[op]eCond && rhs" -> "rhs",   "lhs && [op]eCond" -> "lhs".
// Previously the left-hand case started the removal at eCond itself, which
// left a prefix such as `!` attached to the right-hand operand.
// If eCond is not a direct operand, recursion continues into whichever
// operands are themselves binary expressions.
void Thrower::eliminateFromE_binary(E_binary *e, Expression* eCond) {
  Expression *e1 = skipE_unary(e->e1);
  Expression *e2 = skipE_unary(e->e2);
  if (e1 == eCond) {
    // remove from the start of the whole left operand (including any unary
    // prefix) up to where the right operand begins
    patcher.printPatch("", PairLoc(e->e1->loc, e->e2->loc));
    return;
  } else if (e2 == eCond) {
    // remove everything after the left operand through the end of eCond
    patcher.printPatch("", PairLoc(e->e1->endloc, eCond->endloc));
    return;
  }
  if (e1->isE_binary()) eliminateFromE_binary(e1->asE_binary(), eCond);
  if (e2->isE_binary()) {
    eliminateFromE_binary(e2->asE_binary(), eCond);
  } else {
    // ensure that either e1 or e2 are an E_binary so the recursion continues
    xassert(e1->isE_binary());
  }
}

// Strips the call `eCond` out of the condition of `s` and returns the part
// of the code to reuse:
//  - condition is exactly the call eCond: the range covering the statements
//    of the compound then-branch, first through last;
//  - condition is a binary expression: eCond is patched out of it
//    (eliminateFromE_binary) and the whole if statement is returned;
//  - any other condition shape fails the xassert;
//  - if the statement doesn't have the expected structure (an as*() cast
//    throws x_match, e.g. a non-compound then-branch): the whole if.
PairLoc Thrower::eliminateConditional(S_if *s, E_funCall *eCond) {
  try {
    Expression *cond = s->cond->asCN_expr()->expr->expr;
    if (!cond->isE_funCall()) {
      xassert(cond->isE_binary());
      eliminateFromE_binary(cond->asE_binary(), eCond);
    } else {
      xassert(eCond == cond->asE_funCall());
      S_compound *thenBlock = s->thenBranch->asS_compound();
      xassert(!thenBlock->stmts.isEmpty());
      return PairLoc(thenBlock->stmts.first()->loc,
                     thenBlock->stmts.last()->endloc);
    }
  } catch (x_match &) {
    // unexpected shape: fall back to reusing the whole if statement
  }
  return PairLoc(s->loc, s->endloc);
}

// Top-level forms:
//  - a declaration: its first declarator is offered to processFunctionDecl,
//    and children are visited only if nothing was rewritten;
//  - a class template or function template: skipped (returns false);
//  - anything else: visited normally (returns true).
// An x_match raised along the way (wrong node kind, or an unexpected shape
// inside processFunctionDecl) falls through to the next check.
bool Thrower::visitTopForm(TopForm *t) {
  try {
    Declaration *declaration = t->asTF_decl()->decl;
    bool rewritten = processFunctionDecl(SOME(declaration->decllist->first()),
                                         declaration->spec);
    return !rewritten;
  } catch (x_match &) {
  }
  try {
    TemplateDeclaration *td = t->asTF_template()->td;
    // guard so functions declared as templates are skipped
    if (td->isTD_func()) {
      return false;
    }
    // guard so classes declared as templates are skipped
    // (the cast throws x_match for anything but a class template)
    td->asTD_decl()->d->spec->asTS_classSpec();
    return false;
  } catch (x_match &) {
  }
  return true;
}

bool Thrower::visitMember(Member *t) {
  try {
    MR_decl *mr = t->asMR_decl();
    return !processFunctionDecl(SOME(mr->d->decllist->first()), mr->d->spec);
  } catch (x_match &) {
  }
  return true;
}

// Don't get confused by the nsresult wrapper: the body of a class whose
// (simple) name is `nsresult` is not visited.  Every other type specifier
// is visited normally.
//
// TS_classSpec::name is nullable (anonymous class/struct/union).  It is
// checked before use, because dereferencing it via ->asPQ_name() on such
// a class would crash.
bool Thrower::visitTypeSpecifier(TypeSpecifier *ts) {
  TS_classSpec *classSpec = ts->ifTS_classSpec();
  if (!classSpec || !classSpec->name) {
    return true;
  }
  PQ_name *pqname = classSpec->name->ifPQ_name();
  return !(pqname && pqname->name == str_nsresult);
}
