# -*- encoding: utf-8 -*-
import datetime
from functools import partial
from typing import List, Optional

import pytz
import tornado.escape
from tornado.httpclient import HTTPResponse
from tornado.web import HTTPError

from travel.avia.library.python.shared_dicts.cache.station_code_cache import StationCodeCache

from travel.avia.api_gateway.application.fetcher import Fetcher
from travel.avia.api_gateway.application.fetcher.flight_supplement import flight_supplement, flight_extras
from travel.avia.api_gateway.settings import SHARED_FLIGHTS_API_URL

# Timezone used to render the non-UTC 'created_at' / 'updated_at' fields emitted by map_flight and map_status.
MOSCOW_TZ = pytz.timezone('Europe/Moscow')


class FlightByDepartureDateFetcher(Fetcher):
    """Fetch a single flight for a given departure day and attach supplementary data.

    Reads from ``self.params``: ``company_iata``, ``number``, ``departure_day`` and, optionally,
    ``from_airport_code`` or ``from_airport_id`` (sent upstream as the ``from`` query parameter).
    """

    service = 'shared_flights_flight_by_departure_date'

    def fetch(self, fetchers=None):
        company_iata = self.params.get('company_iata')
        number = self.params.get('number')
        departure_day = self.params.get('departure_day')
        # Fall back to the airport id when the code is missing *or* present but empty/None;
        # a dict.get default alone only covers the missing-key case.
        departure_airport = self.params.get('from_airport_code') or self.params.get('from_airport_id')

        self.request(
            '{}/flight/{}/{}/{}/'.format(
                SHARED_FLIGHTS_API_URL,
                tornado.escape.url_escape(company_iata),
                tornado.escape.url_escape(number),
                departure_day,
            ),
            # Only send 'from' when there is an actual value to filter by.
            params={'from': departure_airport} if departure_airport else {},
            callback=self.on_response,
        )

    def on_response(self, response):
        # type: (HTTPResponse) -> None
        """Map the decoded flight and hand it to flight_supplement together with finish_callback."""
        flight = tornado.escape.json_decode(response.body)
        data = {
            'flight': map_flight(flight, self.cache_root.station_code_cache),
        }
        flight_supplement(data, self.field, self.params, self.finish_callback, self.cache_root)


class FlightByNumberFetcher(Fetcher):
    """Fetch every flight known for an airline / flight-number pair.

    ``self.params`` supplies ``company_iata`` and ``number`` (URL path), an optional ``params``
    dict forwarded as query parameters, and an optional ``need_map`` flag (default True)
    choosing between mapped flights and the raw decoded JSON.
    """

    service = 'shared_flights_flight_by_number'

    def fetch(self, fetchers=None):
        url = '{}/flight-range/{}/{}/'.format(
            SHARED_FLIGHTS_API_URL,
            tornado.escape.url_escape(self.params.get('company_iata')),
            tornado.escape.url_escape(self.params.get('number')),
        )
        handler = partial(self.on_response, self.params.get('need_map', True))
        self.request(url, callback=handler, params=self.params.get('params', {}))

    def on_response(self, need_map, response):
        # type: (bool, HTTPResponse) -> None
        flights = tornado.escape.json_decode(response.body)
        if not need_map:
            # Caller asked for the upstream payload untouched.
            self.finish_callback(flights, field=self.field)
            return
        self.finish_callback(
            [map_flight(item, self.cache_root.station_code_cache) for item in flights],
            field=self.field,
        )


class FlightFetcher(Fetcher):
    """Query the shared-flights ``/flights/`` endpoint and map every returned flight.

    Query parameters come from ``self.params['params']``; the HTTP method is set by the
    ``method`` constructor argument (default 'GET').
    """

    service = 'shared_flights_flight'

    # Timeout used when the caller does not pass ``request_timeout`` explicitly.
    _DEFAULT_REQUEST_TIMEOUT = 20

    def __init__(
        self, finish_callback=None, field=None, connect_timeout=None, request_timeout=None, method='GET', **kwargs
    ):
        super(FlightFetcher, self).__init__(finish_callback, field, connect_timeout, request_timeout, **kwargs)
        self._method = method
        # Apply 20 only as a default so an explicitly passed request_timeout is honoured
        # instead of being silently overwritten.
        if request_timeout is None:
            request_timeout = self._DEFAULT_REQUEST_TIMEOUT
        self.request_timeout = request_timeout

    def fetch(self, fetchers=None):
        params = self.params.get('params', {})

        self.request(
            '{}/flights/'.format(SHARED_FLIGHTS_API_URL),
            callback=self.on_response,
            params=params,
            method=self._method,
        )

    def on_response(self, response):
        # type: (HTTPResponse) -> None
        """Map each flight in the decoded JSON list and pass the list to finish_callback."""
        data = tornado.escape.json_decode(response.body)

        self.finish_callback(
            [map_flight(flight, self.cache_root.station_code_cache) for flight in data], field=self.field
        )


