#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
from functools import partial

import luigi
import yt.yson as yson
from yt.wrapper import with_context

from crypta.profile.utils import utils
from crypta.profile.utils.config import config
from crypta.profile.utils.loggers import TimeTracker
from crypta.profile.utils.luigi_utils import (
    YtDailyRewritableTarget,
    ExternalInput,
    BaseYtTask,
)


@with_context
class SegmentMapper(object):
    """YT mapper that tags each valid yandexuid row with the segment id of its source table.

    Rows whose ``yuid`` does not pass ``utils.is_valid_yandexuid`` are dropped.
    """

    def __init__(self, table_index_to_segment_id_dict):
        # Lookup from input table index (as reported by the YT context) to segment id.
        self.table_index_to_segment_id_dict = table_index_to_segment_id_dict

    def __call__(self, row, context):
        yuid = row['yuid']
        if not utils.is_valid_yandexuid(yuid):
            return

        # A missing/None table index is treated as the first input table.
        table_index = context.table_index or 0
        yield {
            'id': yuid,
            'id_type': 'yandexuid',
            'score': row['score'],
            'segment_id': self.table_index_to_segment_id_dict[table_index],
        }


def segment_reducer(key, rows, data_source):
    """Collapse all mapped rows of one (id, id_type) key into a single output row.

    The aggregated value is written under a column named after ``data_source``.
    Its shape depends on the source:

    * ``marketing_segments`` / ``lal_internal``: ``{segment_id: score}``
    * ``probabilistic_segments``: ``{segment_id: {'0': score}}``
    * ``audience_segments``: sorted list of unique segment ids as ``yson.YsonUint64``

    For the dict-based formats, if a segment id occurs more than once for the
    key, the score of the last row seen wins.

    :param key: reduce key; must provide ``id`` and ``id_type``.
    :param rows: iterable of mapped rows with ``segment_id`` and ``score`` fields.
    :param data_source: segment source name; also used as the output column name.
    :raises ValueError: if ``data_source`` is not one of the supported sources.
    """
    if data_source in ('marketing_segments', 'lal_internal'):
        formatted_segments = {row['segment_id']: row['score'] for row in rows}
    elif data_source == 'probabilistic_segments':
        formatted_segments = {row['segment_id']: {'0': row['score']} for row in rows}
    elif data_source == 'audience_segments':
        # A set deduplicates ids; sorting gives a deterministic list in the output.
        formatted_segments = sorted({yson.YsonUint64(row['segment_id']) for row in rows})
    else:
        # Fail loudly instead of writing a None-valued column for an unknown source.
        raise ValueError('Unsupported data_source for segment_reducer: {}'.format(data_source))

    yield {
        'id': key['id'],
        'id_type': key['id_type'],
        data_source: formatted_segments,
    }


class GetRegularLalFromAudience(BaseYtTask):
    """Merge the per-segment tables of one data source into a single profile-part table.

    Inputs are the tables ``REGULAR_LAL_OUTPUT_DIRECTORY/<data_source>/<segment_id>``.
    The output is ``PROFILES_SEGMENT_PARTS_YT_DIRECTORY/regular_<data_source>``.
    It holds one row per (id, id_type), with the aggregated segments in the
    ``<data_source>`` column.
    """
    date = luigi.Parameter()
    data_source = luigi.Parameter()
    priority = 100
    task_group = 'export_profiles'

    def requires(self):
        """Depend on every non-empty table whose name is a positive integer segment id."""
        input_directory = os.path.join(config.REGULAR_LAL_OUTPUT_DIRECTORY, self.data_source)
        required_tables = []

        for segment_id in self.yt.list(input_directory):
            table = os.path.join(input_directory, segment_id)
            if not (segment_id.isdigit() and int(segment_id) > 0):
                self.logger.warning('Invalid segment_id in {}: {}'.format(input_directory, segment_id))
            elif self.yt.row_count(table) > 0:
                required_tables.append(table)
            else:
                # The id is valid, but there is nothing to export; report it
                # separately so it is not mistaken for a malformed segment id.
                self.logger.warning('Empty table for segment_id in {}: {}'.format(input_directory, segment_id))

        self.logger.info('Input tables for {}: {}'.format(self.data_source, required_tables))
        return [ExternalInput(table) for table in required_tables]

    def output(self):
        """Daily target for this data source; an empty result table is allowed."""
        return YtDailyRewritableTarget(
            os.path.join(
                config.PROFILES_SEGMENT_PARTS_YT_DIRECTORY,
                'regular_{}'.format(self.data_source),
            ),
            self.date,
            allow_empty=True,
        )

    def run(self):
        """Create the output table and fill it via map-reduce over all input tables."""
        with TimeTracker('{}_{}'.format(self.__class__.__name__, self.data_source)):
            output_table = self.output().table

            with self.yt.Transaction():
                self.yt.create_empty_table(
                    output_table,
                    schema={
                        'id': 'string',
                        'id_type': 'string',
                        self.data_source: 'any',
                    },
                )

                input_tables = [target.table for target in self.input()]
                # The mapper sees only each row's input table index; the segment id
                # is the basename of the corresponding input table.
                table_index_to_segment_id_dict = {
                    index: os.path.basename(table) for index, table in enumerate(input_tables)
                }

                # With no inputs the output stays an empty table (allow_empty=True).
                if input_tables:
                    self.yt.run_map_reduce(
                        SegmentMapper(table_index_to_segment_id_dict),
                        partial(segment_reducer, data_source=self.data_source),
                        input_tables,
                        output_table,
                        reduce_by=['id', 'id_type'],
                    )

                self.yt.set_attribute(output_table, 'generate_date', self.date)


class GetRegularLal(luigi.WrapperTask):
    """Wrapper that schedules the regular lookalike export for each supported data source."""
    date = luigi.Parameter()
    priority = 100
    task_group = 'export_profiles'

    def requires(self):
        # Keyed by data source name, in the same order the tasks are listed.
        sources = ('lal_internal', 'marketing_segments', 'audience_segments')
        return {
            source: GetRegularLalFromAudience(self.date, source)
            for source in sources
        }
