from datetime import date
from typing import Optional, Union
from functools import partial

from pydantic import BaseModel, Extra, Field
from tornado.web import HTTPError
import tornado.escape

from travel.proto.dicts.rasp.settlement_pb2 import TSettlement
from travel.proto.dicts.rasp.station_pb2 import TStation
from travel.proto.dicts.rasp.country_pb2 import TCountry

from travel.avia.api_gateway.application.fetcher import Fetcher
from travel.avia.api_gateway.application.fetcher.avatars import SettlementsImagesFetcher
from travel.avia.api_gateway.application.fetcher.avia_statistics import CountrySearchPopularCitiesFetcher
from travel.avia.api_gateway.application.fetcher.country.types import (
    EKlass,
    Linguistics,
    SettlementInfo,
    PriceWithExpiration,
    Currency,
    Image,
    FromPoint,
    ToCountry,
)
from travel.avia.api_gateway.application.fetcher.price_index import MinPriceBatchSearch
from travel.avia.api_gateway.lib.model_utils import get_point_key, get_settlement_by_code, get_settlement_id
from travel.avia.library.python.enum import NationalVersion, Language
from travel.avia.library.python.urls.search import C


class CountryLandingResponse(BaseModel):
    """Result model assembled by ``CountryLandingFetcher.on_prices_and_images``.

    The fetcher serializes it with ``.dict(by_alias=True)``, so the camelCase
    aliases declared below are the keys of the emitted dict.
    """

    # Departure point (key + title forms).
    from_point: FromPoint = Field(alias='fromPoint')
    # Destination country (id, key + title forms).
    to_country: ToCountry = Field(alias='toCountry')
    # One entry per destination city that made it into the response.
    settlements: list[SettlementInfo]

    class Config:
        # Lets the model be constructed with the snake_case field names
        # (``from_point=...``) in addition to the aliases.
        allow_population_by_field_name = True


class CountryLandingRequest(BaseModel):
    """Validated parameters consumed by ``CountryLandingFetcher``."""

    # Departure point code; resolved with ``get_settlement_by_code``.
    from_: str = Field(alias='from')
    # Destination country; resolved by ``CountryLandingFetcher.parse_point_to``
    # (``l<digits>`` id form or a country code).
    to: str = Field(alias='to')
    # Sent to the min-price search as ``forward_date`` (stringified).
    date_forward: date = Field(alias='dateForward')
    # Sent as ``backward_date``; when falsy, ``None`` is sent instead.
    date_backward: Optional[date] = Field(alias='dateBackward')
    # Passenger counts forwarded to the min-price search.
    adults: int = 1
    children: int = 0
    infants: int = 0
    # Prices are requested only when this is ``EKlass.economy``.
    klass: EKlass = EKlass.economy
    # Passed to ``MinPriceBatchSearch`` as ``national_version``.
    nv: Optional[NationalVersion] = NationalVersion.RU
    # Passed to the ``linguistics_for_*`` helpers.
    lang: Optional[Language] = Language.RU
    # Maximum number of popular cities kept after filtering.
    limit: int = 10

    class Config:
        # Reject requests that carry fields not declared above.
        extra = Extra.forbid


def linguistics_for_station(station: TStation, language: Language) -> Linguistics:
    """Build the ``Linguistics`` title forms for a station.

    Only the Russian ``TitleRu*`` fields of the station are read.

    Args:
        station: station record to take the title forms from.
        language: accepted for signature parity with the other
            ``linguistics_for_*`` helpers; currently not used.

    Returns:
        ``Linguistics`` filled from the station's ``TitleRu*`` fields.
    """
    return Linguistics(
        accusative_case=station.TitleRuAccusativeCase,
        # Keyword is spelled ``genitive_case`` (not ``genetive_case``) so it
        # matches the field name used by the settlement and country helpers.
        genitive_case=station.TitleRuGenitiveCase,
        nominative_case=station.TitleRuNominativeCase,
        preposition=station.TitleRuPreposition,
        prepositional_case=station.TitleRuPrepositionalCase,
    )


def linguistics_for_settlement(settlement: TSettlement, language: Language) -> Linguistics:
    """Build the ``Linguistics`` title forms for a settlement.

    Reads the Russian title (``settlement.Title.Ru``) only; ``language`` is
    accepted but not used.
    """
    ru_title = settlement.Title.Ru
    title_forms = {
        'accusative_case': ru_title.Accusative,  # e.g. 'Москву'
        'genitive_case': ru_title.Genitive,  # e.g. 'Москвы'
        'nominative_case': ru_title.Nominative,  # e.g. 'Москва'
        'preposition': ru_title.LocativePreposition,  # e.g. 'в' ("in")
        'prepositional_case': ru_title.Prepositional,  # e.g. 'Москве'
    }
    return Linguistics(**title_forms)


