import os
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.base import ClusterMixin

cimport numpy as np
np.import_array()

from includes.iftDataSet cimport iftDataSet
from includes.iftDataSet cimport TRAIN
from includes.iftDataSet cimport TEST
from includes.iftDataSet cimport iftCreateDataSet
from includes.iftDataSet cimport iftDestroyDataSet
from includes.iftDataSet cimport iftReadOPFDataSet
from includes.iftDataSet cimport iftSetStatus

from includes.iftClassification cimport iftCplGraph
from includes.iftClassification cimport iftCreateCplGraph
from includes.iftClassification cimport iftCreateCplGraph
from includes.iftClassification cimport iftDestroyCplGraph
from includes.iftClassification cimport iftSupTrain
from includes.iftClassification cimport iftClassify

from includes.iftClustering cimport iftKnnGraph
from includes.iftClustering cimport iftUnsupLearn
from includes.iftClustering cimport iftKnnGraphCutFun3
from includes.iftClustering cimport iftCluster
from includes.iftClustering cimport iftDestroyKnnGraph
from includes.iftClustering cimport iftGetKnnRootSamples
from includes.iftClustering cimport iftGetKnnBoundarySamples

from includes.iftSet cimport iftSet
from includes.iftSet cimport iftDestroySet
from includes.iftAdjSet cimport iftAdjSet
from includes.iftSet cimport iftSet_to_list

def load_dataset(filepath):
    """Load iftDataSet stored in opf format from filepath.
    
    Parameters
    ----------
    filepath: string
        Path to the file storing the dataset. Text paths are encoded with the
        file system encoding before being handed to the C reader.
    
    Returns
    -------
    X: float array of shape [n_samples, n_features]
        A 2D array representing the observations. Each observation 
        corresponds to a line in this matrix.
    y: int array of shape [n_samples]
        A 1D array representing the class labels ("truelabel") for each observation. 
        The class labels can be arbitrary integers.
        
    Raises
    ------
    IOError
        If `filepath` is not an existing file, or if the C reader returns a
        NULL dataset.
        
    """
    import sys

    cdef iftDataSet *dataset
    cdef bytes path_bytes
    cdef int i, j
    cdef int n_samples, n_feats
    cdef np.ndarray[np.float64_t, ndim = 2] X
    cdef np.ndarray[np.int64_t, ndim = 1] y

    if not os.path.isfile(filepath):
        raise IOError('Invalid path')

    # The C reader needs a byte string; a text path cannot be coerced to a
    # C string implicitly, so encode it explicitly.
    if isinstance(filepath, bytes):
        path_bytes = filepath
    else:
        path_bytes = filepath.encode(sys.getfilesystemencoding() or 'utf-8')

    dataset = iftReadOPFDataSet(path_bytes)
    if dataset == NULL:
        raise IOError('Could not read dataset')

    # Release the C dataset even if allocating or filling the arrays fails.
    try:
        n_samples = dataset.nsamples
        n_feats = dataset.nfeats

        X = np.zeros((n_samples, n_feats), dtype = np.float64)
        y = np.zeros(n_samples, dtype = np.int64)

        for i in range(n_samples):
            y[i] = dataset.sample[i].truelabel
            for j in range(n_feats):
                X[i,j] = dataset.sample[i].feat[j]
    finally:
        iftDestroyDataSet(&dataset)
    
    return X, y
    
cdef iftDataSet* matrix_to_dataset(np.ndarray[np.float64_t, ndim = 2] X):
    """Build a newly allocated iftDataSet* holding the rows of a 2D float array.
    
    Parameters
    ----------
    X: float array of shape [n_samples, n_features]
        Observation matrix; row `i` becomes sample `i` of the dataset.
        
    Returns
    -------
    iftDataSet*
        Dataset created with iftCreateDataSet. Every sample gets its row
        index as `id`, its status set to TEST, and its features copied from
        the matching row of `X`. The caller is responsible for destroying it.
    
    """
    cdef int row, col
    cdef iftDataSet *converted = iftCreateDataSet(X.shape[0], X.shape[1])

    for row in range(X.shape[0]):
        converted.sample[row].id = row
        converted.sample[row].status = TEST

        # Copy the observation feature by feature into the C sample.
        for col in range(X.shape[1]):
            converted.sample[row].feat[col] = X[row, col]

    return converted