class AirportFlightListFetcher(Fetcher):
    """List flights for an airport via the shared-flights ``/flight-station/`` endpoint.

    Reads ``airport_iata`` and optional ``params`` from ``self.params``. If ``params`` contains
    ``departure_day`` (checked first) or ``arrival_day`` as 'YYYY-MM-DD', that key is replaced
    by ``direction`` plus an ``after``/``before`` window spanning the whole day.
    """

    service = 'shared_flights_flight_list'

    def fetch(self, fetchers=None):
        # type: (Optional[list]) -> None
        airport_iata = self.params.get('airport_iata')
        # Work on a copy: popping/adding keys below must not mutate the caller's dict,
        # otherwise a repeated fetch() would silently lose the day filter.
        params = dict(self.params.get('params') or {})
        flight_day = None

        if 'departure_day' in params:
            flight_day = params.pop('departure_day')
            params['direction'] = 'departure'
        elif 'arrival_day' in params:
            flight_day = params.pop('arrival_day')
            params['direction'] = 'arrival'

        if flight_day:
            shared_flights_datetime = '%Y-%m-%dT%H:%M:%S'
            try:
                after = datetime.datetime.strptime(flight_day, '%Y-%m-%d')
            except ValueError as e:
                # Malformed day value -> HTTP 400 with the parse error as the reason.
                raise HTTPError(400, reason=str(e))
            # NOTE(review): the window is naive (no tz); presumably shared-flights interprets
            # it in a fixed zone — confirm.
            before = after + datetime.timedelta(days=1)
            params['after'] = after.strftime(shared_flights_datetime)
            params['before'] = before.strftime(shared_flights_datetime)

        self.request(
            '{}/flight-station/{}/'.format(
                SHARED_FLIGHTS_API_URL,
                tornado.escape.url_escape(airport_iata),
            ),
            callback=self.on_response,
            params=params,
        )

    def on_response(self, response):
        # type: (HTTPResponse) -> None
        """Map every entry of the response's 'flights' list and pass the result to finish_callback."""
        data = tornado.escape.json_decode(response.body)
        mapped_flights = [map_flight(flight, self.cache_root.station_code_cache) for flight in data['flights']]
        self.finish_callback(mapped_flights, field=self.field)


class FlightP2PSegmentInfo(Fetcher):
    """Request point-to-point segment info between two sets of stations.

    Every element of ``self.params['from']`` / ``self.params['to']`` becomes a repeated
    ``from`` / ``to`` query parameter; the decoded JSON response is passed on unmodified.
    """

    service = 'shared_flights_flight_p2p_segment_info'

    def fetch(self, fetchers=None):
        query = [('from', origin) for origin in self.params.get('from')]
        query += [('to', destination) for destination in self.params.get('to')]
        self.request(SHARED_FLIGHTS_API_URL + '/flight-p2p-segment-info', params=query, callback=self.on_response)

    def on_response(self, response):
        # type: (HTTPResponse) -> None
        self.finish_callback(tornado.escape.json_decode(response.body), field=self.field)


class FlightExtrasByDepartureDateFetcher(Fetcher):
    """Find the flight (or nested segment) departing on ``departure_day`` and return its extras.

    Reads ``company_iata``, ``number`` and ``departure_day`` from ``self.params``. Responds
    with HTTP 404 when no flight or segment in the returned range matches the day.
    """

    # NOTE(review): identical to FlightByDepartureDateFetcher.service — confirm the sharing is
    # intentional before changing it (it may key per-service config or metrics).
    service = 'shared_flights_flight_by_departure_date'

    def fetch(self, fetchers=None):
        company_iata = self.params.get('company_iata')
        number = self.params.get('number')
        departure_day = self.params.get('departure_day')

        self.request(
            '{}/flight-range/{}/{}'.format(
                SHARED_FLIGHTS_API_URL,
                tornado.escape.url_escape(company_iata),
                tornado.escape.url_escape(number),
            ),
            # Let self.request build the query string (as the other fetchers do) instead of
            # splicing departure_day in raw, where '&' or '#' would inject extra parameters
            # or truncate the URL.
            params={
                'departure_day_period': '2:2',
                'time_utc': '{}T00:00:00'.format(departure_day),
            },
            callback=self.on_response,
        )

    def on_response(self, response):
        # type: (HTTPResponse) -> None
        departure_day = self.params.get('departure_day')
        flights = tornado.escape.json_decode(response.body)

        flight = self._find_flight_by_departure_day(departure_day, flights)

        if not flight:
            raise HTTPError(404, 'Flight not found')

        self.finish_callback(
            flight_extras(map_flight(flight, self.cache_root.station_code_cache), self.cache_root),
            field=self.field,
        )

    def _find_flight_by_departure_day(self, departure_day, flights):
        # type: (str, List[dict]) -> Optional[dict]
        """Return the first flight whose 'departureDay' matches, searching depth-first.

        Each flight is checked before its own 'segments' (recursively), and both before the
        next flight in the list. Returns None when nothing matches.
        """
        for flight in flights:
            if flight['departureDay'] == departure_day:
                return flight

            if flight.get('segments'):
                segment = self._find_flight_by_departure_day(departure_day, flight['segments'])
                if segment:
                    return segment

        return None


