from __future__ import (
    absolute_import,
    division,
    print_function,
)

from abc import ABCMeta
import math
import os
import re
import shutil
from subprocess import call
import tempfile

import numpy as np
import pandas as pd
import six
from sklearn.base import (
    BaseEstimator,
    ClassifierMixin,
)
from sklearn.utils import check_X_y


# Captures the leading feature token of a tensor entry: one or more lowercase
# letters followed by digits (e.g. "f3", "sb12"). Such tokens are looked up in
# the "<kind><n>" mapping that parse_tensornet_out builds from feature lines.
# Used with re.match, so the trailing lazy `.*?` matches the empty string and
# has no effect on what group(1) captures.
FEATURE_INDEX_PATTERN = re.compile(r'([a-z]+\d+).*?')


def read_prediction_results(path):
    """Load raw scores written by the TensorNet binary and convert them to probabilities.

    The file is read as headerless tab-separated values with the columns
    query_id, target, url, weight, result. Only ``result`` is used: each raw
    score z is mapped through the logistic function 1 / (1 + exp(-z)).

    :param path: path to the prediction file produced by the binary.
    :return: float ndarray of shape (n_rows, 2); column 1 is the logistic
        probability p, column 0 is 1 - p.
    """
    columns = (
        'query_id',
        'target',
        'url',
        'weight',
        'result',
    )
    df = pd.read_csv(
        path,
        sep='\t',
        header=None,
        names=columns,
    )
    scores = np.asarray(df['result'], dtype=np.float64)

    # Numerically stable sigmoid: exp is only ever evaluated on -|z| <= 0, so
    # large-magnitude scores cannot overflow (math.exp(-z) raised OverflowError
    # for z below roughly -709). For z >= 0 this is exactly 1 / (1 + exp(-z)).
    exp_neg_abs = np.exp(-np.abs(scores))
    positive = np.where(
        scores >= 0,
        1.0 / (1.0 + exp_neg_abs),
        exp_neg_abs / (1.0 + exp_neg_abs),
    )
    return np.column_stack((1 - positive, positive))


def parse_tensor(line):
    """Extract the feature tokens listed between the braces of a tensor line.

    :param line: a log line such as ``selected tensor {f0, b3}``.
    :return: list of feature tokens (e.g. ``['f0', 'b3']``); an empty list when
        the braces enclose nothing, so callers can always iterate the result.
    :raises ValueError: if the line has no ``{`` / ``}`` pair, or an entry does
        not start with a letters-then-digits token (see parse_feature_index).
    """
    start = line.index('{') + 1
    end = line.index('}')
    # Split on bare commas and strip each piece so "a, b", "a,b" and stray
    # whitespace or a trailing comma all parse the same way.
    entries = [entry.strip() for entry in line[start:end].split(',')]
    return [parse_feature_index(entry) for entry in entries if entry]


def parse_feature_index(feature_value):
    """Return the leading ``<letters><digits>`` token of a tensor entry.

    :param feature_value: one entry from a ``selected tensor {...}`` list.
    :return: the token matched by FEATURE_INDEX_PATTERN at the start of the entry.
    :raises ValueError: if the entry does not start with such a token.
    """
    match = FEATURE_INDEX_PATTERN.match(feature_value)
    if match is None:
        raise ValueError('unknown feature %s' % feature_value)
    return match.group(1)


