# coding: utf-8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta

from django.conf import settings
from django.core.cache import cache
from django.db import connection

from common.settings.utils import define_setting
from travel.rasp.library.python.common23.date import environment
from common.utils.caching import global_cache_sync_set, global_cache_sync_add, global_cache_sync_delete
from travel.rasp.train_api.train_partners.base import ApiError, PartnerError
from travel.rasp.train_api.wizard_api.client import train_wizard_api_client

# Module-level logger for the worker machinery below.
log = logging.getLogger(__name__)

# Size of the shared ``worker_pool`` thread pool (passed as its max_workers
# below). define_setting presumably registers the setting with a default of
# 10 and an int converter when it is not configured — confirm in common.settings.
define_setting('TRAIN_QUERY_POOL_WORKERS', default=10, converter=int)


class WorkerApiError(ApiError):
    """ApiError subclass for failures raised inside background workers."""

    def get_user_message(self):
        """Return ``None``: worker errors carry no user-facing message."""
        return None


class WorkerEmptyResultError(WorkerApiError):
    """Worker error recognised by ``WorkerResult.is_empty_result()``."""

    code = 'ufs_empty_result_error'


class WorkerResult(object):
    """Base class for results produced by background query workers.

    A result wraps the originating ``query``, a ``status`` (one of the
    ``STATUS_*`` constants) and an optional ``error``, and stores itself in
    the cache under ``query.cache_key``. Subclasses must provide
    ``cache_timeout`` (in seconds).
    """

    STATUS_PENDING = 'pending'
    STATUS_SUCCESS = 'success'
    STATUS_ERROR = 'error'

    def __init__(self, query, status, error=None):
        self.query = query
        self.status = status
        self.error = error
        # Creation timestamp; both expiry helpers below are measured from it.
        self.created_at = environment.now_aware()

    @property
    def cache_timeout(self):
        """Cache lifetime of this result in seconds; must be overridden.

        A value ``<= 0`` means the result itself is never written to the cache.
        """
        raise NotImplementedError()

    @property
    def expired_at(self):
        """Return ``created_at`` plus ``cache_timeout`` **seconds**."""
        return self.created_at + timedelta(seconds=self.cache_timeout)

    def expired_with_timeout(self, expired_timeout):
        """Return ``created_at`` plus ``expired_timeout`` **minutes**.

        Note the unit differs from ``cache_timeout``/``expired_at`` (seconds).
        """
        return self.created_at + timedelta(minutes=expired_timeout)

    def _change_cache(self, reset_cache, old_status):
        """Write this result to the cache, or drop a pending entry.

        :param reset_cache: when true, use ``global_cache_sync_set``
            (overwrite); otherwise ``global_cache_sync_add`` (presumably
            add-if-absent, mirroring Django's ``cache.add``).
        :param old_status: status of the entry currently cached, or ``None``.
        """
        if self.cache_timeout <= 0:
            # Non-cacheable result: only remove a pending entry that is
            # currently stored under this key.
            if old_status == self.STATUS_PENDING:
                global_cache_sync_delete(self.query.cache_key)
            return

        if reset_cache:
            global_cache_sync_set(self.query.cache_key, self, self.cache_timeout)
        else:
            global_cache_sync_add(self.query.cache_key, self, self.cache_timeout)

    def update_cache(self):
        """Store this result in the cache and, when applicable, publish tariffs.

        The cached entry is overwritten (``reset_cache``) when the existing one
        has expired or when this result is not pending. Successful or empty
        results are also passed to ``train_wizard_api_client.store_tariffs``
        unless ``query.mock_im`` is set.
        """
        reset_cache = False

        old_result = self.get_from_cache(self.query)
        old_status = None
        if old_result:
            old_status = old_result.status
            if old_result.expired_at <= environment.now_aware():
                reset_cache = True

        if self.status != self.STATUS_PENDING:
            reset_cache = True

        log.info('Кешируем результат status=%s cache_key=%s, сброс кеша: %s, время жизни: %s',
                 self.status, self.query.cache_key, reset_cache, self.cache_timeout)
        self._change_cache(reset_cache, old_status)

        if (self.status == self.STATUS_SUCCESS or self.is_empty_result()) and not self.query.mock_im:
            train_wizard_api_client.store_tariffs(self)

    def is_empty_result(self):
        """Return True when ``error`` denotes an empty result.

        That is a ``WorkerEmptyResultError`` or a ``PartnerError`` whose
        ``is_empty_result_error()`` is true.
        """
        return isinstance(self.error, WorkerEmptyResultError) or (
            isinstance(self.error, PartnerError) and self.error.is_empty_result_error()
        )

    @classmethod
    def get_from_cache(cls, query):
        """Return the value cached under ``query.cache_key`` (``None`` if absent)."""
        return cache.get(query.cache_key)

    def __eq__(self, other):
        # ``created_at`` is not part of the comparison.
        if not isinstance(other, WorkerResult):
            return NotImplemented
        return (self.error == other.error and
                self.query == other.query and
                self.status == other.status)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; without this,
        # ``a != b`` would fall back to identity there. Dispatching through
        # ``self.__eq__`` keeps subclass overrides in effect.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal


