# -*- coding: utf-8 -*-
from __future__ import annotations

import logging
from typing import Any, Dict, Iterable, List, Tuple

from travel.hotels.content_manager.data_model.options import TriggerOptions
from travel.hotels.content_manager.data_model.stage import (
    DescriptionTaskInfo, SCUpdateDescriptionsData, SCUpdateDescriptionsDataInput, SCUpdateDescriptionsDataOutput
)
from travel.hotels.content_manager.data_model.storage import EntityWithStatus, StorageSCDescription
from travel.hotels.content_manager.lib.common import dc_to_dict, get_dc_yt_schema
from travel.hotels.content_manager.lib.persistence_manager import PersistenceManager
from travel.hotels.content_manager.lib.storage import Storage
from travel.hotels.content_manager.lib.trigger import EntityGroupingKey, Producer, ThreadFilter, Trigger


LOG = logging.getLogger(__name__)


# Identifier of a service-class description: (carrier_code, car_type_code, sc_code).
# Declared as a real typing alias; the previous bare tuple `(str, str, str)` is a
# value, not a type, and cannot be understood by type checkers or get_type_hints.
DescriptionKey = Tuple[str, str, str]


class ProducerSCUpdateDescriptions(Producer):
    """Producer that prepares service-class description update tasks.

    For every description entity it looks up the carrier and the car type in the
    service-class dictionary tables, builds an ``SCUpdateDescriptionsData`` record
    and writes all records to a ``descriptions`` table under the trigger path.
    """

    # Root directory of the ``carriers`` / ``car_types`` dictionary tables.
    # A class attribute so a subclass (or a test) can point at another location
    # without overriding ``get_trigger_data``.
    DICTS_PATH = '//home/travel/prod/train/service_classes'

    @staticmethod
    def get_description_key(description: StorageSCDescription) -> DescriptionKey:
        """Return the ``(carrier_code, car_type_code, sc_code)`` triple of ``description``."""
        return description.carrier_code, description.car_type_code, description.sc_code

    @staticmethod
    def get_dict_data(
        persistence_manager: PersistenceManager,
        table_path: str,
        key_field: str,
    ) -> Dict[str, Dict[str, Any]]:
        """Read a table and index its rows by the value of ``key_field``.

        :param persistence_manager: manager used to read the table.
        :param table_path: path of the table to read.
        :param key_field: column whose value becomes the key of the result.
        :return: mapping ``row[key_field] -> row``; if several rows share a key,
            the last one read wins.
        """
        return {row[key_field]: row for row in persistence_manager.read(table_path)}

    @staticmethod
    def _get_enabled_entry(
        dict_data: Dict[str, Dict[str, Any]],
        key: str,
        entry_kind: str,
    ) -> Dict[str, Any]:
        """Return ``dict_data[key]``, failing loudly if it is missing or not enabled.

        Explicit raises are used instead of ``assert`` so the check also runs
        under ``python -O``.

        :raises ValueError: if ``key`` is absent or its ``enabled`` field is falsy.
        """
        entry = dict_data.get(key)
        if entry is None:
            raise ValueError(f'{entry_kind} {key!r} is not present in the dictionary')
        if not entry['enabled']:
            raise ValueError(f'{entry_kind} {key!r} is not enabled')
        return entry

    def get_trigger_data(
        self,
        trigger: Trigger,
        entities: Iterable[StorageSCDescription],
    ) -> List[Dict[str, Any]]:
        """Build one task row per description in ``entities``.

        Each description is enriched with carrier data (country, url, name) and
        the car type name read from the dictionaries under ``DICTS_PATH``.

        :param trigger: trigger whose persistence manager is used for reading.
        :param entities: descriptions to turn into tasks.
        :return: list of ``SCUpdateDescriptionsData`` records converted by ``dc_to_dict``.
        :raises ValueError: if a description references a carrier or car type that
            is missing from its dictionary or is not enabled.
        """
        LOG.info('Preparing trigger data')

        pm = trigger.persistence_manager
        carriers_dict = self.get_dict_data(pm, pm.join(self.DICTS_PATH, 'carriers'), 'carrier_code')
        car_types_dict = self.get_dict_data(pm, pm.join(self.DICTS_PATH, 'car_types'), 'car_type_code')

        data = list()
        for description in entities:
            LOG.debug(description)

            carrier_info = self._get_enabled_entry(carriers_dict, description.carrier_code, 'Carrier')
            car_type_info = self._get_enabled_entry(car_types_dict, description.car_type_code, 'Car type')

            d_input = SCUpdateDescriptionsDataInput(
                carrier_code=description.carrier_code,
                car_type_code=description.car_type_code,
                sc_code=description.sc_code,
                country=carrier_info['country'],
                url=carrier_info['url'],
                carrier_name=carrier_info['carrier_name'],
                car_type_name=car_type_info['car_type_name'],
                sc_name='',
                sc_description=description.sc_description,
            )

            d = SCUpdateDescriptionsData(
                input=d_input,
                output=SCUpdateDescriptionsDataOutput(),
                info=DescriptionTaskInfo(),
            )
            data.append(dc_to_dict(d))
        return data

    def prepare_data(
        self,
        trigger: Trigger,
        thread_filter: ThreadFilter,
        local_storage: Storage,
        path: str,
        entities: Iterable[EntityWithStatus],
        grouping_key: EntityGroupingKey,
        options: TriggerOptions,
    ) -> None:
        """Write description tasks to ``<path>/descriptions`` and record the trigger options.

        ``thread_filter`` and ``local_storage`` are accepted for the
        ``prepare_data`` signature but are not used by this producer.
        """
        task_data = self.get_trigger_data(trigger, entities)

        descriptions_table = trigger.persistence_manager.join(path, 'descriptions')
        LOG.info('Writing result to %s', descriptions_table)
        trigger.persistence_manager.write(descriptions_table, task_data, get_dc_yt_schema(SCUpdateDescriptionsData))

        trigger.write_options(options, descriptions_table, grouping_key)
