# -*- coding: utf-8 -*-
from datacloud.dev_utils.yt import yt_utils
from datacloud.dev_utils.solomon import solomon_utils
from datacloud.dev_utils.solomon.solomon_utils import is_date, str2ts
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.stability import stability
from datacloud.model_applyer.lib.features_v2 import DSSMFeature
from datacloud.model_applyer.lib.features_list import base_features
from datacloud.features.dssm.dssm_main import DSSMTables
from datacloud.features.dssm.join_scores import join_dssm_scores
from datacloud.stability.features_psi_table import FeaturesPSITable

logger = get_basic_logger(__name__)

TMP_FOLDER = '//tmp'
FEATURE_REFERENCES_FOLDER_PATH = '//home/x-products/production/datacloud/stability/scores_and_features_PSI/reference_histograms'


# ------------ SOLOMON FUNCTIONS START ------------
def make_sensor_feature_psi(feature_name, timestamp, feature_number, value):
    """ Builds the sensor dict for a single feature of a feature set.

    The labels carry the feature set name, the constant type 'psi' and the
    number of the feature; `timestamp` and `value` are stored unchanged
    under the 'ts' and 'value' keys.
    """
    sensor_labels = dict(
        feature_name=feature_name,
        type='psi',
        feature_number=feature_number,
    )
    return dict(labels=sensor_labels, ts=timestamp, value=value)


def load_row_to_solomon(row):
    """
    Upload the PSI values of one row to solomon, one sensor per feature.

    The row is read through the keys 'feature_name', 'date' and 'PSI'.
    row['PSI'] is a sequence; the position of every value in it is sent
    as the sensor's feature number.
    row['date'] should have 'YYYY-MM-DD' format. When `is_date` rejects
    the date nothing is posted and a warning is logged instead.
    """
    feature_name = row['feature_name']
    date = row['date']
    if not is_date(date):
        # Previously such rows were dropped without any trace in the logs.
        logger.warning('Row for {} was not sent to solomon: bad date {!r}'.format(feature_name, date))
        return

    timestamp = str2ts(date)
    sensors = [
        make_sensor_feature_psi(feature_name, timestamp, feature_number, psi)
        for feature_number, psi in enumerate(row['PSI'])
    ]
    solomon_utils.post_sensors_to_solomon('datacloud', 'feature', 'stability', sensors)
    logger.info('Loaded score {} calcuated on {}'.format(feature_name, date))
# ------------ SOLOMON FUNCTIONS END ------------


