# -*- coding: utf-8 -*-
import re
import enum
from copy import deepcopy
from itertools import izip
from yt.wrapper import ypath_split
from datacloud.dev_utils.yt.yt_config_table import LogTable
from datacloud.dev_utils.solomon.solomon_utils import str2ts, post_sensors_to_solomon
from datacloud.stability.stability import calculate_psi
from datacloud.dev_utils.time.patterns import RE_DAILY_LOG_FORMAT
from datacloud.stability.stability_record import (
    StabilityRecord, DistStabMngr, DistStabTable,
    METAFEATURE_MONITORING_LOG_TABLE,
    METAFEATURE_STABILITY_FOLDER,
    METAFEATURE_MONITORING_LOG_TABLE_SCHEME)

from datacloud.dev_utils.logging.logger import get_basic_logger
logger = get_basic_logger(__name__)

# Distribution kinds compared against the task's `type` field by
# MetafeatureDistributionHandler. DISTRIB_TYPE_DAILY is also used there as the
# single segment name under which daily records are grouped.
DISTRIB_TYPE_DAILY = 'daily'
DISTRIB_TYPE_SEGMENT = 'segment'


class DataTags(enum.Enum):
    """Tags selecting which dataset a loader should read.

    ``str(tag)`` yields the raw tag value, so members can be interpolated
    directly into messages.
    """

    REFERENCE = 'reference'
    INPUT = 'input_data'

    def __str__(self):
        """Return the member's underlying string value."""
        return str(self.value)