class OPFClassifier(BaseOPFClassifier,BaseEstimator, ClassifierMixin):
    """Optimum-path forest classifier.
    
    Examples
    --------
    Learning the exclusive-or function:

    >>> X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
    >>> y = np.array([0, 1, 1, 0])
    >>> opf = OPFClassifier()
    >>> opf = opf.fit(X, y)
    >>> print(opf.predict(X))
    [0 1 1 0]
    
    See also
    --------
    BaseOPFClassifier
    BaseEstimator
    ClassifierMixin

    Notes
    -----
    For details on this classification technique, see:
    Supervised Pattern Classification based on Optimum-Path Forest, J.P. Papa
    et al. International Journal of Imaging Systems and Technology 2009.
    
    """
    def __init__(self):
        BaseOPFClassifier.__init__(self)

    def fit(self, X, y):
        """
        Fit an optimum-path forest classifier to data matrix `X` and class
        labels vector `y`.
        
        Parameters
        ----------
        X: float array of shape [n_samples, n_features]
            A 2D array representing the training observations. Each observation 
            should correspond to a line in this matrix. Any array-like that
            `np.asarray` can convert to float64 is accepted.
        y: int array of shape [n_samples]
            A 1D array representing the class labels for each observation. The
            class labels can be arbitrary integers.
            
        Returns
        -------
        self

        Raises
        ------
        ValueError
            If `X` and `y` do not have the same number of observations.
        
        """
        # Convert before inspecting shapes so that plain sequences (lists,
        # tuples) are accepted and not only ndarray instances.
        X = np.asarray(X, dtype = np.float64)
        y = np.asarray(y, dtype = np.int64)

        if X.shape[0] != y.shape[0]:
            raise ValueError('X and y should have the same number of observations')
        
        return BaseOPFClassifier.fit(self, X, y)
        
    def predict(self, X):
        """
        Predict class labels for observations in the data matrix `X`.

        Parameters
        ----------       
        X: float array of shape [n_samples, n_features]
            A 2D array representing the observations. Each observation 
            corresponds to a line in this matrix.
            
        Returns
        -------
        A 1D array representing the predicted class labels for each observation.
        
        """
        return BaseOPFClassifier.predict(self, np.asarray(X, dtype = np.float64))