def map_flight(flight, station_code_cache):
    # type: (dict, StationCodeCache) -> dict
    """
    map_flight serves three purposes:
    1) To map names from camelCase to snake_case. It is not true camel case though, so automatic mapping is problematic
    2) To hide possible new fields from api consumers until they are explicitly allowed via adding to this function
    3) To document data format
    :param flight: decoded shared-flights flight object (camelCase keys); its 'status' is mapped with
        map_status and any nested 'segments' are mapped recursively with this function
    :param station_code_cache: resolves airport ids into station codes
    :return: snake_case flight dict; 'segments' is always a list (empty when the source has none)
    """
    mapped = {
        'departure_timezone': flight['departureTimezone'],
        'airport_from_code': station_code_cache.get_station_code_by_id(flight['airportFromID']),
        'arrival_time': flight['arrivalTime'],
        'arrival_timezone': flight['arrivalTimezone'],
        'number': flight['title'],
        'arrival_utc': flight['arrivalUtc'],
        'airport_to_id': flight['airportToID'],
        'departure_time': flight['departureTime'],
        'departure_day': flight['departureDay'],
        'airport_from_id': flight['airportFromID'],
        'airline_id': flight['airlineID'],
        'airport_to_code': station_code_cache.get_station_code_by_id(flight['airportToID']),
        'departure_utc': flight['departureUtc'],
        'airline_iata': flight['airlineCode'],
        'arrival_day': flight['arrivalDay'],
        'status': map_status(flight['status'], station_code_cache),
        'created_at': _convert_utc_to_tz(flight['createdAtUtc'], MOSCOW_TZ),
        'created_at_utc': flight['createdAtUtc'],
        'updated_at': _convert_utc_to_tz(flight['updatedAtUtc'], MOSCOW_TZ),
        'updated_at_utc': flight['updatedAtUtc'],
        'transport_model_id': flight['transportModelID'],
    }

    # A missing key and an explicit JSON null both mean "no segments"; iterating None would crash.
    mapped['segments'] = [map_flight(segment, station_code_cache) for segment in flight.get('segments') or []]

    return mapped


def map_status(status, station_code_cache):
    # type: (dict, StationCodeCache) -> dict
    """
    map_status serves three purposes:
    1) To map names from camelCase to snake_case. It is not true camel case though, so automatic mapping is problematic
    2) To hide possible new fields from api consumers until they are explicitly allowed via adding to this function
    3) To document data format
    :param status: decoded shared-flights status object (camelCase keys)
    :param station_code_cache: resolves the diverted airport id into a station code
    :return: snake_case status dict
    """
    diverted_id = status['divertedAirportID']
    departure_updated_utc = status['departureUpdatedAtUtc']
    arrival_updated_utc = status['arrivalUpdatedAtUtc']
    created_utc = status['createdAtUtc']
    updated_utc = status['updatedAtUtc']

    # NOTE(review): 'arrival_updated_at' / 'departure_updated_at' carry the raw UTC value, unlike
    # 'created_at' / 'updated_at', which are converted to MOSCOW_TZ — confirm this is intended.
    return {
        'arrival': status['arrival'],
        'departure_gate': status['departureGate'],
        'diverted_airport_id': diverted_id or None,
        'diverted_airport_code': station_code_cache.get_station_code_by_id(diverted_id),
        'baggage_carousels': status['baggageCarousels'],
        'arrival_terminal': status['arrivalTerminal'],
        'arrival_status': status['arrivalStatus'],
        'departure_updated_at_utc': departure_updated_utc,
        'departure_source': status['departureSource'],
        'departure_terminal': status['departureTerminal'],
        'arrival_source': status['arrivalSource'],
        'status': status['status'],
        'arrival_updated_at': arrival_updated_utc,
        'arrival_gate': status['arrivalGate'],
        'diverted': status['diverted'],
        'departure_updated_at': departure_updated_utc,
        'check_in_desks': status['checkInDesks'],
        'departure_status': status['departureStatus'],
        'arrival_updated_at_utc': arrival_updated_utc,
        'departure': status['departure'],
        'created_at': _convert_utc_to_tz(created_utc, MOSCOW_TZ),
        'created_at_utc': created_utc,
        'updated_at': _convert_utc_to_tz(updated_utc, MOSCOW_TZ),
        'updated_at_utc': updated_utc,
    }


def _convert_utc_to_tz(utc_time, tz):
    # type: (str, datetime.tzinfo) -> str
    f = '%Y-%m-%d %H:%M:%S'
    try:
        return datetime.datetime.strptime(utc_time, f).replace(tzinfo=pytz.utc).astimezone(tz).strftime(f)
    except ValueError:
        return utc_time
