# -*- coding: utf-8 -*-
from collections import defaultdict
import numpy as np
import yt.wrapper as yt_wrapper
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.dev_utils.data import data_utils as du
from datacloud.dev_utils.yt import yt_utils, yt_files
from datacloud.dev_utils.yt.converter import TableConverter

# Module-level logger named after this module.
logger = get_basic_logger(__name__)

# Public API of the module: only the two table converters are exported.
__all__ = [
    'FeaturesToMetaFeaturesConverter',
    'FeaturesToStabilityBinsConverter'
]

# Tag inserted into the titles of the YT operations started by this module.
TAG = 'STABILITY-BINS'

# Default histogram settings: number of bins and the value interval
# [DEFAULT_LEFT_BORDER, DEFAULT_RIGHT_BORDER] that the bins cover.
DEFAULT_NBINS = 100
DEFAULT_LEFT_BORDER = -1
DEFAULT_RIGHT_BORDER = 1


@yt_wrapper.aggregator
class BuildBinsMapper:
    """
    Makes partial histograms for score and feature tables.

    The mapper consumes all records it is given, counts how many values fall
    into every ``(position, bin)`` pair and only then emits one row per pair:

    * ``feature`` -- position of the value inside the decoded array
      (always ``0`` for ``scores`` tables, where each record holds one value);
    * ``bin`` -- index of the bin the value fell into;
    * ``count`` -- number of values seen for this pair;
    * ``segment`` -- ``segment`` of the last consumed record, or ``'daily'``
      when that record has no such field.

    :param n_bins: number of equal-width bins between the borders.
    :param left_border: lower edge of the histogram; smaller values are
        clipped to it.
    :param right_border: upper edge of the histogram; larger values are
        clipped just below it so that they land in the last bin.
    :param table_type: ``'features'`` (the column holds an encoded array read
        with ``du.array_fromstring``) or ``'scores'`` (the column holds a
        single number).
    :param column: name of the column to read; defaults to ``'features'`` or
        ``'score'`` depending on ``table_type``.
    :raises ValueError: if ``table_type`` is not one of the supported values.
    """
    # Column that is read when the caller does not pass ``column`` explicitly.
    _DEFAULT_COLUMNS = {'features': 'features', 'scores': 'score'}

    def __init__(self, n_bins=DEFAULT_NBINS,
                 left_border=DEFAULT_LEFT_BORDER,
                 right_border=DEFAULT_RIGHT_BORDER,
                 table_type='features',
                 column=None):
        # Fail fast: an unknown table type can never be processed by __call__,
        # so report it at construction time instead of inside the running job.
        if table_type not in self._DEFAULT_COLUMNS:
            raise ValueError(
                'Unknown table_type {!r}, expected one of {}'.format(
                    table_type, sorted(self._DEFAULT_COLUMNS)))

        self._n_bins = n_bins
        self._left_border = left_border
        self._right_border = right_border
        self._table_type = table_type
        if column is not None:
            self._column = column
        else:
            self._column = self._DEFAULT_COLUMNS[table_type]

        self._range = abs(self._right_border - self._left_border)
        self._step = self._range / float(self._n_bins)
        self._shift = self._range / 2.0
        # Margin subtracted from the right border so that a value equal to
        # (or above) the border maps to the last bin, not to bin ``n_bins``.
        self._eps = 1e-3

    def __call__(self, recs):
        """
        Count values per ``(position, bin)`` over ``recs`` and yield the totals.

        :param recs: iterable of input records.
        :return: generator of rows with ``feature``, ``bin``, ``count`` and
            ``segment`` keys.
        """
        dist = defaultdict(int)
        upper_bound = self._right_border - self._eps
        last_rec = None
        for rec in recs:
            last_rec = rec

            if self._table_type == 'scores':
                input_array = np.array([rec[self._column]])
            else:
                input_array = du.array_fromstring(rec[self._column])

            # Clip into the histogram range, shift to start at zero and scale
            # by the bin width; truncating to int below gives the bin index.
            numbers = (input_array.clip(self._left_border, upper_bound) -
                       self._left_border) / self._step

            for number_idx, bin_idx in enumerate(numbers):
                dist[(number_idx, int(bin_idx))] += 1

        if not dist:
            return

        # Every emitted row carries the segment of the last consumed record.
        segment = last_rec['segment'] if 'segment' in last_rec else 'daily'
        for (number_idx, bin_idx), count in dist.items():
            yield {
                'feature': number_idx,
                'bin': bin_idx,
                'count': count,
                'segment': segment
            }