class MetafeatureDistributionHandler(object):

    def __init__(self, yt_client, task, is_testing_mode=False):
        """Collect everything needed to process one distribution table.

        Args:
            yt_client: client used for all YT reads performed by the handler.
            task: object whose ``data`` mapping provides ``date``, ``type``,
                ``table_path`` and ``stability_record_params``.
            is_testing_mode: when true, the reference table is looked up among
                earlier distributions instead of being taken from the config.
        """
        self.is_testing_mode = is_testing_mode
        self.yt_client = yt_client

        task_data = task.data
        self.date = task_data['date']
        self.type = task_data['type']
        self.table_path = task_data['table_path']

        record = StabilityRecord(**task_data['stability_record_params'])
        self.stability_record = record
        self.score_id = record.score_id
        self.partner_id = record.partner_id
        # The features table stores the application date in its `key` attribute.
        self.score_application_date = yt_client.get_attribute(record.features_table, 'key')

        # Per partner/score monitoring settings live in the stability config table.
        self.config_table = DistStabTable()
        lookup_params = {
            'partner_id': self.partner_id,
            'score_id': self.score_id
        }
        self.config = self.config_table.get_record_by_params(lookup_params)

    def __call__(self):
        """Run the whole monitoring pipeline for the task's table.

        Steps, in order: pick the reference table, load and normalise input
        and reference histograms, compute PSIs, publish PSIs and histograms to
        Solomon, evaluate the stability criterion, write the log record and,
        unless the reference is fixed or testing mode is on, promote the
        current table to be the new reference.
        """
        self.evaluate_prod_or_test(self.is_testing_mode)
        logger.info('Started table processing {} with reference table {}'.format(self.table_path, self.reference_table_path))

        self.input_data = self.load_data(DataTags.INPUT)
        self.reference_data = self.load_data(DataTags.REFERENCE)
        # Histograms are normalised in place; a missing reference (None) is skipped.
        for loaded in (self.input_data, self.reference_data):
            self.process_histograms(loaded)

        self.calculate_psis()
        self.send_psis_to_solomon()
        self.send_distributions_to_solomon()
        self.check_criterion()
        self.send_psis_to_log_table()

        # The reference only moves forward in production runs with a floating reference.
        if not self.config['fixed_reference_table'] and not self.is_testing_mode:
            self.change_reference_table_to_current()
        logger.info('Everything done!')

    def evaluate_prod_or_test(self, is_testing_mode):
        """Select the reference table using the test or the production strategy.

        Args:
            is_testing_mode: truthy to use ``evaluate_test_mode``, falsy to use
                ``evaluate_prod_mode``.
        """
        choose_reference = self.evaluate_test_mode if is_testing_mode else self.evaluate_prod_mode
        choose_reference()

    def evaluate_prod_mode(self):
        """Take the reference table path from the config record.

        Daily distributions use ``reference_distribution_table``; every other
        type uses ``reference_distribution_table_segments``.
        """
        config_key = (
            'reference_distribution_table'
            if self.type == DISTRIB_TYPE_DAILY
            else 'reference_distribution_table_segments'
        )
        self.reference_table_path = self.config[config_key]

    def evaluate_test_mode(self):
        """Use the table found by ``get_closest_date_path`` as the reference."""
        closest_path = self.get_closest_date_path()
        self.reference_table_path = closest_path

    def get_closest_date_path(self):
        """Find the latest stored distribution whose name precedes ``self.date``.

        The folder to search depends on ``self.type`` (daily or segment). Node
        names are compared with ``self.date`` as plain values, newest first.

        Returns:
            Absolute path of the first node whose name is less than ``self.date``.

        Raises:
            ValueError: ``self.type`` is neither daily nor segment.
            AssertionError: no node in the folder precedes ``self.date``.
        """
        manager = DistStabMngr(self.stability_record)
        if self.type == DISTRIB_TYPE_DAILY:
            search_root = manager.daily_stability_path
        elif self.type == DISTRIB_TYPE_SEGMENT:
            search_root = manager.segment_stability_path
        else:
            raise ValueError('Unexpected type value: {}'.format(self.type))

        logger.info('Looking for previos day distribution for comparison in folder [{}]'.format(search_root))
        candidates = sorted(self.yt_client.list(search_root, absolute=True))
        # Walk from the greatest name down so the first match is the closest earlier one.
        for candidate in reversed(candidates):
            if self.date > ypath_split(candidate)[-1]:
                return candidate
        raise AssertionError('No tables to control psis for this score')

    def load_data(self, which_data):
        """Read histograms from the input or the reference table.

        Args:
            which_data: ``DataTags.INPUT`` to read ``self.table_path`` or
                ``DataTags.REFERENCE`` to read ``self.reference_table_path``.

        Returns:
            ``{segment: {feature: hist}}`` built from the table rows, or None
            when the reference is requested but no reference path is set. For
            daily distributions every row is filed under the daily segment.

        Raises:
            NotImplementedError: ``which_data`` is not one of the two tags.
        """
        if which_data == DataTags.INPUT:
            rows = self.yt_client.read_table(self.table_path)
        elif which_data == DataTags.REFERENCE:
            if self.reference_table_path is None:
                return None
            rows = self.yt_client.read_table(self.reference_table_path)
        else:
            raise NotImplementedError('Field which_data expected to be "input_data" or "reference" but got "{}"'.format(which_data))

        loaded = {}
        for row in rows:
            if self.type == DISTRIB_TYPE_DAILY:
                # Daily tables are treated as one synthetic segment.
                row['segment'] = DISTRIB_TYPE_DAILY
            segment = row['segment']
            loaded.setdefault(segment, {})[row['feature']] = row['hist']
        return loaded

    def process_histograms(self, data):
        """Normalise every histogram in ``data`` in place so its bins sum to 1.

        When ``self.config['custom_bin_aggregation']`` is set, bins are first
        merged with ``reduce_bins_ca`` using that configuration.

        Args:
            data: ``{segment: {feature: hist}}`` or None.

        Returns:
            The same ``data`` object with each histogram replaced by a list of
            float fractions, or None when ``data`` is None.

        Raises:
            ZeroDivisionError: a non-empty histogram sums to zero.
        """
        if data is None:
            return None

        # `.items()` works on both Python 2 and 3; only values of existing keys
        # are replaced, so iterating while assigning is safe.
        for hists in data.values():
            for mfeature, hist in hists.items():
                if self.config['custom_bin_aggregation']:
                    hist = self.reduce_bins_ca(hist, self.config['custom_bin_aggregation'], mfeature)
                total = sum(hist)
                # Build a real list: later steps call len() on and index into it.
                hists[mfeature] = [1.0 * count / total for count in hist]
        return data

    def calculate_psis(self):
        """Compute the PSI of every input histogram against the reference.

        Fills ``self.psis`` as ``{segment: {feature: psi}}``. A PSI of 0 is
        recorded when there is no reference data at all or the reference has
        no (non-empty) entry for the segment.

        Raises:
            IndexError: the reference has the segment but no (non-empty)
                histogram for one of the input features.
        """
        self.psis = {}
        # `.items()` works on both Python 2 and 3.
        for segment, hists in self.input_data.items():
            reference_hists = None
            if self.reference_data is not None:
                reference_hists = self.reference_data.get(segment)

            for mfeature, hist in hists.items():
                segment_psis = self.psis.setdefault(segment, {})

                if not reference_hists:
                    # Nothing to compare with: report the feature as unchanged.
                    segment_psis[mfeature] = 0
                    continue

                reference_hist = reference_hists.get(mfeature)
                if not reference_hist:
                    raise IndexError(
                        'Missmatch of reference (len: {}) and input histograms (len: {})'.format(
                            len(reference_hists),
                            len(hists)
                        )
                    )
                segment_psis[mfeature] = calculate_psi(hist, reference_hist)

    def check_criterion(self):
        """Compare every PSI with the configured criterion.

        Fills ``self.stability_ok`` with the same ``{segment: {feature: ...}}``
        layout as ``self.psis``; a value is True when the PSI is strictly
        below ``self.config['criterion']``. A warning is logged for every
        feature that fails the check.
        """
        self.stability_ok = {}
        # `.items()` works on both Python 2 and 3.
        for segment, psis in self.psis.items():
            segment_ok = {}
            for mfeature, psi in psis.items():
                is_stable = bool(psi < self.config['criterion'])
                segment_ok[mfeature] = is_stable
                if not is_stable:
                    # `warning` is the supported spelling; `warn` is a deprecated alias.
                    logger.warning('PSI for metafeature #{} is higher than client requirements!'.format(mfeature))
            self.stability_ok[segment] = segment_ok

    def send_distributions_to_solomon(self):
        """Post one Solomon sensor per bin of every input histogram.

        Sensors are built with ``make_metafeature_hist_sensor`` and sent in a
        single ``post_sensors_to_solomon`` call.
        """
        sensors = []
        # `.items()` works on both Python 2 and 3.
        for segment, hists in self.input_data.items():
            for mfeature, hist in hists.items():
                sensors.extend(
                    self.make_metafeature_hist_sensor(segment, mfeature, bin_index)
                    for bin_index in range(len(hist))
                )
        post_sensors_to_solomon('datacloud', 'feature', 'metafeature-distribution', sensors)
        logger.info('Histograms were sent to SOLOMON!')

    def send_psis_to_solomon(self):
        """Post one Solomon sensor per computed PSI value.

        Sensors are built with ``make_metafeature_psi_sensor`` and sent in a
        single ``post_sensors_to_solomon`` call.
        """
        sensors = []
        # `.items()` works on both Python 2 and 3; only feature names are
        # needed here because the sensor builder reads the PSI itself.
        for segment, psis in self.psis.items():
            sensors.extend(
                self.make_metafeature_psi_sensor(segment, mfeature)
                for mfeature in psis
            )
        post_sensors_to_solomon('datacloud', 'feature', 'metafeature-psi', sensors)
        logger.info('PSIs were sent to SOLOMON!')

    def make_metafeature_psi_sensor(self, segment, mfeature):
        """Build the Solomon sensor dict for one PSI value.

        Args:
            segment: segment name, a key of ``self.psis``.
            mfeature: feature name, a key of ``self.psis[segment]``.

        Returns:
            Dict with ``labels``, ``ts`` (``str2ts`` of ``self.date``) and
            ``value`` (the PSI converted to a string).
        """
        timestamp = str2ts(self.date)
        labels = dict(
            partner_id=self.partner_id,
            score_id=self.score_id,
            segment=segment,
            feature=mfeature,
            testing_mode=self.is_testing_mode,
            score_application_date=self.score_application_date,
        )
        psi_value = self.psis[segment][mfeature]
        return {'labels': labels, 'ts': timestamp, 'value': str(psi_value)}

    def make_metafeature_hist_sensor(self, segment, mfeature, bin):
        """Build the Solomon sensor dict for one histogram bin.

        Args:
            segment: segment name, a key of ``self.input_data``.
            mfeature: feature name, a key of ``self.input_data[segment]``.
            bin: index of the bin inside the feature's histogram.

        Returns:
            Dict with ``labels`` (including the bin index), ``ts`` (``str2ts``
            of ``self.date``) and ``value`` (the bin value as a string).
        """
        timestamp = str2ts(self.date)
        labels = dict(
            partner_id=self.partner_id,
            score_id=self.score_id,
            segment=segment,
            feature=mfeature,
            testing_mode=self.is_testing_mode,
            score_application_date=self.score_application_date,
            bin=bin,
        )
        bin_value = self.input_data[segment][mfeature][bin]
        return {'labels': labels, 'ts': timestamp, 'value': str(bin_value)}

    def change_reference_table_to_current(self):
        """Make the current table the reference and persist the config record.

        The daily or the segments reference field is updated depending on
        ``self.type``; the whole record is then written back to the config
        table.
        """
        reference_key = (
            'reference_distribution_table'
            if self.type == DISTRIB_TYPE_DAILY
            else 'reference_distribution_table_segments'
        )
        self.config[reference_key] = self.table_path
        self.config_table.insert_records([self.config])
        logger.info('Changed reference distribution for features from {} to {}'.format(self.reference_table_path, self.table_path))

    def send_psis_to_log_table(self):
        """Append this run's PSIs and criterion results to the monitoring log.

        The log table is created first if it does not exist. All dict keys in
        the record are converted to strings before insertion.
        """
        monitoring_log = LogTable(METAFEATURE_MONITORING_LOG_TABLE, METAFEATURE_MONITORING_LOG_TABLE_SCHEME)
        if not monitoring_log.exists():
            monitoring_log.create_table()

        record = {
            'partner_id': self.partner_id,
            'score_id': self.score_id,
            'date': self.date,
            'segment_type': self.type,
            'testing_mode': self.is_testing_mode,
            'metafeature_psi_list': self.psis,
            'criterion_check_list': self.stability_ok,
            'reference_table': self.reference_table_path,
            'input_table': self.table_path,
            'score_application_date': self.score_application_date
        }
        monitoring_log.insert_records([dict_keys_to_string(record)])
        logger.info('Logs Successfully transfered to YT')

    def reduce_bins_ca(self, hist, ca, mfeature_number):
        """Merge histogram bins according to a custom aggregation.

        Args:
            hist: sequence of bin values.
            ca: mapping (or sequence) that yields, for ``mfeature_number``, the
                ordered right boundaries (exclusive slice ends) of the merged
                bins.
            mfeature_number: key into ``ca`` selecting the feature's boundaries.

        Returns:
            List with one value per boundary: the sum of ``hist`` between the
            previous boundary (0 for the first) and this one.
        """
        # Prepend 0 so consecutive pairs describe [left, right) slices of `hist`.
        boundaries = [0] + list(ca[mfeature_number])
        # Builtin zip is available on both Python 2 and 3 (izip is Python 2 only).
        return [
            sum(hist[left:right])
            for left, right in zip(boundaries, boundaries[1:])
        ]


