import os
import numpy as np
import scipy as sp
import scipy.stats

from sklearn import preprocessing
from sklearn.neighbors import NearestNeighbors
from sklearn.cross_validation import LeaveOneOut
from sklearn.cross_validation import StratifiedKFold

import feature_scoring

from projection import ObservationProjection
from projection import DistanceMatrixProjection

from parsing import point_matrix
from parsing import feature_tree
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble.forest import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.metrics import normalized_mutual_info_score

from externals.libift.opf import OPFClassifier
from externals.libift.active_learning import OPFActiveLearning


class FeaturedModel(object):
    """Feature matrix plus the derived views used by the explorer.

    Holds the raw matrix ``X`` (rows are observations/images, columns are
    features), the labels ``y``, a normalized copy ``Xt``, a projection of
    the observations ``Xp``, a projection of the features ``fXp``, the
    current image/feature selection and the user-defined groups
    (``group_map``: group name -> set of row indices).
    """

    def __init__(self, path, settings):
        """Load the data found at ``path``.

        Args:
            path: a directory containing ``features.data`` (and optionally
                ``features.tree``), or the path of a feature file.
            settings: object exposing ``parameter(name)``.
        """
        self.FEATURE_FILE = 'features.data'
        self.FEATURE_HIERARCHY_FILE = 'features.tree'
        self.hierarchy_root = 'all'

        self.settings = settings

        self.load(path)

    def load(self, path):
        """Load the feature matrix and optional feature hierarchy from ``path``.

        Raises:
            Exception: if the feature file does not exist.
        """
        if os.path.isdir(path):
            self.directory = path
            self.filename = self.FEATURE_FILE
            hierarchy_filename = self.FEATURE_HIERARCHY_FILE
        else:
            self.directory, self.filename = os.path.split(path)
            hierarchy_filename = os.path.splitext(self.filename)[0] + '.tree'

        feature_path = os.path.join(self.directory, self.filename)
        if not os.path.exists(feature_path):
            raise Exception('Features not found.')

        # `with` closes the handle even if parsing raises.
        with open(feature_path) as feature_file:
            self.X, self.y, self.image_names, self.feature_names = \
                point_matrix.load(feature_file)

        hierarchy_path = os.path.join(self.directory, hierarchy_filename)
        if os.path.isfile(hierarchy_path):
            with open(hierarchy_path) as feature_hierarchy_file:
                self.feat_hierarchy = feature_tree.load(feature_hierarchy_file)
        else:
            # No hierarchy file: every feature hangs directly off the root.
            self.feat_hierarchy = {self.hierarchy_root:
                                   set(self.feature_names)}

        self.selected_images = set()
        self._selected_features = set(range(self.X.shape[1]))

        self.update_Xt()
        self.update_Xp()

        self.update_fXp()

        self.group_map = self._build_group_map()

        self.nfeats = self.X.shape[1]
        self.nimages = self.X.shape[0]

    def _build_group_map(self):
        """Map each label of ``self.y`` (as ``str(int(label))``) to its row indices."""
        group_map = {}
        for index, label in enumerate(self.y):
            group_map.setdefault(str(int(label)), set()).add(index)
        return group_map

    def _active_features(self):
        """Sorted selected feature indices, or every feature if none is selected."""
        if self.selected_features:
            return sorted(self.selected_features)
        return list(range(self.X.shape[1]))

    def _active_images(self):
        """Sorted selected image indices, or every image if none is selected."""
        if self.selected_images:
            return sorted(self.selected_images)
        return list(range(self.X.shape[0]))

    def update_Xt(self):
        """Recompute ``Xt`` from ``X`` using the ``normalization`` setting.

        Also discards any active-learning state, since it was fit on the
        previous ``Xt``.

        Raises:
            Exception: for an unknown normalization setting.
        """
        normalization = self.settings.parameter('normalization')
        if normalization == 'Standardization':
            self.Xt = preprocessing.StandardScaler().fit_transform(self.X)
        elif normalization == 'Scaling':
            self.Xt = preprocessing.MinMaxScaler().fit_transform(self.X)
        elif normalization == 'None':
            self.Xt = np.array(self.X)
        else:
            raise Exception('Unknown normalization setting')

        self.active_learning_reset()

    def update_Xp(self):
        """Project the selected columns of ``Xt`` and rebuild the neighbor index on ``Xp``."""
        distance = self.settings.parameter('distance')
        technique = self.settings.parameter('technique')
        technique_opts = self.settings.parameter('technique_options')

        feats = sorted(self.selected_features)

        proj = ObservationProjection(technique, technique_opts, distance)

        self.Xp = proj.project(self.Xt[:, feats], self.y,
                               self.image_names,
                               [self.feature_names[i] for i in feats],
                               distance)

        # +1 because a point's own index is usually returned as a neighbor.
        self.neighbors_Xp = NearestNeighbors(
            n_neighbors=self.settings.parameter('neighbors') + 1).fit(self.Xp)

    def update_fXp(self):
        """Project the features (columns of ``Xt``) and rebuild the neighbor index on ``fXp``.

        With technique ``'None'`` every feature is placed at the origin of a
        2-D space.
        """
        technique = self.settings.parameter('feature_projection_technique')
        technique_opts = self.settings.parameter(
            'feature_projection_technique_opts')

        if technique != 'None':
            dist = self.compute_feature_dissimilarity(self.Xt)
            proj = DistanceMatrixProjection(technique, technique_opts)
            self.fXp = proj.project(dist)
        else:
            self.fXp = np.zeros((self.Xt.shape[1], 2))

        self.neighbors_fXp = NearestNeighbors(
            n_neighbors=self.settings.parameter('features_n_neighbors') + 1
        ).fit(self.fXp)

    def compute_feature_dissimilarity(self, Xt):
        """Return the square matrix of pairwise dissimilarities between columns of ``Xt``.

        The measure comes from the ``feature_dissimilarity`` setting.

        Raises:
            Exception: if the setting names an unknown measure.
        """
        distance = self.settings.parameter('feature_dissimilarity')

        if distance == 'Pearson\'s r':
            corrcoef = np.atleast_2d(np.corrcoef(Xt.T))
            dist = 1 - np.abs(np.nan_to_num(corrcoef))
        elif distance == 'Spearman\'s r':
            corrcoef = sp.stats.spearmanr(Xt)[0]
            # spearmanr returns a scalar, not a matrix, for exactly two columns.
            if np.ndim(corrcoef) == 0:
                corrcoef = np.array([[1.0, corrcoef], [corrcoef, 1.0]])
            dist = 1 - np.abs(np.nan_to_num(corrcoef))
        elif distance == 'Transposed data matrix distance':
            dist = sp.spatial.distance.squareform(
                sp.spatial.distance.pdist(Xt.T))
        elif distance == 'Distance correlation':
            import pyximport
            pyximport.install()
            import externals.dcor as dcor
            dist = 1 - sp.spatial.distance.squareform(
                sp.spatial.distance.pdist(Xt.T, metric=dcor.dcov_all))
            np.fill_diagonal(dist, 0)
        elif distance == 'Mutual information':
            dist = 1 - normalized_mutual_information_matrix(Xt, n_bins=8)
            # A feature is at zero distance from itself.
            np.fill_diagonal(dist, 0)
        else:
            raise Exception('Unknown feature dissimilarity setting')

        return dist

    def update_settings(self):
        """Recompute every derived view after the settings changed."""
        self.update_Xt()
        self.update_Xp()
        self.update_fXp()

    def set_selected_groups(self, selected_groups):
        """Select exactly the images belonging to any of ``selected_groups``."""
        self.selected_images = set()
        for g in selected_groups:
            self.selected_images.update(self.group_map[g])

    @property
    def selected_features(self):
        """Set of selected feature (column) indices."""
        return self._selected_features

    @selected_features.setter
    def selected_features(self, selected_features):
        selected_features = set(selected_features)

        # Re-project only when the selection actually changes.
        if selected_features != self._selected_features:
            self._selected_features = selected_features
            self.update_Xp()
            self.active_learning_reset()

    @property
    def n_features(self):
        return self.Xt.shape[1]

    @property
    def n_observations(self):
        return self.Xt.shape[0]

    def get_deselected_features(self):
        """Return the set of feature indices that are not selected."""
        return {f for f in range(self.nfeats) if f not in self.selected_features}

    def get_image_names(self, images):
        return [self.image_names[i] for i in images]

    def get_feature_names(self, features):
        return [self.feature_names[i] for i in features]

    def get_selected_feature_names(self):
        """Return the set of names of the selected features."""
        return {self.feature_names[i] for i in self.selected_features}

    def get_transformed_feature(self, i):
        """Return column ``i`` of the normalized matrix ``Xt``."""
        return self.Xt[:, i]

    def create_group(self, name):
        """Store the current image selection as group ``name``."""
        # Copy so later changes to the selection do not alter the group.
        self.group_map[name] = set(self.selected_images)

    def remove_groups(self, groups):
        for g in groups:
            del self.group_map[g]

    def save_groups(self, filepath):
        """Save ``X`` with labels replaced by group membership.

        Groups are numbered by the sorted order of their names.

        Raises:
            Exception: if the groups do not partition the set of images.
        """
        image_to_group = {i: set() for i in range(self.X.shape[0])}
        for name, members in self.group_map.items():
            for img in members:
                image_to_group[img].add(name)

        group_index = {name: idx
                       for idx, name in enumerate(sorted(self.group_map))}

        y = np.zeros(self.X.shape[0])
        for img, names in image_to_group.items():
            if len(names) != 1:
                raise Exception(
                    "The groups do not partition the set of images.")
            (name,) = names
            y[img] = group_index[name]

        with open(filepath, 'w') as f:
            point_matrix.save(f, self.X, y,
                              self.image_names, self.feature_names)

    def save_features(self, filepath):
        """Save the selected columns of ``X`` (all columns if none is selected)."""
        feats = self._active_features()
        feat_names = np.array(self.feature_names)[feats]

        with open(filepath, 'w') as f:
            point_matrix.save(
                f, self.X[:, feats], self.y, self.image_names, feat_names)

    def compute_distance_histogram(self):
        """Return ``(bin_edges, frequencies)`` of pairwise distances in ``Xt``.

        Only the selected images/features are used (all when none selected).
        Frequencies sum to 1, or are all 0 if there are no pairs.
        """
        images = self._active_images()
        feats = self._active_features()
        X = self.Xt[np.ix_(images, feats)]

        if self.settings.parameter('distance') == 'Cosine':
            dist = sp.spatial.distance.pdist(X, 'cosine')
        else:
            dist = sp.spatial.distance.pdist(X, 'euclidean')

        y, x = np.histogram(dist)
        total = y.sum()
        # Fewer than two images means no pairwise distances: avoid 0/0.
        y = y / float(total) if total > 0 else y.astype(float)

        return x, y

    def compute_similarity_matrix(self):
        """Return the pairwise similarity matrix of the selected images, ordered by label.

        Similarity is the min-max normalized distance inverted into [0, 1]
        (1 = closest pair, 0 = farthest pair). A 1x1 zero matrix is returned
        when all distances are (close to) equal.
        """
        images = self._active_images()
        feats = self._active_features()
        X = self.Xt[np.ix_(images, feats)]

        # Group rows by label; mergesort is stable within a label.
        Xsorted = X[np.argsort(self.y[images], kind='mergesort')]

        if self.settings.parameter('distance') == 'Cosine':
            dist = sp.spatial.distance.pdist(Xsorted, 'cosine')
        else:
            dist = sp.spatial.distance.pdist(Xsorted, 'euclidean')

        dist = sp.spatial.distance.squareform(dist)

        span = dist.max() - dist.min()
        if np.allclose(0, span):
            return np.zeros((1, 1))

        return 1 - (dist - dist.min()) / span

    def compute_projection_error(self):
        """Return the square matrix of (projected - original) normalized pairwise distances."""
        feats = self._active_features()

        def _scaled(d):
            # Scale to a maximum of 1; leave all-zero/empty vectors untouched.
            peak = d.max() if d.size else 0.0
            return d / peak if peak > 0 else d

        dXt = sp.spatial.distance.pdist(self.Xt[:, feats])
        dXp = sp.spatial.distance.pdist(self.Xp)

        return sp.spatial.distance.squareform(_scaled(dXp) - _scaled(dXt))

    def compute_mean_aggregated_projection_error(self):
        """Return, per observation, the mean absolute projection error."""
        return np.abs(self.compute_projection_error()).sum(axis=1) / self.Xt.shape[0]

    def get_nearest_neighbors(self, i):
        """k-nearest neighbors of point ``i`` in the transformed space, excluding ``i``.

        Note: these neighbors are computed using the Euclidean distance,
        regardless of the default distance function.
        """
        feats = self._active_features()
        n_neighbors = self.settings.parameter('neighbors')

        neighbors_Xt = NearestNeighbors(
            n_neighbors=n_neighbors + 1).fit(self.Xt[:, feats])

        return self._nearest(neighbors_Xt, self.Xt[i, feats], i, n_neighbors)

    def _nearest(self, neighborhood, destination, i, n_neighbors):
        """Return up to ``n_neighbors`` neighbor indices of ``destination``, excluding ``i``."""
        # kneighbors expects a 2-D (n_queries, n_dims) array.
        query = np.asarray(destination).reshape(1, -1)
        ns = list(neighborhood.kneighbors(query)[1][0])

        if i in ns:
            ns.remove(i)

        return ns[:n_neighbors]

    def nearest_points(self, destination, i=None):
        """k-nearest neighbors of `destination`, in the projection space, excluding point `i`"""
        return self._nearest(self.neighbors_Xp, destination, i,
                             self.settings.parameter('neighbors'))

    def nearest_feature_points(self, destination, i=None):
        """k-nearest neighbors of `destination`, in the feature projection space, excluding point `i`"""
        return self._nearest(self.neighbors_fXp, destination, i,
                             self.settings.parameter('features_n_neighbors'))

    def _nearest_r(self, neighborhood, destination, r):
        """Return indices of points within radius ``r`` of ``destination``."""
        query = np.asarray(destination).reshape(1, -1)
        return neighborhood.radius_neighbors(
            query, radius=r, return_distance=False)[0]

    def nearest_points_r(self, i, r):
        """neighbors within radius `r` of `i`, in the projection space, including point `i`"""
        return self._nearest_r(self.neighbors_Xp, self.Xp[i], r)

    def nearest_feature_points_r(self, i, r):
        """neighbors within radius `r` of `i`, in the feature projection space, including point `i`"""
        return self._nearest_r(self.neighbors_fXp, self.fXp[i], r)

    def compute_feature_relation_with_neighbors(self):
        """Score each feature by how similar it is to its neighbors in ``fXp``.

        Dissimilarities are computed on the selected images (all if none).
        Higher is more related: the summed neighbor dissimilarity is
        subtracted from its maximum over all features.
        """
        images = self._active_images()

        dist = self.compute_feature_dissimilarity(self.Xt[images])

        relations = np.zeros(self.n_features)
        for i in range(self.n_features):
            ns = self.nearest_feature_points(self.fXp[i], i)
            relations[i] = sum(dist[i, j] for j in ns)

        return relations.max() - relations

    def move_point_to(self, i, destination):
        """Select the features that best separate ``i`` and its new neighbors from the rest.

        The neighbors of ``destination`` in the projection (plus ``i``)
        become the selected images. Features are ranked with the
        ``feature_scoring`` method, and the top ``preserve`` fraction
        (at least one) becomes the selected features.
        """
        method = self.settings.parameter('feature_scoring')

        ns = self.nearest_points(destination, i)
        images1 = [i] + ns
        in_group = set(images1)

        images2 = [j for j in range(self.Xt.shape[0]) if j not in in_group]

        feats = sorted(self.selected_features)
        scorer = feature_scoring.available_methods[
            method](self.Xt, images1, images2, feats)
        scores = scorer.score()

        # Highest score first.
        feats = np.array(feats)[np.argsort(-scores)]

        preserve = self.settings.parameter('preserve')
        npreserved = int(len(feats) * preserve)
        self.selected_features = feats[0:npreserved + 1]

        self.selected_images = set(images1)

    def compute_projection_axis_correlation(self):
        """Return ``(cx, cy)``: correlation of each selected feature with the two projection axes.

        NaN correlations (e.g. constant features) are reported as 0.

        Raises:
            Exception: if no feature is selected.
        """
        images = self._active_images()

        feats = sorted(self.selected_features)
        if not feats:
            raise Exception('At least one feature must be selected.')

        cx = np.zeros(len(feats))
        cy = np.zeros(len(feats))

        Xt = self.Xt[images, :]
        Xp = self.Xp[images, :]

        if self.settings.parameter('correlation') == 'Pearson\'s r':
            corr = scipy.stats.pearsonr
        else:
            corr = scipy.stats.spearmanr

        for k, f in enumerate(feats):
            cx[k] = corr(Xt[:, f], Xp[:, 0])[0]
            cy[k] = corr(Xt[:, f], Xp[:, 1])[0]

        return np.nan_to_num(cx), np.nan_to_num(cy)

    def hide_selected_images(self):
        """Drop the selected images from the data and recompute every derived view.

        Groups are rebuilt from the remaining labels.

        Raises:
            Exception: if every image is selected.
        """
        if len(self.selected_images) == self.nimages:
            raise Exception('At least one sample must be unselected')

        cimages = [i for i in range(self.nimages)
                   if i not in self.selected_images]

        self.X = self.X[cimages, :]
        self.y = self.y[cimages]
        self.image_names = np.array(self.image_names)[cimages]
        self.nimages = self.X.shape[0]

        self.selected_images = set()

        self.update_Xt()
        self.update_Xp()
        self.update_fXp()

        self.group_map = self._build_group_map()

    def _classification_model(self):
        """Build the (unfitted) classifier described by the ``classification`` settings.

        Raises:
            Exception: for an unknown classification method.
        """
        method = self.settings.parameter('classification')
        method_opts = self.settings.parameter('classification_opts')
        if method == 'K-Nearest neighbors':
            model = KNeighborsClassifier(
                n_neighbors=int(method_opts['Number of neighbors']))
        elif method == 'Logistic Regression':
            model = LogisticRegression(random_state=0)
        elif method == 'Linear Support vector machine':
            model = SVC(kernel='linear', C=float(method_opts['C']),
                        random_state=0)
        elif method == 'RBF Support vector machine':
            model = SVC(kernel='rbf', max_iter=1000,
                        gamma=float(method_opts['Gamma']),
                        C=float(method_opts['C']), random_state=0)
        elif method == 'Random Forest Classifier':
            max_depth = int(method_opts['Maximum depth'])
            # A non-positive depth means "unlimited".
            if max_depth < 1:
                max_depth = None
            model = RandomForestClassifier(
                n_estimators=int(method_opts['Number of estimators']),
                max_depth=max_depth, random_state=0)
        elif method == 'Optimum Path Forest':
            model = OPFClassifier()
        else:
            raise Exception('Invalid classification method')

        return model

    def compute_kfold_classification(self, target=None):
        """Return cross-validated predictions for ``target`` (default ``self.y``).

        Uses stratified k-fold with ``classification_nfolds`` folds, or
        leave-one-out when fewer than 2 folds are requested.
        """
        feats = self._active_features()

        X = self.Xt[:, feats]
        y = self.y if target is None else np.asarray(target)

        ypred = np.zeros(len(y))

        model = self._classification_model()

        n_folds = int(self.settings.parameter('classification_nfolds'))
        if n_folds < 2:
            skf = LeaveOneOut(n=len(ypred))
        else:
            skf = StratifiedKFold(y, n_folds=n_folds, shuffle=True,
                                  random_state=0)

        for itrain, itest in skf:
            model.fit(X[itrain], y[itrain])
            ypred[itest] = model.predict(X[itest])

        return ypred

    def compute_neighborhood_hit(self, target=None):
        """Return the mean fraction of each point's k projection-space neighbors sharing its label.

        Labels default to ``self.y``.
        """
        X = self.Xp
        y = self.y if target is None else np.asarray(target)

        n_neighbors = self.settings.parameter('neighbors')
        neighbors_X = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X)

        nhit = 0.0
        for i in range(len(y)):
            ns = self._nearest(neighbors_X, X[i], i, n_neighbors)
            nhit += (y[ns] == y[i]).sum() / float(len(ns))

        return nhit / len(y)

    def get_images_from_group(self, group_name):
        return self.group_map[group_name]

    def compute_leave_one_out_relation_support(self):
        """For each observation, how much the total feature dissimilarity drops without it."""
        feats = self._active_features()

        Xt = self.Xt[:, feats]
        dist_sum = self.compute_feature_dissimilarity(Xt).sum()

        relation_support = np.zeros(self.n_observations)
        loo = LeaveOneOut(n=self.n_observations)

        for itrain, itest in loo:
            relation_support[itest[0]] = dist_sum - \
                self.compute_feature_dissimilarity(Xt[itrain]).sum()

        return relation_support

    def active_learning_reset(self):
        """Discard the fitted active-learning model."""
        self.opf_al = None

    def active_learning_suggestions(self):
        """Return active-learning suggestions, fitting the model on first use."""
        if self.opf_al is None:
            spi = self.settings.parameter(
                'active_learning_suggestions_per_step')
            kmp = self.settings.parameter('active_learning_kmax_percentage')
            ci = self.settings.parameter(
                'active_learning_clustering_iterations')

            self.opf_al = OPFActiveLearning(suggestions_per_iteration=spi,
                                            k_max_percentage=kmp,
                                            clustering_iterations=ci)
            self.opf_al.fit(self.Xt[:, self._active_features()])

        return self.opf_al.suggestions_

    def active_learning_y_pred(self):
        """Return the active-learning model's predictions for every observation."""
        if self.opf_al is None:
            self.active_learning_suggestions()

        return self.opf_al.predict(self.Xt[:, self._active_features()])

    def active_learning_assign(self, indices, classes):
        """Assign ``classes`` to the observations at ``indices`` in the active-learning model."""
        if self.opf_al is None:
            self.active_learning_suggestions()

        self.opf_al.assign(indices, classes)

    def load_prediction(self, learning_algorithm_name):
        """Load labels predicted by ``learning_algorithm_name`` for the current data.

        Raises:
            Exception: if the prediction file is missing, or its matrix or
                names differ from the loaded data.
        """
        prediction_path = os.path.join(
            self.directory, 'prediction',
            os.path.splitext(self.filename)[0] + '_'
            + learning_algorithm_name + '_prediction.data')

        if not os.path.isfile(prediction_path):
            raise Exception('Prediction not found.')

        with open(prediction_path) as f:
            X, y, image_names, feature_names = point_matrix.load(f)

        print('Loaded prediction from {0}.'.format(prediction_path))

        # array_equal/list comparisons work for lists and ndarrays alike
        # (image_names becomes an ndarray after hide_selected_images) and
        # report a shape mismatch as "not equal" instead of raising.
        if (np.array_equal(self.X, X)
                and list(image_names) == list(self.image_names)
                and list(feature_names) == list(self.feature_names)):
            return y
        raise Exception('Incompatible data matrices')
        

