import numpy as np
cimport numpy as np
np.import_array()

from sklearn.base import ClassifierMixin

from externals.libift.opf cimport BaseOPFCluster
from externals.libift.opf cimport matrix_to_dataset
#from opf cimport BaseOPFCluster
#from opf cimport matrix_to_dataset

from includes.iftSet cimport list_to_iftSet
from includes.iftSet cimport iftSet_to_list
from includes.iftSet cimport iftSet
from includes.iftSet cimport iftDestroySet

from includes.iftMST cimport iftMST
from includes.iftMST cimport DECREASING
from includes.iftMST cimport iftCreateMSTFromSet
from includes.iftMST cimport iftDestroyMST
from includes.iftMST cimport iftGetMSTBoundarySamples
from includes.iftMST cimport iftSortNodesByWeightMST

from includes.iftDataSet cimport iftDataSet
from includes.iftDataSet cimport iftCreateDataSet
from includes.iftDataSet cimport iftCopyDataSet
from includes.iftDataSet cimport TRAIN
from includes.iftDataSet cimport TEST
from includes.iftDataSet cimport iftDestroyDataSet
from includes.iftDataSet cimport iftSetStatus

from includes.iftClassification cimport iftCplGraph
from includes.iftClassification cimport iftClassify
from includes.iftClassification cimport iftCreateCplGraph
from includes.iftClassification cimport iftSupTrain
from includes.iftClassification cimport iftDestroyCplGraph

class OPFActiveLearning(BaseOPFActiveLearning, ClassifierMixin):
    """Implementation of active learning based on Optimum-path forest clustering
    and classification.

    After fitting the data matrix, an object of this class is capable of suggesting
    observations for labeling which, according to internal heuristics, would
    benefit a classifier the most.

    Examples
    --------
    Generating sample data:

    >>> from sklearn import datasets
    >>> from sklearn.preprocessing import StandardScaler
    >>> X, y = datasets.make_moons(n_samples=2000, noise=.05)
    >>> X = StandardScaler().fit_transform(X)

    Fitting model to data:

    >>> opf_al = OPFActiveLearning()
    >>> opf_al = opf_al.fit(X)

    Obtaining observations suggested for labeling:

    >>> opf_al.suggestions_

    Labeling these observations:

    >>> opf_al = opf_al.assign(opf_al.suggestions_, y[opf_al.suggestions_])

    Obtaining new suggestions:

    >>> opf_al.suggestions_

    Obtaining predictions for data:

    >>> y_pred = opf_al.predict(X)

    Parameters
    ----------
    suggestions_per_iteration: int
        Maximum number of suggestions per suggestion step.
    k_max_percentage: float in range (0,1]
        Clustering parameter. See `OPFCluster` parameter `k_max_percentage`.
    clustering_iterations: int
        Clustering parameter. See `OPFCluster` parameter `n_iterations`.
    n_classes: int, optional
        The total number of classes in the problem. The default, 0, means the
        number of classes is unknown.

    Attributes
    ----------
    suggestions_: sequence of int
        Indices of the samples in the data matrix that are suggested for
        labeling in the next step. After `fit` these are the cluster centers
        found by the clustering step; after each `assign` call, indices that
        already have a label are removed and the sequence is topped up with
        boundary samples, up to `suggestions_per_iteration` entries.

    See also
    --------
    BaseOPFActiveLearning
    ClassifierMixin

    Notes
    -----
    For more information on this technique, see:
    Active semi-supervised learning using optimum-path forest. P. Saito et al.
    In Proc. International Conference on Pattern Recognition.

    """
    def __init__(self, suggestions_per_iteration=10, k_max_percentage=0.1,
                 clustering_iterations=10, n_classes=0):
        # Parameters are stored under their own names by the base class, which
        # is what scikit-learn's get_params/set_params expect.
        BaseOPFActiveLearning.__init__(self, suggestions_per_iteration,
                                       k_max_percentage, clustering_iterations, n_classes)

    def fit(self, X, y=None):
        """Creates a model from a data matrix.

        Any labels assigned before are discarded.

        Parameters
        ----------
        X: float array of shape [n_samples, n_features]
            A 2D array representing the observations. Each observation
            corresponds to a row in this matrix.
        y: ignored
            Present only for scikit-learn API compatibility.

        Returns
        -------
        self

        """
        return BaseOPFActiveLearning.fit(self, np.asarray(X, dtype=np.float64))

    def assign(self, indices, y):
        """Assigns class labels in `y` to samples in `indices`.

        The supervised classifier is retrained and `suggestions_` is refreshed.

        Parameters
        ----------
        indices: int array
            Array of observation indices, each in range [0, n_samples - 1].
        y: int array
            Array of class labels which will be associated to the observations
            in `indices`. Observation `indices[i]` will be associated to class
            label `y[i]`.

        Returns
        -------
        self

        """
        BaseOPFActiveLearning.assign(self, np.asarray(indices, dtype=np.int64),
                                     np.asarray(y, dtype=np.int64))
        # Return self explicitly so the documented fluent API holds regardless
        # of what the base implementation returns.
        return self

    def predict(self, X):
        """Predicts class labels for the observations in `X`.

        Parameters
        ----------
        X: float array of shape [n_samples, n_features]
            Observations to classify; `n_features` must match the data used
            in `fit`.

        Returns
        -------
        y_pred: int array of shape [n_samples]
            Predicted labels, expressed in the same label values that were
            passed to `assign`.

        Raises
        ------
        Exception
            If the model is not fitted or no labels have been assigned yet.
        ValueError
            If the number of features does not match the fitted data.

        """
        return BaseOPFActiveLearning.predict(self, np.asarray(X, dtype=np.float64))

    @property
    def n_assignments(self):
        """Number of distinct observations that currently have a label."""
        return BaseOPFActiveLearning.get_n_assignments(self)