def detect_ready_api_stream_logs(date_time, yt_client):
    """Yield stability tables whose name matches the daily log format.

    Args:
        date_time: accepted but not used by this function.
        yt_client: client whose ``search`` is run over the stability folder.

    Yields:
        ``(table_path, True)`` for every found node whose last path component
        matches ``RE_DAILY_LOG_FORMAT``.
    """
    for table_path in yt_client.search(METAFEATURE_STABILITY_FOLDER):
        table_name = ypath_split(table_path)[-1]
        if re.match(RE_DAILY_LOG_FORMAT, table_name) is None:
            continue
        yield table_path, True


def dict_keys_to_string(rec):
    """Return a copy of ``rec`` with all dict keys converted via ``str``.

    Nested dicts are converted recursively. Any value that is not a dict
    (including lists that contain dicts) is returned unchanged.

    Args:
        rec: arbitrary value, usually a possibly nested dict.

    Returns:
        A new plain dict with string keys when ``rec`` is a dict, otherwise
        ``rec`` itself.
    """
    if not isinstance(rec, dict):
        return rec
    # `.items()` works on both Python 2 and 3 (iteritems is Python 2 only).
    return {str(key): dict_keys_to_string(value) for key, value in rec.items()}


# No standalone entry point: running this module as a script does nothing.
if __name__ == '__main__':
    pass
