from PyQt4 import QtCore, QtGui
import pyqtgraph as pg
import numpy as np
import pylab


class FeaturedImageDialog(QtGui.QDialog):
    """Dialog that displays a single image inside an aspect-locked view box.

    The dialog is only built here; showing or executing it is left to the
    caller.

    Args:
        parent: parent widget of the dialog.
        title: window title.
        image: array-like image; it is rotated with ``np.rot90(image, -1)``
            (i.e. 90 degrees clockwise) before being wrapped in a
            ``pg.ImageItem`` -- presumably to match pyqtgraph's (x, y)
            image axis convention; TODO confirm.
    """

    def __init__(self, parent, title, image):
        QtGui.QDialog.__init__(self, parent)

        self.setWindowTitle(title)
        self.resize(800, 600)

        # Horizontal layout owned by the dialog, wrapping a vertical layout
        # that holds the graphics widget.
        self.horizontalLayout = QtGui.QHBoxLayout(self)
        self.verticalLayout = QtGui.QVBoxLayout()
        self.image_layout_widget = pg.GraphicsLayoutWidget(self)

        named_objects = (
            (self.horizontalLayout, "horizontalLayout"),
            (self.verticalLayout, "verticalLayout"),
            (self.image_layout_widget, "image_layout_widget"),
        )
        for qt_object, object_name in named_objects:
            qt_object.setObjectName(object_name)

        self.verticalLayout.addWidget(self.image_layout_widget)
        self.horizontalLayout.addLayout(self.verticalLayout)

        view = self.image_layout_widget.addLayout().addViewBox()
        # Keep pixels square regardless of the dialog's shape.
        view.setAspectLocked(True)

        rotated = np.rot90(image, -1)
        view.addItem(pg.ImageItem(rotated))  # @UndefinedVariable


class FeaturedPlotDialog(QtGui.QDialog):
    """Dialog hosting one pyqtgraph plot item on a white background.

    The dialog is only built here; showing or executing it is left to the
    caller.

    Args:
        parent: parent widget of the dialog.
        title: window title.
        plot_item: graphics item added to the embedded ``pg.PlotWidget``.
        labels: mapping expanded as keyword arguments into
            ``PlotWidget.setLabels`` (e.g. ``left=...``, ``bottom=...``).
    """

    def __init__(self, parent, title, plot_item, labels):
        QtGui.QDialog.__init__(self, parent)

        self.setWindowTitle(title)
        self.resize(800, 600)

        self.horizontalLayout = QtGui.QHBoxLayout(self)
        self.verticalLayout = QtGui.QVBoxLayout()
        plot = pg.PlotWidget(self)

        for qt_object, object_name in ((self.horizontalLayout, "horizontalLayout"),
                                       (self.verticalLayout, "verticalLayout"),
                                       (plot, "plot_widget")):
            qt_object.setObjectName(object_name)

        white = (255, 255, 255)
        plot.setBackground(white)
        self.plot_widget = plot

        self.verticalLayout.addWidget(plot)
        self.horizontalLayout.addLayout(self.verticalLayout)

        plot.addItem(plot_item)
        plot.setLabels(**labels)


