import numpy as np

from sklearn.feature_selection import chi2
from sklearn.feature_selection import f_classif
from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import RFE
from sklearn.linear_model import RandomizedLogisticRegression
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.svm import SVC
from sklearn.metrics import silhouette_score
from externals.irelief import Irelief


class FeatureScorer:
    """Common setup for the scorers below: split a data matrix into two image sets.

    Subclasses in this module read ``X1``, ``X2``, ``X`` and ``y`` prepared
    here and implement ``score()``.
    """

    def __init__(self, Xt, images1, images2, feats):
        """Select the rows of both image sets and the requested feature columns.

        Args:
            Xt: 2-D array indexed as ``Xt[image, feature]``.
            images1: Row indices of the first image set (order is ignored).
            images2: Row indices of the second image set (order is ignored).
            feats: Column indices of the features to keep (order is ignored).

        Attributes set:
            Xt: The original matrix, unchanged.
            X1, X2: ``Xt`` restricted to the sorted image rows and sorted
                feature columns of each set.
            X: ``X1`` stacked on top of ``X2``.
            y: Integer label per row of ``X``: 0 for ``images1``, 1 for ``images2``.

        Raises:
            ValueError: If either image set or the feature list is empty.
        """
        # ValueError (a subclass of Exception) keeps existing
        # ``except Exception`` handlers working while being more specific.
        if len(images1) < 1 or len(images2) < 1:
            raise ValueError('Neither set of images may be empty.')
        if len(feats) < 1:
            raise ValueError('At least one feature must be selected.')

        feats = sorted(feats)
        images1 = sorted(images1)
        images2 = sorted(images2)

        self.Xt = Xt

        # Restrict to the selected feature columns once, then pick each set's rows.
        columns = self.Xt[:, feats]
        self.X1 = columns[images1, :]
        self.X2 = columns[images2, :]

        self.X = np.concatenate([self.X1, self.X2])
        # One label per stacked row: 0s for the first set, then 1s for the second.
        self.y = np.repeat([0, 1], [self.X1.shape[0], self.X2.shape[0]])


class ChiSquaredScorer(FeatureScorer):
    """Score features by the chi-squared statistic between feature and image set.

    Construction is inherited unchanged from ``FeatureScorer``.
    """

    def score(self):
        """Return per-feature chi-squared statistics normalised to sum to 1.

        NaN statistics are counted as 0. Note that scikit-learn's ``chi2``
        rejects negative feature values.

        Raises:
            Exception: If no feature has a positive statistic, so the scores
                cannot be normalised.
        """
        selector = SelectKBest(chi2, k='all')
        selector.fit(self.X, self.y)
        # NaN -> 0, +inf -> largest finite float.
        scores = np.nan_to_num(selector.scores_)

        # Rescale by the peak before summing: this avoids a 0/0 division when
        # every score is zero, and prevents several near-float-max values from
        # overflowing the sum to inf (which would silently zero every score).
        peak = scores.max()
        if not peak > 0:
            raise Exception('Impossible to compute feature importances')
        scores = scores / peak
        scores = scores / scores.sum()

        if not np.isfinite(scores).all():
            raise Exception('Impossible to compute feature importances')

        return scores


class OneWayAnovaScorer(FeatureScorer):
    """Score features by the one-way ANOVA F statistic between the two image sets.

    Construction is inherited unchanged from ``FeatureScorer``.
    """

    def score(self):
        """Return per-feature F statistics normalised to sum to 1.

        NaN statistics (e.g. features constant across all images) count as 0.
        An infinite F (a feature constant within each set but different between
        the sets) is treated as the largest finite value. Several such features
        share the top weight instead of collapsing every score to 0.

        Raises:
            Exception: If no feature has a positive statistic, so the scores
                cannot be normalised.
        """
        selector = SelectKBest(f_classif, k='all')
        selector.fit(self.X, self.y)
        # NaN -> 0, +inf -> largest finite float.
        scores = np.nan_to_num(selector.scores_)

        # Rescale by the peak before summing. Summing two float-max values
        # overflows to inf, and dividing by it would silently turn every score
        # into 0. The check also avoids a 0/0 division for all-zero scores.
        peak = scores.max()
        if not peak > 0:
            raise Exception('Impossible to compute feature importances')
        scores = scores / peak
        scores = scores / scores.sum()

        if not np.isfinite(scores).all():
            raise Exception('Impossible to compute feature importances')

        return scores


class ReliefScorer(FeatureScorer):
    """Score features with the I-Relief weights from ``externals.irelief``.

    Construction is inherited unchanged from ``FeatureScorer``.
    """

    def score(self):
        """Return the I-Relief weights (T=100) divided by their total.

        Raises:
            Exception: If any normalised weight is not finite (e.g. the
                weights sum to zero).
        """
        weights = Irelief(T=100).weights(self.X, self.y)
        normalised = weights / sum(weights)

        # Same test as counting finite entries against len(): the shape of the
        # weights returned by Irelief is not visible here, so keep it exact.
        finite_count = np.count_nonzero(np.isfinite(normalised))
        if finite_count < len(normalised):
            raise Exception('Impossible to compute feature importances')

        return normalised


