#!/usr/bin/env python
# -*- coding: utf-8 -*-

from abc import (
    ABCMeta,
    abstractproperty,
)
import os

from crypta.lib.python.custom_ml.regression import model_application
from crypta.lib.python.custom_ml.tools import application_utils
from crypta.profile.runners.segments.lib.coded_segments.ml_tools.prepare_catboost_features import PrepareCatboostFeatures
from crypta.profile.utils.config import config
from crypta.profile.utils.luigi_utils import ExternalInput
from crypta.profile.utils.segment_utils.builders import RegularSegmentBuilder


class AbstractModelApplication(RegularSegmentBuilder):
    """
    Base segment builder that applies a CatBoost regression model.

    Concrete subclasses provide ``train_helper``, ``percentiles`` and
    ``slice_to_segment_name_dict``. ``build_segment`` then applies the model
    to the prepared features, reduces the predictions by ``crypta_id`` where
    one is available, sorts them by ``model_predictions`` and maps percentile
    slices of the sorted table to segment names.
    """

    __metaclass__ = ABCMeta

    juggler_host = config.CRYPTA_ML_JUGGLER_HOST

    def __init__(self, date):
        """
        :param date: forwarded unchanged to the parent builder's constructor
        """
        super(AbstractModelApplication, self).__init__(date)

        # Run YT operations in the pool named by config.TRAINABLE_SEGMENTS_POOL.
        spec_defaults = self.yt.config['spec_defaults']
        spec_defaults['pool'] = config.TRAINABLE_SEGMENTS_POOL

        # The output table is named after the concrete subclass.
        output_path_parts = (
            config.PROFILES_SEGMENT_PARTS_YT_DIRECTORY,
            config.TRAINABLE_SEGMENTS,
            self.__class__.__name__,
        )
        self._segment_output_table = os.path.join(*output_path_parts)

    @abstractproperty
    def train_helper(self):
        """
        :return: CustomRegressionModelTrainHelper's heir class; its
            ``model_params.model_tag`` is used as the CatBoost model in the query
        """

    @abstractproperty
    def percentiles(self):
        """
        :return: OrderedDict percentiles for making segments
        """

    @abstractproperty
    def slice_to_segment_name_dict(self):
        """
        :return: mapping handed to ``GetSegmentWithPercentile`` as
            ``slice_to_segment_name_dict`` (presumably percentile slice ->
            segment name; confirm against the mapper)
        """

    def requires(self):
        """
        :return: dict of upstream tasks keyed by the names ``build_segment``
            reads from ``inputs``
        """
        yandexuid_cryptaid = ExternalInput(config.YANDEXUID_CRYPTAID_TABLE)
        prepared_features = PrepareCatboostFeatures(self.date)
        return {
            'yandexuid_cryptaid': yandexuid_cryptaid,
            'catboost_prepared_features': prepared_features,
        }

    def build_segment(self, inputs, output_path):
        """
        Apply the model and write percentile-based segments to ``output_path``.

        :param inputs: mapping with the keys returned by ``requires``; each
            value must expose a ``table`` attribute
        :param output_path: destination table; created here with string
            ``id``, ``id_type`` and ``segment_name`` columns
        """
        with self.yt.TempTable() as scored, \
                self.yt.TempTable() as scored_with_cryptaid, \
                self.yt.TempTable() as scored_without_cryptaid, \
                self.yt.TempTable() as voted_by_cryptaid, \
                self.yt.TempTable() as sorted_predictions:
            # 1. Apply the CatBoost regression model to the prepared features.
            common_part = application_utils.apply_catboost_common_query.format(
                number_of_classes=1,
                catboost_model=self.train_helper.model_params.model_tag,
                input_table=inputs['catboost_prepared_features'].table,
            )
            regression_part = model_application.apply_catboost_regression_query.format(
                output_table=scored,
            )
            self.yql.query(
                query_string=common_part + regression_part,
                transaction=self.transaction,
            )

            # 2. Split predictions by whether a crypta_id is known for the row.
            split_query = application_utils.split_by_cryptaid_presence_query.format(
                catboost_applied_table=scored,
                yandexuid_cryptaid_table=inputs['yandexuid_cryptaid'].table,
                output_with_cryptaid_table=scored_with_cryptaid,
                output_without_cryptaid_table=scored_without_cryptaid,
            )
            self.yql.query(split_query, transaction=self.transaction)

            # 3. Reduce the rows that have a crypta_id, one group per crypta_id.
            self.yt.run_reduce(
                model_application.voting_by_cryptaid,
                scored_with_cryptaid,
                voted_by_cryptaid,
                reduce_by='crypta_id',
            )

            # 4. Merge both parts into one table sorted by prediction.
            # NOTE(review): 'id' is uint64 here but string in the output table;
            # presumably the mapper converts it - confirm.
            predictions_schema = {
                'id': 'uint64',
                'id_type': 'string',
                'model_predictions': 'double',
            }
            self.yt.create_empty_table(
                path=sorted_predictions,
                schema=predictions_schema,
            )
            unsorted_tables = [scored_without_cryptaid, voted_by_cryptaid]
            self.yt.run_sort(
                unsorted_tables,
                sorted_predictions,
                sort_by='model_predictions',
            )

            # 5. Map percentile slices of the sorted table to segment names.
            segments_schema = {
                'id': 'string',
                'id_type': 'string',
                'segment_name': 'string',
            }
            self.yt.create_empty_table(output_path, schema=segments_schema)

            segment_mapper = model_application.GetSegmentWithPercentile(
                slice_to_segment_name_dict=self.slice_to_segment_name_dict,
            )
            percentile_slices = application_utils.get_slices_for_percentile_segmentation(
                yt=self.yt,
                table=sorted_predictions,
                percentiles=self.percentiles,
            )
            self.yt.run_map(segment_mapper, percentile_slices, output_path)