class OPFCluster(BaseOPFCluster, BaseEstimator, ClusterMixin):
    """Optimum-path clustering implementation.
    
    Examples
    --------
    Clustering three blobs in 2D space

    >>> X, _ = sklearn.datasets.make_blobs(n_samples=n_samples, random_state=8)
    >>> opf = OPFCluster()
    >>> opf = opf.fit(X)
    >>> opf.labels_
    [2 2 0 ..., 2 2 0]    
    >>> opf.cluster_centers_
    [[ 0.71304573  0.02729775]
     [ 0.71171119  1.1578621 ]
     [-1.41845867 -1.28228492]]
    
    Parameters
    ----------
    k_max_percentage: float in the range (0,1]
        Ratio (with respect to the number of samples in the training data) that 
        defines a maximum number of neighbors `kmax = k_max_percentage*n_samples`. 
        
        Several models are considered with a constant varying between 1 and kmax, 
        and the best is chosen heuristically. This parameter indirectly affects the 
        number of clusters. Higher values usually lead to a smaller number of 
        clusters, but also make the fitting slower.
        
    n_iterations: int
        Number of outlier removal steps.
    
    Attributes
    ----------
    labels_
    cluster_centers_
    boundary_observations_ind_
    cluster_centers_ind_
    
    See also
    --------
    BaseOPFCluster
    BaseEstimator
    ClusterMixin
    
    Notes
    -----
    For details on this clustering technique, see:
    Data Clustering as an Optimum-Path Forest Problem with Applications 
    in Image Analysis. L.M. Rocha et al. International Journal of Imaging 
    Systems and Technology 2009.
    
    """
    def __init__(self, k_max_percentage = 0.1, n_iterations = 10):
        BaseOPFCluster.__init__(self, k_max_percentage, n_iterations)
        
    def fit(self, X, y = None):
        """Compute clustering given data matrix.
        
        Parameters
        ----------
        X: float array of shape [n_samples, n_features]
            A 2D array representing the training observations. Each observation 
            corresponds to a line in this matrix.
        y: ignored
            Not used by this method.
        
        Returns
        -------
        self
        """
        training_data = np.array(X)
        fitted = BaseOPFCluster.fit(
            self, np.asarray(training_data, dtype = np.float64))
        # Store the copy of the training matrix only after the fit succeeded,
        # so that `cluster_centers_` never indexes a matrix that does not
        # correspond to the fitted model when `X` is rejected.
        self.X_ = training_data
        return fitted

    @property
    def labels_(self):
        """Cluster labels assigned to each observation in the data matrix used
        to fit this model.
        
        Returns
        -------
        int array of shape [n_samples]
            Cluster labels (in the range [0, n_clusters -1]) assigned to each
            observation in the data matrix.
        
        """
        return self.get_labels()
        
    @property
    def cluster_centers_(self):
        """Representative observations
        
        Returns
        -------
        float array of shape [n_clusters, n_features]
            Matrix of observations that represent each cluster. The i-th
            observation corresponds to cluster i. Rows are taken from the
            copy of the training matrix stored by `fit`.
        
        """
        return self.X_[self.get_cluster_centers_indices()]
        
    @property
    def cluster_centers_ind_(self):
        """Index of representative observations
        
        Returns
        -------
        int array of shape [n_clusters]
            Array of observation indices that represent each cluster. The i-th
            observation index in this array corresponds to cluster i.
        
        """
        return self.get_cluster_centers_indices()
        
    @property
    def boundary_observations_ind_(self):
        """Index of boundary observations. An observation is in the boundary if
        it has a k-nearest neighbor in another cluster. The parameter k depends
        on the model fitted to the data.
        
        Returns
        -------
        Indices of the observations that are in the boundary, exactly as
        returned by `get_boundary_observations_indices()`.
        
        """
        return self.get_boundary_observations_indices()
        
    
