# -*- coding: utf-8 -*-
import math
import os
import re
import uuid

from six import string_types


def sigmoid(z):
    """Return the logistic sigmoid ``1 / (1 + e**-z)`` of ``z``.

    The computation is split by the sign of ``z`` so that ``math.exp`` is only
    ever called with a non-positive argument.  The naive formula raises
    ``OverflowError`` for large negative ``z`` (roughly ``z < -709``), whereas
    this form returns ``0.0`` there.

    :param z: a real number.
    :return: a float in the closed interval ``[0.0, 1.0]``.
    """
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    # For negative z use the algebraically equivalent e**z / (1 + e**z);
    # e**z underflows harmlessly to 0.0 instead of overflowing.
    exp_z = math.exp(z)
    return exp_z / (1.0 + exp_z)


def make_command(program, *args, **kwargs):
    """Build a command-line string from a program name, options and arguments.

    Keyword arguments become options, emitted in sorted key order: a
    one-character key is rendered as ``-k``, a longer key as ``--key``.
    How an option is rendered depends on its value:

    * ``False`` or ``None`` -- the option is omitted (``None`` used to be
      rendered as the literal text ``None``, e.g. ``-F None``);
    * ``True`` -- a bare flag (``-k``);
    * a string -- ``-k value``;
    * any other object with ``__getitem__`` (list-like) -- the option is
      repeated once per element (``-k a -k b``);
    * anything else -- ``-k value`` using the value's string formatting.

    Positional ``args`` are appended after all options and must already be
    strings.  Nothing is shell-quoted; callers are responsible for passing
    values that are safe to hand to a shell.

    :param program: program name or path placed first in the command.
    :return: the parts joined with single spaces.
    """
    command_parts = [program]
    for key in sorted(kwargs):
        value = kwargs[key]
        if value is None or value is False:
            continue

        option = ('-{}' if len(key) == 1 else '--{}').format(key)
        if value is True:
            command_parts.append(option)
        elif not isinstance(value, string_types) and hasattr(value, '__getitem__'):
            # List-like value: repeat the option for every element.  Strings
            # also have __getitem__, hence the explicit exclusion above.
            command_parts.extend('{} {}'.format(option, subvalue) for subvalue in value)
        else:
            command_parts.append('{} {}'.format(option, value))

    command_parts.extend(args)
    return ' '.join(command_parts)


# Feature key pattern: one or more lowercase letters followed by one or more
# digits, captured as group 1 (e.g. "sb12").  The trailing lazy ``.*?`` can
# always match the empty string, so it does not change what group 1 captures.
FEATURE_INDEX_PATTERN = re.compile(r'([a-z]+\d+).*?')


def parse_tensor(line):
    """Extract the feature keys listed between ``{`` and ``}`` in ``line``.

    The text between the first ``{`` and the first ``}`` is split on ``', '``
    and each piece is passed to ``parse_feature_index``.

    :return: list of feature keys, or ``None`` when the braces are empty.
    :raises ValueError: if ``line`` contains no ``{`` or no ``}``.
    """
    open_pos = line.index('{')
    close_pos = line.index('}')
    body = line[open_pos + 1:close_pos]
    if body:
        return [parse_feature_index(chunk) for chunk in body.split(', ')]
    return None


def parse_feature_index(feature_value):
    """Return the feature key found at the start of ``feature_value``.

    The key is group 1 of ``FEATURE_INDEX_PATTERN`` matched at the beginning
    of the string.

    :raises ValueError: if ``feature_value`` does not start with such a key.
    """
    matched = FEATURE_INDEX_PATTERN.match(feature_value)
    if not matched:
        raise ValueError('unknown feature %s' % feature_value)
    return matched.group(1)


