# -*- encoding: utf-8 -*-
import os
import re

from datetime import datetime
from functools import partial
from urllib.parse import unquote

import tornado
import tornado.escape
from tornado.web import HTTPError, RequestHandler, asynchronous
from pydantic.error_wrappers import ValidationError

from travel.avia.api_gateway.application.fetcher import Fetcher
from travel.avia.api_gateway.application.fetcher.backend import (
    AirportsFetcher,
    AirlineFetcher,
    StationFetcher,
    ReviewsFetcher,
    AirlinesSynonymsFetcher,
    TravelRecipesFetcher,
    AirportTabloSourceFetcher,
    NearDistancesFetcher,
    NearDirectionsFetcher,
    TopFlightsFetcher,
    PartnersPopularByRouteFetcher,
    CurrencyFetcher,
    GeoPointFetcher,
)
from travel.avia.api_gateway.application.fetcher.city_to_landing.fetcher import CityToLandingFetcher
from travel.avia.api_gateway.application.fetcher.city_to_landing.mapper import CityToLandingMapper
from travel.avia.api_gateway.application.fetcher.country.fetcher import CountryLandingFetcher, CountryLandingRequest
from travel.avia.api_gateway.application.fetcher.anywhere_landing.fetcher import AnywhereLandingFetcher, AnywhereLandingRequest
from travel.avia.api_gateway.application.fetcher.flight_landing.fetcher import FlightLandingFetcher
from travel.avia.api_gateway.application.fetcher.flight_landing.mapper import FlightLandingMapper
from travel.avia.api_gateway.application.fetcher.route_landing.fetcher import RouteLandingFetcher
from travel.avia.api_gateway.application.fetcher.route_landing.mapper import RouteLandingMapper
from travel.avia.api_gateway.application.fetcher.personal_search.fetcher import PersonalSearchFetcher
from travel.avia.api_gateway.application.fetcher.personal_search.mapper import PersonalSearchMapper
from travel.avia.api_gateway.application.fetcher.slugs.by_route import SlugsByRouteFetcher
from travel.avia.api_gateway.application.fetcher.is_possible_trip.fetcher import IsPossibleTripFetcher
from travel.avia.api_gateway.application.fetcher.conditional_flight_fetcher import (
    get_airport_flight_list_fetcher,
    get_flight_by_departure_date_fetcher,
    get_flight_by_number_fetcher,
)
from travel.avia.api_gateway.application.fetcher.flight_supplement import flight_supplement, flight_extras
from travel.avia.api_gateway.application.fetcher.price_index import MinPriceBatchSearch
from travel.avia.api_gateway.application.fetcher.shared_flights import (
    FlightExtrasByDepartureDateFetcher,
)
from travel.avia.api_gateway.application.mapper.factory import get_mapper
from travel.avia.api_gateway.application.stat import Stat
from travel.avia.api_gateway.application.yt_logging import yt_request_log
from travel.avia.api_gateway.lib.coding import decode

# Matches Stat keys of the form ``<prefix>_hit_ammm`` and captures ``<prefix>``;
# StatHandler uses the prefix to locate the companion ``<prefix>_time_ammm`` key.
# NOTE(review): used with ``.match`` (anchored only at the start), so keys with
# extra text after ``_hit_ammm`` also match — confirm that is intended.
RE_STAT_HIT_KEY = re.compile(r'(.+)_hit_ammm')


class PingHandler(RequestHandler):
    """Answer ``GET`` with plain-text ``OK``, or ``410 Shutdown`` once the shutdown flag is set."""

    def get(self):
        self.set_header('Content-Type', 'text/plain')
        if not self.application.shutdown_flag.is_set():
            self.write('OK')
            return
        # The application was asked to shut down (see ShutdownHandler).
        self.set_status(410)
        self.write('Shutdown')


class ShutdownHandler(RequestHandler):
    """On ``POST``, raise the application's shutdown flag and reply with plain-text ``OK``."""

    def post(self):
        self.set_header('Content-Type', 'text/plain')
        # Once set, PingHandler starts answering 410.
        self.application.shutdown_flag.set()
        self.write('OK')


