# -*- coding: utf-8 -*-
from copy import copy
from datetime import datetime
from logging import getLogger, Logger
from typing import Any, List, Optional, Tuple, Union, cast

from travel.avia.library.python.common.models.partner import Partner, DohopVendor
from travel.avia.library.python.common.models.transport import TransportType
from travel.avia.library.python.common.utils import environment

from travel.avia.ticket_daemon.ticket_daemon.api.flights import Flights, Variant
from travel.avia.ticket_daemon.ticket_daemon.api.models_utils.cabin_class import (
    make_cabin_class_id_by_code
)
from travel.avia.ticket_daemon.ticket_daemon.api.models_utils.interfaces import (
    POINT_TYPE_SETTLEMENT, POINT_TYPE_STATION,
    PointInterface, SettlementInterface, StationInterface
)
from travel.avia.ticket_daemon.ticket_daemon.api.models_utils.modern_currency import (
    make_currency_id_by_code
)
from travel.avia.ticket_daemon.ticket_daemon.api.models_utils.national_version import (
    national_version_by_code
)
from travel.avia.library.python.ticket_daemon.caches.services import get_service_by_code
from travel.avia.ticket_daemon.ticket_daemon.api.models_utils.partners import get_partner_by_code
from travel.avia.ticket_daemon.ticket_daemon.api.query import Query
from travel.avia.ticket_daemon.ticket_daemon.lib.yt_loggers.abstract_variant_logger import IVariantsLogger
from travel.avia.ticket_daemon.ticket_daemon.lib.yt_loggers.yt_logger import YtLogger
from travel.avia.library.python.ticket_daemon.date import CachingDateUtils
from travel.avia.library.python.avia_data.models.amadeus_merchant import AmadeusMerchant


# Module logger; VariantsLogger2.create() passes it as the logger that
# reports exceptions raised while logging variants.
LOGGER = getLogger(__name__)
# Name of the YT logger that receives one record per logged variant.
YT_LOGGER_NAME = 'yt.avia_variants'
# Constructed at import time and shared by every VariantsLogger2.create() instance.
YT_LOGGER = YtLogger(YT_LOGGER_NAME, environment)

# Partner codes used to resolve the parent partner id of Amadeus merchants
# and Dohop vendors (see VariantsLogger2._get_partner_and_vendor_id).
AMADEUS_CODE = 'amadeus'
DOHOP_CODE = 'dohop'