def normalized_mutual_information_matrix(X, n_bins):
    """Return the symmetric matrix of normalized mutual information between columns of ``X``.

    Each column is first discretized with ``np.digitize`` against ``n_bins``
    evenly spaced edges spanning that column's range. Each pair of columns
    is then scored with ``normalized_mutual_info_score``. The diagonal is 1,
    since a column shares all of its information with itself.

    Args:
        X: 2-D array-like (rows are observations, columns are features).
            It is copied, so the caller's data is not modified.
        n_bins: number of evenly spaced bin edges per column.

    Returns:
        ``(n_cols, n_cols)`` float array.
    """
    # Copy: the discretized bin indices overwrite the columns below.
    X = np.array(X)

    n_cols = X.shape[1]
    for i in range(n_cols):
        edges = np.linspace(X[:, i].min(), X[:, i].max(), n_bins)
        X[:, i] = np.digitize(X[:, i], edges) - 1

    # Condensed upper-triangle vector in squareform order; integer division
    # keeps the length an int on Python 3.
    d = np.zeros(n_cols * (n_cols - 1) // 2)
    c = 0
    for i in range(n_cols):
        for j in range(i + 1, n_cols):
            d[c] = normalized_mutual_info_score(X[:, i], X[:, j])
            c += 1

    nmi = sp.spatial.distance.squareform(d)
    np.fill_diagonal(nmi, 1.0)
    return nmi