cdef class BaseOPFClassifier:
    """Base class for OPFClassifier

    Owns the C training dataset and the complete graph built from it, and
    releases both when the model is refitted or deallocated.

    `classes_` is stored as an ordinary Python attribute (it is not a
    declared cdef attribute), so `fit` needs an instance that has a
    `__dict__`, e.g. an instance of a Python subclass.

    See Also
    --------
    OPFClassifier    
    
    """
    cdef iftDataSet *dataset_
    cdef iftCplGraph *graph_
    cdef object _is_fitted
    
    def __init__(self):
        self._is_fitted = False
    
    def fit(self, np.ndarray[np.float64_t, ndim = 2] X, 
            np.ndarray[np.int64_t, ndim = 1] y):
        """Train the classifier on observations `X` with class labels `y`.

        Parameters
        ----------
        X: float64 array of shape [n_samples, n_features]
        y: int64 array of shape [n_samples]

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If `X` is empty or `X` and `y` differ in number of observations.

        """
        cdef int i
        cdef int n_samples = X.shape[0]
        cdef int n_classes
        cdef np.ndarray[np.int64_t, ndim = 1] labels

        if n_samples == 0:
            raise ValueError('X should contain at least one observation')
        if y.shape[0] != n_samples:
            raise ValueError('X and y should have the same number of observations')

        self._clean()
                
        # Map arbitrary integer labels to 1..n_classes; `classes_` keeps the
        # original values so `predict` can map them back. The inverse index
        # is cast explicitly because its integer width is platform dependent.
        classes, inverse = np.unique(y, return_inverse=True)
        labels = np.asarray(inverse, dtype = np.int64) + 1
        n_classes = len(classes)
        self.classes_ = classes
        
        self.dataset_ = matrix_to_dataset(X)
        self.dataset_.nclasses = n_classes
        self.dataset_.nlabels = n_classes
        self.dataset_.ntrainsamples = n_samples
        
        for i in range(n_samples):
            self.dataset_.sample[i].truelabel = labels[i]
            self.dataset_.sample[i].label = labels[i]
            self.dataset_.sample[i].status = TRAIN
            
        self.graph_ = iftCreateCplGraph(self.dataset_)
        iftSupTrain(self.graph_)
        
        self._is_fitted = True
        
        return self
        
    def predict(self, X):
        """Predict the class label of every observation in `X`.

        Parameters
        ----------
        X: float64 array of shape [n_samples, n_features]

        Returns
        -------
        int64 array of shape [n_samples] holding values taken from `classes_`.

        Raises
        ------
        Exception
            If the model has not been fitted.
        ValueError
            If the number of columns of `X` differs from the training data.

        """
        cdef int i
        cdef int n_samples
        cdef iftDataSet *dataset_test
        cdef np.ndarray[np.int64_t, ndim = 1] y 

        if not self._is_fitted:
            raise Exception('Model is not fitted')
        if X.shape[1] != self.dataset_.nfeats:
            raise ValueError('Invalid number of features')
        
        n_samples = X.shape[0]
        y = np.zeros(n_samples, dtype = np.int64) 
        if n_samples == 0:
            # Nothing to classify; avoid creating an empty C dataset.
            return y

        classes = self.classes_
        dataset_test = matrix_to_dataset(X)
        # Free the temporary dataset even if mapping labels back fails.
        try:
            iftClassify(self.graph_, dataset_test)
            for i in range(n_samples):
                # Labels assigned by the C code are 1-based (see `fit`).
                y[i] = classes[dataset_test.sample[i].label - 1]
        finally:
            iftDestroyDataSet(&dataset_test)            
        
        return y
        
    def _clean(self):
        """Release the C graph and dataset, if any, and mark the model unfitted."""
        # Decide on the pointers themselves rather than on `_is_fitted`, so
        # memory left behind by a fit that failed halfway is released too.
        if self.graph_ != NULL:
            iftDestroyCplGraph(&self.graph_)
            self.graph_ = NULL
        if self.dataset_ != NULL:
            iftDestroyDataSet(&self.dataset_)
            self.dataset_ = NULL
            
        self._is_fitted = False
        
    def __dealloc__(self):
        # Only C data is touched here: calling Python-level methods or
        # reading Python attributes of a partially destroyed object is not
        # safe during deallocation.
        if self.graph_ != NULL:
            iftDestroyCplGraph(&self.graph_)
        if self.dataset_ != NULL:
            iftDestroyDataSet(&self.dataset_)
        
