# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
import time
from datetime import timedelta
from typing import List, Dict, Set, AnyStr, Iterable
from collections import defaultdict

from django.conf import settings

from travel.rasp.library.python.common23.date import environment
from common.models.tariffs import SuburbanTariffProvider
from common.settings.utils import define_setting
from common.settings.configuration import Configuration
from travel.library.python.tracing.instrumentation import traced_function

from travel.rasp.suburban_selling.selling.tariffs.tariffs_configuration import TariffsConfiguration
from travel.rasp.suburban_selling.selling.tariffs.interfaces import (
    TariffsProvider, TariffKey, SellingTariff, carrier_to_json, TariffKeyData, TariffKeyDataStatus
)
from travel.rasp.suburban_selling.selling.tariffs.selling_companies import SuburbanCarrierCompany, SuburbanCarrierCode

log = logging.getLogger(__name__)

define_setting('GET_TARIFFS_TIMEOUT', {Configuration.PRODUCTION: 0.1}, default=5)
define_setting('GET_TARIFFS_POLL_FREQUENCY', default=0.02)


class TariffsGetter(object):
    """Получение тарифов продажи электричек по всем провайдерам"""

    def __init__(self, providers, selling_flows, barcode_presets, tariffs_configuration):
        # type: (Dict[AnyStr, TariffsProvider], Iterable[AnyStr], Iterable[AnyStr], TariffsConfiguration) -> None

        self.tariffs_configuration = tariffs_configuration
        self.selling_flows = set(selling_flows)  # type: Set[AnyStr]
        self.barcode_presets = set(barcode_presets)  # type: Set[AnyStr]

        self.providers = {}                      # type: Dict[AnyStr, TariffsProvider]
        for provider_code, provider in providers.items():
            if provider.get_selling_flows() & self.selling_flows:
                self.providers[provider_code] = provider

        provider_codes = list(self.providers.keys())
        self.provider_codes_by_priority = (  # type: List[AnyStr]
            sorted(provider_codes, key=lambda p: SuburbanTariffProvider.PRIORITY.index(p))
        )
        self.priority_by_provider_codes = {
            code: priority for priority, code in enumerate(self.provider_codes_by_priority)
        }

    def prepare_start_tariff_keys_data(self, tariff_keys):
        # type: (List[TariffKey]) -> Dict[AnyStr, Dict[TariffKey, TariffKeyData]]
        # собираем ключи сегментов для каждого провайдера
        data_by_providers_and_keys = defaultdict(defaultdict)
        for tariff_key in tariff_keys:
            provider_codes = self.tariffs_configuration.get_company_provider_codes(tariff_key.company)
            for provider_code in provider_codes:
                key_data = TariffKeyData([], TariffKeyDataStatus.NO_DATA)
                data_by_providers_and_keys[provider_code][tariff_key] = key_data

        return data_by_providers_and_keys

    def _filter_tariffs_by_barcode_presets(self, data_by_providers_and_keys):
        # type: (Dict[AnyStr, Dict[TariffKey, TariffKeyData]]) -> None
        # Оставляем только тарифы с поддерживаемыми пресетами для штрихкодов
        # Пока для каждого перевозчика однозначно задан код пресета, но в будущем возможна более сложная логика
        for provider_code, keys_data in data_by_providers_and_keys.items():
            for key_data in keys_data.values():
                key_data.tariffs = [
                    tariff for tariff in key_data.tariffs if (
                        self.tariffs_configuration.carriers_by_codes.get(tariff.partner).barcode_preset
                        in self.barcode_presets
                        or provider_code == SuburbanCarrierCode.AEROEXPRESS
                    )
                ]

    def get_tariffs_from_providers(self, data_by_providers_and_keys):
        # type: (Dict[AnyStr, Dict[TariffKey, TariffKeyData]]) -> None
        """
        Получение тарифов из провайдеров. Если не получается сразу получить тарифы, то делается поллинг
        В процессе заполняются данные в data_by_providers_and_keys
        """

        cur_dt = environment.now()
        need_request = True
        # Продолжаем попытки получить тарифы, пока они не будут получены, но не больше заданного таймаута
        while need_request and environment.now() - cur_dt < timedelta(seconds=settings.GET_TARIFFS_TIMEOUT):
            need_request = False
            for provider_code, getter_keys_data in data_by_providers_and_keys.items():
                keys_to_process = [
                    key
                    for key, key_data in getter_keys_data.items()
                    if key_data.status != TariffKeyDataStatus.ACTUAL
                ]
                if not keys_to_process:
                    continue
                if provider_code not in self.providers:
                    continue

                keys_data = self.providers[provider_code].get_tariffs(keys_to_process)

                for key, key_data in keys_data.items():
                    if key in getter_keys_data:
                        if (
                            key_data.status == TariffKeyDataStatus.ACTUAL or
                            (
                                key_data.status == TariffKeyDataStatus.OLD and
                                getter_keys_data[key].status == TariffKeyDataStatus.NO_DATA
                            )
                        ):
                            getter_keys_data[key] = key_data
                        if getter_keys_data[key].status != TariffKeyDataStatus.ACTUAL:
                            need_request = True

            # Если данные не удалось получить из кэша, то мы асинхронно запустили получение их от провайдера
            # Затем мы все же пытаемся дождаться результата, недолго, чтобы не перегрузить воркеры
            if need_request:
                time.sleep(settings.GET_TARIFFS_POLL_FREQUENCY)

        self._filter_tariffs_by_barcode_presets(data_by_providers_and_keys)

    def _get_tariffs_json(self, tariffs_by_provider, tariff_ids_by_keys, providers_by_tariffs_keys, carriers):
        # type: (Dict[AnyStr, Set[SellingTariff]], Dict[TariffKey, List[int]], Dict[TariffKey, AnyStr], Iterable[SuburbanCarrierCompany]) -> Dict
        return {
            'selling_tariffs': [
                {
                    'provider': provider,
                    'tariffs': [tariff.to_json() for tariff in tariffs_by_provider[provider]]
                }
                for provider in tariffs_by_provider
            ],
            'keys': [
                {
                    'key': key,
                    'provider': providers_by_tariffs_keys.get(key),
                    'tariff_ids': [tid for tid in tariff_ids_by_keys[key]]
                } for key in tariff_ids_by_keys
            ],
            'selling_partners': [carrier_to_json(carrier) for carrier in carriers]
        }

    def make_tariffs_response(self, data_by_providers_and_keys):
        # type: (Dict[AnyStr, Dict[TariffKey, TariffKeyData]]) -> Dict[str]

        tariffs_by_provider = defaultdict(set)
        tariffs_ids_by_keys = defaultdict(list)
        providers_by_tariffs_keys = {}
        partners_by_code = {}

        # Заполняем id тарифов, сквозной для всех провайдеров
        tariff_id = 1
        for provider_code in self.provider_codes_by_priority:
            if provider_code in data_by_providers_and_keys:
                keys_data = data_by_providers_and_keys[provider_code]

                for tariff_key, key_data in keys_data.items():
                    if tariff_key not in providers_by_tariffs_keys or not tariffs_ids_by_keys[tariff_key]:
                        providers_by_tariffs_keys[tariff_key] = provider_code
                        tariffs_ids_by_keys.setdefault(tariff_key, [])
                    if key_data.status == TariffKeyDataStatus.NO_DATA:
                        continue

                    for tariff in key_data.tariffs:
                        partner = self.tariffs_configuration.carriers_by_codes.get(tariff.partner)
                        if not partner:
                            log.error('Unknown tariff partner code: {}'.format(tariff.partner))
                            continue

                        if not tariff.tariff_id:
                            tariff.set_tariff_id(tariff_id)
                            tariff_id += 1
                            tariffs_by_provider[provider_code].add(tariff)

                        if tariff.tariff_id not in tariffs_ids_by_keys[tariff_key]:
                            tariffs_ids_by_keys[tariff_key].append(tariff.tariff_id)

                        partners_by_code[partner.code] = partner

        log.debug(
            'get_tariffs_for_tariff_keys: end, found %s tariff keys with tariffs',
            sum(1 for tariff_ids in tariffs_ids_by_keys.values() if len(tariff_ids) > 0)
        )

        return self._get_tariffs_json(
            tariffs_by_provider, tariffs_ids_by_keys, providers_by_tariffs_keys, partners_by_code.values()
        )

    @traced_function
    def get_tariffs(self, tariff_keys):
        # type: (List[TariffKey]) -> Dict[str]
        """
        Для каждого сегмента получаем список тарифов, разбитых по провайдерам,
        и список ключей со ссылками на тарифы
        """
        log.debug('get_tariffs_for_tariff_keys: start: %s tariff keys', len(tariff_keys))

        data_by_providers_and_keys = self.prepare_start_tariff_keys_data(tariff_keys)
        self.get_tariffs_from_providers(data_by_providers_and_keys)
        return self.make_tariffs_response(data_by_providers_and_keys)