class VariantsLogger2(IVariantsLogger):
    """
    Log a ticket daemon response for a partner: one YT record per variant.
    """

    def __init__(self, yt_logger, environment, logger):
        # type: (YtLogger, Any, Logger) -> None
        """
        :param yt_logger: sink whose ``log`` receives one dict per variant
        :param environment: object providing ``unixtime()`` for the record timestamp
        :param logger: logger used by ``log`` to report failures
        """
        self._yt_logger = yt_logger
        self._environment = environment
        self._logger = logger
        self._date_utils = CachingDateUtils()

    @classmethod
    def create(cls):
        # type: () -> VariantsLogger2
        """Build an instance wired to the module-level YT logger, environment and LOGGER."""
        return cls(
            yt_logger=YT_LOGGER,
            environment=environment,
            logger=LOGGER,
        )

    def log(self, query, partner, partner_variants):
        # type: (Query, Union[Partner, DohopVendor], List[Variant]) -> None
        """
        Log ``partner_variants`` for ``query``; never raises.

        Any exception is reported through ``self._logger.exception`` and
        swallowed, so a logging failure cannot break the caller.
        """
        try:
            self._log(query, partner, partner_variants)
        except Exception:
            self._logger.exception('Can not log variants')

    def _log(self, query, partner, partner_variants):
        # type: (Query, Union[Partner, DohopVendor], List[Variant]) -> None
        """
        Write one YT record per variant.

        :param query: the search query the variants answer
        :param partner: unused; kept for interface compatibility (each record's
            partner is taken from ``variant.partner``)
        :param partner_variants: variants to log
        :returns: None
        """
        common_data = self._build_common_data(query)
        # Built once per call and reused for every variant.
        currency_id_by_code = make_currency_id_by_code()

        for variant in partner_variants:
            self._yt_logger.log(
                self._build_variant_data(common_data, variant, currency_id_by_code)
            )

    def _build_common_data(self, query):
        # type: (Query) -> dict
        """Build the fields shared by every record logged for ``query``."""
        from_settlement, from_airport = self._point_to_settlement_and_airport(
            query.point_from
        )
        to_settlement, to_airport = self._point_to_settlement_and_airport(
            query.point_to
        )
        cabin_class_id_by_code = make_cabin_class_id_by_code()

        return {
            # Prefer 'base_qid' from the query meta; fall back to the query's own id.
            'query_id': query.meta.get('base_qid') or query.id,
            'init_id': query.id,
            'adults': query.adults,
            'children': query.children,
            'infants': query.infants,
            'class_id': cabin_class_id_by_code[query.klass],
            'forward_date': self._date_utils.naive_date_to_timestamp(query.date_forward),
            'backward_date': (
                self._date_utils.naive_date_to_timestamp(query.date_backward)
                if query.date_backward else None
            ),
            'national_version_id': national_version_by_code()[query.national_version],
            'service_id': get_service_by_code(query.service).id,
            'from_settlement_id': getattr(from_settlement, 'id', None),
            'from_airport_id': getattr(from_airport, 'id', None),
            'to_settlement_id': getattr(to_settlement, 'id', None),
            'to_airport_id': getattr(to_airport, 'id', None),
            'unixtime': self._environment.unixtime(),
        }

    def _build_variant_data(self, common_data, variant, currency_id_by_code):
        # type: (dict, Variant, dict) -> dict
        """Build the record for one variant on top of a shallow copy of ``common_data``."""
        forward_segments = self._serialize_segments(variant.forward)
        backward_segments = self._serialize_segments(variant.backward)
        partner_id, vendor_id = self._get_partner_and_vendor_id(variant.partner)

        variant_data = copy(common_data)
        variant_data.update({
            'partner_id': partner_id,
            'vendor_id': vendor_id,
            'original_price': self._to_hundredths(variant.tariff.value),
            'original_currency_id': currency_id_by_code[variant.tariff.currency],
            'national_price': self._to_hundredths(variant.national_tariff.value),
            'national_currency_id': currency_id_by_code[variant.national_tariff.currency],
            'forward_segments': forward_segments,
            'backward_segments': backward_segments,
            'only_planes': (
                self._is_only_planes(forward_segments)
                and self._is_only_planes(backward_segments)
            ),
            'with_baggage': variant.with_baggage,
            'selfconnect': variant.selfconnect,
            'forward_count_transfers': self._get_transfers(forward_segments),
            'backward_count_transfers': self._get_transfers(backward_segments),
            'forward_duration': self._calc_duration(forward_segments),
            'backward_duration': self._calc_duration(backward_segments),
            'is_charter': variant.charter,
            'tag': variant.tag,
        })
        return variant_data

    @staticmethod
    def _to_hundredths(value):
        # type: (Any) -> int
        """
        Convert a price to an int number of hundredths (``value * 100``).

        Rounds rather than truncates: for a float value, ``19.99 * 100`` is
        ``1998.9999999999998`` and a bare ``int()`` would log ``1998``.
        """
        return int(round(value * 100))

    @staticmethod
    def _is_only_planes(segments):
        # type: (Optional[List[dict]]) -> bool
        """
        True when every segment departs from and arrives at a station whose
        transport type is ``TransportType.PLANE_ID``; True for None/empty input.
        """
        return all(
            s['departure_station_transport_type_id'] == TransportType.PLANE_ID
            and s['arrival_station_transport_type_id'] == TransportType.PLANE_ID
            for s in segments or ()
        )

    @staticmethod
    def _get_transfers(segments):
        # type: (Optional[List[dict]]) -> Optional[int]
        """Number of transfers (segment count minus one); None for None/empty input."""
        if not segments:
            return None
        return len(segments) - 1

    @staticmethod
    def _calc_duration(serialized_segments):
        # type: (Optional[List[dict]]) -> Optional[int]
        """
        Difference between the last arrival and the first departure timestamp.

        Units are those of ``aware_to_timestamp`` (presumably seconds — confirm).
        Returns None for None/empty input, or when either endpoint time was
        serialized as None (missing datetime) instead of raising TypeError.
        """
        if not serialized_segments:
            return None
        departure_time = serialized_segments[0]['departure_time']
        arrival_time = serialized_segments[-1]['arrival_time']
        if departure_time is None or arrival_time is None:
            return None
        return int(arrival_time - departure_time)

    def _serialize_segments(self, segments_data):
        # type: (Optional[Flights]) -> Optional[List[dict]]
        """
        Serialize the segments of one direction into plain dicts.

        Returns None when ``segments_data`` is falsy (presumably a missing
        backward leg of a one-way variant).
        """
        if not segments_data:
            return None
        return [self._serialize_segment(f) for f in segments_data.segments]

    def _serialize_segment(self, f):
        # type: (Any) -> dict
        """Serialize one flight segment; missing datetimes become None fields."""
        date_utils = self._date_utils
        departure = f.departure
        arrival = f.arrival
        return {
            'route': f.number,
            'operating_route': f.operating.number if f.operating else None,
            'company_id': getattr(f.company, 'id', None),
            'airline_id': getattr(f.avia_company, 'rasp_company_id', None),
            'departure_station_id': f.station_from.id,
            'departure_station_transport_type_id': f.station_from.t_type_id,
            'arrival_station_id': f.station_to.id,
            'arrival_station_transport_type_id': f.station_to.t_type_id,
            'departure_time': date_utils.aware_to_timestamp(departure) if departure else None,
            'arrival_time': date_utils.aware_to_timestamp(arrival) if arrival else None,
            'departure_offset': date_utils.aware_utc_offset(departure) if departure else None,
            'arrival_offset': date_utils.aware_utc_offset(arrival) if arrival else None,
            'baggage': f.baggage.key() if f.baggage is not None else None,
            'fare_code': f.fare_code,
            'fare_family': f.fare_family,
            'selfconnect': f.selfconnect,
        }

    @staticmethod
    def _point_to_settlement_and_airport(point):
        # type: (PointInterface) -> Tuple[SettlementInterface, Optional[StationInterface]]
        """
        Split a query point into a (settlement, station) pair.

        A settlement yields ``(point, None)``; a station yields
        ``(point.get_related_settlement(), point)``.

        :raises RuntimeError: for any other point type
        """
        point_type = point.get_point_type()

        if point_type == POINT_TYPE_SETTLEMENT:
            return cast(SettlementInterface, point), None
        if point_type == POINT_TYPE_STATION:
            return point.get_related_settlement(), cast(StationInterface, point)

        raise RuntimeError('Unexpected argument', point)

    @staticmethod
    def _get_partner_and_vendor_id(partner):
        # type: (Union[Partner, DohopVendor, AmadeusMerchant]) -> Tuple[int, Optional[int]]
        """
        Resolve the (partner_id, vendor_id) pair for a variant's partner.

        partner_id - id of the partner in the order_partner table.
        vendor_id - id of the seller in either the order_dohopvendor table or
                    the avia_amadeusmerchant table; None for a plain Partner.

        :raises RuntimeError: for any other partner type
        """
        if isinstance(partner, Partner):
            return partner.id, None
        if isinstance(partner, DohopVendor):
            return get_partner_by_code(DOHOP_CODE).id, partner.id
        if isinstance(partner, AmadeusMerchant):
            return get_partner_by_code(AMADEUS_CODE).id, partner.id

        raise RuntimeError('Unexpected argument', partner)

    @staticmethod
    def _get_utc_offset(d):
        # type: (Optional[datetime]) -> Optional[int]
        """UTC offset of ``d`` in whole seconds; None for None or a naive datetime."""
        if d is None:
            return None
        offset = d.utcoffset()
        # utcoffset() returns None for naive datetimes.
        if offset is None:
            return None
        return int(offset.total_seconds())