cdef class BaseOPFCluster:
    """Base class for OPFCluster.

    Owns the C dataset and the k-nn graph produced by the unsupervised
    learning step, and releases both when the model is refitted or
    deallocated.

    See Also
    --------
    OPFCluster    
    
    """
    # NOTE(review): `dataset_` and `knn_graph_` hold C pointers and therefore
    # have to be declared as cdef attributes. No declaration is visible in
    # this class body; presumably it lives in a companion .pxd. Confirm, and
    # otherwise declare `cdef iftDataSet *dataset_` and
    # `cdef iftKnnGraph *knn_graph_` here.
    def __init__(self, k_max_percentage, n_iterations):
        self.k_max_percentage = k_max_percentage
        self.n_iterations = n_iterations
        
        self._is_fitted = False
        
    def fit(self, np.ndarray[np.float64_t, ndim = 2] X):
        """Cluster the observations in `X`.

        Parameters
        ----------
        X: float64 array of shape [n_samples, n_features]

        Returns
        -------
        self

        """
        cdef int i
        cdef int n_samples = X.shape[0]
        cdef np.ndarray[np.int64_t, ndim = 1] raw_labels

        self._clean()
        
        self.dataset_ = matrix_to_dataset(X)
        iftSetStatus(self.dataset_, TRAIN)
                
        self.knn_graph_ = iftUnsupLearn(self.dataset_, 
            self.k_max_percentage, iftKnnGraphCutFun3, self.n_iterations)
        iftCluster(self.knn_graph_, self.dataset_)
        
        # Fill a typed local buffer instead of going through the Python
        # attribute on every iteration.
        raw_labels = np.zeros(n_samples, dtype = np.int64)
        for i in range(n_samples):
            raw_labels[i] = self.dataset_.sample[i].label
            
        # Renumber the labels so that they are consecutive and start at 0.
        _, self._labels = np.unique(raw_labels, return_inverse=True)
        
        self._is_fitted = True
        
        return self
        
    def get_cluster_centers_indices(self):
        """Return, for each cluster label, the index of its root observation.

        Returns
        -------
        int64 array with one entry per root sample of the k-nn graph; entry
        `c` is the index of the root whose label is `c`.

        Raises
        ------
        Exception
            If the model has not been fitted.

        """
        cdef iftSet *rootSet
        cdef int ind
        cdef np.ndarray[np.int64_t, ndim = 1] representatives

        if not self._is_fitted:
            raise Exception('Model is not fitted')
            
        rootSet = iftGetKnnRootSamples(self.knn_graph_)
        # Free the C set even if converting it to a Python object fails.
        try:
            indices = iftSet_to_list(rootSet)
        finally:
            iftDestroySet(&rootSet)
        
        labels = self._labels
        representatives = np.zeros(len(indices), dtype = np.int64)
        for ind in indices:
            representatives[labels[ind]] = ind
            
        return representatives
    
    def get_boundary_observations_indices(self):
        """Return the indices of the boundary samples of the k-nn graph.

        Returns
        -------
        The result of converting the C set of boundary samples with
        `iftSet_to_list`.

        Raises
        ------
        Exception
            If the model has not been fitted.

        """
        cdef iftSet *boundarySet

        if not self._is_fitted:
            raise Exception('Model is not fitted')
            
        boundarySet = iftGetKnnBoundarySamples(self.knn_graph_)
        # Free the C set even if converting it to a Python object fails.
        try:
            indices = iftSet_to_list(boundarySet)
        finally:
            iftDestroySet(&boundarySet)
        
        return indices
        
    def _clean(self):
        """Release the C graph and dataset, if any, and mark the model unfitted."""
        # Decide on the pointers themselves rather than on `_is_fitted`, so
        # memory left behind by a fit that failed halfway is released too.
        # The graph is destroyed before the dataset it was built from,
        # matching the order used by BaseOPFClassifier.
        if self.knn_graph_ != NULL:
            iftDestroyKnnGraph(&self.knn_graph_)
            self.knn_graph_ = NULL
        if self.dataset_ != NULL:
            iftDestroyDataSet(&self.dataset_)
            self.dataset_ = NULL
            
        self._is_fitted = False
            
    def get_labels(self):
        """Return the cluster labels computed by the last call to `fit`."""
        return self._labels
    
    def __dealloc__(self):
        # Only C data is touched here: calling Python-level methods or
        # reading Python attributes of a partially destroyed object is not
        # safe during deallocation.
        if self.knn_graph_ != NULL:
            iftDestroyKnnGraph(&self.knn_graph_)
        if self.dataset_ != NULL:
            iftDestroyDataSet(&self.dataset_)