class VersionHandler(RequestHandler):
    """Report the Tornado version and the deployed image tag/hash as plain text."""

    def get(self):
        self.set_header('Content-Type', 'text/plain')
        report = (
            ('Tornado version: {}\n', tornado.version),
            ('Package tag: {}\n', os.getenv('DEPLOY_DOCKER_IMAGE')),
            ('Package hash: {}\n', os.getenv('DEPLOY_DOCKER_HASH')),
        )
        for template, value in report:
            self.write(template.format(value))


class StatHandler(RequestHandler):
    """Return (and reset) collected service stats as a JSON list of ``[key, value]`` pairs.

    For each ``<prefix>_hit_ammm`` counter whose ``<prefix>_time_ammm`` total is
    truthy, the time total is replaced in place by ``total / hits / 1000``
    rounded to 3 decimals, or ``0`` when the hit count is not positive.
    """

    def get(self):
        self.set_header('Content-Type', 'application/json')

        tracked_services = [
            get_flight_by_departure_date_fetcher().service,
            StationFetcher.service,
            AirlineFetcher.service,
            AirportsFetcher.service,
            ReviewsFetcher.service,
            get_flight_by_number_fetcher().service,
            GeoPointFetcher.service,
        ]
        stats = Stat.get_and_clear(tracked_services)

        for hit_key in stats:
            match = RE_STAT_HIT_KEY.match(hit_key)
            if match is None:
                continue

            time_key = '{}_time_ammm'.format(match.group(1))
            total_time = stats.get(time_key)
            if not total_time:
                continue

            hits = stats[hit_key]
            # Only an existing key is overwritten, so iterating ``stats`` stays safe.
            stats[time_key] = round(total_time / hits / 1000.0, 3) if hits > 0 else 0

        pairs = [list(item) for item in stats.items()]
        self.write(tornado.escape.json_encode(pairs))


class ApiHandler(RequestHandler):
    """Base class for the JSON API handlers in this module.

    Subclasses set ``method`` and optionally ``service``; a ``service`` query
    argument overrides the class-level value. Every request is logged via
    ``yt_request_log`` in :meth:`prepare`, and ``Stat.flush()`` runs on finish.
    """

    method = ''
    service = None

    def __init__(self, application, request, **kwargs):
        super(ApiHandler, self).__init__(application, request, **kwargs)
        self.application = application
        # Falls back to the class attribute when the query has no ``service``.
        self.service = self.get_query_argument('service', self.service)

    def prepare(self):
        yt_request_log(self.__class__.__name__, self.request)
        super(ApiHandler, self).prepare()

    def on_finish(self):
        Stat.flush()
        super(ApiHandler, self).on_finish()

    def write_mapped(self, data):
        """Write ``data`` as JSON, first passing it through ``get_mapper(service, method)`` if one exists."""
        self.set_header('Content-Type', 'application/json')
        mapper = get_mapper(self.service, self.method) if self.service else None
        payload = mapper(data) if mapper else data
        self.write(tornado.escape.json_encode(payload))


