# -*- coding: utf-8 -*-

from copy import copy
from typing import Any, Dict, List
from datetime import datetime

import pytz
from ciso8601 import parse_datetime_as_naive

from travel.avia.library.python.common.utils.iterrecipes import pairwise
from travel.avia.ticket_daemon.ticket_daemon.api.models_utils.partners import get_partner_by_code


class FareExtender(object):
    """Add derived fields (times, duration, transfers, discomfort) to a fare.

    Input shapes, as read by the methods below:

    * ``fare['route']`` is indexed at ``[0]`` and ``[1]``; each entry is a
      list of keys into ``flights`` and the second one may be empty.
    * ``fare['partner']`` is a code passed to ``get_partner_by_code``.
    * ``flights[key]`` has ``'departure'`` and ``'arrival'`` time points
      (dicts with ``'tzname'`` and ``'local'``) plus ``'from'`` and ``'to'``.
    """

    def __init__(self):
        # (tzname, local) -> timezone-aware datetime. Lives as long as the
        # instance, so repeated time points are converted only once.
        self._timepoint_conversion_cache = {}

    def _difference_in_minutes(self, first_time, second_time):
        # type: (datetime, datetime) -> int
        """Return ``first_time - second_time`` in whole minutes.

        The fractional part is dropped by ``int()``, i.e. truncated
        towards zero.
        """
        return int((first_time - second_time).total_seconds() / 60)

    def _time_point_to_datetime(self, time_point_info):
        # type: (Dict[str, Any]) -> datetime
        """Convert a time point dict to a timezone-aware datetime.

        :param time_point_info: dict with ``'tzname'`` (a pytz timezone name)
            and ``'local'`` (a datetime string parsed by ciso8601).
        :return: the parsed local time localized to the given timezone.
            Results are memoized per ``(tzname, local)`` pair.
        """
        key = (time_point_info['tzname'], time_point_info['local'])
        cached = self._timepoint_conversion_cache.get(key)
        # Explicit None check: a cache hit must not depend on truthiness.
        if cached is not None:
            return cached

        naive_local = parse_datetime_as_naive(time_point_info['local'])
        value = pytz.timezone(time_point_info['tzname']).localize(naive_local)
        self._timepoint_conversion_cache[key] = value
        return value

    def _direction_duration(self, route, flights):
        # type: (List[str], Any) -> int
        """Return minutes from the first departure to the last arrival.

        An empty ``route`` yields 0.
        """
        if not route:
            return 0

        start = self._time_point_to_datetime(flights[route[0]]['departure'])
        finish = self._time_point_to_datetime(flights[route[-1]]['arrival'])

        return self._difference_in_minutes(finish, start)

    def _duration(self, fare, flights):
        # type: (Any, Any) -> int
        """Return the summed duration in minutes of both fare directions."""
        forward, backward = fare['route'][0], fare['route'][1]
        return (
            self._direction_duration(forward, flights)
            + self._direction_duration(backward, flights)
        )

    def _direction_transfers_count(self, route):
        # type: (List[str]) -> int
        """Return the number of transfers: one less than the flight count.

        Never negative, so an empty route gives 0.
        """
        return max(0, len(route) - 1)

    def _transfers_count(self, fare):
        # type: (Any) -> int
        """Return the summed transfers count of both fare directions."""
        forward, backward = fare['route'][0], fare['route'][1]
        return (
            self._direction_transfers_count(forward)
            + self._direction_transfers_count(backward)
        )

    def _from_aviacompany(self, fare):
        # type: (Any) -> bool
        """Return whether the fare's partner has ``is_aviacompany`` set.

        Returns False when ``get_partner_by_code`` finds no partner.
        """
        partner = get_partner_by_code(fare['partner'])
        if not partner:
            return False

        # Coerce to a real bool: the result is used in arithmetic by
        # _discomfort_level, where a None flag would raise TypeError.
        return bool(partner.is_aviacompany)

    def _direction_airport_changes(self, route, flights):
        # type: (List[str], Any) -> int
        """Count consecutive flight pairs whose airports do not connect.

        A change is counted when a flight's ``'to'`` differs from the next
        flight's ``'from'``.
        """
        airport_changes = 0

        for first_flight, second_flight in pairwise(flights[f] for f in route):
            if first_flight['to'] != second_flight['from']:
                airport_changes += 1

        return airport_changes

    def _airport_changes(self, fare, flights):
        # type: (Any, Any) -> int
        """Return the summed airport changes of both fare directions."""
        forward, backward = fare['route'][0], fare['route'][1]
        return (
            self._direction_airport_changes(forward, flights)
            + self._direction_airport_changes(backward, flights)
        )

    # Original (misspelt) name, kept so existing callers keep working.
    _aiport_changes = _airport_changes

    def _discomfort_level(self, fare, flights):
        # type: (Any, Any) -> int
        """Return the discomfort score of a fare.

        Each airport change adds 2; the score is lowered by 1 when
        ``_from_aviacompany`` is true for the fare.
        """
        return 2 * self._airport_changes(fare, flights) - self._from_aviacompany(fare)

    def extend(self, fare, flights):
        # type: (Any, Any) -> Any
        """Return a shallow copy of ``fare`` with derived fields added.

        Added keys: ``start_time``, ``end_time``, ``duration``,
        ``transfers_count`` and ``discomfort_level``. The input ``fare``
        itself gets no new keys.

        :param fare: fare dict with ``'route'`` and ``'partner'``.
        :param flights: mapping from flight key to flight dict.
        """
        fare_with_additional_info = copy(fare)

        forward, backward = fare['route'][0], fare['route'][1]
        # The trip ends with the second direction when it has flights,
        # otherwise with the first direction.
        last_direction = backward or forward

        fare_with_additional_info['start_time'] = self._time_point_to_datetime(
            flights[forward[0]]['departure']
        )
        fare_with_additional_info['end_time'] = self._time_point_to_datetime(
            flights[last_direction[-1]]['arrival']
        )
        fare_with_additional_info['duration'] = self._duration(fare, flights)
        fare_with_additional_info['transfers_count'] = self._transfers_count(fare)
        fare_with_additional_info['discomfort_level'] = self._discomfort_level(fare, flights)

        return fare_with_additional_info