def parse_tensornet_out(out_filepath):
    """Parse a text file holding the captured output of a TensorNet run.

    The file is scanned line by line (each line stripped of surrounding
    whitespace) and the following kinds of lines are recognised:

    * lines starting with ``feature``: the second whitespace-separated token
      is kept (as a string) as the feature index and registered under the
      type ``sb`` (line ends with ``sparse binary``), ``b`` (ends with
      ``binary``) or ``f`` (contains ``borders``);
    * lines starting with ``sparse feature`` / ``categ feature`` and ending
      with ``dhash start``: the third token, stripped of ``:`` and spaces, is
      converted to ``int`` and registered under type ``s`` / ``c``;
    * lines starting with ``learn`` that contain ``test`` and ``min_test``:
      token 1 is the learn value and token 3 the test value of one iteration;
    * lines containing ``selected tensor``: the keys between braces are
      translated through the registered features and appended to ``tensors``.

    Registered features get keys made of the type prefix plus a per-type
    running counter (``f0``, ``f1``, ``sb0`` ...).

    :param out_filepath: path of the output file to read.
    :return: a ``TensorNetOutput`` with the learn/test series, the smallest
        test value seen below the initial bound of 1 together with its
        1-based iteration number, and the list of selected tensors.
    :raises KeyError: if a selected tensor references an unregistered key.
    """
    target_learn = []
    target_test = []
    tensors = []
    # NOTE(review): the minimum starts at 1, so a run whose test values are
    # all >= 1 reports min_test=1 / min_test_iteration=0 -- confirm intended.
    min_test = 1
    min_test_iteration = 0
    iteration = 0

    feature_types = {'b': 0, 'sb': 0, 'f': 0, 'c': 0, 's': 0}
    feature_mapping = {}

    def register_feature(prefix, feature_index):
        # Key is the type prefix plus the number of features of that type
        # seen so far, e.g. "f0", "sb3".
        feature_mapping['%s%s' % (prefix, feature_types[prefix])] = feature_index
        feature_types[prefix] += 1

    with open(out_filepath) as f:
        for line in f:
            line = line.strip()

            if line.startswith('feature'):
                feature_index = line.split()[1]

                # 'sparse binary' must be tested first: such lines also end
                # with 'binary'.
                if line.endswith('sparse binary'):
                    register_feature('sb', feature_index)
                elif line.endswith('binary'):
                    register_feature('b', feature_index)
                elif 'borders' in line:
                    register_feature('f', feature_index)

            elif line.startswith('sparse feature') and line.endswith('dhash start'):
                register_feature('s', int(line.split()[2].strip(': ')))

            elif line.startswith('categ feature') and line.endswith('dhash start'):
                register_feature('c', int(line.split()[2].strip(': ')))

            elif line.startswith('learn') and 'test' in line and 'min_test' in line:
                iteration += 1
                line_parts = line.split()
                try:
                    test_value = float(line_parts[3])
                    learn_value = float(line_parts[1])
                except ValueError:
                    # Non-numeric values: the iteration is still counted but
                    # contributes nothing to the series.
                    pass
                else:
                    target_learn.append(learn_value)
                    target_test.append(test_value)

                    if test_value < min_test:
                        min_test = test_value
                        min_test_iteration = iteration

            # Deliberately a separate ``if``: checked for every line,
            # including those handled by one of the branches above.
            if 'selected tensor' in line:
                # parse_tensor returns None for an empty "{}"; record such a
                # tensor as an empty list instead of failing on iteration.
                feature_keys = parse_tensor(line) or []
                tensors.append([feature_mapping[key] for key in feature_keys])

    return TensorNetOutput(
        target_learn=target_learn,
        target_test=target_test,
        min_test=min_test,
        min_test_iteration=min_test_iteration,
        tensors=tensors,
    )


class TensorNetOutput(object):
    """Plain container for the values parsed from a TensorNet output file.

    Attributes are stored exactly as passed in:

    * ``target_learn`` / ``target_test`` -- per-iteration learn/test values;
    * ``min_test`` / ``min_test_iteration`` -- smallest test value and the
      iteration it was seen at;
    * ``tensors`` -- selected tensors.
    """

    def __init__(self, target_learn, target_test, min_test=None, min_test_iteration=None, tensors=None):
        self.target_learn, self.target_test = target_learn, target_test
        self.min_test, self.min_test_iteration = min_test, min_test_iteration
        self.tensors = tensors


class TensorNetError(Exception):
    """Raised when a TensorNet run produces no usable result."""