def parse_tensornet_out(out_filepath):
    """Parse the text log written to stdout by a TensorNet training run.

    Three kinds of lines are recognised:

    * feature declarations (``feature ...``, ``sparse feature N: ... dhash start``,
      ``categ feature N: ... dhash start``). Each declared column is assigned a
      ``<kind><ordinal>`` token such as ``b0``, ``sb1``, ``f2``, ``s0`` or ``c3``.
    * metric lines starting with ``learn`` and containing ``test`` and
      ``min_test``. Their 2nd and 4th whitespace-separated fields are read as
      the learn and test values; malformed lines are counted but skipped.
    * lines containing ``selected tensor``. Their ``{...}`` tokens are mapped
      back to column indices through the declaration tokens.

    :param out_filepath: path to the captured log.
    :return: dict with the following keys:

        * ``target_learn`` / ``target_test``: lists of floats, one per parsed
          metric line.
        * ``min_test``: the lowest test value, or 1 when none was recorded.
        * ``min_test_iteration``: 1-based count of metric lines up to the
          first occurrence of that minimum, or 0 when none was recorded.
        * ``tensors``: list of lists of column indices, as they appear in the
          log: str for dense/binary features, int for sparse/categorical ones.
    :raises KeyError: if a tensor references a token with no declaration.
    """
    target_learn = []
    target_test = []
    tensors = []
    # Start from +inf rather than 1: test metrics (e.g. log-loss) can exceed 1,
    # and a threshold of 1 silently ignored every such value.
    min_test = float('inf')
    min_test_iteration = 0
    iteration = 0

    # Next free ordinal per feature kind, and token -> column index mapping.
    kind_counters = dict.fromkeys(('b', 'sb', 'f', 'c', 's'), 0)
    feature_mapping = {}

    def register_feature(kind, feature_index):
        # Bind the next "<kind><n>" token to this column index.
        feature_mapping['%s%s' % (kind, kind_counters[kind])] = feature_index
        kind_counters[kind] += 1

    with open(out_filepath) as f:
        for line in f:
            line = line.strip()

            if line.startswith('feature'):
                feature_index = line.split()[1]
                # 'sparse binary' must be tested before the plain 'binary' suffix.
                if line.endswith('sparse binary'):
                    register_feature('sb', feature_index)
                elif line.endswith('binary'):
                    register_feature('b', feature_index)
                elif 'borders' in line:
                    register_feature('f', feature_index)

            elif line.startswith('sparse feature') and line.endswith('dhash start'):
                register_feature('s', int(line.split()[2].strip(': ')))

            elif line.startswith('categ feature') and line.endswith('dhash start'):
                register_feature('c', int(line.split()[2].strip(': ')))

            elif line.startswith('learn') and 'test' in line and 'min_test' in line:
                iteration += 1
                line_parts = line.split()
                try:
                    learn_value = float(line_parts[1])
                    test_value = float(line_parts[3])
                except (IndexError, ValueError):
                    # Malformed or truncated metric line: still counted in
                    # `iteration`, but contributes no values.
                    pass
                else:
                    target_learn.append(learn_value)
                    target_test.append(test_value)
                    if test_value < min_test:
                        min_test = test_value
                        min_test_iteration = iteration

            if 'selected tensor' in line:
                # parse_tensor may yield None for an empty "{}" list.
                tensor_tokens = parse_tensor(line) or []
                tensors.append([feature_mapping[token] for token in tensor_tokens])

    return dict(
        target_learn=target_learn,
        target_test=target_test,
        # Keep the historical default of 1 when no test value was recorded.
        min_test=min_test if min_test_iteration else 1,
        min_test_iteration=min_test_iteration,
        tensors=tensors,
    )


