# -*- coding: utf-8 -*-
import logging
import time
import zlib
from collections import defaultdict

from travel.avia.library.python.common.utils import environment
from travel.avia.library.python.ticket_daemon.protobuf_converting.big_wizard.filtering import search_result_filtering
from travel.avia.library.python.ticket_daemon.protobuf_converting.big_wizard.search_result_converter import SearchResultConverter
from travel.proto.avia.wizard.search_result_pb2 import SearchResult

from travel.avia.ticket_daemon_api.environment import YANDEX_ENVIRONMENT_TYPE
from travel.avia.ticket_daemon_api.jsonrpc.lib import feature_flags
from travel.avia.ticket_daemon_api.jsonrpc.lib.flights import IATAFlight
from travel.avia.ticket_daemon_api.jsonrpc.lib.partner_store_time_provider import partner_store_time_provider
from travel.avia.ticket_daemon_api.jsonrpc.lib.result import Statuses
from travel.avia.ticket_daemon_api.jsonrpc.lib.result.collector.status import collect_statuses
from travel.avia.ticket_daemon_api.jsonrpc.lib.result.collector.variants_fabric import (
    VariantsFabric, SimpleApiVariants, CollectingModes
)
from travel.avia.ticket_daemon_api.jsonrpc.lib.result.collector import variants_fetcher as fetcherlib
from travel.avia.ticket_daemon_api.jsonrpc.lib.result.collector.test_context import parse_test_context, parse_test_context_proto
from travel.avia.ticket_daemon_api.jsonrpc.lib.ydb import cache as service_cache, wizard_cache
from travel.avia.ticket_daemon_api.jsonrpc.lib.yt_loggers.yt_logger import YtLogger


log = logging.getLogger(__name__)
json_logger = YtLogger(__name__, environment)
# Limit passed to search_result_filtering.slice when reading cached wizard results.
MINI_SAAS_MAX_VARIANTS_COUNT = 20
# Variants budget for one portional update (see collect_variants_v3).
MAX_VARIANTS_IN_ONE_UPDATE = 1000
# For the first few seconds we answer with the 'querying' status
# even if memcache is empty
QUERY_INIT_TIMEOUT = 9


def collect_variants(
    query,
    skip_partner_codes=None,
    result_revisions=None,
    allow_portional=False,
    prefilter=None,
    partner_codes=None,
    mode=CollectingModes.instant_search,
    max_age=None,
    prefetched_results=None,
    test_context=None,
):
    """Collect variants for the first sub-query of ``query``.

    The list of partners enabled for the query is narrowed down before
    delegating to ``collect_variants_v3``:

    * if ``partner_codes`` is non-empty, only partners listed there are kept;
    * if ``skip_partner_codes`` is non-empty, partners listed there are dropped.

    All remaining arguments, including ``prefetched_results``, are passed
    through to ``collect_variants_v3`` unchanged.

    :return: whatever ``collect_variants_v3`` returns, i.e. the tuple
        ``(variants_by_partner, statuses, revisions)``.
    """
    enabled_partner_codes = query.get_enabled_partner_codes()

    if partner_codes:
        enabled_partner_codes = [c for c in enabled_partner_codes if c in partner_codes]

    if skip_partner_codes:
        enabled_partner_codes = [c for c in enabled_partner_codes if c not in skip_partner_codes]

    return collect_variants_v3(
        query.queries[0], enabled_partner_codes, prefilter=prefilter,
        result_revisions=result_revisions,
        allow_portional=allow_portional,
        mode=mode,
        max_age=max_age,
        # Previously this argument was accepted but silently dropped, so
        # already fetched results were ignored and fetched again.
        prefetched_results=prefetched_results,
        test_context=test_context,
    )


def _get_statuses(partner_code_to_status, seconds_after_init):
    """Map every partner code to a plain status value.

    A partner without status data is reported as QUERYING while the query is
    younger than QUERY_INIT_TIMEOUT seconds and as FAIL afterwards. Any status
    that ``Statuses.is_done`` accepts is collapsed into ``Statuses.DONE``;
    every other status is passed through as is.
    """
    within_init_window = seconds_after_init < QUERY_INIT_TIMEOUT
    resolved = {}

    for partner, payload in partner_code_to_status.iteritems():
        if payload is None:
            resolved[partner] = Statuses.QUERYING if within_init_window else Statuses.FAIL
        elif Statuses.is_done(payload['status']):
            resolved[partner] = Statuses.DONE
        else:
            resolved[partner] = payload['status']

    return resolved


