#!/usr/bin/env python
# -*- coding: utf-8 -*-

from abc import abstractmethod
from functools import partial
from collections import (
    defaultdict,
    namedtuple,
)
import os

import luigi

from yt.wrapper import create_table_switch
import yt.yson as yson

from crypta.profile.utils.luigi_utils import BaseYtTask, YtDailyRewritableTarget, YtDateTarget
from crypta.profile.utils.config import config


# Value written to the 'segment_id' attribute of every LaL sample table by
# CustomSegmentBuilder.prepare_samples_for_lal.
fake_segment_id = 1
# Description of one LaL sample handled by prepare_samples_for_lal:
#   name          - value of the 'segment_name' column selecting the sample rows
#   id            - integer id, used as the name of the sample table
#   type          - sub-directory of config.REGULAR_LAL_INPUT_DIRECTORY for the table
#   coverage      - converted to int and stored in the '_max_coverage' attribute
#   include_input - stored as-is in the '_include_input' attribute
LaLParams = namedtuple('LaLParams', ['name', 'id', 'type', 'coverage', 'include_input'])


class CustomSegmentBuilder(BaseYtTask):
    """Base task that builds segments inside a single YT transaction.

    Subclasses provide ``build_segment`` and ``output``; ``run`` opens the
    transaction, remembers it in ``self.transaction`` and delegates to
    ``build_segment``.
    """
    date = luigi.Parameter()

    @abstractmethod
    def build_segment(self, inputs, outputs):
        """Build the segment(s); ``run`` passes ``self.input()`` and ``self.output()``."""
        pass

    @abstractmethod
    def output(self):
        """Return the target produced by the task."""
        pass

    def __init__(self, date):
        super(CustomSegmentBuilder, self).__init__(date)
        self.transaction = None
        # All operations started by the task go to the segments pool.
        self.yt.config['spec_defaults']['pool'] = config.SEGMENTS_POOL

    def run(self):
        # Build all segments of this builder within one transaction.
        with self.yt.Transaction() as transaction:
            self.transaction = transaction
            task_inputs = self.input()
            task_outputs = self.output()
            self.build_segment(inputs=task_inputs, outputs=task_outputs)

    def prepare_samples_for_lal(self, input_table, id_field, lals_params, segment_priority=1):
        """Split ``input_table`` into one LaL sample table per entry of ``lals_params``.

        Rows are routed by their 'segment_name' column: a row whose name equals
        ``lal_params.name`` is written as ``{'id_value': str(row[id_field])}``
        to that sample's temporary table. Each temporary table then receives
        its attributes and is moved to
        ``config.REGULAR_LAL_INPUT_DIRECTORY/<type>/<id>``.

        :param input_table: table with 'segment_name' and ``id_field`` columns
        :param id_field: name of the column holding the identifier
        :param lals_params: iterable of ``LaLParams``; every ``id`` must be an int
        :param segment_priority: value of the 'segment_priority' table attribute
        """
        def route_row(row, table_index_by_segment):
            # Rows of segments nobody asked for are dropped.
            if row.get('segment_name') not in table_index_by_segment:
                return
            yield create_table_switch(table_index_by_segment[row['segment_name']])
            yield {'id_value': str(row[id_field])}

        tmp_tables = []
        table_index_by_segment = {}
        params_by_tmp_table = {}

        for table_index, params in enumerate(lals_params):
            assert isinstance(params.id, int)
            table_index_by_segment[params.name] = table_index

            tmp_table = os.path.join(config.PROFILES_TMP_YT_DIRECTORY, str(params.id))
            self.yt.create_empty_table(tmp_table)
            tmp_tables.append(tmp_table)
            params_by_tmp_table[tmp_table] = params

        self.yt.run_map(
            partial(route_row, table_index_by_segment=table_index_by_segment),
            input_table,
            tmp_tables,
            spec={'enable_legacy_live_preview': False},
        )

        for tmp_table in tmp_tables:
            params = params_by_tmp_table[tmp_table]
            set_table_attribute = partial(self.yt.set_attribute, tmp_table)

            set_table_attribute('crypta_maintain_device_distribution', False)
            set_table_attribute('crypta_maintain_geo_distribution', False)
            set_table_attribute('_max_coverage', int(params.coverage))
            set_table_attribute('_include_input', params.include_input)
            set_table_attribute('crypta_related_goals', [])
            set_table_attribute('crypta_status', 'new')
            set_table_attribute('segment_priority', segment_priority)
            set_table_attribute('segment_id', fake_segment_id)

            destination = os.path.join(config.REGULAR_LAL_INPUT_DIRECTORY, params.type, str(params.id))
            self.logger.info('Moving {} to {}'.format(tmp_table, destination))
            self.yt.move(
                tmp_table,
                destination,
                preserve_account=True,
                force=True,
                recursive=True,
            )


