# -*- coding: utf-8 -*-
from __future__ import annotations

from abc import abstractmethod, ABCMeta
from typing import Dict, Iterable, List, Optional, Set
from uuid import uuid4
import logging

from travel.hotels.content_manager.data_model import storage
from travel.hotels.content_manager.data_model.options import StageOptions
from travel.hotels.content_manager.data_model.stage import EntityInfo, EntityInfoClass
from travel.hotels.content_manager.data_model.storage import (
    DispatchableEntity, AdvanceableEntityCls, EntityClassWithStatus, EntityWithStatus
)
from travel.hotels.content_manager.lib.common import (
    dc_to_dict, get_dc_yt_schema, str_from_set, str_to_set, ts_to_str_msk_tz
)
from travel.hotels.content_manager.data_model.types import StageResult
from travel.hotels.content_manager.lib.attributes import Attributes
from travel.hotels.content_manager.lib.dispatcher import Dispatcher
from travel.hotels.content_manager.lib.storage import TableData
from travel.hotels.content_manager.lib.path_info import PathInfo
from travel.hotels.content_manager.lib.path_mapping import PathMapping
from travel.hotels.content_manager.lib.persistence_manager import PersistenceManager
from travel.hotels.content_manager.lib.yql_simple_client import YqlSimpleClient


# Module-level logger used by the updater classes defined in this module.
LOG = logging.getLogger(__name__)