def calculate_feature_psi_table(yt_client, date_str, feature_name, input_tables):
    """
    Calculate PSI for every feature of a feature set, store the result and
    refresh the reference histograms.

    Everything runs inside one YT transaction with one temporary table:

    1. The input table is the single entry of `input_tables`. When several
       tables are passed, the temporary table is set as `ready_table` of a
       DSSMTables config and `join_dssm_scores` is run (presumably it fills
       that table with the joined scores - confirm against join_dssm_scores).
    2. The input table is converted to stability bins (n_bins=100) into the
       temporary table.
    3. The reference table is searched in FEATURE_REFERENCES_FOLDER_PATH:
       the first table, in reverse name order, whose name without the
       trailing '_YYYY-MM-DD' part equals `feature_name`. If none is found
       the input itself is treated as the reference and all PSI are 0.
    4. PSI is calculated per feature against the reference histogram; a
       feature missing in the reference gets PSI 0 and a warning.
    5. The resulting row is added to FeaturesPSITable and sent to solomon.
    6. The temporary table is copied into the references folder as
       '<feature_name>_<date_str>' and the previously used reference table
       is removed.

    :param yt_client: YT client used for all table operations
    :param date_str: date of the calculation, used in the output row and
        in the name of the new reference table
    :param feature_name: name of the feature set
    :param input_tables: non-empty list of input table paths
    """
    with yt_client.Transaction(), yt_client.TempTable(TMP_FOLDER) as tmp_table:
        if len(input_tables) > 1:
            path_config = DSSMTables(date_str, yt_client=yt_client, garbage_collect_on=False)
            input_table = tmp_table
            path_config.config.ready_table = input_table

            join_dssm_scores(path_config.processor)
        else:
            input_table = input_tables[0]
        table_size = yt_client.row_count(input_table)

        stability.FeaturesToStabilityBinsConverter(yt_client, n_bins=100)(input_table, tmp_table)
        logger.info('Processed table: {}'.format(input_table))

        # check whether there is a reference table for the input features
        is_reference = True
        reference_table = input_table
        date_suffix_len = len('_0000-00-00')
        reference_tables = yt_client.list(FEATURE_REFERENCES_FOLDER_PATH, absolute=True)
        for table in sorted(reference_tables, reverse=True):
            if table.split('/')[-1][:-date_suffix_len] == feature_name:
                reference_table = table
                is_reference = False
                break

        if not is_reference:
            reference_histograms = dict()
            for row in yt_client.read_table(reference_table):
                reference_histograms[row['feature']] = row['norm_hist']
        else:
            logger.warning('Reference histogram for input features was not found! Input histogram will be used as reference')

        psi_list = []
        for feature in yt_client.read_table(tmp_table):
            if is_reference:
                psi = 0
            else:
                ref_hist = reference_histograms.get(feature['feature'])
                if ref_hist is not None:
                    psi = stability.calculate_psi(feature['norm_hist'], ref_hist)
                else:
                    psi = 0
                    logger.warning('New {} feature found at index of {}.'.format(feature_name, feature['feature']))

            psi_list.append(psi)

        row_out = {
            'date': date_str,
            'feature_name': feature_name,
            'PSI': psi_list,
            'reference_table': reference_table,
            'input_table': sorted(input_tables)[-1],
            'is_reference': is_reference,
            'table_size': table_size
        }
        if len(input_tables) > 1:
            row_out['additional'] = {'num_tables_used': len(input_tables)}
        FeaturesPSITable().add_record(**row_out)

        logger.info('PSI uploaded to YT for table: {}'.format(input_table))

        load_row_to_solomon(row_out)
        logger.info('PSI uploaded to Solomon for table: {}'.format(input_table))

        new_reference_table = FEATURE_REFERENCES_FOLDER_PATH + '/' + feature_name + '_' + row_out['date']
        yt_client.copy(tmp_table, new_reference_table, force=True)
        # When the calculation is repeated for the date of the current
        # reference, the copy above has just overwritten that very table;
        # removing it would leave the feature without any reference.
        if not is_reference and str(reference_table) != new_reference_table:
            yt_client.remove(reference_table)


def run_calculate_feature_psi_table(yt_client, date_str, feature_name, first_table, last_table):
    """
    Pick the weekly tables of a feature and run the PSI calculation on them.

    The feature is looked up in `base_features` by `feature_name`; tables of
    its weekly folder whose absolute path lies between `first_table` and
    `last_table` (both inclusive, plain string comparison) are passed to
    `calculate_feature_psi_table`.

    More than one selected table is allowed only for the DSSM feature,
    otherwise the assertion below fails.
    """
    feature = base_features[feature_name]
    all_tables = sorted(yt_client.list(feature.get_weekly_folder(), absolute=True))
    # A real list is required: the result is measured with len() here and
    # indexed/sorted by calculate_feature_psi_table, which a lazy filter()
    # object (Python 3) does not support.
    input_tables = [table for table in all_tables if first_table <= table <= last_table]

    if len(input_tables) > 1:
        assert feature_name == DSSMFeature().feature_name, \
            'Multiple input tables allowed only for DSMM; {} found!'.format(feature_name)
        # if last_table == all_tables[-1]:
        #     input_tables = [feature.get_ready_table()]

    calculate_feature_psi_table(yt_client, date_str, feature_name, input_tables)


if __name__ == '__main__':
    # Manual run over a hard-coded list of tables; alternative inputs are
    # kept below as commented-out lines.
    yt_client = yt_utils.get_yt_client()
    table_list = []
    # table_list.append('//home/x-products/production/datacloud/aggregates/cluster/user2clust/2018-01-31')
    # table_list.append('//home/x-products/production/datacloud/aggregates/cluster/user2clust/2018-04-05')
    table_list.append('//home/x-products/production/datacloud/aggregates/cluster/user2clust/2018-05-18')
    # for table in yt_client.list('//home/x-products/junk/nryzhikov/stability/aggregates/test_features', absolute=True):
    for table in table_list:
        # NOTE(review): this call does not match the signature
        # calculate_feature_psi_table(yt_client, date_str, feature_name, input_tables):
        # only two positional arguments are passed (and in the old order), so it
        # raises TypeError as written. The proper date_str / feature_name for
        # these tables cannot be told from this file - TODO fix the call.
        calculate_feature_psi_table(table, yt_client)