def linguistics_for_country(country: TCountry, language: Language) -> Linguistics:
    """Build the ``Linguistics`` title forms for a country.

    Reads the Russian title (``country.Title.Ru``) only; ``language`` is
    accepted but not used.
    """
    ru_title = country.Title.Ru
    title_forms = {
        'accusative_case': ru_title.Accusative,  # e.g. 'Россию'
        'genitive_case': ru_title.Genitive,  # e.g. 'России'
        'nominative_case': ru_title.Nominative,  # e.g. 'Россия'
        'preposition': ru_title.LocativePreposition,  # e.g. 'в' ("in")
        'prepositional_case': ru_title.Prepositional,  # e.g. 'России'
    }
    return Linguistics(**title_forms)


def linguistics_for_point(point: Union[TStation, TSettlement, TCountry], language: Language) -> Linguistics:
    """Dispatch to the ``linguistics_for_*`` helper matching the point's type.

    Raises:
        TypeError: if ``point`` is not a station, settlement or country.
    """
    # Checked in this order; the first matching type wins.
    builders_by_type = (
        (TStation, linguistics_for_station),
        (TSettlement, linguistics_for_settlement),
        (TCountry, linguistics_for_country),
    )
    for point_type, build_linguistics in builders_by_type:
        if isinstance(point, point_type):
            return build_linguistics(point, language)
    raise TypeError(f'Invalid point type {type(point)}')


def on_prices_and_images(
    cache_root,
    cities: list[dict[str, int]],
    prices_and_images,
    lang,
) -> 'list[SettlementInfo]':
    """Combine popular cities with the fetched prices, images and restrictions.

    Args:
        cache_root: object exposing ``settlement_cache.get_settlement_by_id``.
        cities: entries with a ``cityId`` key and an optional ``popularity``
            key.
        prices_and_images: fetch results keyed by field name. The ``prices``,
            ``images`` and ``country_restrictions`` keys are each optional; a
            missing key is treated as "no data" for that field.
        lang: forwarded to ``linguistics_for_point``.

    Returns:
        One ``SettlementInfo`` per city that is found in the settlement cache
        and has a non-empty image bundle, in the order of ``cities``.
    """
    if 'prices' in prices_and_images:
        prices_by_destination = {price['to_id']: price for price in prices_and_images['prices']['data']}
    else:
        prices_by_destination = {}
    # Handled like the other two fields: absent key -> empty mapping.
    images_by_destination = prices_and_images['images'] if 'images' in prices_and_images else {}
    if 'country_restrictions' in prices_and_images:
        country_restrictions_by_destination = prices_and_images['country_restrictions']
    else:
        country_restrictions_by_destination = {}

    def price_for_destination(destination_id: int):
        """Return the destination's minimal price, or None when unavailable."""
        price = prices_by_destination.get(destination_id)
        if price is None or not price.get('min_price'):
            return None
        min_price = price['min_price']
        return PriceWithExpiration(
            expired=price['expired'],
            value=min_price['value'],
            currency=Currency.from_str_with_correction(min_price['currency']),
        )

    def country_restrictions_for_destination(destination_id: int):
        """Look up restrictions by the ``str(C(id))`` form of the destination id."""
        destination_code = str(C(destination_id))
        return country_restrictions_by_destination.get(destination_code)

    def images_for_destination(destination_id: int):
        """Return the destination's images as nested lists of ``Image``, or None."""
        if destination_id not in images_by_destination:
            return None
        return [[Image(**size) for size in image_data] for image_data in images_by_destination[destination_id]]

    settlements = []
    for city in cities:
        settlement: TSettlement = cache_root.settlement_cache.get_settlement_by_id(city['cityId'])
        if not settlement:
            continue
        # Cities without images are dropped, so check that before building
        # anything else for them.
        images = images_for_destination(settlement.Id)
        if not images:
            continue
        geo_data_dict = {
            'countryId': settlement.CountryId,
            'latitude': settlement.Latitude,
            'longitude': settlement.Longitude,
        }
        # A falsy popularity (missing, None or 0) is left out of geo_data.
        popularity = city.get('popularity')
        if popularity:
            geo_data_dict['popularity'] = popularity
        settlements.append(
            SettlementInfo(
                key=get_point_key(settlement),
                title=linguistics_for_point(settlement, lang),
                images=images,
                price=price_for_destination(settlement.Id),
                country_restrictions=country_restrictions_for_destination(settlement.Id),
                geo_data=geo_data_dict,
            )
        )
    return settlements