class FlightByDepartureDateHandler(ApiHandler):
    """Fetch a flight by company IATA code, number and departure day.

    Optional query arguments: ``lang`` (default ``'ru'``), ``fields``
    (comma-separated), ``from-airport-code`` and ``from-airport-id``.
    """

    method = 'flight_by_departure_date'

    @asynchronous
    def get(self, company_iata, number, departure_day):
        # type: (str, str, str) -> None
        lang = self.get_query_argument('lang', 'ru')
        raw_fields = self.get_query_argument('fields', None)
        requested_fields = raw_fields.split(',') if raw_fields else []

        fetcher = Fetcher(finish_callback=self.finish_fetch)
        fetcher.waiting_fields.add('flight')

        flight_fetcher = get_flight_by_departure_date_fetcher(
            field='flight',
            company_iata=company_iata.upper(),
            number=number,
            departure_day=departure_day,
            lang=lang,
            fields=requested_fields,
            from_airport_code=self.get_query_argument('from-airport-code', None),
            from_airport_id=self.get_query_argument('from-airport-id', None),
            cache_root=self.application.cache_root,
        )
        fetcher.fetch([flight_fetcher])

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class MailFlightByDepartureDateHandler(ApiHandler):
    """Return one flight (or flight segment) matched by flight number and exact departure time.

    Query arguments:
        number: flight number, either ``"SU 1234"`` or ``"SU1234"``.
        date: departure timestamp ``YYYY-MM-DDTHH:MM:SS`` with an optional
            trailing ``Z``; percent-encoding is tolerated.
        departure_day_period: forwarded to the flight fetcher (default ``'2:2'``).
        lang: forwarded to the flight fetcher (default ``'ru'``).

    Responds 400 on a malformed ``number`` or ``date``. Responds 404 when no
    fetched route or segment departs exactly at ``date``. A match is enriched
    with ``flight_supplement`` before it is written.
    """

    service = 'mail'
    method = 'mail_flight_by_departure_date'

    @asynchronous
    def get(self):
        company_iata, number = self._parse_flight_number(self.get_query_argument('number'))
        departure_date = self._parse_departure_date(self.get_query_argument('date'))
        departure_day_period = self.get_query_argument('departure_day_period', '2:2')
        lang = self.get_query_argument('lang', 'ru')
        params = dict(
            departure_day_period=departure_day_period,
            time_utc=departure_date,
            lang=lang,
            fields='flight_status',
        )

        fetcher = Fetcher(
            finish_callback=partial(
                self.finish_fetch,
                params=params,
            )
        )
        fetcher.waiting_fields.add('flight')
        fetcher.fetch(
            [
                get_flight_by_number_fetcher(
                    field=None,
                    company_iata=company_iata,
                    number=number,
                    params=params,
                    cache_root=self.application.cache_root,
                )
            ]
        )

    def finish_fetch(self, data, params):
        """Pick the route/segment departing at ``params['time_utc']``, supplement it and respond."""

        def finish_get(flight_data):
            self.write_mapped(flight_data)
            self.finish()

        data = {
            'flight': self._get_segment_by_departure_date(data, params.get('time_utc')),
        }
        flight_supplement(data, 'flight', params, finish_get, self.application.cache_root)

    @staticmethod
    def _get_segment_by_departure_date(data, departure_date):
        """Return the first route or segment in ``data`` whose departure day/time equals ``departure_date``.

        ``departure_date`` is ``'<day>T<time>'`` as produced by :meth:`_parse_departure_date`.

        Raises:
            HTTPError: 404 when nothing departs at that moment.
        """
        expected_day, expected_time = departure_date.split('T')

        for route in data:
            segments = route.get('segments', [])

            # No segments means a plain (non-segmented) flight: match the route itself.
            if not segments:
                if route.get('departure_day') == expected_day and route.get('departure_time') == expected_time:
                    return route
                continue

            for segment in segments:
                if segment.get('departure_day') == expected_day and segment.get('departure_time') == expected_time:
                    return segment

        raise HTTPError(404, reason='Flight with that departure date does not exist')

    @staticmethod
    def _parse_departure_date(departure_date):
        """Normalise ``departure_date`` to ``%Y-%m-%dT%H:%M:%S``.

        Trailing ``Z`` characters are stripped and the value is URL-unquoted first.

        Raises:
            HTTPError: 400 when the value does not match the format.
        """
        departure_date = unquote(departure_date).rstrip('Z')

        try:
            return datetime.strptime(departure_date, '%Y-%m-%dT%H:%M:%S').strftime('%Y-%m-%dT%H:%M:%S')
        except ValueError:
            raise HTTPError(400, reason='Invalid departure date format')

    @staticmethod
    def _parse_flight_number(flight_number):
        """Split a flight number into a ``(company_code, number)`` tuple.

        Splitting rules:
            - ``"SU 1234"`` splits on whitespace; tokens after the second are ignored.
            - ``"SU1234"`` splits at the first digit, but never before position 2,
              so codes containing a digit (``"S7123"`` -> ``("S7", "123")``) stay intact.

        Raises:
            HTTPError: 400 when either part would be empty (e.g. ``""``, ``"SU"``),
                instead of sending an empty code/number to the backend.
        """
        if ' ' in flight_number:
            parts = flight_number.split()[:2]
        else:
            first_digit = re.search(r'\d', flight_number)
            # No digit at all, or a digit within the first two characters: split after two chars.
            split_at = max(first_digit.start(), 2) if first_digit else 2
            parts = [flight_number[:split_at], flight_number[split_at:]]

        if len(parts) != 2 or not all(parts):
            raise HTTPError(400, reason='Invalid flight number format')
        return parts[0], parts[1]


