import os
from PyQt4 import QtCore, QtGui
import skimage.io
skimage.io.use_plugin('freeimage')

import numpy as np
import pyqtgraph as pg

from model import feature_scoring
from model.parsing import point_matrix
from widgets import dialogs


class FeaturedHandler(object):
    """Event-handler mixin for the Featured main window.

    The ``handle_*`` methods are presumably connected to widget signals by
    the host class.  This class keeps no state of its own: it expects the
    host to provide ``model``, ``settings``, ``ignore_events``, the widgets
    referenced below (``images_list``, ``groups_list``, ``features_tree``,
    combo boxes, actions, lensing buttons, plot view boxes) and the
    controller methods it calls (``open_path``, ``set_selected_points``,
    ``set_selected_features``, ``update_point_coloring``, ...).  Methods
    that pass ``self`` as a dialog parent assume the host is a ``QWidget``.
    """

    def __init__(self):
        pass

    # ------------------------------------------------------------------
    # File handling
    # ------------------------------------------------------------------

    def handle_open_folder(self):
        """Ask for an image directory and open it with ``open_path``."""
        directory = str(QtGui.QFileDialog.
                        getExistingDirectory(self, 'Open image directory'))
        self._open_existing_path(directory)

    def handle_open_file(self):
        """Ask for a features file and open it with ``open_path``."""
        filepath = str(QtGui.QFileDialog.
                       getOpenFileName(self, 'Open features file'))
        self._open_existing_path(filepath)

    def _open_existing_path(self, path):
        """Open ``path`` via ``open_path`` if it exists; show errors in a box."""
        # A cancelled file dialog yields an empty string, which fails here.
        if not os.path.exists(path):
            return

        try:
            self.open_path(path)
        except Exception as e:
            QtGui.QMessageBox.warning(self, "Warning", str(e))

    def handle_save_groups(self):
        """Ask for a destination and call ``model.save_groups`` with it."""
        filepath = str(QtGui.QFileDialog.
                       getSaveFileName(self, 'Save groups'))
        if not filepath:
            # Dialog cancelled: do not try to save to an empty path.
            return

        try:
            self.model.save_groups(filepath)
        except Exception as e:
            QtGui.QMessageBox.warning(self, "Warning", str(e))

    def handle_save_features(self):
        """Ask for a destination and call ``model.save_features`` with it."""
        filepath = str(QtGui.QFileDialog.
                       getSaveFileName(self, 'Save features'))
        if not filepath:
            # Dialog cancelled: do not try to save to an empty path.
            return

        try:
            self.model.save_features(filepath)
        except Exception as e:
            QtGui.QMessageBox.warning(self, "Warning", str(e))

    def handle_load_feature_selection(self):
        """Select the features named in a features file chosen by the user.

        Feature names are read with ``point_matrix.load`` and mapped to
        indices in ``model.feature_names``; any error (unreadable file,
        unknown feature name) is shown in a warning box.
        """
        filepath = str(QtGui.QFileDialog.
                       getOpenFileName(self, 'Open features file'))

        if not os.path.exists(filepath):
            return

        try:
            # ``with`` closes the file even if parsing raises.
            with open(filepath) as f:
                _, _, _, feature_names = point_matrix.load(f)

            selected_features = [self.model.feature_names.index(featname)
                                 for featname in feature_names]

            self.set_selected_features(selected_features)
        except Exception as e:
            QtGui.QMessageBox.warning(self, "Warning", str(e))

    # ------------------------------------------------------------------
    # Observation / feature selection
    # ------------------------------------------------------------------

    def handle_run_projection_clicked(self):
        """Select the features currently selected in the features tree."""
        selected = self._get_selected_features_tree()
        self.set_selected_features(selected)

    def handle_points_selected(self, modifiers, points):
        """Update the observation selection from a group of points.

        Ctrl adds ``points`` to the current selection, Ctrl+Alt removes
        them, and no Ctrl replaces the selection.  With Shift held,
        ``create_unnamed_group`` is called afterwards.
        """
        selected_points = self.model.selected_images
        # Normalize so the set operators below also accept lists/iterables.
        points = set(points)

        if modifiers & QtCore.Qt.ControlModifier:
            if modifiers & QtCore.Qt.AltModifier:
                selected_points = selected_points - points
            else:
                selected_points = selected_points | points
        else:
            selected_points = points

        self.set_selected_points(selected_points)

        if modifiers & QtCore.Qt.ShiftModifier:
            self.create_unnamed_group()

    def handle_point_clicked(self, plot, points):
        """Update the observation selection from clicked scatter points.

        Reads the current keyboard modifiers: Ctrl adds the clicked points,
        Ctrl+Alt removes them, and no Ctrl selects only the last clicked
        point.  With Shift held, ``create_unnamed_group`` is called
        afterwards.  ``plot`` is unused.
        """
        # Work on a copy so the model's set only changes through
        # set_selected_points.
        selected_points = set(self.model.selected_images)

        modifiers = QtGui.QApplication.keyboardModifiers()
        ctrl = modifiers & QtCore.Qt.ControlModifier
        alt = modifiers & QtCore.Qt.AltModifier

        for p in points:
            if ctrl and alt:
                selected_points.discard(p.data())
            elif ctrl:
                selected_points.add(p.data())
            else:
                selected_points = set([p.data()])

        self.set_selected_points(selected_points)

        if modifiers & QtCore.Qt.ShiftModifier:
            self.create_unnamed_group()

    def handle_point_dragged(self, i, destination):
        """Forward a point drag to ``select_neighborhood``."""
        self.select_neighborhood(i, destination)

    def handle_point_released(self, i, destination):
        """Forward a point release to ``move_point_to``."""
        self.move_point_to(i, destination)

    def handle_feature_points_selected(self, modifiers, points):
        """Update the feature selection from a group of feature points.

        Same modifier semantics as ``handle_points_selected`` (without the
        Shift/group behaviour).
        """
        selected_feature_points = self.model.selected_features
        points = set(points)

        if modifiers & QtCore.Qt.ControlModifier:
            if modifiers & QtCore.Qt.AltModifier:
                selected_feature_points = selected_feature_points - points
            else:
                selected_feature_points = selected_feature_points | points
        else:
            selected_feature_points = points

        self.set_selected_features(selected_feature_points)

    def handle_feature_points_clicked(self, plot, points):
        """Update the feature selection from clicked feature points.

        Same modifier semantics as ``handle_point_clicked`` (without the
        Shift/group behaviour).  ``plot`` is unused.
        """
        # Work on a copy so the model's set only changes through
        # set_selected_features.
        selected_feature_points = set(self.model.selected_features)

        modifiers = QtGui.QApplication.keyboardModifiers()
        ctrl = modifiers & QtCore.Qt.ControlModifier
        alt = modifiers & QtCore.Qt.AltModifier

        for p in points:
            if ctrl and alt:
                selected_feature_points.discard(p.data())
            elif ctrl:
                selected_feature_points.add(p.data())
            else:
                selected_feature_points = set([p.data()])

        self.set_selected_features(selected_feature_points)

    def handle_image_selected(self):
        """Select the observations whose list items are selected."""
        if self.ignore_events:
            return

        selected_images = {item.data(QtCore.Qt.UserRole).toInt()[0]
                           for item in self.images_list.selectedItems()}

        self.set_selected_points(selected_images)

    def handle_group_selected(self):
        """Select the groups whose list items are selected."""
        if self.ignore_events:
            return

        selected_groups = [str(item.text())
                           for item in self.groups_list.selectedItems()]

        self.set_selected_groups(selected_groups)

    def handle_feature_clicked(self, item):
        """Propagate a features-tree click to the item's children.

        If auto-update is checked, the tree's selection is also applied as
        the selected features.
        """
        if self.ignore_events:
            return

        # Other handlers check ignore_events, so this keeps them from
        # reacting to the programmatic changes below.  The finally block
        # guarantees they are re-enabled even if an update raises.
        self.ignore_events = True
        try:
            self._select_children_features(item, item.isSelected())

            if self.action_auto_update_projection.isChecked():
                self.set_selected_features(
                    self._get_selected_features_tree())
        finally:
            self.ignore_events = False

    def handle_order_by_changed(self):
        """Reorder the image list and re-apply the current selection."""
        if self.ignore_events:
            return

        # Guard against re-entrant handlers; always re-enable them.
        self.ignore_events = True
        try:
            self.update_image_list_order()
            self.update_point_selection(self.model.selected_images)
        finally:
            self.ignore_events = False

    def handle_hide_selected(self):
        """Call ``hide_selected_images``; show any error in a warning box."""
        try:
            self.hide_selected_images()
        except Exception as e:
            QtGui.QMessageBox.warning(self, "Warning", str(e))

    # ------------------------------------------------------------------
    # Views and dialogs
    # ------------------------------------------------------------------

    def handle_view_clicked(self):
        """Open an image dialog for every selected image path.

        Images that fail to load are reported in a warning box and skipped.
        """
        for image_path in self.get_selected_image_paths():
            try:
                img = skimage.io.imread(image_path)
            except Exception as e:
                QtGui.QMessageBox.warning(self, "Warning", str(e))
                continue

            title = 'Featured - Image {0}'.format(os.path.basename(image_path))
            dialog = dialogs.FeaturedImageDialog(self, title, img)
            dialog.show()

    def handle_distance_histogram(self):
        """Show ``model.compute_distance_histogram`` as a step plot."""
        x, y = self.model.compute_distance_histogram()

        hist_plot = pg.PlotCurveItem(x, y, stepMode=True, fillLevel=0,
                                     brush=(0, 0, 255))
        labels = {'left': 'frequency', 'bottom': 'distance'}
        dialog = dialogs.FeaturedPlotDialog(self, 'Distance histogram',
                                            hist_plot, labels)
        dialog.show()

    def handle_similarity_matrix(self):
        """Show ``model.compute_similarity_matrix`` in an image dialog."""
        d = self.model.compute_similarity_matrix()

        dialog = dialogs.FeaturedImageDialog(self, 'Similarity matrix', d)
        dialog.show()

    def handle_projection_options(self):
        """Edit the settings parameter tree and apply it if accepted."""
        p = self.settings.parameter_tree()

        ok, state, values = dialogs.FeaturedParametersDialog(
            self, 'Settings', p).parameters()

        if ok:
            self.set_parameter_state(state, values)

    def handle_expand_features(self):
        """Expand every node of the features tree."""
        self.features_tree.expandAll()

    def handle_collapse_features(self):
        """Collapse every node of the features tree."""
        self.features_tree.collapseAll()

    def handle_show_feature_plot_clicked(self):
        """Show or hide the feature plot to match its action's check state."""
        if self.action_show_feature_plot.isChecked():
            self.widget_feature_plot.show()
        else:
            self.widget_feature_plot.hide()

    def handle_show_observation_plot_clicked(self):
        """Show or hide the observation plot to match its action."""
        if self.action_show_observation_plot.isChecked():
            self.widget_observation_plot.show()
        else:
            self.widget_observation_plot.hide()

    def handle_show_image_view_clicked(self):
        """Show or hide the image view to match its action."""
        if self.action_show_image_view.isChecked():
            self.widget_image_view.show()
        else:
            self.widget_image_view.hide()

    # ------------------------------------------------------------------
    # Groups
    # ------------------------------------------------------------------

    def handle_add_group_clicked(self):
        """Ask for a group name and create the group if accepted."""
        name, ok = QtGui.QInputDialog.getText(self, 'Group Identification',
                                              'Enter the name for the group:')
        if ok:
            self.create_group(str(name))

    def handle_remove_group_clicked(self):
        """Remove every group selected in the groups list."""
        gs = [str(item.text()) for item in self.groups_list.selectedItems()]

        self.remove_groups(gs)

    # ------------------------------------------------------------------
    # Coloring
    # ------------------------------------------------------------------

    def handle_color_by_changed(self):
        """Recolor the observation points using the chosen color mode."""
        if self.ignore_events:
            return

        color_by = str(self.color_by_combobox.currentText())
        # The feature combo box only applies when coloring by a feature.
        self.feature_combobox.setEnabled(color_by == 'Feature')

        self.primary_obs_points = self._get_colored_observation_points(
            self.model.Xp, self.model.Xt, color_by)
        self.update_point_coloring(self.primary_obs_points)

    def handle_feature_color_by_changed(self):
        """Recolor the feature points using the chosen color mode."""
        if self.ignore_events:
            return

        color_by = str(self.feature_plot_color_by_combobox.currentText())
        # The scoring combo box only applies when coloring by score.
        self.scoring_combobox.setEnabled(color_by == 'Score')

        self.primary_feat_points = self._get_colored_feature_points(
            self.model.fXp, color_by)

        self.update_feature_point_coloring(self.primary_feat_points)

    # ------------------------------------------------------------------
    # Projection and feature state
    # ------------------------------------------------------------------

    def handle_auto_update_projection(self):
        """Re-run the projection when auto-update gets checked."""
        if self.action_auto_update_projection.isChecked():
            self.run_projection()

    def handle_save_feature_state_clicked(self):
        """Remember a snapshot of the current feature selection."""
        # Copy: the model's set may be modified later, which would otherwise
        # silently change the saved state too.
        self.last_feature_state = set(self.model.selected_features)

    def handle_load_feature_state_clicked(self):
        """Restore the feature selection saved by the save-state handler.

        Does nothing if no (non-empty) state has been saved yet.
        """
        saved = getattr(self, 'last_feature_state', None)
        if saved:
            self.set_selected_features(set(saved))

    def handle_invert_feature_state_clicked(self):
        """Select ``model.get_deselected_features`` instead."""
        feats = self.model.get_deselected_features()
        self.set_selected_features(feats)

    def handle_last_feature_state_clicked(self):
        """Delegate to ``pop_feature_selection``."""
        self.pop_feature_selection()

    # ------------------------------------------------------------------
    # Statistics
    # ------------------------------------------------------------------

    def handle_correlation_axis(self, axis):
        """Show |correlation| of each selected feature with a projection axis.

        ``axis`` is 'x' for the first axis; any other value uses the second.
        If the user accepts a selection in the stats dialog, it becomes the
        selected features.  Errors are shown in a warning box.
        """
        try:
            cx, cy = self.model.compute_projection_axis_correlation()
            contributions = np.abs(cx if axis == 'x' else cy)

            features = np.array(sorted(self.model.selected_features))
            feature_names = np.array(self.model.get_feature_names(features))

            Xt = self.model.Xt

            # Fall back to every observation when nothing is selected.
            images = sorted(self.model.selected_images)
            if len(images) < 1:
                images = range(Xt.shape[0])

            dialog = dialogs.FeaturedStatsDialog(
                self, 'Correlation with ' + axis,
                contributions, features, feature_names, images, Xt)

            ok, selected_features = dialog.get_selection()
            if ok:
                self.set_selected_features(selected_features)

        except Exception as e:
            QtGui.QMessageBox.warning(self, "Warning", str(e))

    def handle_chi_squared_stats(self):
        self.handle_feature_relevance_stats(feature_scoring.ChiSquaredScorer)

    def handle_anova_stats(self):
        self.handle_feature_relevance_stats(feature_scoring.OneWayAnovaScorer)

    def handle_relief_stats(self):
        self.handle_feature_relevance_stats(feature_scoring.ReliefScorer)

    def handle_recursive_elimination_stats(self):
        self.handle_feature_relevance_stats(
            feature_scoring.RecursiveFeatureEliminationScorer)

    def handle_decision_trees_stats(self):
        self.handle_feature_relevance_stats(
            feature_scoring.RandomizedDecisionTreesScorer)

    def handle_randomized_logistic_score_stats(self):
        self.handle_feature_relevance_stats(
            feature_scoring.RandomizedLogisticRegressionScorer)

    def handle_variance_based_coherence(self):
        self.handle_feature_relevance_stats(
            feature_scoring.VarianceBasedCoherenceScorer)

    def handle_silhouette_coefficient(self):
        self.handle_feature_relevance_stats(
            feature_scoring.SilhouetteCoefficientScorer)

    def handle_feature_relevance_stats(self, FeatureScorer):
        """Score the selected features with ``FeatureScorer`` and show them.

        ``FeatureScorer`` is constructed as
        ``FeatureScorer(Xt, images1, images2, features)`` and its ``score()``
        result is shown in a stats dialog; an accepted selection becomes the
        selected features.  Errors are shown in a warning box.
        """
        try:
            Xt = self.model.Xt
            features = np.array(sorted(self.model.selected_features))
            feature_names = np.array(self.model.get_feature_names(features))

            images1, images2, title = self._split_selection_for_scoring()

            feature_scorer = FeatureScorer(Xt, images1, images2, features)
            contributions = feature_scorer.score()

            dialog = dialogs.FeaturedStatsDialog(
                self, title,
                contributions, features, feature_names, images1 + images2, Xt)

            ok, selected_features = dialog.get_selection()
            if ok:
                self.set_selected_features(selected_features)
        except Exception as e:
            QtGui.QMessageBox.warning(self, "Warning", str(e))

    def _selection_labels(self):
        """Return class labels used by the evaluation handlers.

        When the selection is a proper, non-empty subset of the
        observations, the labels are binary (1 = selected, 0 = not);
        otherwise ``model.y`` is returned as-is.
        """
        selected_images = self.model.selected_images
        if 0 < len(selected_images) < self.model.n_observations:
            y = np.zeros(self.model.n_observations, dtype=int)
            y[sorted(selected_images)] = 1
            return y
        return self.model.y

    def handle_compute_selection_accuracy_loo(self):
        """Report ``model.compute_kfold_classification`` accuracy in %."""
        y = self._selection_labels()
        ypred = self.model.compute_kfold_classification(y)

        self.update_message('{0:.3f}% accuracy.'.
                            format((100. * np.sum(ypred == y)) / len(ypred)))

    def handle_compute_neighborhood_hit(self):
        """Report ``model.compute_neighborhood_hit`` as a percentage."""
        y = self._selection_labels()
        nhit = self.model.compute_neighborhood_hit(y)

        self.update_message('{0:.3f}% neighborhood hit.'.format(nhit * 100))

    # ------------------------------------------------------------------
    # Lensing (hover recoloring)
    # ------------------------------------------------------------------

    def handle_point_hovered(self, i):
        """React to the hovered observation changing.

        ``i`` is presumably the hovered observation index, or None when no
        point is hovered.  In lensing mode, observations returned by
        ``model.nearest_points_r(i, circle_radius)`` use the secondary
        coloring and all others the primary one; recoloring is skipped
        while the same point stays hovered.  Outside lensing mode only the
        tooltip visibility is toggled.
        """
        if not self.button_lensing_observations.isChecked():
            self.plot_widget_vb.show_tooltip(i is not None)
            return

        if i is not None:
            # Only recolor when the hovered point actually changed.
            if i != self.last_hovered_obs:
                self.plot_widget_vb.show_circle(True)
                circle_radius = self.plot_widget_vb.circle_radius

                lensed = set(self.model.nearest_points_r(i, circle_radius))

                primary = self.primary_obs_points
                secondary = self.secondary_obs_points
                points = [secondary[j] if j in lensed else primary[j]
                          for j in range(self.model.n_observations)]

                self.update_point_coloring(points)

                self.last_hovered_obs = i
        elif self.last_hovered_obs is not None:
            # Hover just ended: restore the primary coloring.
            self.update_point_coloring(self.primary_obs_points)
            self.plot_widget_vb.show_circle(False)
            self.last_hovered_obs = None

    def handle_button_lensing_observations_clicked(self):
        """Compute (or drop) the secondary observation coloring for lensing."""
        if self.button_lensing_observations.isChecked():
            self.secondary_obs_points = self._get_colored_observation_points(
                self.model.Xp, self.model.Xt,
                self.settings.parameter('observation_lensing_color'))
        else:
            self.secondary_obs_points = None

    def handle_feature_point_hovered(self, i):
        """Feature-plot counterpart of ``handle_point_hovered``."""
        if not self.button_lensing_features.isChecked():
            self.features_plot_widget_vb.show_tooltip(i is not None)
            return

        if i is not None:
            # Only recolor when the hovered point actually changed.
            if i != self.last_hovered_feat:
                self.features_plot_widget_vb.show_circle(True)
                circle_radius = self.features_plot_widget_vb.circle_radius

                lensed = set(
                    self.model.nearest_feature_points_r(i, circle_radius))

                primary = self.primary_feat_points
                secondary = self.secondary_feat_points
                points = [secondary[j] if j in lensed else primary[j]
                          for j in range(self.model.n_features)]

                self.update_feature_point_coloring(points)

                self.last_hovered_feat = i
        elif self.last_hovered_feat is not None:
            # Hover just ended: restore the primary coloring.
            self.update_feature_point_coloring(self.primary_feat_points)
            self.features_plot_widget_vb.show_circle(False)
            self.last_hovered_feat = None

    def handle_button_lensing_features_clicked(self):
        """Compute (or drop) the secondary feature coloring for lensing."""
        if self.button_lensing_features.isChecked():
            self.secondary_feat_points = self._get_colored_feature_points(
                self.model.fXp,
                self.settings.parameter('feature_lensing_color'))
        else:
            self.secondary_feat_points = None

    # ------------------------------------------------------------------
    # Active learning / labelling
    # ------------------------------------------------------------------

    def handle_select_suggestions_clicked(self):
        """Select and highlight ``model.active_learning_suggestions``."""
        suggestions = self.model.active_learning_suggestions()
        self.set_selected_points(suggestions)
        self.highlight_points(suggestions, size=1.5, symbol='t')

    def handle_auto_label_selection(self):
        """Commit each selected observation with its current ``model.y``."""
        indices = sorted(self.model.selected_images)
        classes = self.model.y[indices]

        self.commit_labelling(dict(zip(indices, classes)))

    def handle_reset_active_learning_clicked(self):
        """Reset active learning and re-run the projection."""
        self.model.active_learning_reset()
        self.run_projection()

    def handle_label_suggestions_clicked(self):
        """Ask for an integer label and apply it to the selected observations.

        Nothing is labelled if the dialog is cancelled.
        """
        label, ok = QtGui.QInputDialog.getInteger(self, 'Label observations',
                                                  'Class label (integer):', 0, 0)
        if not ok:
            return

        self.update_labelling({i: label for i in self.model.selected_images})

    def handle_discard_labelling_clicked(self):
        """Delegate to ``discard_labelling``."""
        self.discard_labelling()

    def handle_commit_labelling_clicked(self):
        """Delegate to ``commit_labelling``."""
        self.commit_labelling()