# -*- coding: utf-8 -*-
from __future__ import annotations

from copy import deepcopy
from itertools import chain, product
from typing import Iterable, Tuple
import logging

from travel.hotels.content_manager.data_model.stage import ClusterizationData
from travel.hotels.content_manager.data_model.stage import ClusterizationDataInput, ClusterizationDataOutput
from travel.hotels.content_manager.data_model.stage import PermalinkTaskInfo
from travel.hotels.content_manager.data_model.storage import StoragePermalinkWL
from travel.hotels.content_manager.data_model.options import TriggerOptions
from travel.hotels.content_manager.data_model.types import (
    AssigneeSkill, ClusterizationStartReason
)
from travel.hotels.content_manager.lib.common import dc_to_dict, get_dc_yt_schema, str_to_set
from travel.hotels.content_manager.lib.storage import Storage
from travel.hotels.content_manager.lib.trigger import EntityGroupingKey, FilterCallable, Producer, ThreadFilter, Trigger


LOG = logging.getLogger(__name__)


class ClusterizationFilter(ThreadFilter):
    """A ``ThreadFilter`` that also carries a reward multiplier.

    ``reward_multiplier`` is read by ``ProducerClusterization.prepare_data``.
    That method multiplies ``options.pool_options.reward_per_assignment`` by it
    before writing the trigger options.
    """

    def __init__(
            self,
            name: str,
            options_key: str,
            filter_callable: FilterCallable,
            reward_multiplier: float,
    ):
        # Stored before delegating so the attribute already exists while the
        # base initializer runs.
        self.reward_multiplier = reward_multiplier
        super().__init__(name, options_key, filter_callable)


class ProducerClusterization(Producer):
    """Producer that turns permalink entities into clusterization task rows."""

    def prepare_data(
            self,
            trigger: Trigger,
            thread_filter: ClusterizationFilter,
            local_storage: Storage,
            path: str,
            entities: Iterable[StoragePermalinkWL],
            grouping_key: EntityGroupingKey,
            options: TriggerOptions,
    ):
        """Write one ``ClusterizationData`` row per entity, then the trigger options.

        Rows go to the ``hotels`` table under *path*. The options are written
        for that table with ``reward_per_assignment`` scaled by
        ``thread_filter.reward_multiplier``. The caller's *options* object is
        not modified; a deep copy is scaled instead.

        Side effect: an entity whose ``comments`` is ``None`` has it replaced
        by ``''`` in place.
        """
        LOG.info('Preparing trigger data')

        def build_row(permalink: StoragePermalinkWL):
            # Converts a single entity into the dict form of a ClusterizationData row.
            LOG.debug(permalink)

            # Normalise missing comments (in place) so they can be split below.
            if permalink.comments is None:
                permalink.comments = ''

            # First-iteration tasks are never handed out with the ADVANCED skill.
            skill = permalink.assignee_skill
            is_first_iteration = permalink.clusterization_iteration == 1
            if is_first_iteration and skill == AssigneeSkill.ADVANCED:
                skill = AssigneeSkill.BASIC

            stages = str_to_set(permalink.required_stages)

            row = ClusterizationData(
                input=ClusterizationDataInput(
                    permalink=str(permalink.permalink),
                    altay_url=f'https://altay.yandex-team.ru/cards/perm/{permalink.permalink}',
                    requirements=permalink.requirements.split('\n'),
                    prev_comments=permalink.comments.split('\n'),
                    hotel_name=permalink.hotel_name,
                    assignee_skill=skill,
                    stage_actualization_required='actualization' in stages,
                    stage_call_center_required='call_center' in stages,
                ),
                output=ClusterizationDataOutput(),
                info=PermalinkTaskInfo(),
            )
            return dc_to_dict(row)

        rows = [build_row(entity) for entity in entities]

        hotels_table = trigger.persistence_manager.join(path, 'hotels')
        LOG.info(f'Writing result to {hotels_table}')
        trigger.persistence_manager.write(hotels_table, rows, get_dc_yt_schema(ClusterizationData))

        # Scale the reward on a private copy so the caller's options stay intact.
        scaled_options = deepcopy(options)
        pool = scaled_options.pool_options
        pool.reward_per_assignment = round(pool.reward_per_assignment * thread_filter.reward_multiplier, 0)
        trigger.write_options(scaled_options, hotels_table, grouping_key)


def get_filter(
        start_reason: ClusterizationStartReason,
        assignee_skill: AssigneeSkill,
) -> Tuple[ClusterizationFilter, ...]:
    """Build the (first, next) clusterization filter pair for one reason/skill combination.

    Both filters use the options key ``f'{start_reason.value}_{assignee_skill.value}'``
    and match only entities with exactly this start reason and assignee skill.

    * ``<options_id>_first`` accepts entities whose ``clusterization_iteration``
      is 1. Its reward multiplier is 1.0.
    * ``<options_id>_next`` accepts entities whose ``clusterization_iteration``
      is greater than 1. Its reward multiplier is 0.8.

    :param start_reason: required ``clusterization_start_reason`` of an entity.
    :param assignee_skill: required ``assignee_skill`` of an entity.
    :return: a ``(first_filter, next_filter)`` tuple of ``ClusterizationFilter``.
        The concrete subclass matters: ``ProducerClusterization.prepare_data``
        reads its ``reward_multiplier``.
    """
    options_id = f'{start_reason.value}_{assignee_skill.value}'

    def matches_group(entity: StoragePermalinkWL) -> bool:
        # Shared part of both predicates: the entity belongs to this reason/skill group.
        return (
            entity.clusterization_start_reason == start_reason and
            entity.assignee_skill == assignee_skill
        )

    # noinspection PyUnusedLocal
    def func_first(entity: StoragePermalinkWL, storage: Storage) -> bool:
        return matches_group(entity) and entity.clusterization_iteration == 1

    # noinspection PyUnusedLocal
    def func_next(entity: StoragePermalinkWL, storage: Storage) -> bool:
        return matches_group(entity) and entity.clusterization_iteration > 1

    return (
        ClusterizationFilter(f'{options_id}_first', options_id, func_first, 1.0),
        ClusterizationFilter(f'{options_id}_next', options_id, func_next, 0.8),
    )


# Registry of all clusterization filters. Each (start reason, assignee skill)
# combination contributes its (first, next) pair. Pairs are flattened in
# product() order: start reasons vary slowest, assignee skills fastest.
FILTERS_CLUSTERIZATION = list(chain.from_iterable(
    get_filter(start_reason, skill)
    for start_reason, skill in product(ClusterizationStartReason, AssigneeSkill)
))