class FlightByNumberHandler(ApiHandler):
    """Fetch flights by company IATA code and number, forwarding all query arguments as params."""

    method = 'flight_by_number'

    @asynchronous
    def get(self, company_iata, number):
        # type: (str, str) -> None
        # Last value wins for repeated query arguments.
        query = {name: values[-1] for name, values in self.request.query_arguments.items()}

        fetcher = Fetcher(self.finish_fetch)
        number_fetcher = get_flight_by_number_fetcher(
            field=None,
            company_iata=company_iata.upper(),
            number=number,
            params=query,
            cache_root=self.application.cache_root,
        )
        fetcher.fetch([number_fetcher])

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class AirlinesSynonymsHandler(ApiHandler):
    """Return the result of ``AirlinesSynonymsFetcher`` (no request arguments are used)."""

    method = 'airlines_synonyms_list'

    @asynchronous
    def get(self):
        fetcher = Fetcher(self.finish_fetch)
        sources = [AirlinesSynonymsFetcher()]
        fetcher.fetch(sources)

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class AirportFlightListHandler(ApiHandler):
    """List flights for an airport IATA code, forwarding decoded query arguments as params.

    When ``fields`` (comma-separated) contains ``flight_extras``, each flight
    gets an ``extras`` entry computed by ``flight_extras`` before responding.
    """

    method = 'airport_flight_list'

    @asynchronous
    def get(self, airport_iata):
        # type: (str) -> None
        query = {name: decode(values[-1]) for name, values in self.request.query_arguments.items()}
        wants_extras = 'flight_extras' in query.get('fields', '').split(',')
        on_done = self.map_flight_extras if wants_extras else self.finish_fetch

        fetcher = Fetcher(on_done)
        list_fetcher = get_airport_flight_list_fetcher(
            field=None,
            airport_iata=airport_iata.upper(),
            params=query,
            request_timeout=200,
            cache_root=self.application.cache_root,
        )
        fetcher.fetch([list_fetcher])

    def map_flight_extras(self, data):
        """Attach ``extras`` to every flight in ``data`` (mutated in place), then respond."""
        cache_root = self.application.cache_root
        for flight in data:
            flight['extras'] = flight_extras(flight, cache_root)
        self.finish_fetch(data)

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class TravelRecipesHandler(ApiHandler):
    """Pass every query argument (last value wins) to ``TravelRecipesFetcher`` as keyword arguments."""

    method = 'travel_recipes'

    @asynchronous
    def get(self):
        # NOTE(review): arbitrary query names become constructor kwargs; an
        # unexpected name presumably raises TypeError (500) — consider a whitelist.
        fetcher_kwargs = {name: values[-1] for name, values in self.request.query_arguments.items()}
        fetcher = Fetcher(self.finish_fetch)
        recipes_fetcher = TravelRecipesFetcher(**fetcher_kwargs)
        fetcher.fetch([recipes_fetcher])

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class AirportTabloSourceHandler(ApiHandler):
    """Return airport tablo sources, filtered by the raw ``trusted`` query argument when given."""

    method = 'airport_tablo_source'

    @asynchronous
    def get(self):
        query = {name: values[-1] for name, values in self.request.query_arguments.items()}
        trusted = query.get('trusted')

        fetcher = Fetcher(self.finish_fetch)
        fetcher.fetch([AirportTabloSourceFetcher(trusted=trusted)])

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class NearDirectionsHandler(ApiHandler):
    """Serve ``near_directions`` by chaining three fetches.

    1. ``NearDistancesFetcher`` for ``from_id``/``to_id``.
    2. ``NearDirectionsFetcher`` with the ``distance`` query argument (defaulting
       to ``defaultDistance`` from step 1) and the date/passenger parameters.
    3. ``MinPriceBatchSearch`` with one min-price request per direction from step 2.

    Missing required query arguments answer 400; a falsy result from step 1 or
    step 2 answers 404. The JSON body contains ``directions``, ``distances`` and
    ``prices`` (``[]`` unless the price search reports ``status == 'ok'``).
    """

    method = 'near_directions'

    # Query arguments read with ``params[...]`` somewhere in the fetch chain.
    # Checked up front so a missing one answers 400 instead of raising KeyError
    # (a 500) halfway through the chain.
    REQUIRED_PARAMS = ('from_id', 'to_id', 'forward_date', 'lang', 'national_version')

    def get_query_params(self):
        """Return ``{name: last value}`` from ``request.query_arguments`` (values are not decoded here)."""
        return {p: v[-1] for p, v in self.request.query_arguments.items()}

    def _not_found(self):
        """Finish the request with an empty 404 response."""
        self.set_status(404)
        self.finish()

    @asynchronous
    def get(self):
        params = self.get_query_params()
        missing = [name for name in self.REQUIRED_PARAMS if name not in params]
        if missing:
            raise HTTPError(400, reason='Missing required arguments: {}'.format(', '.join(missing)))

        fetcher = Fetcher(finish_callback=self.on_near_distances_response)
        fetcher.fetch(
            [
                NearDistancesFetcher(
                    from_id=params['from_id'],
                    to_id=params['to_id'],
                ),
            ]
        )

    def on_near_distances_response(self, data):
        """Step 2: fetch directions within the requested (or default) distance."""
        if not data:
            self._not_found()
            return
        params = self.get_query_params()
        distance = self.get_query_argument('distance', default=data['defaultDistance'])
        fetcher = Fetcher(finish_callback=partial(self.on_near_directions_response, distances=data))

        fetcher.fetch(
            [
                NearDirectionsFetcher(
                    from_id=params['from_id'],
                    to_id=params['to_id'],
                    distance=distance,
                    forward_date=params['forward_date'],
                    backward_date=params.get('backward_date'),
                    adult_seats=params.get('adult_seats', 1),
                    children_seats=params.get('children_seats', 0),
                    infant_seats=params.get('infant_seats', 0),
                    klass=params.get('klass', 'economy'),
                    lang=params['lang'],
                )
            ]
        )

    def on_near_directions_response(self, directions, distances=None):
        """Step 3: request minimal prices for every found direction in one batch."""
        if not directions:
            self._not_found()
            return
        params = self.get_query_params()
        fetcher = Fetcher(finish_callback=partial(self.finish_fetch, distances=distances, directions=directions))
        base_direction_request = {
            'forward_date': params['forward_date'],
            'backward_date': params.get('backward_date'),
            'from_id': params['from_id'],
            'adults_count': params.get('adult_seats', 1),
            'children_count': params.get('children_seats', 0),
            'infants_count': params.get('infant_seats', 0),
        }
        request = {
            'min_requests': [dict(base_direction_request, to_id=direction['toCity']['id']) for direction in directions]
        }
        fetcher.fetch(
            [
                MinPriceBatchSearch(
                    national_version=params['national_version'],
                    request=request,
                )
            ]
        )

    def finish_fetch(self, prices, distances=None, directions=None):
        """Write the combined response.

        A falsy ``prices`` is tolerated: the other callbacks treat a falsy
        fetch result as "no data", so it yields ``[]`` instead of crashing
        on ``.get``.
        """
        prices_ok = bool(prices) and prices.get('status') == 'ok'
        self.write_mapped(
            {
                'directions': directions,
                'distances': distances,
                'prices': prices['data'] if prices_ok else [],
            }
        )
        self.finish()