def collect_variants_v3(
    query,
    enabled_partner_codes,
    prefilter=None,
    result_revisions=None,
    allow_portional=False,
    mode=CollectingModes.instant_search,
    max_age=None,
    prefetched_results=None,
    test_context=None,
):
    """Collect per-partner variants, statuses and revisions for a single query.

    :param query: query object; ``qkey``, ``created``, ``id`` and ``base_qid``
        are read from it here.
    :param enabled_partner_codes: partner codes to collect statuses for.
    :param prefilter: optional filter applied to each partner's variants via
        ``variants.filter``.
    :param result_revisions: mapping partner code -> last known revision;
        defaults to an empty dict.
    :param allow_portional: when true, the number of variants handled in one
        call is limited by MAX_VARIANTS_IN_ONE_UPDATE and the partners that do
        not fit are reported as QUERYING.
    :param mode: collecting mode passed to ``VariantsFabric.create``.
    :param max_age: passed to ``VariantsFabric.create``.
    :param prefetched_results: optional iterable of
        ``(partner_code, unpacked_variants)`` pairs used instead of fetching.
    :param test_context: optional test context; ignored in production.
    :return: tuple ``(variants_by_partner, statuses, revisions)``.
    """
    # Test contexts must never affect production responses.
    if YANDEX_ENVIRONMENT_TYPE == 'production':
        test_context = None

    if result_revisions is None:
        result_revisions = {}

    variants_by_partner = {}
    revisions = {}

    partner_code_to_status = collect_statuses(
        query.qkey, enabled_partner_codes,
    )
    seconds_after_init = time.time() - query.created

    statuses = _get_statuses(partner_code_to_status, seconds_after_init)
    to_collect_partner_codes = enabled_partner_codes
    if allow_portional:
        # Partners that report a non-zero variants count go first and, within
        # each group, the ones with DONE status go first.
        done_status_partners = sorted(
            enabled_partner_codes,
            key=lambda p_code: (
                _get_status_variants_count(partner_code_to_status[p_code]) > 0,
                statuses.get(p_code) == Statuses.DONE
            ),
            reverse=True
        )

        # Take partners until the reported variants total exceeds the budget;
        # the check happens before adding, so the total may overshoot it.
        variants_count = 0
        to_collect_partner_codes = []
        for p_code in done_status_partners:
            if variants_count > MAX_VARIANTS_IN_ONE_UPDATE:
                statuses[p_code] = Statuses.QUERYING  # Set querying status for left partners for next update request
                continue
            variants_count += _get_status_variants_count(partner_code_to_status[p_code])
            to_collect_partner_codes.append(p_code)

    # Source of results, by priority: mocked variants from the test context,
    # then non-empty prefetched results, then a regular fetch.
    results = []
    if prefetched_results:
        results = prefetched_results
    mockAviaVariants = False
    if test_context:
        tk = parse_test_context_proto(test_context)
        mockAviaVariants = tk.MockAviaVariants
    if mockAviaVariants:
        results = parse_test_context(test_context)
    elif not results:
        results = fetcherlib.fetch_and_get_variants(query, to_collect_partner_codes)
    # From here on the counter holds the number of variants actually built,
    # not the counts reported in the partner statuses.
    variants_count = 0
    instant_search_variants_count = 0
    for partner_code, unpacked_variants in results:
        if allow_portional and variants_count > MAX_VARIANTS_IN_ONE_UPDATE:
            statuses[partner_code] = Statuses.QUERYING  # Set querying status for left partners for next update request
            continue
        if not unpacked_variants:
            # Nothing new for this partner: echo the caller's revision
            # (None when the caller did not provide one).
            revisions[partner_code] = result_revisions.get(partner_code)
            continue

        try:
            variants = VariantsFabric.create(
                unpacked_variants,
                status=statuses[partner_code],
                query=query,
                partner_code=partner_code,
                last_revision=result_revisions.get(partner_code, 0),
                mode=mode,
                max_age=max_age,
            )
            # The created object carries the final status and revision.
            statuses[partner_code] = variants.status
            revisions[partner_code] = variants.revision
            if prefilter:
                variants.filter(prefilter)

            # Partners whose variants object is falsy after filtering are
            # left out of the result.
            if variants:
                variants_by_partner[partner_code] = variants
                variants_count += len(variants)
                if variants.status == Statuses.OUTDATED:
                    instant_search_variants_count += len(variants)

        except Exception as e:
            # A failure for one partner must not break the others: log it and
            # mark only this partner as failed.
            log.exception('collect_partner_variants: %r', e)
            statuses[partner_code] = Statuses.FAIL

    # Structured log record with the per-partner outcome of this call.
    json_logger.log({
        'qid': query.id,
        'base_qid': query.base_qid,
        'all_variants_count': variants_count,
        'instant_search_variants_count': instant_search_variants_count,
        'partners': variants_by_partner.keys(),
        'info': [
            {
                'code': p_code,
                'original_status': partner_code_to_status.get(p_code),
                'status': statuses.get(p_code),
                'variants_count': len(variants_by_partner[p_code]) if p_code in variants_by_partner else None,
            }
            for p_code in enabled_partner_codes
        ],
    })
    return variants_by_partner, statuses, revisions