class TrainTariffsResult(WorkerResult):
    """Worker result that additionally carries a list of tariff ``segments``."""

    def __init__(self, query, status, segments=None, error=None):
        super(TrainTariffsResult, self).__init__(query, status, error=error)

        # Any falsy ``segments`` (None, empty sequence) becomes a fresh list.
        self.segments = segments if segments else []

    def __eq__(self, other):
        if isinstance(other, TrainTariffsResult):
            base_equal = super(TrainTariffsResult, self).__eq__(other)
            return base_equal and self.segments == other.segments
        return NotImplemented

    def __str__(self):
        # An empty segment list is reported as ``None`` rather than ``0``.
        segments_count = len(self.segments) if self.segments else None
        return "{}(status={}, segments_count={}, error={})".format(
            self.__class__.__name__, self.status, segments_count, self.error
        )


class BaseErrorResult(Exception):
    """Exception carrying an error result for a query.

    Subclasses set ``ResultClass`` to a result class; on construction
    ``self.result`` becomes
    ``ResultClass(train_query, ResultClass.STATUS_ERROR, error=error, **kwargs)``.

    :raises TypeError: when instantiated without ``ResultClass`` configured.
    """
    ResultClass = None

    def __init__(self, train_query, error, **kwargs):
        result_class = self.ResultClass
        if result_class is None:
            # Same exception type as calling ``None(...)`` would raise, but
            # with a message that names the misconfigured class.
            raise TypeError('{} must define ResultClass'.format(type(self).__name__))
        # Take the error status from the configured result class instead of a
        # hard-coded sibling class, so any result subclass can be used.
        self.result = result_class(train_query, result_class.STATUS_ERROR, error=error, **kwargs)
        super(BaseErrorResult, self).__init__()


class Worker(threading.Thread):
    """Thread that closes its Django DB connection once the target finishes."""

    def run(self):
        try:
            super(Worker, self).run()
        finally:
            # Django's ``connection`` is per-thread; close it even when the
            # target raised, otherwise this thread's connection is leaked.
            connection.close()

    def start_daemon(self):
        """Mark the thread as a daemon and start it."""
        self.daemon = True
        self.start()


# Shared thread pool on which PoolWorker.start() schedules background calls.
# Its size comes from the TRAIN_QUERY_POOL_WORKERS setting; its threads are
# named with the 'background_train_query_pool' prefix.
worker_pool = ThreadPoolExecutor(settings.TRAIN_QUERY_POOL_WORKERS, thread_name_prefix='background_train_query_pool')


def _run_pool_worker(worker):
    return worker.target(*worker.args)


class PoolWorker(object):
    """Run ``target(*args)`` on the shared ``worker_pool`` thread pool."""

    def __init__(self, target, args):
        self.target = target
        self.args = args

    def start(self):
        """Schedule the call on ``worker_pool`` and return its ``Future``.

        Exceptions raised by ``target`` are logged, so they are not lost when
        the caller ignores the returned future.
        """
        future = worker_pool.submit(_run_pool_worker, self)
        future.add_done_callback(self._log_failure)
        return future

    @staticmethod
    def _log_failure(future):
        """Done-callback: log the exception raised by the pooled call, if any."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            # Build exc_info explicitly: there is no "current" exception inside
            # a done-callback, and ``__traceback__`` is absent on Python 2.
            log.error('Background pool worker failed: %r', exc,
                      exc_info=(type(exc), exc, getattr(exc, '__traceback__', None)))