class FeaturedStatsDialog(QtGui.QDialog):
    """Modal dialog plotting per-feature scores and selecting a range of them.

    Features are sorted by ascending score (stable mergesort) and drawn as a
    bar chart. Each bar is filled with a vertical gradient that encodes the
    histogram of that feature's values over the selected rows, coloured with
    the ``summer`` colormap. A linear region item picks a contiguous range of
    bars; ``get_selection`` returns the chosen feature indices.

    Note: the constructor runs the dialog modally (``exec_``), so it returns
    only after the user has closed the dialog.
    """

    def __init__(self, parent, title, scores, feats,
                 feature_names, images, Xt):
        """Build the plot and run the dialog modally.

        Args:
            parent: parent widget.
            title: window title.
            scores: one score per feature (array-like, aligned with ``feats``).
            feats: column indices into ``Xt`` of the scored features.
            feature_names: display names aligned with ``scores``.
            images: row selector into ``Xt``; the selected rows are
                histogrammed and averaged.
            Xt: 2-D data matrix whose columns are indexed by ``feats``.
        """
        QtGui.QDialog.__init__(self, parent)
        # Default to "cancelled" so get_selection() never hits a missing
        # attribute, whatever path closed the dialog.
        self.ok = False
        self.setupUi(title)

        scores = np.asarray(scores)
        feats = np.asarray(feats)

        # Stable sort so that equal scores keep their original order.
        order = np.argsort(scores, kind='mergesort')
        self.sorted_scores = scores[order]
        self.sorted_feats = feats[order]
        self.sorted_fnames = np.array(feature_names)[order]

        columns = Xt[:, feats]
        # Histogram ranges come from ALL rows, so bars are comparable to the
        # feature's full range even though only `images` rows are counted.
        rMaxSorted = columns.max(axis=0)[order]
        rMinSorted = columns.min(axis=0)[order]

        selected = columns[images, :]
        self.sorted_means = selected.mean(axis=0)[order]
        hist_plot = self.create_histogram(
            selected[:, order], rMaxSorted, rMinSorted)

        n_features = len(self.sorted_scores)
        self.linear_region = pg.LinearRegionItem(
            values=[0, n_features], brush=pg.mkBrush(0, 128, 0, 10))
        self.linear_region.setBounds([0, n_features])

        self.feat_tooltip = pg.TextItem(text='', anchor=(1, 1),
                                        border=pg.mkPen('#000000', width=1),
                                        fill=pg.mkBrush(255, 255, 255, 200))
        self.feat_tooltip.hide()

        self.plot_widget.addItem(hist_plot)
        self.plot_widget.addItem(self.linear_region)
        self.plot_widget.addItem(self.feat_tooltip)
        self.plot_widget.setLabels(left='score',
                                   bottom='feature')

        view_box = self.plot_widget.getViewBox()
        # Only horizontal panning/zooming; the score axis stays fixed.
        view_box.setMouseEnabled(x=True, y=False)
        view_box.setRange(yRange=(0, self.sorted_scores.max()))
        view_box.setRange(xRange=(0, n_features))

        self.linear_region.sigRegionChanged.connect(self.handle_region_change)
        hist_plot.scene().sigMouseMoved.connect(self.handle_mouse_moved)

        self.exec_()

    def setupUi(self, title):
        """Create the plot widget, status bar and OK/Cancel button box."""
        self.setWindowTitle(title)
        self.resize(800, 600)

        self.horizontalLayout = QtGui.QHBoxLayout(self)
        self.horizontalLayout.setObjectName("horizontalLayout")

        self.verticalLayout = QtGui.QVBoxLayout()
        self.verticalLayout.setObjectName("verticalLayout")

        self.plot_widget = pg.PlotWidget(self)
        self.plot_widget.setObjectName("plot_widget")
        self.plot_widget.setBackground((255, 255, 255))

        self.verticalLayout.addWidget(self.plot_widget)
        self.horizontalLayout.addLayout(self.verticalLayout)

        self.statusbar = QtGui.QStatusBar(self)
        self.statusbar.setObjectName("statusbar")
        self.verticalLayout.addWidget(self.statusbar)

        self.buttonBox = QtGui.QDialogButtonBox(self)
        self.buttonBox.setOrientation(QtCore.Qt.Horizontal)
        self.buttonBox.setStandardButtons(
            QtGui.QDialogButtonBox.Cancel | QtGui.QDialogButtonBox.Ok)
        self.buttonBox.setObjectName("buttonBox")
        self.verticalLayout.addWidget(self.buttonBox)

        # New-style connections, consistent with the rest of this class.
        self.buttonBox.accepted.connect(self.accept)
        self.buttonBox.rejected.connect(self.reject)
        QtCore.QMetaObject.connectSlotsByName(self)

    def create_histogram(self, Xsorted, rMaxSorted, rMinSorted):
        """Build the score bar chart, filling each bar with a histogram gradient.

        Args:
            Xsorted: values of the selected rows, columns in score-sorted order.
            rMaxSorted: per-feature maximum over all rows, score-sorted.
            rMinSorted: per-feature minimum over all rows, score-sorted.

        Returns:
            A ``pg.BarGraphItem`` with one unit-width bar per feature, bar ``i``
            centred at ``x = i + 0.5`` with height ``sorted_scores[i]``.
        """
        hist_plot = pg.BarGraphItem(x=0.5 + np.arange(len(self.sorted_scores)),
                                    height=self.sorted_scores, width=1.0)

        # Sturges' rule for the bin count. max(..., 1) keeps it finite when no
        # rows are selected (log2(0) = -inf would make int() raise).
        nbins = int(np.ceil(1 + np.log2(max(Xsorted.shape[0], 1))))

        brushes = []
        pens = []
        for i in range(len(self.sorted_scores)):
            if np.allclose(rMaxSorted[i], rMinSorted[i]):
                # Constant feature: no meaningful histogram range.
                gradient = self._constant_gradient()
            else:
                gradient = self._histogram_gradient(
                    Xsorted[:, i], rMinSorted[i], rMaxSorted[i], nbins)

            brushes.append(QtGui.QBrush(gradient))
            pens.append(QtGui.QPen(gradient))

        hist_plot.setOpts(brushes=brushes, pens=pens)

        return hist_plot

    @staticmethod
    def _summer_rgb(value):
        """Map ``value`` through the ``summer`` colormap to an 0-255 RGB tuple."""
        color = pylab.cm.summer(value)  # @UndefinedVariable
        return int(color[0] * 255), int(color[1] * 255), int(color[2] * 255)

    @staticmethod
    def _vertical_gradient():
        """Return a linear gradient from (0, 0) to (0, 1) in item-bounding-box coordinates."""
        gradient = QtGui.QLinearGradient(QtCore.QPointF(0.0, 0.0),
                                         QtCore.QPointF(0.0, 1.0))
        gradient.setCoordinateMode(QtGui.QGradient.ObjectBoundingMode)
        return gradient

    def _constant_gradient(self):
        """Return a single-colour gradient (colormap value 0) for constant features."""
        gradient = self._vertical_gradient()
        base = QtGui.QColor(*self._summer_rgb(0))
        gradient.setColorAt(0, base)
        # Original code set stop 0 twice; the closing stop belongs at 1.
        gradient.setColorAt(1, base)
        return gradient

    def _histogram_gradient(self, values, low, high, nbins):
        """Return a gradient whose stops encode the histogram of ``values``.

        Gradient position 0 corresponds to the lowest-value bin and position
        1 to the highest; each bin's colour is placed at the bin centre.
        """
        histogram, bin_edges = np.histogram(values, nbins,
                                            range=(low, high), density=True)

        # Normalise to [0, 1] for the colormap. An all-zero or NaN histogram
        # (e.g. no selected rows) falls back to zeros instead of NaN colours.
        peak = histogram.max()
        if peak > 0:
            histogram = histogram / peak
        else:
            histogram = np.zeros_like(histogram)

        gradient = self._vertical_gradient()

        # Width of one bin as a fraction of the full [low, high] range.
        bin_ratio = (bin_edges[1] - bin_edges[0]) / (high - low)

        gradient.setColorAt(0, QtGui.QColor(*self._summer_rgb(histogram[0])))
        for bi in range(nbins):
            gradient.setColorAt((bi + 0.5) * bin_ratio,
                                QtGui.QColor(*self._summer_rgb(histogram[bi])))
        gradient.setColorAt(1, QtGui.QColor(*self._summer_rgb(histogram[-1])))
        return gradient

    def _region_bounds(self):
        """Return the linear region's edges rounded to integer bar indices."""
        first, last = self.linear_region.getRegion()
        return int(round(first)), int(round(last))

    def _get_selection(self):
        """Return the set of feature indices covered by the linear region."""
        first, last = self._region_bounds()
        return set(self.sorted_feats[first:last])

    def handle_region_change(self):
        """Report the selected feature count and their summed score in the status bar."""
        first, last = self._region_bounds()

        self.statusbar.showMessage(
            '{0} features selected (area under curve: {1:.3f}).'.
            format(last - first, self.sorted_scores[first:last].sum()))

    def get_selection(self):
        """Return ``(ok, selected_feature_indices)``; ``ok`` is True only if accepted."""
        return self.ok, self._get_selection()

    def handle_mouse_moved(self, pos):
        """Show a tooltip for the bar under the cursor, or hide it when off-range."""
        view_pos = self.plot_widget.mapToView(pos)
        # floor rather than int(): int() truncates toward zero, which would
        # map x in (-1, 0) onto bar 0.
        i = int(np.floor(view_pos.x()))
        if 0 <= i < len(self.sorted_scores):
            text = 'Feature: {0}.\n'.format(self.sorted_fnames[i])
            text += 'Average: {0:.3f}.\n'.format(self.sorted_means[i])
            text += 'Feature score: {0:.3f}.'.format(self.sorted_scores[i])

            self.feat_tooltip.setText(text, color=pg.mkColor(0, 0, 0))
            self.feat_tooltip.setPos(view_pos)
            self.feat_tooltip.show()
        else:
            self.feat_tooltip.hide()

    def accept(self):
        """Record acceptance before closing the dialog."""
        self.ok = True
        QtGui.QDialog.accept(self)

    def reject(self):
        """Record cancellation before closing the dialog."""
        self.ok = False
        QtGui.QDialog.reject(self)


