#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging

from crypta.lib.python.bigb_catboost_applier import (
    train_sample as train_sample_utils,
    features_mapping as features_mapping_utils,
)
from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.prism.lib.config import (
    config,
    fields,
)
from crypta.prism.services.training.lib import utils


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def prepare_train_val_samples(
    yt_client,
    features_mapping,
    raw_train_sample_table,
    catboost_train_table,
    catboost_val_table,
):
    """Build sorted Catboost train and validation tables from the raw sample.

    Steps, in order:
      1. create both destination tables via ``create_kv_empty_table``;
      2. run a YT map with ``PrepareSamplesMapper`` over
         ``raw_train_sample_table``, writing to both destination tables;
      3. sort each destination table by ``fields.KEY``;
      4. fail if the train table ended up empty.

    Args:
        yt_client: YT client used for all table operations.
        features_mapping: mapping handed to ``PrepareSamplesMapper`` unchanged.
        raw_train_sample_table: path of the map input table.
        catboost_train_table: path of the first map output table.
        catboost_val_table: path of the second map output table.

    Raises:
        AssertionError: if ``catboost_train_table`` has no rows after the map
            and sort operations.
    """
    # Order matters: it is the order of output tables given to run_map.
    # NOTE(review): presumably the mapper routes train rows to the first
    # output and validation rows to the second -- confirm in the mapper.
    output_tables = [catboost_train_table, catboost_val_table]

    for table in output_tables:
        train_sample_utils.create_kv_empty_table(yt_client=yt_client, path=table)

    mapper = train_sample_utils.PrepareSamplesMapper(
        features_mapping=features_mapping,
        counters_to_features=utils.COUNTERS_TO_FEATURES,
        keywords_to_features=utils.KEYWORDS_TO_FEATURES,
        key_column=fields.CRYPTA_ID,
        target_column='target',
        weight_column=None,
        validation_sample_percentage=config.VALIDATION_SAMPLE_PERCENTAGE,
        validation_sample_rest=config.VALIDATION_SAMPLE_REST,
    )
    yt_client.run_map(mapper, raw_train_sample_table, output_tables)

    for table in output_tables:
        yt_client.run_sort(table, sort_by=fields.KEY)

    # Raised explicitly instead of via `assert`: assert statements are removed
    # when Python runs with -O, which would let an empty train sample through.
    # The exception type and message are unchanged for existing callers.
    if not (yt_client.row_count(catboost_train_table) > 0):
        raise AssertionError('empty prism train sample')


def prepare(
    yt_client,
):
    """Prepare the Catboost training inputs for prism.

    Everything runs inside a ``NirvanaTransaction`` context opened on
    ``yt_client``:
      1. load the features mapping and float features description from
         ``config.FEATURES_MAPPING_TABLE``;
      2. build the Catboost train and validation sample tables;
      3. write the row count of each sample table to
         ``config.DATALENS_REALTIME_PRISM_COUNTS_TABLE``;
      4. create the column description (cd) file at
         ``config.CATBOOST_COLUMN_DESCRIPTION_TABLE``.

    Args:
        yt_client: YT client used for all operations.
    """
    with NirvanaTransaction(yt_client):
        mapping_result = features_mapping_utils.get_features_mapping(
            yt_client=yt_client,
            features_mapping_table_path=config.FEATURES_MAPPING_TABLE,
        )
        features_mapping, float_features_description = mapping_result

        prepare_train_val_samples(
            yt_client=yt_client,
            features_mapping=features_mapping,
            raw_train_sample_table=config.RAW_TRAIN_SAMPLE_TABLE,
            catboost_train_table=config.CATBOOST_TRAIN_SAMPLE_TABLE,
            catboost_val_table=config.CATBOOST_VAL_SAMPLE_TABLE,
        )
        logger.info('Prepared prism Catboost train samples')

        # One stats record per sample table: (counter name, table to count).
        row_counters = (
            ('train_rows', config.CATBOOST_TRAIN_SAMPLE_TABLE),
            ('validation_rows', config.CATBOOST_VAL_SAMPLE_TABLE),
        )
        for counter_name, sample_table in row_counters:
            stats_row = {
                'counter_name': counter_name,
                'count': yt_client.row_count(sample_table),
            }
            # Built anew on every iteration, as each call gets its own dict.
            stats_schema = {
                'counter_name': 'string',
                'count': 'uint64',
            }
            yt_helpers.write_stats_to_yt(
                yt_client=yt_client,
                table_path=config.DATALENS_REALTIME_PRISM_COUNTS_TABLE,
                data_to_write=stats_row,
                schema=stats_schema,
            )

        text_features = ['binding_text']
        train_sample_utils.create_cd_file(
            yt_client=yt_client,
            path=config.CATBOOST_COLUMN_DESCRIPTION_TABLE,
            float_features_description=float_features_description,
            text_features_description=text_features,
        )
        logger.info('Prepared pool.cd Catboost file')
