# -*- coding: utf-8 -*-
from __future__ import annotations

from copy import deepcopy
from datetime import datetime, timedelta, timezone
from itertools import groupby
from typing import Any, Callable, Generator, Iterable, List, NamedTuple, Optional, Type, Union
from uuid import uuid4
import logging

from travel.hotels.content_manager.data_model.storage import EntityWithStatus
from travel.hotels.content_manager.data_model.types import EXCEPTIONAL_ENTITY_PRIORITY, ProcessType, uint
from travel.hotels.content_manager.data_model.options import (
    NirvanaWorkflowOptions, StageOptions, TolokaPoolOptions, TriggerOptions
)
from travel.hotels.content_manager.lib.common import dc_to_dict
from travel.hotels.content_manager.lib.delayed_executor import DelayedExecutor
from travel.hotels.content_manager.lib.path_info import PathInfo
from travel.hotels.content_manager.lib.path_mapping import PathMapping
from travel.hotels.content_manager.lib.storage import Storage
from travel.hotels.content_manager.lib.persistence_manager import Condition, PersistenceManager
from travel.hotels.content_manager.lib.processor_run_info import (
    ProcessorRunInfo, ProcessorRunProgress, ProcessorRunStatus
)
import travel.hotels.content_manager.data_model.storage as storage


# Module-level logger named after this module.
LOG = logging.getLogger(__name__)

# Priority value that Trigger.write_options puts into the workflow options when
# the grouping key priority equals EXCEPTIONAL_ENTITY_PRIORITY.
EXCEPTIONAL_POOL_PRIORITY = 98


# Signature of a thread filter predicate: (entity, local storage) -> bool.
FilterCallable = Callable[[storage.EntityWithStatus, Storage], bool]


class EntityGroupingKey(NamedTuple):
    """Key used to split entities into job groups.

    Built by ``Trigger.job_grouper`` from an entity's ``priority`` and
    ``grouping_key`` attributes; being a tuple, keys order by priority first
    and by group name second.
    """
    # Entity priority; compared against EXCEPTIONAL_ENTITY_PRIORITY by the trigger.
    priority: uint
    # Group name; also becomes part of the pool name in Trigger.write_options.
    group_name: str


class ThreadFilter:
    """Named predicate deciding which entities belong to a processing thread.

    ``options_key`` selects the trigger options used for the matched entities.
    ``local_storage`` is not set by the constructor; it must be assigned
    before :meth:`apply` is called.
    """
    name: str
    options_key: str
    local_storage: Storage

    def __init__(self, name: str, options_key: str, filter_callable: Optional[FilterCallable] = None):
        """Store the identifiers and, when given, a custom filter callable."""
        self.name, self.options_key = name, options_key
        if filter_callable is None:
            # Keep the pass-all default defined on the class.
            return
        self.filter = filter_callable

    # noinspection PyUnusedLocal
    @staticmethod
    def filter(entity: storage.EntityWithStatus, local_storage: Storage) -> bool:
        """Default predicate: accept every entity."""
        return True

    def apply(self, entity: storage.EntityWithStatus) -> bool:
        """Run the filter for ``entity`` against the currently assigned local storage."""
        predicate = self.filter
        return predicate(entity, self.local_storage)


class Producer:
    """Base class for job data producers; the default implementation produces nothing."""

    def prepare_data(
            self,
            trigger: 'Trigger',
            thread_filter: ThreadFilter,
            local_storage: Storage,
            path: str,
            entities: Iterable[storage.EntityWithStatus],
            grouping_key: EntityGroupingKey,
            options: TriggerOptions,
    ):
        """Prepare job data for ``entities`` under ``path``.

        The base implementation is a no-op and returns ``None``; subclasses are
        expected to override it.
        """
        return None


