from collections import defaultdict, Counter

import pytz
from dataclasses import dataclass

from travel.library.python.aioapp.utils import localize_dt
from travel.rasp.pathfinder_proxy.client_tariffs.tariff_info import TariffInfo

# Minimum share of partners whose status is 'done' before results are built:
# while the done share is below this value, _check_status reports that the
# query is still in progress and _make returns an empty tariff list.
PARTNERS_DONE_PART = 0.5


@dataclass()
class TicketDaemonResult:
    """Per-flight tariffs extracted from a ticket-daemon response."""

    # Tariff objects, one per flight; left empty while ``querying`` is True.
    tariffs: list
    # True when the results are still considered incomplete.
    querying: bool


@dataclass()
class VariantInfo:
    """One interline itinerary: its variants plus its flight segments."""

    # Raw variant dicts whose 'forward' key names this itinerary.
    variants: list
    # Tariff objects for each flight segment of the itinerary, in order.
    segments: list


@dataclass()
class TicketDaemonInterlineResult:
    """Interline itineraries extracted from a ticket-daemon response."""

    # VariantInfo entries; left empty while ``querying`` is True.
    tariffs: list
    # True when the results are still considered incomplete.
    querying: bool


# Format of the 'local' timestamps found in flight departure/arrival points.
_FLIGHT_LOCAL_DT_FORMAT = '%Y-%m-%d %H:%M:%S'
# Format of the UTC timestamps handed to TariffInfo.
_FLIGHT_UTC_DT_FORMAT = '%Y-%m-%dT%H:%M'


def _flight_point_to_utc(point):
    """Convert a flight time point dict to a UTC 'YYYY-MM-DDTHH:MM' string.

    ``point`` must carry 'local' (a string in _FLIGHT_LOCAL_DT_FORMAT) and
    'tzname'; localize_dt presumably parses the string and attaches that
    timezone (TODO confirm) before the value is shifted to UTC.
    """
    local_moment = localize_dt(point['local'], point['tzname'], _FLIGHT_LOCAL_DT_FORMAT)
    return local_moment.astimezone(pytz.UTC).strftime(_FLIGHT_UTC_DT_FORMAT)


def _flight_to_tariff(flight):
    """Build a TariffInfo from a flight dict of the daemon reference.

    Missing 'variants' defaults to an empty dict; departure and arrival
    times are converted to UTC strings.
    """
    return TariffInfo(
        flight['station_from'],
        flight['station_to'],
        flight.get('variants', {}),
        _flight_point_to_utc(flight['departure']),
        _flight_point_to_utc(flight['arrival']),
        flight['number'],
    )


def calculate_interlines(data):
    """Group daemon variants by itinerary and attach each itinerary's segments.

    Returns one VariantInfo per itinerary with a non-empty key. Its variants are
    those whose 'forward' equals the itinerary key. Its segments are the
    itinerary's flights converted to tariffs. Raises KeyError if an
    itinerary references an unknown flight key.
    """
    reference = data['reference']
    segment_keys_by_itinerary = reference['itineraries']
    flight_index = {flight['key']: flight for flight in reference['flights']}

    offers_by_itinerary = defaultdict(list)
    for offer in data['variants']:
        offers_by_itinerary[offer['forward']].append(offer)

    interlines = []
    for itinerary_key, segment_keys in segment_keys_by_itinerary.items():
        # Empty itinerary keys are skipped entirely.
        if not itinerary_key:
            continue
        segments = [_flight_to_tariff(flight_index[key]) for key in segment_keys]
        interlines.append(VariantInfo(offers_by_itinerary[itinerary_key], segments))
    return interlines


def calculate_tariffs(data):
    """Attach partner variants to flights and convert every flight to a tariff.

    NOTE: this mutates the input. Each dict in ``data['reference']['flights']``
    has its 'variants' key replaced by a fresh ``{partner: variant}`` mapping.
    When several variants from the same partner cover one flight, the last
    one in ``data['variants']`` wins.

    Raises KeyError if a variant's 'forward' is not a known itinerary or an
    itinerary references an unknown flight key.
    """
    reference = data['reference']
    segment_keys_by_itinerary = reference['itineraries']
    all_flights = reference['flights']
    flight_index = dict((f['key'], f) for f in all_flights)

    # Reset first so flights not covered by any variant still get an empty map.
    for f in all_flights:
        f['variants'] = {}

    for offer in data['variants']:
        for segment_key in segment_keys_by_itinerary[offer['forward']]:
            flight_index[segment_key]['variants'][offer['partner']] = offer

    return [_flight_to_tariff(f) for f in all_flights]


def _check_status(result):
    """Return True when too few partners have reached the 'done' status.

    ``result['data']['status']`` is a mapping (presumably partner -> status
    string; verify against the daemon). The share of 'done' values is compared
    against PARTNERS_DONE_PART. An empty mapping raises ZeroDivisionError.
    """
    statuses = result['data']['status']
    done_share = Counter(statuses.values())['done'] / len(statuses)
    return done_share < PARTNERS_DONE_PART


def _make(result, poll, result_class, calculate_func):
    querying = not poll
    if result.get('data') and result['data'].get('status'):
        querying = _check_status(result)

    return result_class(
        [] if querying else calculate_func(result['data']),
        querying
    )


def make_tariffs(result, poll):
    """Convert a ticket-daemon response into a TicketDaemonResult of flight tariffs."""
    return _make(
        result,
        poll,
        result_class=TicketDaemonResult,
        calculate_func=calculate_tariffs,
    )


def make_interlines(result, poll):
    """Convert a ticket-daemon response into a TicketDaemonInterlineResult of VariantInfo."""
    return _make(
        result,
        poll,
        result_class=TicketDaemonInterlineResult,
        calculate_func=calculate_interlines,
    )


def make_empty_interlines(querying):
    """Return a TicketDaemonInterlineResult with no entries and the given querying flag."""
    return TicketDaemonInterlineResult(tariffs=[], querying=querying)


def make_empty_tariffs(querying):
    """Return a TicketDaemonResult with no tariffs and the given querying flag."""
    return TicketDaemonResult(tariffs=[], querying=querying)