class CountryLandingFetcher(Fetcher):
    """Fetcher that builds the country landing response for a request.

    The work is split into three steps chained through callbacks:
    ``fetch`` resolves both points and requests the popular cities of the
    destination country, ``on_statistics`` requests images (and prices for
    economy class) for those cities, and ``on_prices_and_images`` assembles
    the ``CountryLandingResponse`` and passes it to ``finish_callback``.
    """

    def __init__(self, country_request: CountryLandingRequest, *args, **kwargs):
        """Store the request; remaining arguments go to ``Fetcher.__init__``."""
        super().__init__(*args, **kwargs)
        self.country_request: CountryLandingRequest = country_request

    @staticmethod
    def _unknown_point_error(raw_point: str) -> HTTPError:
        """Build the 404 raised for a point that cannot be resolved.

        The raw value is URL-escaped and the whole reason is cut to 40
        characters.
        """
        return HTTPError(404, reason=f'Unknown point {tornado.escape.url_escape(raw_point)}'[:40])

    def parse_point_to(self, slug_or_key: str) -> Optional[TCountry]:
        """Resolve the destination country from an ``l<id>`` key or a code.

        Lookups are tried in order: numeric id (for ``l<digits>``), 3-letter
        code, then code. Returns ``None`` when nothing matches.
        """
        country_cache = self.cache_root.country_cache
        raw_id = slug_or_key[1:]
        # ``isdecimal`` rather than ``isdigit``: the latter also accepts
        # characters such as superscript digits that ``int()`` rejects.
        if slug_or_key.startswith('l') and raw_id.isdecimal():
            if country := country_cache.get_country_by_id(int(raw_id)):
                return country
        if country := country_cache.get_country_by_code3(slug_or_key):
            return country
        if country := country_cache.get_country_by_code(slug_or_key):
            return country
        return None

    def fetch(self, fetchers=None):
        """Resolve the request's points and start the popular-cities fetch.

        Args:
            fetchers: not used by this implementation.

        Raises:
            HTTPError: 404 when either point cannot be resolved.
        """
        request = self.country_request
        try:
            point_from = get_settlement_by_code(self.cache_root, request.from_)
        except Exception as err:
            # Only ordinary errors mean "unknown point"; BaseException
            # subclasses such as KeyboardInterrupt must propagate.
            raise self._unknown_point_error(request.from_) from err

        if point_from is None:
            raise self._unknown_point_error(request.from_)

        point_to = self.parse_point_to(request.to)
        if point_to is None:
            raise self._unknown_point_error(request.to)

        on_statistics = partial(self.on_statistics, point_from, point_to)
        Fetcher(finish_callback=on_statistics).fetch(
            [CountrySearchPopularCitiesFetcher(country_id=point_to.Id)],
        )

    def on_statistics(self, point_from: TSettlement, point_to: TCountry, cities: list[dict[str, int]]):
        """Request images and, for economy class, prices for the popular cities.

        The departure city is removed from ``cities`` and the remainder is
        truncated to ``country_request.limit``.

        Raises:
            HTTPError: 404 when no cities are left after filtering.
        """
        request = self.country_request
        from_id = get_settlement_id(point_from)
        cities = [city for city in cities if city['cityId'] != from_id]
        cities = cities[: request.limit]

        if not cities:
            raise HTTPError(404, reason='No popular cities found in requested country')

        base_direction_request = {
            'forward_date': str(request.date_forward),
            'backward_date': str(request.date_backward) if request.date_backward else None,
            'from_id': from_id,
            'adults_count': request.adults,
            'children_count': request.children,
            'infants_count': request.infants,
        }
        fetcher = Fetcher(finish_callback=partial(self.on_prices_and_images, point_from, point_to, cities))
        fetchers = [
            SettlementsImagesFetcher(
                cache_root=self.cache_root, field='images', settlement_ids=[city['cityId'] for city in cities]
            ),
        ]
        fetcher.waiting_fields = {'images'}
        # Prices are requested for economy class only.
        if request.klass == EKlass.economy:
            fetchers.append(
                MinPriceBatchSearch(
                    field='prices',
                    national_version=request.nv,
                    request={'min_requests': [dict(base_direction_request, to_id=city['cityId']) for city in cities]},
                )
            )
            fetcher.waiting_fields.add('prices')
        fetcher.fetch(fetchers)

    def on_prices_and_images(
        self,
        point_from: TSettlement,
        point_to: TCountry,
        cities: list[dict[str, int]],
        prices_and_images,
    ):
        """Assemble the response and pass it to ``finish_callback`` as a dict."""
        lang = self.country_request.lang
        # Inside a method this name resolves to the module-level function,
        # not to this method.
        settlements = on_prices_and_images(self.cache_root, cities, prices_and_images, lang)
        response = CountryLandingResponse(
            from_point=FromPoint(
                key=get_point_key(point_from),
                title=linguistics_for_point(point_from, lang),
            ),
            to_country=ToCountry(
                id=point_to.GeoId,
                key=get_point_key(point_to),
                title=linguistics_for_point(point_to, lang),
            ),
            settlements=settlements,
        )
        self.finish_callback(response.dict(by_alias=True))