def combine_bins(key, recs):
    """
    Reducer that merges partial counts of one ``(feature, bin)`` pair.

    :param key: mapping with the ``feature`` and ``bin`` of the group.
    :param recs: iterable of rows with ``count`` and ``segment`` fields.
    :return: generator yielding a single row with the summed ``count``; its
        ``segment`` is copied from the last row of the group.
    """
    total = 0
    for row in recs:
        total += row['count']
    # ``row`` still refers to the last record consumed by the loop above.
    merged = dict(
        feature=key['feature'],
        bin=key['bin'],
        count=total,
        segment=row['segment']
    )
    yield merged


class BuildHistogramReducer():
    """
    Reducer that assembles the full histogram of one feature.

    :param n_bins: number of bins that are pre-filled with zero, so bins
        without any row still appear in the histogram.
    """
    def __init__(self, n_bins=DEFAULT_NBINS):
        self.n_bins = n_bins

    def __call__(self, key, recs):
        """
        Build raw and normalised histograms from the rows of one feature.

        :param key: mapping with the ``feature`` of the group.
        :param recs: iterable of rows with ``bin``, ``count`` and ``segment``.
        :return: generator yielding one row with ``feature``, ``hist`` (counts
            ordered by bin index), ``norm_hist`` (counts divided by their sum,
            or all zeros when the sum is not positive) and ``segment`` (taken
            from the last row of the group).
        """
        counts_by_bin = dict.fromkeys(range(self.n_bins), 0)
        for row in recs:
            counts_by_bin[row['bin']] = row['count']

        ordered_counts = [counts_by_bin[bin_idx] for bin_idx in sorted(counts_by_bin)]
        total = float(sum(ordered_counts))
        if total > 0:
            normalised = [count / total for count in ordered_counts]
        else:
            normalised = [0] * len(ordered_counts)

        # ``row`` still refers to the last record consumed by the loop above.
        yield {
            'feature': key['feature'],
            'hist': ordered_counts,
            'norm_hist': normalised,
            'segment': row['segment']
        }


def calculate_psi(target, reference, eps=1e-9):
    """
    Compute ``sum((t - r) * log((t + eps) / (r + eps)))`` over two sequences.

    :param target: array-like of values for the distribution under test.
    :param reference: array-like of values for the baseline distribution.
    :param eps: constant added to both numerator and denominator inside the
        logarithm so that zero entries do not produce a division by zero.
    :return: the summed value as a numpy scalar.
    """
    observed = np.array(target)
    baseline = np.array(reference)
    difference = observed - baseline
    log_ratio = np.log((observed + eps) / (baseline + eps))
    return np.sum(difference * log_ratio)


class FeaturesToStabilityBinsConverter(TableConverter):
    """
    Converter that turns a features (or scores) table into per-feature
    histograms.

    The output table is written with the schema ``feature`` (int32, required),
    ``hist``, ``segment`` and ``norm_hist``.

    :param yt_client: YT client passed on to ``TableConverter``.
    :param n_bins: number of histogram bins.
    :param table_type: forwarded to ``BuildBinsMapper``.
    """
    def __init__(self, yt_client, n_bins, table_type='features'):
        super(FeaturesToStabilityBinsConverter, self).__init__(yt_client)
        self._n_bins = n_bins
        self._table_type = table_type
        self._output_table_schema = [
            dict(name='feature', type='int32', required=True),
            dict(name='hist', type='any'),
            dict(name='segment', type='string'),
            dict(name='norm_hist', type='any'),
        ]

    def _convert(self, input_table, output_table):
        """
        Run the histogram pipeline from ``input_table`` into ``output_table``.

        Steps, all inside one transaction and with one temporary table:
        map-reduce into per-(feature, bin) counts, reduce into per-feature
        histograms, merge the output chunks, sort the output by ``feature``.
        """
        client = self._yt_client
        schematized_output = yt_wrapper.TablePath(
            output_table,
            schema=self._output_table_schema
        )
        with client.Transaction(), \
                client.TempTable('//tmp', 'xprod-hist') as bins_table:
            # Step 1: count values per (feature, bin).
            bins_spec = {'title': '[{}] Compute stability bins'.format(TAG)}
            client.run_map_reduce(
                BuildBinsMapper(n_bins=self._n_bins, table_type=self._table_type),
                combine_bins,
                input_table,
                bins_table,
                reduce_by=['feature', 'bin'],
                spec=bins_spec
            )

            # Step 2: gather the bins of every feature into one histogram row.
            hist_spec = {'title': '[{}] Build histogramms'.format(TAG)}
            client.run_map_reduce(
                None,
                BuildHistogramReducer(n_bins=self._n_bins),
                bins_table,
                schematized_output,
                reduce_by='feature',
                spec=hist_spec
            )

            # Step 3: merge the output table into itself to combine chunks.
            merge_spec = {
                'title': '[{}] Merge chunks'.format(TAG),
                'combine_chunks': True
            }
            client.run_merge(schematized_output, schematized_output, spec=merge_spec)

            # Step 4: sort by feature.
            # NOTE(review): unlike the other titles, this one has no brackets
            # around the tag -- confirm whether that is intentional.
            sort_spec = {'title': '{} Sort histogramms'.format(TAG)}
            client.run_sort(schematized_output, sort_by='feature', spec=sort_spec)