class TopFlightsHandler(ApiHandler):
    """Return top flights for a route.

    Forwards ``national_version``, ``point_from``, ``point_to``, ``date`` and
    ``limit`` to the fetcher; any missing argument is passed as ``None``.
    """

    method = 'top_flights'

    @asynchronous
    def get(self):
        query = {name: values[-1] for name, values in self.request.query_arguments.items()}
        forwarded = ('national_version', 'point_from', 'point_to', 'date', 'limit')

        fetcher = Fetcher(self.finish_fetch)
        fetcher.fetch([TopFlightsFetcher(**{name: query.get(name) for name in forwarded})])

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class PartnersPopularByRouteHandler(ApiHandler):
    """Pass every query argument (last value wins) to ``PartnersPopularByRouteFetcher`` as kwargs."""

    method = 'partners_popular_by_route'

    @asynchronous
    def get(self):
        # NOTE(review): unvalidated query names become constructor kwargs — consider a whitelist.
        fetcher_kwargs = {name: values[-1] for name, values in self.request.query_arguments.items()}
        fetcher = Fetcher(self.finish_fetch)
        partners_fetcher = PartnersPopularByRouteFetcher(**fetcher_kwargs)
        fetcher.fetch([partners_fetcher])

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class CurrencyHandler(ApiHandler):
    """Proxy ``CurrencyFetcher`` for the ``national_version`` query argument (``None`` if absent)."""

    method = 'currency'

    @asynchronous
    def get(self):
        query = {name: values[-1] for name, values in self.request.query_arguments.items()}
        national_version = query.get('national_version')

        fetcher = Fetcher(self.finish_fetch)
        fetcher.fetch([CurrencyFetcher(national_version=national_version)])

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class GeoPointHandler(ApiHandler):
    """Proxy ``GeoPointFetcher`` for the ``key`` and ``lang`` query arguments (``None`` if absent)."""

    method = 'point'

    @asynchronous
    def get(self):
        query = {name: values[-1] for name, values in self.request.query_arguments.items()}

        fetcher = Fetcher(self.finish_fetch)
        point_fetcher = GeoPointFetcher(key=query.get('key'), lang=query.get('lang'))
        fetcher.fetch([point_fetcher])

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class RouteLandingHandler(ApiHandler):
    """Build the route landing payload for ``from_slug`` -> ``to_slug``.

    The ``lang`` path argument is accepted but not used here.
    """

    method = 'route_landing'

    @asynchronous
    def get(self, from_slug, to_slug, national_version, lang):
        app = self.application
        fetcher = Fetcher(self.finish_fetch)
        landing_fetcher = RouteLandingFetcher(
            route_landing_mapper=RouteLandingMapper(app.cache_root, app.route_landing_templater),
            request_headers=self.request.headers,
            cache_root=app.cache_root,
            from_slug=from_slug,
            to_slug=to_slug,
            national_version=national_version,
        )
        fetcher.fetch([landing_fetcher])

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class CityToLandingHandler(ApiHandler):
    """Build the "flights to city" landing payload for ``to_slug``.

    The ``lang`` path argument is accepted but not used here.
    """

    method = 'city_to_landing'

    @asynchronous
    def get(self, to_slug, national_version, lang):
        app = self.application
        fetcher = Fetcher(self.finish_fetch)
        landing_fetcher = CityToLandingFetcher(
            city_to_landing_mapper=CityToLandingMapper(app.cache_root, app.city_to_landing_templater),
            cache_root=app.cache_root,
            to_slug=to_slug,
            national_version=national_version,
        )
        fetcher.fetch([landing_fetcher])

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class PersonalSearchHandler(ApiHandler):
    """Return personal search data for the optional ``geoId`` and ``yandexUid`` arguments."""

    method = 'personal_search'

    @asynchronous
    def get(self):
        cache_root = self.application.cache_root
        fetcher = Fetcher(self.finish_fetch)
        search_fetcher = PersonalSearchFetcher(
            personal_search_mapper=PersonalSearchMapper(cache_root),
            cache_root=cache_root,
            geo_id=self.get_argument('geoId', None),
            yandex_uid=self.get_argument('yandexUid', None),
        )
        fetcher.fetch([search_fetcher])

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class SettlementSlugByGeoIdHandler(ApiHandler):
    """Map each ``geo_id`` query argument to its settlement slug using the local cache.

    Non-integer ids, ids without a settlement and settlements without a slug
    are skipped. The response is ``{"data": {geo_id: slug, ...}}``.
    """

    method = 'settlement_slugs_by_geo_ids'

    @asynchronous
    def get(self):
        settlement_cache = self.application.cache_root.settlement_cache
        slug_by_geo_id = {}
        for raw_geo_id in self.get_arguments('geo_id'):
            try:
                geo_id = int(raw_geo_id)
            except (TypeError, ValueError):
                # Skip malformed ids rather than failing the whole batch; a bare
                # ``except`` here would also have swallowed KeyboardInterrupt/SystemExit.
                continue
            settlement = settlement_cache.get_settlement_by_geo_id(geo_id)
            if not settlement:
                continue
            slug = settlement_cache.get_slug_by_id(settlement.Id)
            if slug:
                slug_by_geo_id[geo_id] = slug
        self.finish_fetch(dict(data=slug_by_geo_id))

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class SettlementSlugByRouteHandler(ApiHandler):
    """Return settlement slugs for the ``from``/``to`` settlement codes under the ``data`` field.

    Both ``from`` and ``to`` are required (400 otherwise). ``national_version``
    defaults to :attr:`DEFAULT_NATIONAL_VERSION`.
    """

    # NOTE(review): same ``method`` value as SettlementSlugByGeoIdHandler — looks
    # copy-pasted; confirm which mapper (if any) should apply to this endpoint.
    method = 'settlement_slugs_by_geo_ids'

    DEFAULT_NATIONAL_VERSION = 'ru'

    @asynchronous
    def get(self):
        from_settlement_code = self.get_argument('from', None)
        to_settlement_code = self.get_argument('to', None)
        if not from_settlement_code or not to_settlement_code:
            raise HTTPError(400, reason='Both arguments "from" and "to" must be specified')
        national_version = self.get_argument('national_version', self.DEFAULT_NATIONAL_VERSION)

        slugs_fetcher = SlugsByRouteFetcher(
            cache_root=self.application.cache_root,
            from_settlement_code=from_settlement_code,
            to_settlement_code=to_settlement_code,
            national_version=national_version,
            field='data',
        )
        Fetcher(self.finish_fetch).fetch([slugs_fetcher])

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class FlightLandingHandler(ApiHandler):
    """Build the flight landing payload via ``FlightLandingFetcher``.

    Path arguments: ``company_code``, ``flight_number``, ``national_version``
    and ``lang``. Optional query arguments:

    - ``departureDate`` and ``fromCode``.
    - ``userGeoId``: parsed as int; unparsable values become ``None``.
    - ``lang``: overrides the path value when present.
    """

    method = 'flight_landing'

    @asynchronous
    def get(self, company_code, flight_number, national_version, lang):
        departure_date = self.get_argument('departureDate', None)
        from_code = self.get_argument('fromCode', None)
        user_geo_id = self._int_or_none(self.get_query_argument('userGeoId', None))
        # Fall back to the path segment so a request without ``?lang=`` keeps the
        # URL language instead of passing ``None`` downstream.
        lang = self.get_query_argument('lang', lang)
        fetcher = Fetcher(self.finish_fetch)
        flight_landing_mapper = FlightLandingMapper(
            self.application.cache_root,
            self.application.flight_landing_templater,
        )
        fetcher.fetch(
            [
                FlightLandingFetcher(
                    company_code=company_code,
                    flight_number=flight_number,
                    departure_date=departure_date,
                    from_code=from_code,
                    national_version=national_version,
                    user_geo_id=user_geo_id,
                    lang=lang,
                    cache_root=self.application.cache_root,
                    flight_landing_mapper=flight_landing_mapper,
                    iata_corrector=self.application.iata_corrector,
                )
            ]
        )

    def _int_or_none(self, raw_int):
        """Return ``int(raw_int)``, or ``None`` when it is ``None`` or not an integer literal."""
        try:
            return int(raw_int)
        except (TypeError, ValueError):
            return None

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class IsPossibleTripHandler(ApiHandler):
    """Check whether a trip between two settlements is possible around ``departure_date``.

    The optional ``window_size`` argument is parsed as int. Malformed values fall
    back to :attr:`DEFAULT_WINDOW_SIZE`.
    """

    method = 'is_possible_trip'

    DEFAULT_WINDOW_SIZE = 0

    @asynchronous
    def get(self, from_settlement_code, to_settlement_code, departure_date, national_version):
        try:
            window_size = int(self.get_argument('window_size', self.DEFAULT_WINDOW_SIZE))
        except (TypeError, ValueError):
            # Narrowed from a bare ``except``, which also caught KeyboardInterrupt/SystemExit.
            window_size = self.DEFAULT_WINDOW_SIZE

        fetcher = Fetcher(self.finish_fetch)
        fetcher.fetch(
            [
                IsPossibleTripFetcher(
                    from_settlement_code=from_settlement_code,
                    to_settlement_code=to_settlement_code,
                    departure_date=departure_date,
                    national_version=national_version,
                    window_size=window_size,
                    cache_root=self.application.cache_root,
                )
            ]
        )

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class FlightExtrasByDepartureDateHandler(ApiHandler):
    """Fetch flight extras by company IATA code, number and departure day.

    Optional query arguments: ``from-airport-code`` and ``from-airport-id``.
    """

    method = 'flight_extras_by_departure_date'

    @asynchronous
    def get(self, company_iata, number, departure_day):
        # type: (str, str, str) -> None
        fetcher = Fetcher(self.finish_fetch)
        extras_fetcher = FlightExtrasByDepartureDateFetcher(
            company_iata=company_iata.upper(),
            number=number,
            departure_day=departure_day,
            from_airport_code=self.get_query_argument('from-airport-code', None),
            from_airport_id=self.get_query_argument('from-airport-id', None),
            cache_root=self.application.cache_root,
        )
        fetcher.fetch([extras_fetcher])

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class CountryLandingHandler(ApiHandler):
    """Validate decoded query arguments as ``CountryLandingRequest`` and fetch the country landing.

    A pydantic ``ValidationError`` answers 400, with the error text (newlines
    flattened) as the reason.
    """

    method = 'country_to_landing'

    @asynchronous
    def get(self):
        try:
            landing_request = CountryLandingRequest(
                **{name: decode(values[-1]) for name, values in self.request.query_arguments.items()}
            )
        except ValidationError as error:
            # The reason goes into the HTTP status line, so it must stay on one line.
            raise HTTPError(400, reason=f'Invalid request: {error}'.replace('\n', ' '))

        landing_fetcher = CountryLandingFetcher(
            country_request=landing_request,
            cache_root=self.application.cache_root,
        )
        Fetcher(self.finish_fetch).fetch([landing_fetcher])

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()


class AnywhereLandingHandler(ApiHandler):
    """Validate decoded query arguments as ``AnywhereLandingRequest`` and fetch the "anywhere" landing.

    A pydantic ``ValidationError`` answers 400, with the error text (newlines
    flattened) as the reason.
    """

    method = 'anywhere_to_landing'

    @asynchronous
    def get(self):
        try:
            landing_request = AnywhereLandingRequest(
                **{name: decode(values[-1]) for name, values in self.request.query_arguments.items()}
            )
        except ValidationError as error:
            # The reason goes into the HTTP status line, so it must stay on one line.
            raise HTTPError(400, reason=f'Invalid request: {error}'.replace('\n', ' '))

        landing_fetcher = AnywhereLandingFetcher(
            anywhere_request=landing_request,
            cache_root=self.application.cache_root,
        )
        Fetcher(self.finish_fetch).fetch([landing_fetcher])

    def finish_fetch(self, data):
        self.write_mapped(data)
        self.finish()
