import numpy

def save(f, X, y, image_names, feature_names):
    """Write a dense, supervised point matrix to the text stream *f*.

    The header is the magic code 'DY', then the row count and the column
    count of X on separate lines, then the ';'-joined feature names.
    Each following line is "<image name>;<value>;...;<class>". The class
    is written as int(y[i]).
    """
    rows, cols = X.shape[0], X.shape[1]
    f.write('DY\n{0}\n{1}\n'.format(rows, cols))
    f.write(';'.join(feature_names) + '\n')

    for row_idx in range(rows):
        fields = [image_names[row_idx]]
        fields.extend(str(value) for value in X[row_idx])
        fields.append(str(int(y[row_idx])))
        f.write(';'.join(fields) + '\n')

def load(f):
    """Read point matrix X, class vector y, image and feature names from f.

    Expected header:

    * a magic code line. Its first character selects the row format
      ('D' dense, 'S' sparse). Its second character is 'Y' when every row
      carries a trailing class label; any other or missing second
      character means unsupervised.
    * the number of points n and the number of features m, one per line.
    * a ';'-separated list of feature names. An empty line makes the
      function generate the names 'f1' ... 'fm'.

    Returns (X, y, image_names, feature_names). X is an (n, m) float
    array and y is a length-n float array; y stays zero when the file is
    unsupervised.

    Raises Exception if the format code is unknown or the number of
    feature names differs from m.
    """
    magic_code = f.readline().strip()

    # Slice instead of index so a short or empty magic code does not raise
    # a bare IndexError. A missing second character means "unsupervised";
    # a missing first character is reported as an unknown format.
    row_format = magic_code[:1]
    supervised = (magic_code[1:2] == 'Y')

    if row_format == 'D':
        loader = load_dense
    elif row_format == 'S':
        loader = load_sparse
    else:
        raise Exception('Unknown point matrix format')

    n = int(f.readline())
    m = int(f.readline())

    feature_names = f.readline().strip()
    if feature_names == '':
        feature_names = ['f{0}'.format(i + 1) for i in range(m)]
    else:
        feature_names = feature_names.split(';')

    if len(feature_names) != m:
        raise Exception('Invalid number of features')

    X = numpy.zeros((n, m))
    y = numpy.zeros(n)

    X, y, image_names = loader(f, X, y, n, m, supervised)

    return X, y, image_names, feature_names
    
    
def load_dense(f, X, y, n, m, supervised):
    """Fill X (and y when supervised) from n dense rows read from f.

    Each row is "<image name>;<v1>;...;<vm>", followed by ";<label>" when
    *supervised* is true. Values are parsed with float(). The label is
    parsed as int(float(label)) and stored in y.

    Returns (X, y, image_names). X and y are the same arrays that were
    passed in, filled in place.

    Raises Exception naming the image when a row has the wrong number of
    fields.
    """
    expected_fields = m + 2 if supervised else m + 1
    names = []

    for row in range(n):
        fields = f.readline().strip().split(';')
        names.append(fields[0])

        if len(fields) != expected_fields:
            raise Exception(
                'Invalid features for image {0}'.format(fields[0]))

        # Supervised rows end with the label, so stop one field early.
        stop = -1 if supervised else len(fields)
        X[row] = [float(v) for v in fields[1:stop]]
        if supervised:
            y[row] = int(float(fields[-1]))

    return X, y, names


def load_sparse(f, X, y, n, m, supervised):
    """Fill X (and y when supervised) from n sparse rows read from f.

    Each row is "<image name>;<col>:<value>;...". When *supervised* is
    true the row ends with ";<label>", which is stored in y as
    int(float(label)). Column indices are used directly as 0-based
    positions into X. Columns not listed keep their existing value (zero
    for arrays created by numpy.zeros).

    Returns (X, y, image_names). X and y are the same arrays that were
    passed in, filled in place.
    """
    image_names = []

    for i in range(n):
        line = f.readline().strip().split(';')
        image_names.append(line[0])

        if supervised:
            y[i] = int(float(line[-1]))
            entries = line[1:-1]
        else:
            # No trailing label: every field after the name is a
            # "<col>:<value>" entry. Slicing off the last field here
            # would silently drop one feature per row.
            entries = line[1:]

        for cv in entries:
            if not cv:
                # Tolerate empty fields, e.g. from a trailing ';'.
                continue
            c, v = cv.split(':')
            X[i, int(c)] = float(v)

    return X, y, image_names

def save_dmatrix(f, dmatrix, y, image_names):
    """Write a distance matrix to the text stream *f* as a lower triangle.

    Header lines are the row count of *dmatrix*, the ';'-joined image
    names, and the ';'-joined str() of each entry of *y*. Then, for each
    row index r from 1 up to (but excluding) dmatrix.shape[1], one line
    holds dmatrix[r, 0..r-1]. Only the strict lower triangle is written.

    NOTE(review): the header uses shape[0] while the loop uses shape[1];
    this presumably assumes a square matrix. Confirm with callers.
    """
    f.write('{0}\n'.format(dmatrix.shape[0]))
    f.write(';'.join(image_names) + '\n')
    f.write(';'.join(map(str, y)) + '\n')

    for row in range(1, dmatrix.shape[1]):
        below_diagonal = (str(dmatrix[row, col]) for col in range(row))
        f.write(';'.join(below_diagonal) + '\n')