class TensorNetClassifier(six.with_metaclass(ABCMeta, BaseEstimator, ClassifierMixin)):
    """scikit-learn style classifier that delegates to the external TensorNet binary.

    ``fit`` writes the training rows to a tab-separated file in a working
    folder, runs the binary in training mode and parses its log with
    ``parse_tensornet_out``. ``predict_proba`` writes the rows to score, runs
    the binary in apply mode and converts its raw scores with
    ``read_prediction_results``.

    :param iterations: passed to the binary as ``-i``.
    :param depth: passed to the binary as ``-n``.
    :param cv: passed to the binary as ``-X``; the option is omitted when falsy.
    :param binary_name: name or path of the TensorNet executable.
    :param base_folder: working folder for data and model files.

        * When given, the folder is removed and recreated at the start of
          every ``fit``, and is never deleted by the estimator afterwards.
        * When ``None``, each ``fit`` creates a fresh temporary folder. It is
          removed on the next ``fit`` or when the estimator is
          garbage-collected.
    """

    def __init__(self,
                 iterations=10,
                 depth=3,
                 cv='1/5',
                 binary_name='tsnet2',
                 base_folder=None):
        # sklearn convention: __init__ only stores hyper-parameters verbatim.
        self.iterations = iterations
        self.depth = depth
        self.cv = cv
        self.binary_name = binary_name
        self.base_folder = base_folder

    def __del__(self):
        self._remove_owned_folder()

    def _remove_owned_folder(self):
        """Delete the working folder if this estimator created it; no-op otherwise."""
        # Access the name-mangled attributes directly: hasattr(self, '__x') looks
        # up the unmangled name and is always False for self.__x.
        try:
            folder = self.__real_base_folder
            owned = self.__owns_base_folder
        except AttributeError:
            # Never fitted: nothing to clean up.
            return
        if owned and folder is not None:
            # ignore_errors: this runs from __del__, where an exception can only
            # be printed, and the folder may already be gone.
            # NOTE(review): copies of a fitted estimator (deepcopy/pickle) share
            # this folder; whichever copy is collected first deletes it.
            shutil.rmtree(folder, ignore_errors=True)
        self.__owns_base_folder = False

    def _run_binary(self, cmd, stdout_path):
        """Run ``cmd`` (an argument list, no shell) with stdout written to ``stdout_path``.

        :raises RuntimeError: if the process exits with a non-zero status.
        """
        # No shell: paths with spaces or shell metacharacters are passed verbatim.
        with open(stdout_path, 'w') as stdout_file:
            return_code = call(cmd, stdout=stdout_file)
        if return_code != 0:
            raise RuntimeError(
                '%s exited with status %d: %s' % (self.binary_name, return_code, ' '.join(cmd))
            )

    def get_feature_importances(self):
        """Rank input columns by how often they occur in the tensors selected during training.

        Must be called after ``fit``.

        :return: list of ``(column, count)`` pairs. Columns used by at least one
            tensor come first, sorted by descending count; they are followed by
            the unused columns with count 0, in their original order.
        """
        counts = {}
        for tensor in self.__parsed_model['tensors']:
            for factor in tensor:
                # Tensor entries are column positions (str or int); map them
                # back to column labels.
                column = self.__columns[int(factor)]
                counts[column] = counts.get(column, 0) + 1
        ranked = sorted(counts.items(), key=lambda item: -item[1])
        # A real list (py3 `map` returns an iterator, and list + map raises
        # TypeError); iterating the columns keeps the order deterministic.
        unused = [(column, 0) for column in self.__columns if column not in counts]
        return ranked + unused

    def fit(self, X, y, sample_weight=None):
        '''
        Train the TensorNet binary on X / y.

        :param X: pd.DataFrame (other 2-D array-likes are wrapped in one);
            object-dtype columns are declared to the binary as categorical.
        :param y: pd.Series or np.ndarray of targets, one per row of X.
        :param sample_weight: (optional) per-row weights, written as the fourth
            column of the training file; defaults to 1 for every row.
        :return: self
        :raises RuntimeError: if the binary exits with a non-zero status.
        '''
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        # dtype=None keeps object (categorical) columns as they are; the default
        # dtype='numeric' would try to cast them to float and reject strings.
        check_X_y(X, y, dtype=None)
        n_rows = X.shape[0]

        # A refit must not leak the temporary folder of the previous fit.
        self._remove_owned_folder()
        if self.base_folder is None:
            self.__real_base_folder = tempfile.mkdtemp(prefix='tensornet_model')
            self.__owns_base_folder = True
        else:
            self.__real_base_folder = self.base_folder
            self.__owns_base_folder = False
            # Start from an empty folder so files from an earlier run are not reused.
            if os.path.exists(self.__real_base_folder):
                shutil.rmtree(self.__real_base_folder)
            os.makedirs(self.__real_base_folder)

        self.__columns = X.columns

        self.__model_folder = os.path.join(self.__real_base_folder, 'model')
        self.__train_dataset_filepath = os.path.join(self.__real_base_folder, 'train.tsv')
        self.__test_dataset_filepath = os.path.join(self.__real_base_folder, 'test.tsv')
        self.__prediction_dataset_filepath = os.path.join(self.__real_base_folder, 'prediction.tsv')
        self.__fd_filepath = os.path.join(self.__real_base_folder, 'semiautoform.fd')
        self.__features_filepath = os.path.join(self.__real_base_folder, 'features.txt')
        self.__trained_model_filepath = os.path.join(self.__real_base_folder, 'tensornet.txt')

        meta = pd.DataFrame(
            {
                'query_id': range(1, 1 + n_rows),
                # np.asarray: assign positionally. A pandas Series would be
                # aligned on its own index, and different indexes on y and
                # sample_weight would yield NaN-padded rows.
                'target': np.asarray(y),
                'url': range(1, 1 + n_rows),
                'host': [1] * n_rows if sample_weight is None else np.asarray(sample_weight),
            },
            columns=['query_id', 'target', 'url', 'host'],
        )
        data = pd.concat([meta, X.reset_index(drop=True)], axis=1)
        data.to_csv(
            self.__train_dataset_filepath,
            sep='\t',
            float_format='%.10f',
            header=False,
            index=False,
        )

        with open(self.__features_filepath, 'w+') as features_file:
            for column in self.__columns:
                features_file.write('%s\n' % column)

        # Declare every object-dtype column (by position) as categorical.
        categorical_columns = X.select_dtypes(include=[object]).columns
        with open(self.__fd_filepath, 'w+') as fd_file:
            for feature_index in np.where(self.__columns.isin(categorical_columns))[0]:
                fd_file.write('%s\t%s\n' % (feature_index, 'categ'))

        cmd = [
            self.binary_name,
            '-L',
            '-g',
            '-W',
            '-f', self.__train_dataset_filepath,
            '-F', self.__model_folder,
            '-d', self.__fd_filepath,
            '-n', str(self.depth),
            '-i', str(self.iterations),
        ]
        if self.cv:
            cmd.extend(['-X', str(self.cv)])

        # The training log (stdout) is what parse_tensornet_out reads.
        self._run_binary(cmd, self.__trained_model_filepath)
        self.__parsed_model = parse_tensornet_out(self.__trained_model_filepath)
        return self

    def predict_proba(self, X):
        '''
        Score rows with the model trained by ``fit``.

        :param X: pd.DataFrame or np.ndarray laid out like the training data
            (feature columns are written positionally).
        :return: np.ndarray of shape (n_rows, 2); column 1 is the logistic
            transform of the binary's raw score, column 0 its complement.
        :raises RuntimeError: if the binary exits with a non-zero status.
        '''
        if not isinstance(X, pd.DataFrame):
            # X.reset_index below requires a DataFrame.
            X = pd.DataFrame(X)
        n_rows = X.shape[0]
        if n_rows == 0:
            # Nothing to score; the placeholder target column needs >= 1 row.
            return np.empty((0, 2))

        meta = pd.DataFrame(
            {
                'query_id': range(10 ** 9, 10 ** 9 + n_rows),
                # Placeholder labels: true targets are unknown when predicting.
                # NOTE(review): the trailing 0 presumably guarantees both labels
                # occur in the file for the binary -- confirm before changing.
                'target': [1] * (n_rows - 1) + [0],
                'url': range(10 ** 9, 10 ** 9 + n_rows),
                'host': [1] * n_rows,
            },
            columns=['query_id', 'target', 'url', 'host'],
        )
        data = pd.concat([meta, X.reset_index(drop=True)], axis=1)
        data.to_csv(
            self.__test_dataset_filepath,
            sep='\t',
            float_format='%.10f',
            header=False,
            index=False,
        )

        cmd = [
            self.binary_name,
            '-L',
            '-A',
            '-f', self.__test_dataset_filepath,
            '-F', self.__model_folder,
            '-d', self.__fd_filepath,
            '-e', self.__prediction_dataset_filepath,
        ]

        # Apply-mode stdout is discarded; the scores are read from the -e file.
        self._run_binary(cmd, os.devnull)
        return read_prediction_results(self.__prediction_dataset_filepath)