#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging

from crypta.lib.python.bigb_catboost_applier import (
    train_sample as train_sample_utils,
    features_mapping as features_mapping_utils,
)
from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.rt_socdem.lib.python.model.config import config
from crypta.rt_socdem.lib.python.model import (
    fields,
    utils,
)


# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


def prepare_train_val_samples(
    yt_client,
    features_mapping,
    segment_name,
    catboost_train_table=None,
    catboost_val_table=None,
    raw_train_table=None,
    validation_sample_percentage=config.VALIDATION_SAMPLE_PERCENTAGE,
    validation_sample_rest=config.VALIDATION_SAMPLE_REST,
):
    """Build the sorted CatBoost train and validation tables for one segment.

    Both destination tables are initialised with ``create_kv_empty_table``,
    filled by a single map operation running ``PrepareSamplesMapper`` over
    the raw train table, and then sorted by ``fields.KEY``.

    Args:
        yt_client: YT client used to run the operations and count rows.
        features_mapping: Features mapping passed through to the mapper.
        segment_name: Segment to prepare; used as the target column, as the
            prefix of the ``<segment_name>_weight`` weight column and to fill
            the default table path templates.
        catboost_train_table: Destination train table path. A falsy value
            selects ``config.CATBOOST_TRAIN_TEMPLATE_TABLE``.
        catboost_val_table: Destination validation table path. A falsy value
            selects ``config.CATBOOST_VAL_TEMPLATE_TABLE``.
        raw_train_table: Source table path. A falsy value selects
            ``config.RAW_TRAIN_TEMPLATE_TABLE``.
        validation_sample_percentage: Passed through to the mapper.
        validation_sample_rest: Passed through to the mapper.

    Returns:
        Tuple ``(catboost_train_table, catboost_val_table)`` with the
        resolved destination paths.

    Raises:
        AssertionError: If the train table is empty after the map operation.
    """
    catboost_train_table = catboost_train_table or config.CATBOOST_TRAIN_TEMPLATE_TABLE.format(segment_name=segment_name)
    catboost_val_table = catboost_val_table or config.CATBOOST_VAL_TEMPLATE_TABLE.format(segment_name=segment_name)
    raw_train_table = raw_train_table or config.RAW_TRAIN_TEMPLATE_TABLE.format(segment_name=segment_name)

    # Order matters: the same list is handed to run_map as its outputs.
    dst_paths = [catboost_train_table, catboost_val_table]

    for path in dst_paths:
        train_sample_utils.create_kv_empty_table(yt_client=yt_client, path=path)

    yt_client.run_map(
        train_sample_utils.PrepareSamplesMapper(
            features_mapping=features_mapping,
            counters_to_features=utils.COUNTERS_TO_FEATURES,
            keywords_to_features=utils.KEYWORDS_TO_FEATURES,
            key_column=fields.YANDEXUID,
            target_column=segment_name,
            weight_column='{}_weight'.format(segment_name),
            validation_sample_percentage=validation_sample_percentage,
            validation_sample_rest=validation_sample_rest,
        ),
        raw_train_table,
        dst_paths,
    )

    for path in dst_paths:
        yt_client.run_sort(path, sort_by=fields.KEY)

    # Raise explicitly instead of using `assert`: assert statements are
    # stripped under `python -O`, which would let an empty train sample
    # through. AssertionError is kept so existing callers see the same type.
    if yt_client.row_count(catboost_train_table) <= 0:
        raise AssertionError('empty {} sample'.format(segment_name))
    return catboost_train_table, catboost_val_table


def prepare(
    yt_client,
    features_mapping_table_path=config.FEATURES_MAPPING_TABLE,
    catboost_pool_file=config.CATBOOST_POOL_FILE,
    segments=config.SOCDEM_TYPES,
    train_val_kwargs=None,
):
    """Prepare CatBoost samples for every segment plus the pool.cd file.

    The features mapping is loaded first. Then, inside one
    ``NirvanaTransaction``, train/validation samples are prepared for each
    segment, their row counts are written to
    ``config.DATALENS_REALTIME_SOCDEM_COUNTS_TABLE``, and the column
    description file is created.

    Args:
        yt_client: YT client used for all operations.
        features_mapping_table_path: Path of the features mapping table.
        catboost_pool_file: Path where the column description file is created.
        segments: Iterable of segment names to prepare.
        train_val_kwargs: Optional extra keyword arguments forwarded to
            ``prepare_train_val_samples``; a falsy value means none.
    """
    if not train_val_kwargs:
        train_val_kwargs = {}

    mapping_and_description = features_mapping_utils.get_features_mapping(
        yt_client=yt_client,
        features_mapping_table_path=features_mapping_table_path,
    )
    features_mapping, float_features_description = mapping_and_description

    def report_count(counter_name, count):
        # One stats row per counter; dicts are rebuilt on every call.
        yt_helpers.write_stats_to_yt(
            yt_client=yt_client,
            table_path=config.DATALENS_REALTIME_SOCDEM_COUNTS_TABLE,
            data_to_write={
                'counter_name': counter_name,
                'count': count,
            },
            schema={
                'counter_name': 'string',
                'count': 'uint64',
            },
        )

    with NirvanaTransaction(yt_client):
        for segment_name in segments:
            train_table, val_table = prepare_train_val_samples(
                yt_client=yt_client,
                features_mapping=features_mapping,
                segment_name=segment_name,
                **train_val_kwargs
            )

            # Both counts are read before anything is written.
            train_rows = yt_client.row_count(train_table)
            val_rows = yt_client.row_count(val_table)

            report_count('{}_train_rows'.format(segment_name), train_rows)
            report_count('{}_validation_rows'.format(segment_name), val_rows)

        joined_segments = ', '.join(segments)
        logger.info('Prepared {} Catboost train and val samples'.format(joined_segments))

        train_sample_utils.create_cd_file(
            yt_client=yt_client,
            path=catboost_pool_file,
            float_features_description=float_features_description,
            text_features_description=['binding_text'],
            has_weight=True,
        )
        logger.info('Prepared pool.cd Catboost file')