class FeaturedParametersDialog(QtGui.QDialog):
    """Modal dialog that edits a pyqtgraph parameter in a ParameterTree.

    Note: the constructor runs the dialog modally (``exec_``), so it returns
    only after the user has closed the dialog. Read the outcome with
    ``parameters()``.

    Args:
        parent: parent widget.
        title: window title.
        parameter: parameter object shown in the tree; it is edited in place
            and must provide ``saveState()`` and ``getValues()``.
    """

    def __init__(self, parent, title, parameter):
        QtGui.QDialog.__init__(self, parent)
        # Default to "cancelled" so parameters() never hits a missing
        # attribute, whatever path closed the dialog.
        self.ok = False

        self.setWindowTitle(title)

        self.resize(400, 600)

        self.horizontalLayout = QtGui.QHBoxLayout(self)
        self.horizontalLayout.setObjectName("horizontalLayout")
        self.verticalLayout = QtGui.QVBoxLayout()
        self.verticalLayout.setObjectName("verticalLayout")

        self.parameter_tree = pg.parametertree.ParameterTree(
            self)  # @UndefinedVariable
        self.parameter = parameter
        self.parameter_tree.setParameters(self.parameter)
        self.parameter_tree.header().hide()

        self.verticalLayout.addWidget(self.parameter_tree)

        self.buttonBox = QtGui.QDialogButtonBox(self)
        self.buttonBox.setOrientation(QtCore.Qt.Horizontal)
        self.buttonBox.setStandardButtons(
            QtGui.QDialogButtonBox.Cancel | QtGui.QDialogButtonBox.Ok)
        self.buttonBox.setObjectName("buttonBox")
        self.verticalLayout.addWidget(self.buttonBox)
        self.horizontalLayout.addLayout(self.verticalLayout)

        # New-style signal connections (same effect as the old SIGNAL() form).
        self.buttonBox.accepted.connect(self.accept)
        self.buttonBox.rejected.connect(self.reject)
        QtCore.QMetaObject.connectSlotsByName(self)

        self.exec_()

    def accept(self):
        """Record acceptance before closing the dialog."""
        self.ok = True
        QtGui.QDialog.accept(self)

    def reject(self):
        """Record cancellation before closing the dialog."""
        self.ok = False
        QtGui.QDialog.reject(self)

    def parameters(self):
        """Return ``(ok, saved_state, values)`` for the edited parameter.

        ``ok`` is True only if the dialog was accepted.
        """
        return self.ok, self.parameter.saveState(), self.parameter.getValues()
