#!/usr/bin/env python
# -*- coding: utf-8 -*-

from abc import (
    ABCMeta,
    abstractproperty,
)
from collections import namedtuple
from functools import partial
import os

from crypta.lib.python.custom_ml.classification import model_application
from crypta.lib.python.custom_ml.tools import application_utils
from crypta.profile.runners.segments.lib.coded_segments.ml_tools.prepare_catboost_features import PrepareCatboostFeatures
from crypta.profile.utils.api import get_api
from crypta.profile.utils.config import config
from crypta.profile.utils.luigi_utils import ExternalInput
from crypta.profile.utils.segment_utils.builders import RegularSegmentBuilder
from crypta.profile.utils.utils import report_ml_metrics_to_solomon


class AbstractModelApplication(RegularSegmentBuilder):
    """
    Base segment builder that applies a trained CatBoost classification model
    to prepared features and turns the predictions into segments.

    Subclasses provide the model configuration through the abstract properties
    `train_helper`, `percentiles`, `slice_to_segment_name_dict` and
    `probability_classes`. They may also override `industry_name`,
    `positive_name`, `model_description` and the class attributes below.
    """
    # NOTE(review): a `__metaclass__` class attribute is honoured only by Python 2.
    # On Python 3 it is ignored, so the abstract properties below are enforced
    # only if RegularSegmentBuilder itself already uses ABCMeta — confirm.
    __metaclass__ = ABCMeta

    juggler_host = config.CRYPTA_ML_JUGGLER_HOST
    keyword = None
    # Optional collection of lab export segment ids (keyword 557) to tag in
    # add_tags_to_exports. None disables that tagging branch.
    audience_segments = None

    # Description used by check_model_in_api to register the model as a new industry.
    ModelDescription = namedtuple('ModelDescription', [
        'industry',
        'objective',
        'positive_conversions',
        'negative_conversions',
    ])

    def __init__(self, date):
        """
        :param date: build date, passed through to RegularSegmentBuilder
        """
        super(AbstractModelApplication, self).__init__(date)

        # Every YT operation started by this builder runs in the trainable-segments pool.
        self.yt.config['spec_defaults']['pool'] = config.TRAINABLE_SEGMENTS_POOL
        self._segment_output_table = os.path.join(
            config.PROFILES_SEGMENT_PARTS_YT_DIRECTORY,
            config.TRAINABLE_SEGMENTS,
            self.__class__.__name__,
        )

    @abstractproperty
    def train_helper(self):
        """
        :return: CustomClassificationModelTrainHelper's heir class; its
            `model_params` supplies model_tag, ordered_thresholds and
            metrics_group_name
        """
        pass

    @abstractproperty
    def percentiles(self):
        """
        :return: OrderedDict percentiles for making segments, or None to
            assign segments by `ordered_thresholds` instead
        """
        pass

    @abstractproperty
    def slice_to_segment_name_dict(self):
        """
        :return: mapping from percentile slice to segment name. It must have
            exactly as many entries as `percentiles`.
        """
        pass

    @abstractproperty
    def probability_classes(self):
        """
        :return: list of names for probabilities (the model's class order)
        """
        pass

    @property
    def industry_name(self):
        """Optional industry name. When set, exports also get the tag 'ml_segments|<industry_name>'."""
        return None

    @property
    def model_name(self):
        """Model name derived from the class name, with 'ModelApplication' removed."""
        return self.__class__.__name__.replace('ModelApplication', '')

    @property
    def positive_name(self):
        """Name of the positive class. Used for percentile segmentation of binary models."""
        return 'positive'

    @property
    def model_description(self):
        """
        :return: ModelDescription object or None (None skips API registration)
        """
        return None

    def get_tags(self, export_id):
        """
        Fetch the lab tags of an export.

        Errors are deliberately swallowed: they are logged and an empty list is returned.

        :param export_id: lab export id
        :return: list of tags (empty on failure)
        """
        try:
            tags = get_api().lab.getTags(id=export_id).result().tags
        except Exception as e:
            self.logger.info(e)
            tags = []

        self.logger.info('export_id: {}, tags: {}'.format(export_id, ', '.join(tags)))
        return tags

    def add_tags_to_exports(self):
        """
        Update lab tags on the exports related to this model.

        Every row of LAB_SEGMENTS_INFO_TABLE is checked against two rules:
          * keyword 557 exports whose segment id is in `audience_segments`
            receive every missing tag out of 'ml_segments', 'crypta_id',
            `model_name` and, if `industry_name` is set,
            'ml_segments|<industry_name>';
          * keyword 546 exports whose segment id belongs to this builder's
            segments receive the 'trainable' tag (if missing). Their export to
            BigB is disabled regardless of whether the tag was already present.
        """
        # The second element of every name_segment_dict value is compared with
        # export segment ids below. Presumably the values are
        # (keyword_id, segment_id) pairs — verify against RegularSegmentBuilder.
        segments = {segment[1] for segment in self.name_segment_dict.values()}

        tags_to_add = ['ml_segments', 'crypta_id', self.model_name]
        if self.industry_name is not None:
            tags_to_add.append('ml_segments|{}'.format(self.industry_name))

        for row in self.yt.read_table(config.LAB_SEGMENTS_INFO_TABLE):
            # Empty/missing ids become None so they never match below.
            export_keyword_id = int(row['exportKeywordId']) if row['exportKeywordId'] else None
            export_segment_id = int(row['exportSegmentId']) if row['exportSegmentId'] else None
            export_id = row['exportId']

            if self.audience_segments is not None and export_keyword_id == 557 and \
                    export_segment_id in self.audience_segments:
                tags = self.get_tags(export_id)

                for tag in tags_to_add:
                    if tag not in tags:
                        self.logger.info(get_api().lab.addTag(id=export_id, tag=tag).result())

            elif export_keyword_id == 546 and export_segment_id in segments:
                tags = self.get_tags(export_id)
                if 'trainable' not in tags:
                    self.logger.info(get_api().lab.addTag(id=export_id, tag='trainable').result())
                self.logger.info(get_api().lab.disableExportToBigB(id=export_id).result())

    def check_model_in_api(self):
        """
        Register this model as a new industry in the lab API unless it is already known.

        Does nothing (and makes no API calls) when `model_description` is None.
        """
        description = self.model_description
        if description is None:
            return

        api = get_api()
        known_models = {industry.modelName for industry in api.lab.getTrainingSamplesIndustries().result()}
        if self.model_name in known_models:
            return

        self.logger.info(api.lab.addNewIndustry(
            name=description.industry,
            modelName=self.model_name,
            objective=description.objective,
            positiveConversions=description.positive_conversions,
            negativeConversions=description.negative_conversions,
        ).result())

    def requires(self):
        """
        :return: dict of inputs: the yandexuid -> crypta_id table and the
            prepared CatBoost features for `self.date`
        """
        return {
            'yandexuid_cryptaid': ExternalInput(config.YANDEXUID_CRYPTAID_TABLE),
            'catboost_prepared_features': PrepareCatboostFeatures(self.date),
        }

    def map_predictions_to_segments(
        self,
        model_predictions,
        output_path
    ):
        """
        Assign a segment to every prediction row and write the rows to output_path.

        There are two modes:
          * percentile mode (`percentiles` is not None): the class probabilities
            are reduced to one score column 'probability'. The table is sorted
            by it and cut into percentile slices. Each slice is mapped to a
            segment through `slice_to_segment_name_dict`.
          * threshold mode: every row is matched against the model's
            `ordered_thresholds`.

        All configuration is validated before any YT operation is started.

        :param model_predictions: list of YT tables with model predictions
        :param output_path: destination YT table
        """
        if self.percentiles is not None:
            # Check these before calling len() on them; a None value would
            # otherwise surface as an unhelpful TypeError.
            assert self.probability_classes is not None, 'Probability classes must be specified'
            assert len(self.slice_to_segment_name_dict) == len(self.percentiles), \
                'For each percentile segment_name must exist'

            if len(self.probability_classes) == 2:
                # Binary model: the score is the probability of the positive class.
                assert self.positive_name is not None, 'positive_name must be specified'
                score_mapper = partial(
                    model_application.extract_probability_for_segmentation,
                    positive_name=self.positive_name,
                )
            else:
                # Multiclass model: collapse class probabilities into one ordered score.
                score_mapper = partial(
                    model_application.extract_integral_score_for_segmentation,
                    class_order=self.probability_classes,
                )

            with self.yt.TempTable() as for_percentile_segmentation:
                self.yt.run_map(score_mapper, model_predictions, for_percentile_segmentation)
                self.yt.run_sort(for_percentile_segmentation, sort_by='probability')

                self.yt.run_map(
                    model_application.GetSegmentWithPercentile(
                        slice_to_segment_name_dict=self.slice_to_segment_name_dict,
                    ),
                    application_utils.get_slices_for_percentile_segmentation(
                        yt=self.yt,
                        table=for_percentile_segmentation,
                        percentiles=self.percentiles,
                    ),
                    output_path,
                )
        else:
            thresholds = self.train_helper.model_params.ordered_thresholds
            assert thresholds is not None, 'Ordered thresholds must be specified'
            self.yt.run_map(
                partial(
                    model_application.get_most_appropriate_segment_with_probability,
                    thresholds=thresholds,
                    # Materialize as a list. A dict keys view cannot be pickled by
                    # the stdlib pickle on Python 3, and this partial is shipped to
                    # YT mappers. On Python 2 .keys() already returns a list.
                    segments=list(self.name_segment_dict.keys()),
                ),
                model_predictions,
                output_path,
            )

    def build_segment(self, inputs, output_path):
        """
        Apply the model and write the resulting segment assignments to output_path.

        Steps:
          1. Validate that every segment has a threshold (when thresholds are used).
          2. Update lab export tags. Outside local testing, also register the
             model in the API.
          3. Apply the CatBoost model with YQL.
          4. Split predictions by crypta_id presence and vote per crypta_id.
          5. Map the predictions to segments.
          6. With two or more segments, report application metrics to Solomon.

        :param inputs: dict with 'yandexuid_cryptaid' and
            'catboost_prepared_features' inputs (see requires). Each one exposes
            a `.table` path.
        :param output_path: YT path of the resulting segments table
        """
        model_params = self.train_helper.model_params

        # Fail fast on misconfiguration, before any lab tags are changed.
        if model_params.ordered_thresholds:
            missing_thresholds = [
                segment for segment in self.name_segment_dict
                if segment not in model_params.ordered_thresholds
            ]
            assert not missing_thresholds, \
                'Ordered thresholds are missing for segments: {}'.format(missing_thresholds)

        self.add_tags_to_exports()
        if config.environment != 'local_testing':
            self.check_model_in_api()

        with self.yt.TempTable() as catboost_applied, \
                self.yt.TempTable() as catboost_applied_with_cryptaid, \
                self.yt.TempTable() as catboost_applied_without_cryptaid, \
                self.yt.TempTable() as catboost_voted_by_cryptaid, \
                self.yt.TempTable() as predictions_stats:
            # Common CatBoost application query, followed by the
            # classification-specific output part.
            self.yql.query(
                query_string=''.join((
                    application_utils.apply_catboost_common_query.format(
                        number_of_classes=len(self.probability_classes),
                        catboost_model=model_params.model_tag,
                        input_table=inputs['catboost_prepared_features'].table,
                    ),
                    model_application.apply_catboost_classification_query.format(
                        schema=', '.join("'{}'".format(class_name) for class_name in self.probability_classes),
                        output_table=catboost_applied,
                    ),
                )),
                transaction=self.transaction,
            )

            # Split predictions into rows that have / do not have a crypta_id.
            self.yql.query(
                application_utils.split_by_cryptaid_presence_query.format(
                    catboost_applied_table=catboost_applied,
                    yandexuid_cryptaid_table=inputs['yandexuid_cryptaid'].table,
                    output_with_cryptaid_table=catboost_applied_with_cryptaid,
                    output_without_cryptaid_table=catboost_applied_without_cryptaid,
                ),
                transaction=self.transaction,
            )

            # Combine predictions that share a crypta_id.
            self.yt.run_reduce(
                model_application.voting_by_cryptaid,
                catboost_applied_with_cryptaid,
                catboost_voted_by_cryptaid,
                reduce_by='crypta_id',
            )

            self.yt.create_empty_table(
                output_path,
                schema={
                    'id': 'string',
                    'id_type': 'string',
                    'segment_name': 'string',
                    'probability': 'double',
                },
            )

            self.map_predictions_to_segments(
                model_predictions=[
                    catboost_applied_without_cryptaid,
                    catboost_voted_by_cryptaid,
                ],
                output_path=output_path
            )

            # Metrics are skipped when the builder has fewer than two segments.
            if len(self.name_segment_dict) >= 2:
                self._report_application_metrics(output_path, catboost_applied, predictions_stats)

    def _report_application_metrics(self, output_path, catboost_applied, predictions_stats):
        """
        Send application metrics for the built segments to Solomon.

        Two kinds of metric are sent:
          * 'application_distribution': for each segment, its row count divided
            by the number of rows in output_path;
          * 'application_sure_ratio': the number of rows in output_path divided
            by the number of rows in catboost_applied.

        :param output_path: table with final segment assignments
        :param catboost_applied: table with raw model predictions (ratio denominator)
        :param predictions_stats: scratch table that receives per-segment row counts
        """
        metrics_group_name = self.train_helper.model_params.metrics_group_name

        table_size = self.yt.row_count(output_path)
        self.yt.unique_count(output_path, predictions_stats, unique_by=['segment_name'], result_field='counts')

        metrics_to_send = [
            {
                'labels': {
                    'metric': 'application_distribution',
                    'model': metrics_group_name,
                    'segment': row['segment_name'],
                },
                'value': float(row['counts']) / table_size,
            }
            for row in self.yt.read_table(predictions_stats)
        ]

        # NOTE: raises ZeroDivisionError if catboost_applied is empty.
        metrics_to_send.append({
            'labels': {
                'metric': 'application_sure_ratio',
                'model': metrics_group_name,
            },
            'value': float(table_size) / self.yt.row_count(catboost_applied),
        })

        report_ml_metrics_to_solomon(
            service=config.SOLOMON_TRAINABLE_SEGMENTS_SERVICE,
            metrics_to_send=metrics_to_send,
        )