class RecursiveFeatureEliminationScorer(FeatureScorer):
    """Score features by recursive feature elimination with a linear SVM.

    Construction is inherited unchanged from ``FeatureScorer``.
    """

    def score(self):
        """Return ``n_features - rank`` for each feature as float64.

        Features are eliminated one per step until one remains. RFE rank 1 is
        the surviving feature, so higher returned values mean the feature
        survived longer. The survivor scores ``n_features - 1``.
        """
        svc = SVC(kernel='linear', random_state=0)
        # n_features_to_select is passed by keyword: it is keyword-only in
        # newer scikit-learn releases, and the keyword works on older ones too.
        rfe = RFE(svc, n_features_to_select=1, step=1)
        rfe.fit(self.X, self.y)

        return (self.X.shape[1] - rfe.ranking_).astype(np.float64)


class SilhouetteCoefficientScorer(FeatureScorer):
    """Score each feature by how well it alone separates the two image sets.

    Construction is inherited unchanged from ``FeatureScorer``.
    """

    def score(self):
        """Return per-feature silhouette coefficients min-max rescaled to [0, 1].

        For every feature, the silhouette coefficient of that single column
        under the set labels is computed. The best-separating feature maps
        to 1 and the worst to 0.

        When all coefficients are equal (always the case for a single
        feature), min-max rescaling is undefined; every feature then gets 1.0
        instead of NaN.
        """
        # self.X stacks X1 over X2 and self.y labels them 0/1, which is exactly
        # the data/labeling the silhouette is computed on.
        coefficients = np.array([
            silhouette_score(column.reshape(-1, 1), self.y, random_state=0)
            for column in self.X.T
        ], dtype=np.float64)

        low, high = coefficients.min(), coefficients.max()
        if high == low:
            # Tied coefficients: no feature separates better than another.
            return np.ones_like(coefficients)

        return (coefficients - low) / (high - low)


class VarianceBasedCoherenceScorer(FeatureScorer):
    """Score features by how tightly the first image set clusters on them.

    Construction is inherited unchanged from ``FeatureScorer``.
    """

    def score(self):
        """Return a coherence score in [0, 1] per feature.

        For each feature the ratio ``var(X1) / var(X1 stacked with X2)`` is
        computed. A small ratio means the first set is tight relative to the
        overall spread. Ratios are min-max rescaled and inverted so the
        smallest ratio scores 1 and the largest scores 0.

        Features whose overall variance is (close to) zero carry no
        information and always score 0. If every informative feature has the
        same ratio (e.g. only one feature), those features score 1 rather than
        NaN.
        """
        lvar = self.X1.var(axis=0)
        gvar = np.vstack([self.X1, self.X2]).var(axis=0)

        # Constant features stay at 0, the bottom of the scale.
        scores = np.zeros_like(gvar)
        informative = ~np.isclose(gvar, 0)
        if not informative.any():
            return scores

        # Only divide where the denominator is non-negligible (no divide warnings).
        ratio = lvar[informative] / gvar[informative]
        low, high = ratio.min(), ratio.max()
        if high == low:
            scores[informative] = 1.0
        else:
            scores[informative] = 1 - (ratio - low) / (high - low)

        return scores


class RandomizedDecisionTreesScorer(FeatureScorer):
    """Score features by impurity-based importance in an extra-trees ensemble.

    Construction is inherited unchanged from ``FeatureScorer``.
    """

    def score(self):
        """Fit 1000 extremely randomized trees (seed 0) and return their importances."""
        ensemble = ExtraTreesClassifier(random_state=0, n_estimators=1000)
        ensemble.fit(self.X, self.y)
        importances = ensemble.feature_importances_
        return importances


class RandomizedLogisticRegressionScorer(FeatureScorer):
    """Score features by randomized logistic regression (stability selection).

    NOTE(review): ``RandomizedLogisticRegression`` was deprecated and later
    removed from scikit-learn, so this scorer only works with the older
    releases that still ship it.

    Construction is inherited unchanged from ``FeatureScorer``.
    """

    def score(self):
        """Fit with 1000 resamplings (seed 0) and return the estimator's ``scores_``."""
        stability = RandomizedLogisticRegression(n_resampling=1000, random_state=0)
        stability.fit(self.X, self.y)
        selection_scores = stability.scores_
        return selection_scores

# Registry from method name to the FeatureScorer subclass implementing it.
available_methods = dict(
    chi_squared=ChiSquaredScorer,
    one_way_anova=OneWayAnovaScorer,
    relief=ReliefScorer,
    rfe=RecursiveFeatureEliminationScorer,
    silhouette=SilhouetteCoefficientScorer,
    variance=VarianceBasedCoherenceScorer,
    randomized_trees=RandomizedDecisionTreesScorer,
    randomized_logistic=RandomizedLogisticRegressionScorer,
)
