# -*- coding: utf-8 -*-
from collections import defaultdict
from logging import getLogger

from typing import List  # noqa

from travel.avia.library.python.price_prediction import PriceCategory
from travel.avia.price_prediction.api.v1 import check_price_pb2
from travel.avia.ticket_daemon.ticket_daemon.api.flights import Variant  # noqa
from travel.avia.ticket_daemon.ticket_daemon.api.query import Query  # noqa
from travel.avia.ticket_daemon.ticket_daemon.lib import feature_flags
from travel.avia.library.python.ticket_daemon.date import naive_datetime_to_timestamp
from travel.avia.ticket_daemon.ticket_daemon.lib.tvm_provider import tvm_provider
from travel.avia.ticket_daemon.ticket_daemon.settings.price_prediction import (
    build_price_prediction_client, PRICE_PREDICTION_TIMEOUT
)

logger = getLogger(__name__)

price_prediction_client = build_price_prediction_client()


def safe_add_price_prediction_category(variants, query):
    # type: (List[Variant], Query) -> None
    try:
        add_price_prediction_category(variants, query)
    except Exception:
        logger.exception('Error in add_price_prediction_category')


def add_price_prediction_category(variants, query):
    # type: (List[Variant], Query) -> None
    if not feature_flags.use_price_prediction():
        return

    # В данных для прогноза сейчас только варианты в одну сторону, поэтому для туда-обратно не оцениваем
    if query.date_backward:
        logger.debug('has backward, skip apprising')
        return

    # В данных для прогноза только рубли,
    # для русской версии для других валют можем брать рублевые цены из national_tariff
    if query.national_version != 'ru':
        logger.debug('national version is not ru - %s', query.national_version)
        return

    point_from_key = (query.point_from.get_related_settlement() or query.point_from).point_key
    point_to_key = (query.point_to.get_related_settlement() or query.point_to).point_key

    # 1. В данных для прогноза минимальные цены.
    # 2. Поэтому для цены без багажа и с багажем может быть разная категория цены (good, bad, unknown)
    # Поэтому решили для одного партнера для всех цен одного варианта брать оценку
    # по минимальной цене для этого партнера для этого варианта
    variants_by_apprising_key = defaultdict(list)
    for variant in variants:
        local_departure = naive_datetime_to_timestamp(variant.forward.segments[0].local_departure)
        route = ';'.join(s.number for s in variant.forward.segments)
        variants_by_apprising_key[(local_departure, route, variant.partner)].append(variant)

    variants_by_id = {}
    sub_requests = {}
    for i, (key, gathered_variants) in enumerate(variants_by_apprising_key.iteritems()):
        local_departure, route, _partner = key
        price = min(
            v.tariff.value if v.tariff.currency == 'RUR' else v.national_tariff.value
            for v in gathered_variants
        )
        sub_requests[i] = check_price_pb2.TCheckPriceReq(
            PointFromKey=point_from_key,
            PointToKey=point_to_key,
            Routes=route,
            LocalDeparture=local_departure,
            AdultSeats=query.adults,
            ChildrenSeats=query.children,
            InfantSeats=query.infants,
            Price=price,
        )
        variants_by_id[i] = gathered_variants
    request = check_price_pb2.TCheckPricesReq(CheckPricesReq=sub_requests)

    tvm_service_ticket = tvm_provider.get_ticket('price-prediction')
    categories = price_prediction_client.check_prices(
        request, timeout=PRICE_PREDICTION_TIMEOUT, tvm_service_ticket=tvm_service_ticket
    )
    for i, gathered_variants in variants_by_id.iteritems():
        price_category = PriceCategory.from_proto(categories[i])
        for variant in gathered_variants:
            variant.price_category = price_category