def _get_status_variants_count(status):
    if status is None:
        return 0

    return status.get('all_variants_count', 0)


def collect_wizard_variants(query):
    """Read the cached wizard search result for ``query``.

    Every row returned by the wizard cache is decoded in turn; the value left
    after the last row is the one that is used. A non-empty payload is
    decompressed with zlib when possible (otherwise used as is), parsed as a
    ``SearchResult`` message, sliced to MINI_SAAS_MAX_VARIANTS_COUNT and
    converted to a dictionary.

    :return: tuple ``(variants_by_partner, qid)``; ``({}, None)`` when there
        is no usable content.
    """
    content = None
    for row in wizard_cache.WizardCache.get(query, columns=('search_result',)):
        content = row['search_result']
        if not content:
            continue

        try:
            payload = zlib.decompress(content)
        except zlib.error:
            # The stored value is not zlib-compressed; parse it as is.
            payload = content

        parsed = SearchResult()
        parsed.ParseFromString(payload)
        search_result_filtering.slice(parsed, limit=MINI_SAAS_MAX_VARIANTS_COUNT)
        content = SearchResultConverter().to_dictionary(parsed)
        content['fares'] = list(content['fares'])

    if not content:
        return {}, None

    return _prepare_wizard_variants(content), content['qid']


def _make_route(datetime_deserializer, segment):
    """Build a flight tag from a segment's 'depDt' and 'number' fields."""
    departure = datetime_deserializer.deserialize(segment['depDt'])
    return IATAFlight.make_flight_tag(departure, segment['number'])


def _prepare_wizard_variants(search_result):
    """Convert variants from the wizard format into the ticket daemon format.

    Fares are grouped by their 'partner' value and each group is wrapped in
    ``SimpleApiVariants`` together with the shared 'flights' data. When the
    fare-family feature flag is on, every fare gets a 'fare_families'
    placeholder: one list of None per route direction, as long as that
    direction.
    """
    grouped_fares = defaultdict(list)
    for fare in search_result['fares']:
        if feature_flags.fill_fare_family_enabled():
            forward, backward = fare['route'][0], fare['route'][1]
            fare['fare_families'] = [[None] * len(forward), [None] * len(backward)]
        grouped_fares[fare['partner']].append(fare)

    return {
        partner: SimpleApiVariants(search_result['flights'], partner_fares)
        for partner, partner_fares in grouped_fares.iteritems()
    }


def get_results_meta(query, partner_codes):
    """Return per-partner result metadata for ``query``; delegates to the YDB-backed lookup."""
    return _get_results_meta_from_ydb(query, partner_codes)


def _get_results_meta_from_ydb(query, partner_codes):
    """Load result metadata for the given partners from the service cache.

    :return: dict partner code -> metadata dict with the keys 'created',
        'expire' and 'instant_search_expiration_time'; partners without a
        cache row are mapped to None.
    """
    meta_by_partner = {code: None for code in partner_codes}

    rows = service_cache.ServiceCache.get_many(
        query, partner_codes,
        columns=('partner_code', 'expires_at', 'created_at'),
    )
    for row in rows:
        code = row['partner_code']
        created_at = row['created_at']
        meta_by_partner[code] = {
            'created': created_at,
            # 'expire' is computed from the creation time plus the partner's
            # store time, not from the row's 'expires_at'.
            'expire': created_at + partner_store_time_provider.get_result_time(code),
            'instant_search_expiration_time': row['expires_at'],
        }

    return meta_by_partner