class LalSampleBuilder(CustomSegmentBuilder):
    """Base task for builders that produce a sample table for LaL.

    The output is a daily rewritable table named after the concrete class
    inside ``config.LAL_SAMPLE_FOLDER``.
    """
    task_group = 'sample_for_lal'

    def output(self):
        sample_table = os.path.join(config.LAL_SAMPLE_FOLDER, type(self).__name__)
        return YtDailyRewritableTarget(
            table=sample_table,
            date=self.date,
            allow_great_or_equal_date=True,
        )

    @abstractmethod
    def build_segment(self, inputs, output_path):
        """Build the sample; ``run`` passes ``self.input()`` and the output table path."""
        pass

    def run(self):
        # The whole build happens inside one transaction.
        with self.yt.Transaction() as transaction:
            self.transaction = transaction

            task_inputs = self.input()
            sample_table = self.output().table
            self.build_segment(inputs=task_inputs, output_path=sample_table)


# Maps a keyword id to the name of the output column that holds the segments
# of that keyword. FormatSegmentReducer uses it to name the fields of the
# result record and RegularSegmentBuilder uses it to build the output schema.
keyword_names = {
    216: 'heuristic_segments',
    217: 'probabilistic_segments',
    281: 'marketing_segments',

    544: 'lal_common',
    545: 'lal_private',
    546: 'lal_internal',

    547: 'heuristic_common',
    548: 'heuristic_private',
    549: 'heuristic_internal',

    557: 'audience_segments',
    601: 'longterm_interests',
}


def new_heuristic_segment_formatter(segment_set):
    """Convert a collection of segment ids into a list of ``YsonUint64``.

    Returns ``None`` when ``segment_set`` is empty or ``None``.
    """
    if segment_set:
        return list(map(yson.YsonUint64, segment_set))

    return None


def old_heuristic_segment_formatter(segment_set):
    """Convert a collection of segment ids into ``{segment: YsonUint64(1)}``.

    Returns ``None`` when ``segment_set`` is empty or ``None``.
    """
    if segment_set:
        formatted = {}
        for segment in segment_set:
            formatted[segment] = yson.YsonUint64(1)
        return formatted

    return None


def old_probabilistic_segment_formatter(segment_probabilities):
    """Convert ``{segment_id: probability}`` into ``{segment_id: {'0': YsonDouble(probability)}}``.

    Returns ``None`` when ``segment_probabilities`` is empty or ``None``.
    """
    if not segment_probabilities:
        return None

    # ``items()`` instead of the Python 2-only ``iteritems()`` so the formatter
    # works on both Python 2 and Python 3.
    return {
        segment_id: {'0': yson.YsonDouble(probability)}
        for segment_id, probability in segment_probabilities.items()
    }