class Updater(metaclass=ABCMeta):
    """Base class for a content-manager stage updater.

    The updater lists result entries under ``<stages_path>/<stage_name>/output``.
    For each one it calls :meth:`run` (implemented by subclasses). It optionally
    copies the result's input/output to history and then deletes them.

    It also provides helpers to:
    - set per-stage entity statuses and route entities to their next stage;
    - track which fields changed, and upsert those fields into storage.

    Subclasses must implement :meth:`run`.
    """

    # ``metaclass=ABCMeta`` (rather than a Python-2 style ``__metaclass__`` class
    # attribute, which Python 3 ignores) is what makes ``@abstractmethod`` on run()
    # actually enforced.

    def __init__(
            self,
            stage_name: str,
            persistence_manager: PersistenceManager,
            yql_client: Optional[YqlSimpleClient],
            path_info: PathInfo,
            start_ts: int,
            save_history: bool,
            options: Dict[str, StageOptions],
    ) -> None:
        """Create an updater for one stage.

        :param stage_name: name of this stage. Used to build the stage
            input/output/history paths and the ``status_<stage>`` field names.
        :param persistence_manager: backend used for every path, read, write,
            copy, delete and upsert operation.
        :param yql_client: optional YQL client. It is stored for subclasses and
            is not used by this base class.
        :param path_info: root paths (stages, history, temp, logs).
        :param start_ts: timestamp of this run. Used as the status timestamp and
            in task-log table names.
        :param save_history: whether processed results are copied to history
            before deletion.
        :param options: per-stage options keyed by stage name. Only ``delay`` is
            read here.
        """
        self.stage_name = stage_name

        self.persistence_manager = persistence_manager
        self.yql_client = yql_client

        self.path_info = path_info
        self.stage_input_path = persistence_manager.join(path_info.stages_path, stage_name, 'input')
        self.stage_output_path = persistence_manager.join(path_info.stages_path, stage_name, 'output')
        self.entity_to_path = PathMapping(path_info).entity_to_path
        # Inverse of storage.TABLE_NAMES: entity class -> table name.
        self.entity_to_name = {v: k for k, v in storage.TABLE_NAMES.items()}

        self.start_ts = start_ts
        self.save_history = save_history
        self.options = options

        # Per entity class: names of fields modified by this updater, which
        # update_storage_entity() passes on to the upsert.
        self.fields_to_update: Dict[storage.EntityClass, Set[str]] = dict()

        self.dispatcher = Dispatcher()

    def add_fields_to_update(self, entity_cls: storage.EntityClass, fields: Iterable[str]):
        """Register ``fields`` as modified for ``entity_cls`` (accumulates across calls)."""
        fields_to_update = self.fields_to_update.setdefault(entity_cls, set())
        fields_to_update.update(fields)

    def update_entity_status(
            self,
            entity_cls: EntityClassWithStatus,
            entities: Iterable[EntityWithStatus],
            stage: str,
            status: storage.StageStatus,
            ts: int
    ) -> None:
        """Set ``status_<stage>`` to ``status`` and ``status_<stage>_ts`` to ``ts``.

        Both field names are registered as modified for ``entity_cls``.

        :raises RuntimeError: if any entity is not an instance of ``entity_cls``.
            In that case no entity is modified.
        """
        status_field = f'status_{stage}'
        ts_field = f'{status_field}_ts'

        # Materialize and validate everything first, so a type error cannot leave
        # the batch partially updated.
        entities = list(entities)
        for entity in entities:
            if not isinstance(entity, entity_cls):
                raise RuntimeError(f'Expected {entity_cls} but got {entity}')

        for entity in entities:
            setattr(entity, status_field, status)
            setattr(entity, ts_field, ts)

        self.add_fields_to_update(entity_cls, [status_field, ts_field])

    def mark_as_processed(
            self,
            entity_cls: EntityClassWithStatus,
            entities: Iterable[EntityWithStatus],
    ) -> None:
        """Set this stage's status to NOTHING_TO_DO, timestamped with ``start_ts``."""
        self.update_entity_status(
            entity_cls=entity_cls,
            entities=entities,
            stage=self.stage_name,
            status=storage.StageStatus.NOTHING_TO_DO,
            ts=self.start_ts,
        )

    def send_to_stage(
            self,
            entity_cls: EntityClassWithStatus,
            entities: Iterable[EntityWithStatus],
            stage: str,
            delay: Optional[int] = None,
    ) -> None:
        """Mark entities TO_BE_PROCESSED for ``stage`` and append ``stage`` to their route.

        :param delay: added to ``start_ts`` to form the status timestamp. When
            ``None``, the stage's configured ``options[stage].delay`` is used, or
            0 if the stage has no options.
        :raises ValueError: if ``stage`` is empty.
        """
        if not stage:
            raise ValueError(f'Expected stage name but got "{stage}"')

        # The entities are traversed twice (status update, then route update).
        # Materialize so a one-shot iterator does not silently skip the second pass.
        entities = list(entities)

        if delay is None:
            stage_options = self.options.get(stage)
            delay = stage_options.delay if stage_options is not None else 0

        self.update_entity_status(
            entity_cls=entity_cls,
            entities=entities,
            stage=stage,
            status=storage.StageStatus.TO_BE_PROCESSED,
            ts=self.start_ts + delay,
        )
        for entity in entities:
            if hasattr(entity, 'route'):
                entity.route = ','.join(part for part in (entity.route, stage) if part)

    def update_storage_entity(
            self,
            entity_data: TableData,
            entity_cls: storage.EntityClass,
    ) -> None:
        """Request an upsert of ``entity_data`` into the storage table of ``entity_cls``.

        Only fields registered via add_fields_to_update() for ``entity_cls`` are
        passed as ``fields_to_update``.
        """
        dst_path = self.entity_to_path[entity_cls]
        # Sorted so the field order is deterministic (set iteration order of str
        # can differ between processes because of hash randomization).
        fields_to_update = sorted(self.fields_to_update.get(entity_cls, ()))

        LOG.info(f'Updating storage for {entity_cls.__name__}')
        self.persistence_manager.request_upsert(
            src_data=entity_data,
            dst_path=dst_path,
            dc=entity_cls,
            fields_to_update=fields_to_update,
        )

    @abstractmethod
    def run(self, output_path: str, temp_dir: str) -> None:
        """Apply one stage result located at ``output_path``.

        :param temp_dir: scratch path the implementation may use. It is deleted
            by process_result() afterwards if it exists.
        """
        pass

    def do_save_history(self, input_path: str, output_path: str, result_id: str):
        """Copy a result's input and output to ``<history_path>/<stage_name>/<result_id>``."""
        history_path = self.persistence_manager.join(self.path_info.history_path, self.stage_name, result_id)
        history_input_path = self.persistence_manager.join(history_path, 'input')
        history_output_path = self.persistence_manager.join(history_path, 'output')

        LOG.info('Moving data to history')

        if not self.persistence_manager.exists(history_path):
            self.persistence_manager.create_dir(history_path)
        LOG.info(f'Copying from {input_path} to {history_input_path}')
        self.persistence_manager.copy(input_path, history_input_path)
        LOG.info(f'Copying from {output_path} to {history_output_path}')
        self.persistence_manager.copy(output_path, history_output_path)

    def process_result(self, result_id: str, input_path: str, output_path: str) -> None:
        """Run the stage on one result, optionally archive it, then delete it.

        The input/output are deleted only after a successful run. The per-result
        temp directory is removed in every case, including when an earlier step
        raises.
        """
        temp_dir = self.persistence_manager.join(self.path_info.temp_path, str(uuid4()))
        try:
            self.run(output_path, temp_dir)

            if self.save_history:
                self.do_save_history(input_path, output_path, result_id)

            LOG.info(f'Deleting {input_path}')
            self.persistence_manager.delete(input_path)

            LOG.info(f'Deleting {output_path}')
            self.persistence_manager.delete(output_path)
        finally:
            # Clean up scratch data even on failure, so failed results do not leak it.
            if self.persistence_manager.exists(temp_dir):
                LOG.info(f'Deleting {temp_dir}')
                self.persistence_manager.delete(temp_dir)

    def process(self) -> None:
        """Process every result listed under the stage output directory.

        A missing output directory means there is nothing to process.
        """
        result_ids = self.list_path(self.stage_output_path, list())
        for result_id in result_ids:
            input_path = self.persistence_manager.join(self.stage_input_path, result_id)
            output_path = self.persistence_manager.join(self.stage_output_path, result_id)
            LOG.debug(f'New result: {output_path}')
            self.process_result(result_id, input_path, output_path)

    def list_path(self, path: str, default: Optional[List[str]] = None) -> List[str]:
        """Return the names of the nodes under ``path``.

        :param default: returned when ``path`` does not exist.
        :raises IOError: if ``path`` does not exist and ``default`` is None.
        """
        if self.persistence_manager.exists(path):
            return [n.name for n in self.persistence_manager.list(path)]
        if default is not None:
            return default
        raise IOError(f'No such path: {path}')

    def update_task_logs(self, logs_data: Iterable[EntityInfo], logs_cls: EntityInfoClass, table_prefix: str) -> None:
        """Write task-log records to a new ``<logs_path>/task_info`` table.

        The table is named ``<table_prefix> <start_ts>``, with ``start_ts``
        formatted by ts_to_str_msk_tz. The schema is derived from ``logs_cls``.
        """
        table_name = f'{table_prefix} {ts_to_str_msk_tz(self.start_ts)}'
        logs_table = self.persistence_manager.join(self.path_info.logs_path, 'task_info', table_name)
        rows = (dc_to_dict(ld) for ld in logs_data)
        self.persistence_manager.write(logs_table, rows, get_dc_yt_schema(logs_cls))

    def get_table_data(
        self,
        path: str, entities: Iterable[storage.EntityClass],
    ) -> Dict[storage.EntityClass, TableData]:
        """Read the table of each entity class under ``path`` into a list of rows."""
        return {
            entity_cls: list(self.persistence_manager.read(
                self.persistence_manager.join(path, self.entity_to_name[entity_cls])
            ))
            for entity_cls in entities
        }

    def update_finished_stages(self, entities: Iterable[DispatchableEntity]) -> None:
        """Add this stage to each entity's ``finished_stages`` set (stored as a string)."""
        for entity in entities:
            finished_stages = str_to_set(entity.finished_stages)
            finished_stages.add(self.stage_name)
            entity.finished_stages = str_from_set(finished_stages)

    def dispatch_entities(
        self,
        entity_cls: AdvanceableEntityCls,
        entities: Iterable[DispatchableEntity],
        final_stage: str,
        delay: Optional[int] = 0,
    ) -> None:
        """Let the dispatcher pick each entity's next stage and send it there.

        :param delay: forwarded to send_to_stage(). ``None`` means use the
            target stage's configured delay.
        """
        # Materialize: the entities are iterated here and again by the dispatcher.
        entities = list(entities)

        # TODO: choose better place for attributes check
        # Outside the actualization stage, an entity that finished actualization
        # successfully but still has attributes to check loses 'actualization' from
        # its finished stages (presumably so the dispatcher routes it there again).
        if self.stage_name != 'actualization':
            for entity in entities:
                finished_stages = str_to_set(entity.finished_stages)
                if not ('actualization' in finished_stages and entity.actualization_result == StageResult.SUCCESS):
                    continue
                if not Attributes.get_attributes_to_check(entity):
                    continue
                finished_stages -= {'actualization'}
                entity.finished_stages = str_from_set(finished_stages)

        entity_by_stage = self.dispatcher.dispatch_entities(entities, final_stage)

        for stage, stage_entities in entity_by_stage.items():
            LOG.info(f'Stage: {stage}, entities: {stage_entities}')
            self.send_to_stage(entity_cls, stage_entities, stage, delay)

    def dispatch_entities_delayed(
        self,
        entity_cls: AdvanceableEntityCls,
        entities: Iterable[DispatchableEntity],
        final_stage: str,
    ) -> None:
        """Like dispatch_entities(), but each target stage uses its configured delay."""
        self.dispatch_entities(entity_cls, entities, final_stage, None)


class StubUpdater(Updater):
    """Updater whose per-result ``run`` step is a no-op.

    The inherited process()/process_result() flow still runs around it: results
    are listed, optionally copied to history and then deleted. Only the work on
    each result is skipped.
    """

    def run(self, output_path: str, temp_dir: str) -> None:
        """Intentionally ignore the stage result."""
        return None