class TensorNet(object):
    """Wrapper that drives a TensorNet binary through shell commands.

    Commands are assembled with ``make_command`` and executed with
    ``os.system``; the exit status of the binary is not inspected.

    :param bin: path or name of the TensorNet executable.
    :param model: default model directory used when a call passes none.
    :param fdfile: default ``.fd`` file used when a call passes none.
    """

    def __init__(self, bin, model=None, fdfile=None):
        self.bin = bin
        self.model = model
        self.fdfile = fdfile

    def _run(self, params, stdout_path):
        """Build the command from ``params`` and run it via the shell.

        Standard output of the command is redirected to ``stdout_path``.
        Neither the command nor the path is shell-quoted.
        """
        tsnet_cmd = make_command(**params)
        os.system('%s > %s' % (tsnet_cmd, stdout_path))

    def predict(self, features, model=None, pdfile=None, **kwargs):
        """Score a single feature vector and return ``sigmoid(prediction)``.

        A one-line dataset is written to a uniquely named file under ``/tmp``,
        evaluated with :meth:`test`, and the fifth tab-separated column of the
        first line of the results file is taken as the raw prediction.  Both
        temporary files are removed afterwards.

        :param features: non-empty sequence (list, tuple, ...) of values.
        :param model: model directory; defaults to ``self.model``.
        :param pdfile: ``.fd`` file; defaults to ``self.fdfile``.
        :param kwargs: extra options forwarded to :meth:`test`.
        :raises ValueError: if ``features`` is empty.
        :raises TensorNetError: if the results file is missing or its first
            line does not hold exactly five columns with a numeric last one.
        """
        if not features:
            raise ValueError('no features')

        model = model or self.model
        pdfile = pdfile or self.fdfile

        rnd = uuid.uuid4().hex

        dataset_filepath = '/tmp/tensornet-%s.tsv' % rnd
        results_filepath = '/tmp/tensornet-eval-%s.tsv' % rnd

        # Prepend the TensorNet columns: query_id, target, url, host.
        # list() lets callers pass any sequence, not only a list.
        row = [0, 0, 0, 0] + list(features)

        with open(dataset_filepath, 'w+') as dataset_file:
            dataset_file.write('\t'.join(map(str, row)))

        try:
            self.test(dataset_filepath, results_filepath, model, pdfile, **kwargs)
        finally:
            os.remove(dataset_filepath)

        try:
            with open(results_filepath) as results_file:
                try:
                    _, _, _, _, prediction = results_file.readline().split('\t')
                    prediction = float(prediction)
                    return sigmoid(prediction)
                except ValueError:
                    raise TensorNetError('no prediction')
                finally:
                    os.remove(results_filepath)
        except (IOError, OSError):
            # On Python 2 a missing file makes open() raise IOError, which is
            # not an OSError subclass there; on Python 3 they are the same.
            raise TensorNetError('results file was not found')

    def train(self, outfile, dataset, depth, iterations,
              cross_validation=None, weights=False, model=None, fdfile=None, **kwargs):
        """Run a training command and parse its captured output.

        :param outfile: file that receives the command's standard output and
            is then parsed with ``parse_tensornet_out``.
        :param dataset: dataset file.
        :param depth: tree depth.
        :param iterations: number of iterations.
        :param cross_validation: value for the ``-X`` option; omitted if falsy.
        :param weights: when truthy, adds the ``-W`` flag.
        :param model: model directory; defaults to ``self.model``.
        :param fdfile: ``.fd`` file; defaults to ``self.fdfile``.
        :param kwargs: extra options; they override the ones built here.
        :return: the ``TensorNetOutput`` parsed from ``outfile``.
        """
        model = model or self.model
        fdfile = fdfile or self.fdfile

        params = dict(
            program=self.bin,
            L=True,  # local run
            g=True,  # classification mode with the gradient walker
            f=dataset,  # dataset file
            F=model,  # name of the directory the model will be saved to
            d=fdfile,  # .fd file
            n=depth,  # tree depth
            i=iterations,  # number of iterations
        )
        params.update(kwargs)

        if cross_validation:
            params['X'] = cross_validation

        if weights:
            params['W'] = True  # take weights into account

        self._run(params, outfile)

        return parse_tensornet_out(outfile)

    def test(self, dataset, resultsfile, model=None, fdfile=None, **kwargs):
        """Run an evaluation command, discarding its standard output.

        :param dataset: dataset file.
        :param resultsfile: value of the ``-e`` option; :meth:`predict` reads
            the predictions back from this file.
        :param model: model directory; defaults to ``self.model``.
        :param fdfile: ``.fd`` file; defaults to ``self.fdfile``.
        :param kwargs: extra options; they override the ones built here.
        """
        model = model or self.model
        fdfile = fdfile or self.fdfile

        params = dict(
            program=self.bin,
            L=True,  # local run
            A=True,
            F=model,  # model directory name
            f=dataset,  # dataset file
            d=fdfile,  # .fd file
            e=resultsfile,
        )
        params.update(kwargs)

        self._run(params, '/dev/null')

    def grid_search(self, train_dataset, model, pdfile, output, n_list, i_list, x_list, **kwargs):
        """Train once for every combination of depth, iterations and folds.

        The same ``output`` file is overwritten by each run and parsed right
        after it.

        :param train_dataset: dataset file.
        :param model: name of the directory the model will be saved to.
        :param pdfile: ``.fd`` file.
        :param output: file that receives each run's standard output.
        :param n_list: tree depths to try.
        :param i_list: iteration counts to try.
        :param x_list: cross-validation values to try.
        :param kwargs: extra options; they override the ones built here.
        :return: list of ``((n, i, x), TensorNetOutput)`` pairs in loop order.
        """
        results = []
        for n in n_list:
            for i in i_list:
                for x in x_list:
                    params = dict(
                        program=self.bin,
                        F=model,  # name of the directory the model will be saved to
                        f=train_dataset,  # dataset file
                        d=pdfile,  # .fd file
                        L=True,  # local run
                        g=True,  # classification mode with the gradient walker
                        n=n,  # tree depth
                        i=i,  # number of iterations
                        X=x,  # cross-fold validation
                    )
                    params.update(kwargs)

                    self._run(params, output)
                    results.append(((n, i, x), parse_tensornet_out(output)))

        return results