class Trigger(object):

    __considerable_job_running_time__ = 10 * 60

    def __init__(
            self,
            process_type: ProcessType,
            stage_name: str,
            producer_cls: Type[Producer],
            thread_filters: Optional[List[ThreadFilter]],
            persistence_manager: PersistenceManager,
            delayed_executor: DelayedExecutor,
            path_info: PathInfo,
            entity_cls: storage.EntityClassWithStatus,
            other_entities: List[storage.EntityClass],
            start_ts: int,
            options: StageOptions,
            jobs_max: Optional[int] = None,
            job_size: Optional[int] = None,
            manual_start: bool = False,
            run_on_storage_change: bool = False,
            job_retry_count: Optional[int] = None,
            job_max_run_time: Optional[int] = None,
    ) -> None:
        self.process_type = process_type
        self.stage_name = stage_name

        self.producer_cls = producer_cls

        if thread_filters is None:
            thread_filters = [ThreadFilter('default', 'default')]
        self.thread_filters = thread_filters

        self.persistence_manager = persistence_manager
        self.delayed_executor = delayed_executor

        self.start_ts = start_ts
        msk_delta = timedelta(hours=3)
        self.start_time_msk = (datetime.fromtimestamp(start_ts) + msk_delta).replace(tzinfo=timezone(msk_delta))

        self.path_info = path_info
        self.stage_input_path = persistence_manager.join(path_info.stages_path, stage_name, 'input')
        self.stage_request_path = persistence_manager.join(path_info.requests_path, stage_name)
        self.entity_to_path = PathMapping(path_info).entity_to_path

        self.entity_cls = entity_cls
        self.other_entities = other_entities

        self.jobs_max = jobs_max
        self.job_size = job_size
        self.manual_start = manual_start
        self.run_on_storage_change = run_on_storage_change
        self.job_retry_count = job_retry_count
        self.job_max_running_time = job_max_run_time

        if options is None:
            LOG.info(f'No options for {stage_name}')
            options = StageOptions(
                triggers={'default': TriggerOptions()}
            )
        self.options = options

        if not self.persistence_manager.exists(self.stage_input_path):
            LOG.info(f'Creating {self.stage_input_path}')
            self.persistence_manager.create_dir(self.stage_input_path)

        self.jobs_count = 0

    def add_storage_link_delayed(self, path: str) -> None:
        self.delayed_executor.schedule(self.add_storage_link, path)

    def add_storage_link(self, path: str) -> None:
        LOG.info('Adding storage link')

        if not self.persistence_manager.exists(path):
            self.persistence_manager.create_dir(path)

        trigger_storage_path = self.persistence_manager.join(path, 'storage')
        storage_real_path = self.persistence_manager.realpath(self.path_info.storage_path)
        LOG.info(f'Linking from {trigger_storage_path} to {storage_real_path}')
        self.persistence_manager.link(trigger_storage_path, storage_real_path)

    def clean_path(self, path):
        if self.persistence_manager.exists(path):
            LOG.info(f'Deleting {path}')
            self.persistence_manager.delete(path)

    def get_new_trigger_path(self):
        return self.persistence_manager.join(self.stage_input_path, str(uuid4()))

    def update_entity_status(self, entities: Iterable[storage.EntityWithStatus]) -> None:
        status_field = f'status_{self.stage_name}'
        ts_field = f'{status_field}_ts'
        for entity in entities:
            setattr(entity, status_field, storage.StageStatus.IN_PROCESS)
            setattr(entity, ts_field, self.start_ts)

    def new_job(
            self,
            producer: Producer,
            thread_filter: ThreadFilter,
            local_storage: Storage,
            entities: Iterable[storage.EntityWithStatus],
            grouping_key: EntityGroupingKey,
            options: TriggerOptions,
    ) -> None:
        trigger_path = self.get_new_trigger_path()

        if not self.persistence_manager.exists(trigger_path):
            LOG.info('Creating trigger directory')
            self.persistence_manager.create_dir(trigger_path)

        producer.prepare_data(self, thread_filter, local_storage, trigger_path, entities, grouping_key, options)
        self.update_entity_status(entities)

        if self.persistence_manager.exists(trigger_path) and not self.persistence_manager.list(trigger_path):
            LOG.info('No job data. Removing directory')
            self.persistence_manager.delete(trigger_path)

    def get_local_storage(
            self,
            process_type: ProcessType,
            entity_cls: storage.EntityClassWithStatus,
            other_entities: List[storage.EntityClass],
    ) -> Optional[Storage]:
        if process_type == ProcessType.CATROOM:
            return self.get_local_storage_catroom(entity_cls, other_entities)
        elif process_type == ProcessType.WHITELIST:
            return self.get_local_storage_whitelist(entity_cls, other_entities)
        elif process_type == ProcessType.SERVICE_CLASS:
            if other_entities:
                raise RuntimeError(f'No other_entities expected for {process_type}')
            return self.get_local_storage_sc(entity_cls)
        else:
            raise RuntimeError(f'No suitable updater for {process_type}')

    def get_local_storage_catroom(
            self,
            entity_cls: storage.CatroomEntityClassWithStatus,
            other_entities: List[storage.CatroomEntityClass],
    ) -> Optional[Storage]:
        for other_entity in other_entities:
            if entity_cls is other_entity:
                raise RuntimeError('"other_entities" arg contains "entity" value')

        table_data = dict()

        src_path = self.entity_to_path[entity_cls]
        status_field = f'status_{self.stage_name}'

        LOG.info(f'Getting {entity_cls.__name__} from {src_path}')
        entities = self.persistence_manager.request_select(
            src_path=src_path,
            dc=entity_cls,
            match_conditions=[
                Condition(status_field, '==', storage.StageStatus.TO_BE_PROCESSED.value),
            ],
        )
        table_data[entity_cls] = entities

        permalink_ids = list()

        if entity_cls is storage.StoragePermalink:
            permalinks = [storage.StoragePermalink(**p) for p in entities]
            LOG.info(f'Got {len(permalinks)} permalinks to process')

            if not permalinks:
                return None

            permalink_ids = list({p.id for p in permalinks})

        if entity_cls is storage.StorageMapping:
            mappings = [storage.StorageMapping(**m) for m in entities]
            LOG.info(f'Got {len(mappings)} mappings to process')

            if not mappings:
                return None

            permalink_ids = list({m.permalink for m in mappings})

        for other_entity in other_entities:
            permalink_field = 'id' if other_entity is storage.StoragePermalink else 'permalink'
            src_path = self.entity_to_path[other_entity]

            LOG.info(f'Getting {other_entity.__name__} from {src_path}')
            entities = self.persistence_manager.request_select(
                src_path=src_path,
                dc=other_entity,
                match_conditions=[Condition(permalink_field, 'in', permalink_ids)],
            )
            table_data[other_entity] = entities

        local_storage = Storage()
        local_storage.apply_data(table_data)
        return local_storage

    def get_local_storage_whitelist(
            self,
            entity_cls: storage.WhitelistEntityClassWithStatus,
            other_entities: List[storage.WhitelistEntityClass],
    ) -> Optional[Storage]:
        for other_entity in other_entities:
            if entity_cls is other_entity:
                raise RuntimeError('"other_entities" arg contains "entity" value')

        table_data = dict()

        src_path = self.entity_to_path[entity_cls]
        status_field = f'status_{self.stage_name}'

        LOG.info(f'Getting {entity_cls.__name__} from {src_path}')
        entities = self.persistence_manager.request_select(
            src_path=src_path,
            dc=entity_cls,
            match_conditions=[
                Condition(status_field, '==', storage.StageStatus.TO_BE_PROCESSED.value),
            ],
        )
        table_data[entity_cls] = entities

        permalink_ids = list()

        if entity_cls is storage.StorageHotelWL:
            hotels = [storage.StorageHotelWL(**h) for h in entities]
            LOG.info(f'Got {len(hotels)} hotels to process')

            if not hotels:
                return None

            permalink_ids = list({h.permalink for h in hotels})

        if entity_cls is storage.StoragePermalinkWL:
            permalinks = [storage.StoragePermalinkWL(**p) for p in entities]
            LOG.info(f'Got {len(permalinks)} permalinks to process')

            if not permalinks:
                return None

            permalink_ids = list({p.permalink for p in permalinks})

        for other_entity in other_entities:
            src_path = self.entity_to_path[other_entity]

            LOG.info(f'Getting {other_entity.__name__} from {src_path}')
            entities = self.persistence_manager.request_select(
                src_path=src_path,
                dc=other_entity,
                match_conditions=[Condition('permalink', 'in', permalink_ids)],
            )
            table_data[other_entity] = entities

        local_storage = Storage()
        local_storage.apply_data(table_data)
        return local_storage

    def get_local_storage_sc(
            self,
            entity_cls: storage.SCDescriptionEntityClassWithStatus,
    ) -> Optional[Storage]:
        path_info = self.path_info

        src_path = path_info.storage_sc_descriptions_table
        status_field = f'status_{self.stage_name}'

        LOG.info(f'Getting {entity_cls.__name__} from {src_path}')
        entities = self.persistence_manager.request_select(
            src_path=src_path,
            dc=entity_cls,
            match_conditions=[
                Condition(status_field, '==', storage.StageStatus.TO_BE_PROCESSED.value),
            ],
        )

        if not entities:
            return None

        table_data = {entity_cls: entities}
        local_storage = Storage()
        local_storage.apply_data(table_data)
        return local_storage

    def update_storage(
            self,
            local_storage: Storage,
            entity_cls: storage.EntityClassWithStatus,
    ) -> None:
        dst_path = self.entity_to_path[entity_cls]
        status_field = f'status_{self.stage_name}'
        ts_field = f'{status_field}_ts'

        LOG.info(f'Updating storage for {entity_cls.__name__}')
        self.persistence_manager.request_upsert(
            src_data=local_storage.get_entity_data(entity_cls),
            dst_path=dst_path,
            dc=entity_cls,
            fields_to_update=[status_field, ts_field],
        )

    def get_batches(
        self,
        entities: Iterable[Any],
        size: int,
        ts_field: str,
        job_batching_delay_max: int
    ) -> Generator[List[Any]]:
        batch = list()
        sorted_entities = sorted(entities, key=lambda x: getattr(x, ts_field))
        while sorted_entities:
            batch = sorted_entities[:size]
            sorted_entities = sorted_entities[size:]
            if len(batch) == size:
                yield batch
                continue
            break
        if len(batch) == 0:
            return

        LOG.info('Got small batch')
        batch_min_ts = getattr(batch[0], ts_field)
        LOG.info(f'start_ts = {self.start_ts} batch_min_ts = {batch_min_ts}, batching_delay = {job_batching_delay_max}')
        batch_min_ts = batch_min_ts or self.start_ts
        job_delay = self.start_ts - batch_min_ts
        if job_delay < job_batching_delay_max:
            return

        yield batch

    def write_attribute(self, path, attribute, value):
        options_path = self.persistence_manager.join(path, '@' + attribute)
        LOG.info(f'Writing {value} to {options_path}')
        self.persistence_manager.set(options_path, value)

    def write_options_attribute(
            self,
            options: Union[NirvanaWorkflowOptions, TolokaPoolOptions],
            path: str,
            attribute: str,
            **kwargs,
    ):
        if options is None:
            options_dict = dict()
        else:
            options_dict = dc_to_dict(options)
        options_dict.update(kwargs)

        self.write_attribute(path, attribute, options_dict)

    def write_options(self, options: TriggerOptions, path: str, grouping_key: EntityGroupingKey):
        workflow_options = deepcopy(options.workflow_options)
        if grouping_key.priority == EXCEPTIONAL_ENTITY_PRIORITY:
            workflow_options.priority = EXCEPTIONAL_POOL_PRIORITY
        self.write_options_attribute(workflow_options, path, '_workflow_options')

        job_path = self.persistence_manager.split(path)[0]
        job_id = self.persistence_manager.split(job_path)[1]
        pool_name = f'{self.start_time_msk} [{self.stage_name}] {grouping_key.group_name} {job_id}'
        self.write_options_attribute(options.pool_options, path, '_pool_options', private_name=pool_name)
        self.write_attribute(path, '_priority', grouping_key.priority)

    @staticmethod
    def job_grouper(entity: EntityWithStatus) -> EntityGroupingKey:
        """Return the grouping key of ``entity``: its priority and its ``grouping_key`` as group name."""
        return EntityGroupingKey(priority=entity.priority, group_name=entity.grouping_key)

    def run_group_jobs(
            self,
            grouping_key: EntityGroupingKey,
            group: List[EntityWithStatus],
            local_storage: Storage,
            producer: Producer,
            ts_field: str,
    ) -> bool:
        has_data = False
        for thread_filter in self.thread_filters:
            options = self.options.triggers.get(thread_filter.options_key)
            if options is None:
                raise RuntimeError(f'No suitable options with key "{thread_filter.options_key}"')
            thread_filter.local_storage = local_storage

            filtered_entities = (e for e in group if getattr(e, ts_field) <= self.start_ts)
            filtered_entities = filter(thread_filter.apply, filtered_entities)
            filtered_entities = list(filtered_entities)

            if len(filtered_entities) == 0:
                continue

            check_jobs_count = grouping_key.priority < EXCEPTIONAL_ENTITY_PRIORITY
            job_batching_delay_max = self.options.job_batching_delay_max if check_jobs_count else 0
            for batch in self.get_batches(filtered_entities, self.job_size, ts_field, job_batching_delay_max):
                if check_jobs_count and self.jobs_max and self.jobs_count >= self.jobs_max:
                    LOG.info(f'Actual jobs count ({self.jobs_count}) >= max jobs count ({self.jobs_max})')
                    break

                LOG.info(f'Running job for {thread_filter.name}')
                self.new_job(producer, thread_filter, local_storage, batch, grouping_key, options)
                if check_jobs_count:
                    self.jobs_count += 1

                has_data = True

        return has_data

    def retry_job(self, input_path: str, run_info: ProcessorRunInfo) -> bool:
        if self.job_retry_count is None:
            LOG.info('Job not supposed to be retried')
            return False
        retry_count = run_info.retry_count or 0
        retry_count += 1
        if retry_count > self.job_retry_count:
            return False
        run_info.retry_count = retry_count
        run_info.status = ProcessorRunStatus.UNKNOWN
        run_info.progress = ProcessorRunProgress.ENQUEUED
        run_info.processing_start_ts = None
        retry_path = f'{input_path}-{retry_count}'
        LOG.info(f'Retrying job. Moving from {input_path} to {retry_path}')
        self.persistence_manager.move(input_path, retry_path)
        return True

    def check_and_retry(self) -> None:
        LOG.info('Checking running jobs')
        if not self.persistence_manager.exists(self.stage_input_path):
            LOG.info('No jobs to check')
            return
        jobs = self.persistence_manager.list(self.stage_input_path)
        if not jobs:
            LOG.info('No jobs to check')
            return
        for job in jobs:
            LOG.info(f'Checking {job.name}')
            run_info = ProcessorRunInfo(self.persistence_manager, job.path)

            if run_info.status == ProcessorRunStatus.SUCCESS:
                LOG.info(f'Job {job.name} successfully finished')
                continue

            if run_info.status == ProcessorRunStatus.FAILED:
                LOG.info(f'Job {job.name} failed')
                self.retry_job(job.path, run_info)
                continue

            with self.persistence_manager.transaction():
                locked_by_processor = not self.persistence_manager.lock(job.path)
            if locked_by_processor:
                LOG.info('Job is locked by processor')
                continue

            job_is_running = run_info.progress == ProcessorRunProgress.RUNNING
            if job_is_running and run_info.processing_start_ts:
                running_time = self.start_ts - run_info.processing_start_ts
                if running_time < self.__considerable_job_running_time__:
                    LOG.info('Job running time is not considerable')
                    continue

                if running_time > self.job_max_running_time:
                    LOG.info(f'Job {job.name} running too long')
                    self.retry_job(job.path, run_info)
                    continue

    def process(self) -> None:
        LOG.info(f'Running {self.stage_name} trigger')

        self.check_and_retry()

        if self.persistence_manager.exists(self.stage_input_path):
            self.jobs_count = len(self.persistence_manager.list(self.stage_input_path))
        LOG.info(f'Stage jobs count: {self.jobs_count}')

        if self.manual_start:
            if not self.persistence_manager.exists(self.stage_request_path):
                LOG.info('No request for this stage')
                return

            jobs = self.persistence_manager.list(self.stage_request_path)
            jobs = sorted(jobs, key=lambda x: x.created_at)

            for job_id in jobs:
                if self.jobs_max and self.jobs_count >= self.jobs_max:
                    LOG.info(f'Actual jobs count ({self.jobs_count}) >= max jobs count ({self.jobs_max})')
                    break

                job_id = job_id.name
                src_path = self.persistence_manager.join(self.stage_request_path, job_id)
                dst_path = self.persistence_manager.join(self.stage_input_path, job_id)
                trigger_path = dst_path

                if not self.persistence_manager.is_dir(src_path):
                    trigger_path = self.get_new_trigger_path()
                    if not self.persistence_manager.exists(trigger_path):
                        LOG.info(f'Creating {trigger_path}')
                        self.persistence_manager.create_dir(trigger_path)
                    dst_path = self.persistence_manager.join(trigger_path, job_id)

                LOG.info(f'Copying {src_path} to {dst_path}')
                self.persistence_manager.copy(src_path, dst_path)

                self.add_storage_link_delayed(trigger_path)

                LOG.info(f'Clearing request at {self.stage_request_path}')
                self.persistence_manager.delete(src_path)

                self.jobs_count += 1

            return

        if self.run_on_storage_change:
            trigger_path = self.get_new_trigger_path()
            LOG.info(f'Storage changed. Adding storage link to {trigger_path}')
            self.add_storage_link_delayed(trigger_path)
            return

        LOG.info('Getting storage info')
        local_storage = self.get_local_storage(self.process_type, self.entity_cls, self.other_entities)

        if local_storage is None:
            LOG.info('Nothing to process')
            return

        if self.entity_cls is storage.StorageHotelWL:
            entities = [h for h in local_storage.hotels_wl.values()]
        elif self.entity_cls is storage.StoragePermalinkWL:
            entities = [p for p in local_storage.permalinks_wl.values()]
        elif self.entity_cls is storage.StoragePermalink:
            entities = [p for p in local_storage.permalinks.values()]
        elif self.entity_cls is storage.StorageMapping:
            entities = [m for m in local_storage.mappings.values()]
        elif self.entity_cls is storage.StorageSCDescription:
            entities = [d for d in local_storage.sc_descriptions.values()]
        else:
            raise Exception(f'Entity not supported by trigger: {self.entity_cls}')

        entities = sorted(entities, key=self.job_grouper, reverse=True)

        producer = self.producer_cls()

        status_field = f'status_{self.stage_name}'
        ts_field = f'{status_field}_ts'

        groups = dict()
        for key, group in groupby(entities, self.job_grouper):
            groups[key] = list(group)

        has_data = False
        for key in sorted(groups.keys(), reverse=True):
            if self.run_group_jobs(key, groups[key], local_storage, producer, ts_field):
                has_data = True

        if not has_data:
            LOG.info('Nothing to process')
            return

        self.update_storage(local_storage, self.entity_cls)

        LOG.info('Process finished')