def new_probabilistic_segment_formatter(segment_probabilities):
    """Convert ``{segment_id: probability}`` into ``{segment_id: YsonDouble(probability)}``.

    Returns ``None`` when ``segment_probabilities`` is empty or ``None``.
    """
    if not segment_probabilities:
        return None

    # ``items()`` instead of the Python 2-only ``iteritems()`` so the formatter
    # works on both Python 2 and Python 3.
    return {
        segment_id: yson.YsonDouble(probability)
        for segment_id, probability in segment_probabilities.items()
    }


# Keyword ids whose rows FormatSegmentReducer aggregates with
# process_probabilistic_keyword (segment -> probability); rows of every other
# keyword are aggregated with process_heuristic_keyword (set of segments).
probabilistic_keywords = {217, 281, 544, 545, 546}

# Maps a keyword id to the function that turns the aggregated segments of that
# keyword into the value stored in the output column.
keyword_formatters = {
    216: old_heuristic_segment_formatter,
    217: old_probabilistic_segment_formatter,
    281: new_probabilistic_segment_formatter,

    544: new_probabilistic_segment_formatter,
    545: new_probabilistic_segment_formatter,
    546: new_probabilistic_segment_formatter,

    547: new_heuristic_segment_formatter,
    548: new_heuristic_segment_formatter,
    549: new_heuristic_segment_formatter,

    557: new_heuristic_segment_formatter,

    601: new_heuristic_segment_formatter,
}


class FormatSegmentReducer(object):
    """Reducer that folds the raw segment rows of one key into a single record.

    ``name_segment_dict`` maps a segment name to a ``(keyword_id, segment_id)``
    tuple. Rows with a name missing from that dict are ignored.
    """

    def __init__(self, name_segment_dict):
        self.name_segment_dict = name_segment_dict

    def _get_segment_id(self, row):
        """Return the segment id for ``row``, or ``None`` for an unknown segment name."""
        segment_info = self.name_segment_dict.get(row['segment_name'])
        if segment_info is None:
            # Indexing the missing entry used to raise TypeError here.
            return None
        return segment_info[1]

    def process_probabilistic_keyword(self, rows):
        """Return ``{str(segment_id): max probability}`` over ``rows``.

        A row without a 'probability' field counts as probability 1.0. Rows
        with an unknown segment name, a ``None`` probability or a probability
        outside ``[0, 1]`` are skipped.
        """
        segments_with_probabilities = defaultdict(int)

        for row in rows:
            segment_id = self._get_segment_id(row)
            probability = row.get('probability', 1.0)

            # The explicit None check keeps the "skip" behaviour on Python 3,
            # where comparing None with a number raises TypeError.
            if segment_id is None or probability is None:
                continue

            if 0 <= probability <= 1:
                segment_key = str(segment_id)
                segments_with_probabilities[segment_key] = max(
                    segments_with_probabilities[segment_key],
                    probability,
                )

        return dict(segments_with_probabilities)

    def process_heuristic_keyword(self, rows):
        """Return the set of ``str(segment_id)`` for rows with a known segment name."""
        segments = set()

        for row in rows:
            segment_id = self._get_segment_id(row)

            if segment_id is not None:
                segments.add(str(segment_id))

        return segments

    def __call__(self, key, rows):
        """Yield one record for ``key`` with a formatted field per keyword.

        Nothing is yielded when none of the keyword fields ends up non-empty.
        """
        result_record = dict(key)

        # Group the rows of known segments by the keyword they belong to.
        rows_by_keyword = defaultdict(list)
        for row in rows:
            segment_info = self.name_segment_dict.get(row['segment_name'])
            if segment_info is None:
                continue
            rows_by_keyword[segment_info[0]].append(row)

        has_data = False

        for keyword_id, keyword_rows in rows_by_keyword.items():
            segment_field_name = keyword_names[keyword_id]

            if keyword_id in probabilistic_keywords:
                aggregated = self.process_probabilistic_keyword(keyword_rows)
            else:
                aggregated = self.process_heuristic_keyword(keyword_rows)

            result_record[segment_field_name] = keyword_formatters[keyword_id](aggregated)

            if result_record[segment_field_name]:
                has_data = True

        if has_data:
            yield result_record