cdef class BaseOPFActiveLearning:
    """Base class for OPFActiveLearning.

    Owns the native libift structures used by the active-learning loop:

    * ``dataset_`` -- a copy of the dataset built by ``BaseOPFCluster.fit``;
    * ``mst_`` -- an MST built over the clustering boundary observations and
      sorted with ``iftSortNodesByWeightMST(..., DECREASING)``;
    * ``graph_`` -- the supervised OPF graph, rebuilt on every non-empty
      ``assign`` call.

    Plain Python attributes (``suggestions_``, ``assignments_``, ...) are set
    on ``self``. This ``cdef class`` declares no storage for them, so it is
    meant to be used through a Python subclass such as ``OPFActiveLearning``,
    whose instances have a ``__dict__``.

    See Also
    --------
    OPFActiveLearning

    """
    cdef iftMST *mst_
    cdef iftDataSet *dataset_
    cdef iftCplGraph *graph_

    cdef object _is_fitted
    cdef object _has_assg

    def __init__(self, suggestions_per_iteration, k_max_percentage,
                 clustering_iterations, n_classes):
        self.n_classes = n_classes
        self.suggestions_per_iteration = suggestions_per_iteration
        self.k_max_percentage = k_max_percentage
        self.clustering_iterations = clustering_iterations

        self.suggestions_ = None

        self._is_fitted = False
        self._has_assg = False

        # assignments_: sample index -> internal label id (1, 2, ...)
        # mapped_class: user label -> internal label id
        # orig_class:   internal label id -> user label
        self.assignments_ = {}
        self.mapped_class = {}
        self.orig_class = {}

    def fit(self, np.ndarray[np.float64_t, ndim = 2] X):
        """Cluster `X` and prepare the first round of suggestions.

        Any previous model state, including assigned labels, is discarded.
        Returns self.
        """
        cdef BaseOPFCluster opf_cluster
        cdef iftSet* bo_set = NULL

        self._clean()

        opf_cluster = BaseOPFCluster(self.k_max_percentage, self.clustering_iterations)
        opf_cluster.fit(X)
        self.dataset_ = iftCopyDataSet(opf_cluster.dataset_)

        # The MST over the clustering boundary observations drives later
        # suggestions (see assign).
        bo_set = list_to_iftSet(opf_cluster.get_boundary_observations_indices())
        self.mst_ = iftCreateMSTFromSet(self.dataset_, bo_set)
        iftDestroySet(&bo_set)

        iftSortNodesByWeightMST(self.mst_, DECREASING)

        # The first suggestions are the cluster centers.
        self.suggestions_ = opf_cluster.get_cluster_centers_indices()

        self._is_fitted = True

        return self

    def assign(self, np.ndarray[np.int64_t, ndim = 1] indices,
               np.ndarray[np.int64_t, ndim = 1] y):
        """Label `indices` with `y`, retrain the classifier and refresh suggestions.

        Returns self.

        Raises
        ------
        RuntimeError
            If the model is not fitted.
        ValueError
            If `indices` and `y` differ in size, or an index is out of range.
        """
        cdef Py_ssize_t i
        cdef Py_ssize_t n_new
        cdef iftSet* new_obs = NULL

        if not self._is_fitted:
            raise RuntimeError('Model is not fitted')
        if indices.size != y.size:
            raise ValueError('The number of indices and labels does not match')
        if indices.size == 0:
            return self
        if indices.max() >= self.dataset_.nsamples or indices.min() < 0:
            raise ValueError('Invalid index found in assignment')

        # Map user labels to consecutive internal ids starting at 1, in order
        # of first appearance. Assumption: libift reserves label 0 -- TODO confirm.
        for i in range(y.shape[0]):
            if y[i] not in self.mapped_class:
                self.mapped_class[y[i]] = len(self.mapped_class) + 1
        self.orig_class = {v: k for (k, v) in self.mapped_class.items()}

        y_mapped = [self.mapped_class[yi] for yi in y]
        self.assignments_.update(zip(indices, y_mapped))
        self.suggestions_ = [s for s in self.suggestions_
                             if s not in self.assignments_]

        # Every sample becomes TEST, then labeled samples are flagged as TRAIN.
        iftSetStatus(self.dataset_, TEST)
        for k, v in self.assignments_.items():
            self.dataset_.sample[k].truelabel = v
            self.dataset_.sample[k].label = v
            self.dataset_.sample[k].status = TRAIN

        self.dataset_.nclasses = max(self.n_classes, len(self.mapped_class))
        self.dataset_.nlabels = max(self.n_classes, len(self.mapped_class))
        self.dataset_.ntrainsamples = len(self.assignments_)

        # Check the pointer itself: the first call has no graph yet.
        if self.graph_ != NULL:
            iftDestroyCplGraph(&self.graph_)

        self.graph_ = iftCreateCplGraph(self.dataset_)
        iftSupTrain(self.graph_)
        self._has_assg = True

        # Top up the suggestions with MST boundary samples. Never request a
        # non-positive count, which happens when enough suggestions remain.
        n_new = self.suggestions_per_iteration - len(self.suggestions_)
        if n_new > 0:
            new_obs = iftGetMSTBoundarySamples(self.graph_, self.mst_, n_new)
            self.suggestions_ += iftSet_to_list(new_obs)
            iftDestroySet(&new_obs)

        return self

    def predict(self, np.ndarray[np.float64_t, ndim = 2] X):
        """Return int64 predictions for `X`, using the labels given to `assign`.

        Raises
        ------
        RuntimeError
            If the model is not fitted or has no labels.
        ValueError
            If the number of features does not match the fitted data.
        """
        cdef iftDataSet *dataset_test = NULL
        cdef np.ndarray[np.int64_t, ndim = 1] y_pred
        cdef Py_ssize_t i
        cdef Py_ssize_t n_samples = X.shape[0]

        if not self._is_fitted:
            raise RuntimeError('Model is not fitted')
        if not self._has_assg:
            raise RuntimeError('Model has no labels')
        if X.shape[1] != self.dataset_.nfeats:
            raise ValueError('Invalid number of features')

        # np.int64 matches the buffer type on every platform (np.int was a
        # platform-dependent alias and has been removed from NumPy).
        y_pred = np.zeros(n_samples, dtype=np.int64)
        if n_samples == 0:
            return y_pred

        if len(self.mapped_class) == 1:
            # Only one class was ever assigned: every prediction is that class.
            y_pred.fill(next(iter(self.mapped_class)))
            return y_pred

        dataset_test = matrix_to_dataset(X)
        try:
            iftClassify(self.graph_, dataset_test)
            for i in range(n_samples):
                y_pred[i] = self.orig_class[dataset_test.sample[i].label]
        finally:
            # Release the temporary dataset even if a label lookup fails.
            iftDestroyDataSet(&dataset_test)

        return y_pred

    def _clean(self):
        """Release native structures and reset all learned state."""
        # Free based on the pointers themselves, so a partially completed
        # fit() (dataset_ set but _is_fitted still False) does not leak.
        # Free in reverse order of creation.
        if self.graph_ != NULL:
            iftDestroyCplGraph(&self.graph_)
        if self.mst_ != NULL:
            iftDestroyMST(&self.mst_)
        if self.dataset_ != NULL:
            iftDestroyDataSet(&self.dataset_)
        # Defensive: guarantee NULL even if a destroy helper does not reset it.
        self.graph_ = NULL
        self.mst_ = NULL
        self.dataset_ = NULL

        self.suggestions_ = None

        self.assignments_ = {}
        self.mapped_class = {}
        self.orig_class = {}

        self._is_fitted = False
        self._has_assg = False

    def get_n_assignments(self):
        """Number of distinct sample indices that currently have a label."""
        return len(self.assignments_)

    def __dealloc__(self):
        # Free only C-level resources here. Cython advises against calling
        # Python methods or touching Python attributes from __dealloc__,
        # because the object may already be partially torn down. For example,
        # a Python subclass instance's __dict__ may already have been cleared.
        if self.graph_ != NULL:
            iftDestroyCplGraph(&self.graph_)
        if self.mst_ != NULL:
            iftDestroyMST(&self.mst_)
        if self.dataset_ != NULL:
            iftDestroyDataSet(&self.dataset_)