from yabs.proto import user_profile_pb2
from yt.yson import get_bytes
from yt.wrapper import create_table_switch

from crypta.lib.python.yt import yt_helpers
from crypta.lib.python.bigb_catboost_applier import (
    fields,
    catboost_features_calculator,
)


def create_kv_empty_table(yt_client, path):
    """Create an empty YT table at ``path`` with string ``key``/``value`` columns.

    The table is created through ``yt_helpers.create_empty_table`` with
    ``optimize_for='scan'`` and ``force=True`` (presumably replacing any
    existing node at ``path`` — see that helper to confirm).

    :param yt_client: YT client used to create the table.
    :param path: Cypress path of the table to create.
    """
    kv_schema = {
        'key': 'string',
        'value': 'string',
    }
    yt_helpers.create_empty_table(
        yt_client=yt_client,
        path=path,
        schema=kv_schema,
        additional_attributes=dict(optimize_for='scan'),
        force=True,
    )


def add_to_cd_file(to_cd_file, value):
    """Append ``value`` to ``to_cd_file`` as the next numbered description row.

    The new row's ``key`` is the string form of the list length before the
    append, i.e. the index the row will occupy. The list is mutated in place
    and nothing is returned.

    :param to_cd_file: list of ``{'key': str, 'value': ...}`` rows to extend.
    :param value: description to store in the new row's ``value``.
    """
    next_index = len(to_cd_file)
    to_cd_file.append(dict(key=str(next_index), value=value))


def create_cd_file(yt_client, path, float_features_description, text_features_description=None, has_weight=False):
    """Write a CatBoost column-description (CD) table to ``path``.

    The table is (re)created via ``create_kv_empty_table`` and then filled with
    one row per column. Each row's ``key`` is the column index as a string,
    and its ``value`` is the CatBoost column type:

    * column 0 is ``Auxiliary`` and column 1 is ``Label``;
    * the next column is ``Weight`` if ``has_weight`` is true;
    * then one ``Num\\t<description>`` per float feature;
    * then one ``Text\\t<description>`` per text feature (skipped when
      ``text_features_description`` is falsy).

    NOTE(review): presumably column 0 is the record key and the following
    columns are the tab-separated fields of the value emitted by
    ``PrepareSamplesMapper`` — confirm against the training setup.

    :param yt_client: YT client used to create and write the table.
    :param path: Cypress path of the CD table.
    :param float_features_description: iterable of float feature names.
    :param text_features_description: optional iterable of text feature names.
    :param has_weight: whether a ``Weight`` column follows the label.
    """
    create_kv_empty_table(
        yt_client=yt_client,
        path=path,
    )

    column_types = ['Auxiliary', 'Label']
    if has_weight:
        column_types.append('Weight')
    column_types.extend('Num\t{}'.format(name) for name in float_features_description)
    for name in text_features_description or ():
        column_types.append('Text\t{}'.format(name))

    cd_rows = []
    for column_type in column_types:
        add_to_cd_file(to_cd_file=cd_rows, value=column_type)

    yt_client.write_table(path, cd_rows)


class PrepareSamplesMapper(object):
    """YT mapper that turns profile rows into key/value CatBoost samples.

    For every input row, the mapper:

    * parses the serialized ``Profile`` proto stored under ``fields.PROFILE``;
    * computes float and text features with ``TCatboostFeaturesCalculator``;
    * yields one ``{fields.KEY: ..., fields.VALUE: ...}`` record.

    The value is ``target[\\tweight]\\t<float features>\\t<text features>``.
    When ``split`` is true, each record is preceded by a table switch that
    routes it to output table 0 (train) or 1 (validation).

    ``start`` must run before the first ``__call__``, because it creates the
    proto buffer and the features calculator (presumably invoked by the YT job
    runtime on the worker — confirm).
    """

    def __init__(
        self,
        features_mapping,
        counters_to_features,
        keywords_to_features,
        key_column,
        target_column,
        weight_column=None,
        split=True,
        validation_sample_percentage=10,
        validation_sample_rest=0,
    ):
        """Store the configuration; no heavy objects are built here.

        :param features_mapping: forwarded to ``TCatboostFeaturesCalculator``.
        :param counters_to_features: forwarded to ``TCatboostFeaturesCalculator``.
        :param keywords_to_features: forwarded to ``TCatboostFeaturesCalculator``.
        :param key_column: input column whose ``str()`` becomes the output key.
        :param target_column: input column holding the label.
        :param weight_column: optional input column appended after the label.
        :param split: whether to emit train/validation table switches.
        :param validation_sample_percentage: modulus for the validation split
            (see the note below).
        :param validation_sample_rest: remainder that selects validation rows.
        """
        # Feature-calculator arguments; the calculator itself is created in start().
        self.features_mapping = features_mapping
        self.counters_to_features = counters_to_features
        self.keywords_to_features = keywords_to_features

        # Input column names.
        self.target_column = target_column
        self.weight_column = weight_column
        self.key_column = key_column

        # Train/validation routing.
        self.split = split
        # NOTE(review): despite its name, this value is used as a modulus. A row
        # goes to validation when row[fields.CRYPTA_ID] % N == validation_sample_rest.
        # That gives 1/N of CryptaIDs (10% at the default N=10, assuming IDs are
        # evenly spread), not N% in general. The split also uses fields.CRYPTA_ID,
        # not key_column.
        self.validation_sample_percentage = validation_sample_percentage
        self.validation_sample_rest = validation_sample_rest
        self.train_sample_index, self.val_sample_index = 0, 1

    def start(self):
        """Create per-job state: a reusable ``Profile`` message and the calculator."""
        self.profile = user_profile_pb2.Profile()
        self.features_calculator = catboost_features_calculator.TCatboostFeaturesCalculator(
            self.features_mapping,
            self.counters_to_features,
            self.keywords_to_features,
        )

    def __call__(self, row):
        """Yield an optional table switch, then the key/value sample for ``row``."""
        target = row[self.target_column]
        if self.weight_column is not None:
            # The weight travels as an extra tab-separated column right after the label.
            target = '{}\t{}'.format(target, row[self.weight_column])

        # Reuse one message object; ParseFromString replaces its previous content.
        profile = self.profile
        profile.ParseFromString(get_bytes(row[fields.PROFILE]))

        calculator = self.features_calculator
        float_part = '\t'.join(map(str, calculator.PrepareFloatFeatures(profile)))
        text_part = calculator.PrepareTextFeatures(profile)
        sample_line = '{}\t{}\t{}'.format(target, float_part, text_part)
        value = get_bytes(sample_line).decode('utf-8')

        if self.split:
            in_validation = (
                row[fields.CRYPTA_ID] % self.validation_sample_percentage == self.validation_sample_rest
            )
            yield create_table_switch(self.val_sample_index if in_validation else self.train_sample_index)

        yield {
            fields.KEY: str(row[self.key_column]),
            fields.VALUE: value,
        }