class GroupFeaturesForMonitoringMapper(object):
    """
    Mapper that collapses a feature vector into one meta-feature per group.

    :param clf: model object; only ``clf.coef_[0]`` is read.
    :param list_groups: indexable collection of groups; every group is used
        to index both the feature vector and the coefficient vector.
    """
    def __init__(self, clf, list_groups):
        self._clf = clf
        self._groups = list_groups

    def __call__(self, rec):
        """
        Replace ``rec['features']`` with the encoded meta-feature vector.

        Each meta-feature is the sum of ``feature * coefficient`` over the
        members of the corresponding group. The record is modified in place
        and yielded.
        """
        feature_vector = du.array_fromstring(rec['features'])
        group_count = len(self._groups)
        group_scores = np.zeros(group_count)
        weights = self._clf.coef_[0]
        for position in range(group_count):
            members = self._groups[position]
            group_scores[position] = sum(feature_vector[members] * weights[members])
        rec['features'] = du.array_tostring(group_scores)
        yield rec


class FeaturesToMetaFeaturesConverter(TableConverter):
    """
    Converter that maps a features table to a meta-features table.

    :param yt_client: YT client passed on to ``TableConverter``.
    :param clf_yt_path: YT path of the joblib-serialised model.
    :param list_groups_yt_path: YT path of the joblib-serialised feature
        groups.
    """
    def __init__(self, yt_client, clf_yt_path, list_groups_yt_path):
        super(FeaturesToMetaFeaturesConverter, self).__init__(yt_client)
        self._list_groups_yt_path = list_groups_yt_path
        self._clf_yt_path = clf_yt_path

    def _convert(self, input_table, output_table):
        """
        Load the model and the groups from YT, then run the grouping mapper
        over ``input_table`` writing to ``output_table``.
        """
        client = self._yt_client
        model = yt_files.joblib_load_from_yt(client, self._clf_yt_path)
        groups = yt_files.joblib_load_from_yt(client, self._list_groups_yt_path)
        mapper = GroupFeaturesForMonitoringMapper(model, groups)
        client.run_map(mapper, input_table, output_table)


if __name__ == '__main__':
    # Ad-hoc manual run: build meta-features from a features table and then
    # compute stability histograms for them.
    yt_client = yt_utils.get_yt_client()

    features_to_metafeatures = FeaturesToMetaFeaturesConverter(
        yt_client,
        clf_yt_path='//home/x-products/production/datacloud/bins/partner_models/tcs/tcs_xprod_987_20181212/current',
        list_groups_yt_path='//home/x-products/production/datacloud/bins/partner_models/tcs/additional/sklearn181_bins_for_987_20181212.pkl'
    )

    features_to_stability_bins = FeaturesToStabilityBinsConverter(yt_client, n_bins=DEFAULT_NBINS)
    # Presumably chains the converters so that meta-features are produced
    # before the histograms -- see TableConverter for the exact semantics.
    features_to_stability_bins << features_to_metafeatures

    features_to_stability_bins(
        # Alternative input: '//projects/scoring/tmp/stability-scores/interesting_features'
        '//tmp/r9-cid-to-features-table',
        '//tmp/r9-tmp-stability'
    )
