from __future__ import (
    absolute_import,
    division,
    print_function,
)

from abc import ABCMeta
import math
import os
import shutil
from subprocess import call
import tempfile

import numpy as np
from passport.backend.utils.file import chdir
import six
from sklearn.base import (
    BaseEstimator,
    ClassifierMixin,
)
from sklearn.utils import check_X_y


# Lookup table from the ``method`` name given to MatrixNetClassifier to the
# command-line flag inserted into the training command. 'quad' maps to an
# empty string, i.e. no extra flag is passed for it.
matrix_net_train_method = dict([
    ('quad', ''),
    ('qrmse', '-q'),
    ('binary', '-c'),
    ('multi', '-m'),
    ('pair', '-P'),
    ('pfound', '-F'),
    ('pfound_classic', '-L'),
    ('pfound_test', '-U'),
    ('pfound_flat', '-u'),
    ('pfound2', '-2'),
])


class MatrixNetClassifier(six.with_metaclass(ABCMeta, BaseEstimator, ClassifierMixin)):
    """
    scikit-learn style estimator that delegates training and prediction to an
    external MatrixNet command-line binary.

    Data is exchanged with the binary through tab-separated files stored in a
    working folder (``base_folder`` or a fresh temporary directory). The
    instance can be used as a context manager; leaving the ``with`` block
    calls :meth:`clear`, which removes the working folder.
    """

    def __init__(self,
                 iterations=10,
                 depth=3,
                 method='binary',
                 cv='1/5',
                 binary_name='matrixnet',
                 base_folder=None,
                 fmt='%.10f'):
        """
        Construct MatrixNet wrapper
        :param iterations: value passed to the binary with ``-i``
        :param depth: value passed to the binary with ``-n``
        :param method: key of ``matrix_net_train_method`` selecting the training flag
        :param cv: Cross-validation; passed with ``-X`` when truthy (must be a string)
        :param binary_name: binary name on host
        :param base_folder: Folder to store model files; a temporary directory
            is created on ``fit`` when None
        :param fmt: printf-style format used for the weight and feature columns
            of the generated TSV files
        :return:
        """
        self.iterations = iterations
        self.depth = depth
        self.method = method
        self.cv = cv
        self.binary_name = binary_name
        self.base_folder = base_folder
        self.fmt = fmt
        self.trained = False
        self.__real_base_folder = None

    def __enter__(self):
        # Return the estimator so that ``with MatrixNetClassifier() as clf:``
        # binds ``clf`` to the instance instead of None.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.clear()

    def clear(self):
        """
        Mark the model as untrained and delete the working folder used by the
        last ``fit`` (if it still exists on disk).
        """
        self.trained = False
        if self.__real_base_folder is not None:
            if os.path.exists(self.__real_base_folder):
                shutil.rmtree(self.__real_base_folder)
            self.__real_base_folder = None

    def fit(self, X, y, sample_weight=None, columns=None):
        """
        Write the training set to disk, run the MatrixNet binary on it and
        parse the produced log and feature-strength files.

        NOTE: the working folder is wiped before training. When ``base_folder``
        is None every call creates a new temporary directory; the one from a
        previous ``fit`` is not removed here, call ``clear()`` first to avoid
        leaving it behind.

        :param X: nd.array or list of lists
        :param y: nd.array or list
        :param sample_weight: (optional) nd.array or list, one weight per row;
            defaults to 1 for every row
        :param columns: names of columns (for get_feature_importances)
        :return: self
        """

        check_X_y(X, y, multi_output=(self.method == 'multi'))
        X = np.asarray(X)
        y = np.asarray(y)

        self.__real_base_folder = self.base_folder
        if self.base_folder is None:
            self.__real_base_folder = tempfile.mkdtemp(prefix='matrixnet_model')
        # Always start from an empty working folder.
        if os.path.exists(self.__real_base_folder):
            shutil.rmtree(self.__real_base_folder)
        os.makedirs(self.__real_base_folder)

        self.__columns = columns
        if self.__columns is None:
            self.__columns = ['feature_%d' % i for i in range(X.shape[1])]

        if sample_weight is None:
            sample_weight = np.ones(X.shape[0])
        else:
            # Accept plain lists as well as arrays; the column slicing below
            # needs an ndarray.
            sample_weight = np.asarray(sample_weight)

        folder = self.__real_base_folder
        self.__model_folder = os.path.join(folder, 'matrixnet')
        self.__train_dataset_filepath = os.path.join(folder, 'train.tsv')
        self.__test_dataset_name = 'test.tsv'
        self.__test_dataset_filepath = os.path.join(folder, self.__test_dataset_name)
        self.__prediction_dataset_filepath = os.path.join(folder, 'prediction.tsv')
        self.__features_filepath = os.path.join(folder, 'features.txt')
        self.__trained_model_filepath = os.path.join(folder, 'matrixnet.txt')
        self.__trained_model_fstr_filepath = self.__model_folder + '.fstr'

        # Row layout: id, target, id (again), weight, then the features.
        row_ids = np.arange(1, X.shape[0] + 1)[:, np.newaxis]
        data = np.c_[
            row_ids,
            y[:, np.newaxis],
            row_ids,
            sample_weight[:, np.newaxis],
            X
        ]

        # The first three columns are written as integers, the weight and all
        # feature columns with the configured float format.
        np.savetxt(
            self.__train_dataset_filepath,
            data,
            delimiter='\t',
            fmt=['%d'] * 3 + [self.fmt] * (X.shape[1] + 1)
        )

        # File paths are shell-quoted so folders containing spaces or shell
        # metacharacters do not break (or alter) the command line.
        # ``binary_name`` is intentionally left as-is.
        quote = six.moves.shlex_quote
        cmd = [self.binary_name]
        method_flag = matrix_net_train_method[self.method]
        if method_flag:
            cmd.append(method_flag)
        cmd.extend([
            '-W',
            '-f', quote(self.__train_dataset_filepath),
            '-o', quote(self.__model_folder),
            '-n', str(self.depth),
            '-i', str(self.iterations),
        ])

        if self.cv:
            cmd.extend(['-X', self.cv])

        # The binary's stdout is captured into a file and parsed afterwards.
        call(
            '%s > %s' % (' '.join(cmd), quote(self.__trained_model_filepath)),
            shell=True
        )
        self.__parsed_model = self.__parse_matrixnet_out()
        self.trained = True
        return self

    def __parse_matrixnet_fstr(self):
        """
        Read the ``.fstr`` file written next to the model.

        For every line, the first whitespace-separated field is used as an
        index into the column names and the fourth field as that column's
        value.

        :return: dict mapping column name to float
        """
        features = {}
        if self.__trained_model_fstr_filepath:
            with open(self.__trained_model_fstr_filepath) as fstr_file:
                for line in fstr_file:
                    fields = line.strip().split()
                    features[self.__columns[int(fields[0])]] = float(fields[3])
        return features

    def __parse_matrixnet_out(self):
        """
        Parse the captured stdout of the training run.

        Lines that start with ``learn`` and contain ``test`` are treated as
        per-iteration reports: field 1 is the learn value and field 3 the test
        value (trailing commas stripped, absolute value taken). Lines whose
        values cannot be converted to float are counted as an iteration but
        otherwise skipped.

        :return: dict with the learn/test series, the smallest test value seen
            below the initial threshold of 1 together with its (1-based)
            iteration, and the parsed feature strengths
        """
        target_learn = []
        target_test = []
        min_test = 1
        min_test_iteration = 0
        iteration = 0

        with open(self.__trained_model_filepath) as log_file:
            for line in log_file:
                line = line.strip()
                if not (line.startswith('learn') and 'test' in line):
                    continue
                iteration += 1
                line_parts = line.split()
                try:
                    test_ll = abs(float(line_parts[3].strip(',')))
                    target_learn.append(abs(float(line_parts[1].strip(','))))
                    target_test.append(test_ll)

                    if test_ll < min_test:
                        min_test = test_ll
                        min_test_iteration = iteration
                except ValueError:
                    # Report line without numeric values; ignore it.
                    pass

        return dict(
            target_learn=target_learn,
            target_test=target_test,
            min_test=min_test,
            min_test_iteration=min_test_iteration,
            features=self.__parse_matrixnet_fstr(),
        )

    def get_feature_importances(self):
        """
        :return: dict mapping column name to the value parsed from the
            ``.fstr`` file during ``fit``
        :raises RuntimeError: if the model has not been trained
        """
        if not self.trained:
            raise RuntimeError('Model has not been trained yet.')
        return self.__parsed_model['features']

    def predict_proba(self, X):
        """
        Write ``X`` to the test TSV file, run the binary on it inside the
        working folder and read the predictions it writes to
        ``<test file>.matrixnet``.

        :param X: np.ndarray
        :return: np.ndarray; for method 'multi' the rounded integer prediction
            per row, otherwise one ``[1 - p, p]`` pair per row where ``p`` is
            the sigmoid of the raw prediction
        :raises RuntimeError: if the model has not been trained
        """

        if not self.trained:
            raise RuntimeError('Model has not been trained yet.')

        X = np.asarray(X)
        n_rows = X.shape[0]
        # Same column layout as the training file; target and weight are
        # filled with the placeholder value 1.
        row_ids = np.arange(10 ** 9, 10 ** 9 + n_rows)[:, np.newaxis]
        placeholder = np.ones(n_rows, dtype=int)[:, np.newaxis]
        data = np.c_[
            row_ids,
            placeholder,
            row_ids,
            placeholder,
            X
        ]

        np.savetxt(
            self.__test_dataset_filepath,
            data,
            delimiter='\t',
            fmt=['%d'] * 3 + [self.fmt] * (X.shape[1] + 1)
        )

        cmd = ' '.join([
            self.binary_name,
            '-A',
            '-f', self.__test_dataset_name,
        ])

        # The binary is run from inside the working folder with a relative
        # test file name.
        with chdir(self.__real_base_folder):
            call('%s > /dev/null' % cmd, shell=True)

        result = []
        with open(self.__test_dataset_filepath + '.matrixnet') as prediction_file:
            for line in prediction_file:
                # Five tab-separated fields per row; the last is the prediction.
                _, _, _, _, prediction = line.split('\t')
                if self.method == 'multi':
                    result.append(int(round(float(prediction))))
                else:
                    probability = 1.0 / (1.0 + math.exp(-float(prediction)))
                    result.append([1.0 - probability, probability])
        return np.array(result)