class RegularSegmentBuilder(CustomSegmentBuilder):
    """Base task for coded segments.

    ``run`` lets the subclass write raw rows to ``_segment_raw_output_path``
    via ``build_segment`` and then reduces them with ``FormatSegmentReducer``
    into ``_segment_output_table``.

    Subclasses must provide ``name_segment_dict``. Its values are either
    ``(keyword_id, segment_id)`` tuples or bare segment ids; bare ids require
    the subclass to define ``keyword`` as well.
    """
    task_group = 'coded_segments'

    # When True, the output table gets the 'indevice' attribute.
    indevice = False

    @property
    def name_segment_dict(self):
        raise NotImplementedError('name_segment_dict')

    @property
    def keyword(self):
        # A common keyword is optional only if every entry already carries
        # its own keyword id, i.e. is a tuple.
        for item in self.name_segment_dict.values():
            if not isinstance(item, tuple):
                raise NotImplementedError('keyword')
        return None

    @abstractmethod
    def build_segment(self, inputs, output_path):
        """Write raw segment rows to ``output_path``; called by ``run``."""
        pass

    def __init__(self, date):
        super(RegularSegmentBuilder, self).__init__(date)

        # Evaluate the keyword once instead of once per dict entry.
        keyword = self.keyword
        if keyword:
            # Normalise bare segment ids to (keyword, segment_id) tuples.
            # Iterate over a snapshot because the dict is updated in the loop;
            # ``items()`` replaces the Python 2-only ``iteritems()``.
            for name, segment_id in list(self.name_segment_dict.items()):
                if not isinstance(segment_id, tuple):
                    self.name_segment_dict[name] = (keyword, segment_id)

        self.output_schema = {
            'id': 'string',
            'id_type': 'string'
        }
        # One 'any' column per keyword used by the builder.
        for item in self.name_segment_dict.values():
            self.output_schema[keyword_names[item[0]]] = 'any'

        self._segment_output_table = os.path.join(
            config.PROFILES_SEGMENT_PARTS_YT_DIRECTORY,
            config.REGULAR_SEGMENTS,
            self.__class__.__name__,
        )

        self._segment_raw_output_path = os.path.join(
            config.SEGMENT_RAW_OUTPUT_FOLDER,
            self.__class__.__name__,
        )

    def output(self):
        return YtDateTarget(
            self._segment_output_table,
            self.date,
        )

    def _format_segment(self):
        """Reduce the raw rows by ('id', 'id_type') into the final output table."""
        self.logger.info('Schema for output table: {}'.format(self.output_schema))
        self.yt.create_empty_table(
            path=self._segment_output_table,
            schema=self.output_schema,
            additional_attributes={'generate_date': self.date},
        )

        self.logger.info('Using name segment dict: {}'.format(self.name_segment_dict))
        self.yt.run_map_reduce(
            None,
            FormatSegmentReducer(self.name_segment_dict),
            self._segment_raw_output_path,
            self._segment_output_table,
            reduce_by=['id', 'id_type'],
            spec={'title': '{name} make final output'.format(name=self.__class__.__name__)},
        )

        if self.indevice:
            self.yt.set_attribute(
                self._segment_output_table,
                'indevice',
                True,
            )

    def run(self):
        with self.yt.Transaction() as transaction:
            self.transaction = transaction

            # Make sure the folder for raw outputs exists before building.
            if not self.yt.exists(config.SEGMENT_RAW_OUTPUT_FOLDER):
                self.yt.create("map_node", config.SEGMENT_RAW_OUTPUT_FOLDER, recursive=True)

            self.build_segment(
                inputs=self.input(),
                output_path=self._segment_raw_output_path,
            )

            self._format_segment()
