# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

from rest_framework import serializers

from common.utils.date import MSK_TZ, timedelta2minutes
from travel.library.python.tracing.instrumentation import traced_function
from travel.rasp.wizards.train_wizard_api.direction.filters import Filters
from travel.rasp.wizards.train_wizard_api.direction.sorting import DEFAULT_SORTING, SORTINGS
from travel.rasp.wizards.train_wizard_api.lib.facility_provider import facility_provider
from travel.rasp.wizards.train_wizard_api.serialization.route import find_route_with_train_stations
from travel.rasp.wizards.wizard_lib.experiment_flags import ExperimentFlag
from travel.rasp.wizards.wizard_lib.serialization.date import dump_datetime, parse_date
from travel.rasp.wizards.wizard_lib.serialization.direction import DirectionQuery
from travel.rasp.wizards.wizard_lib.serialization.experiment_flags import parse_experiment_flags
from travel.rasp.wizards.wizard_lib.serialization.language import parse_language
from travel.rasp.wizards.wizard_lib.serialization.point import dump_point, parse_point
from travel.rasp.wizards.wizard_lib.serialization.place import dump_places


# Country identifiers referenced by this module, listed alphabetically.
# NOTE(review): these are presumably geobase region IDs (the query also carries
# `*_settlement_geoid` parameters) — confirm against the geobase before editing.
ABKHAZIA_ID = 21619
ARMENIA_ID = 168
AUSTRIA_ID = 113
BELARUS_ID = 149
CZECH_ID = 125
ESTONIA_ID = 179
FRANCE_ID = 124
GERMANY_ID = 96
ITALY_ID = 205
KAZAKHSTAN_ID = 159
KYRGYZSTAN_ID = 207
LATVIA_ID = 206
LITHUANIA_ID = 117
MOLDOVA_ID = 208
MONACO_ID = 10070
POLAND_ID = 120
RUSSIA_ID = 225
TAJIKISTAN_ID = 209
TURKMENISTAN_ID = 170
UKRAINE_ID = 187
UZBEKISTAN_ID = 171

# Countries for which ticket sale is allowed: `load_query` rejects a direction
# whenever the departure or arrival point has a `country_id` outside this set.
COUNTRIES_WHERE_SALE_TICKETS_IDS = {
    ABKHAZIA_ID,
    ARMENIA_ID,
    AUSTRIA_ID,
    BELARUS_ID,
    CZECH_ID,
    ESTONIA_ID,
    FRANCE_ID,
    GERMANY_ID,
    ITALY_ID,
    KAZAKHSTAN_ID,
    KYRGYZSTAN_ID,
    LATVIA_ID,
    LITHUANIA_ID,
    MOLDOVA_ID,
    MONACO_ID,
    POLAND_ID,
    RUSSIA_ID,
    TAJIKISTAN_ID,
    TURKMENISTAN_ID,
    UKRAINE_ID,
    UZBEKISTAN_ID,
}


def _parse_sorting(value):
    """Look up the ``SORTINGS`` entry named by an ``order_by`` query value.

    :param value: raw ``order_by`` value taken from the query parameters.
    :returns: the matching ``SORTINGS[value]``, or ``None`` when ``value`` is
        empty/missing (the caller then substitutes a default sorting).
    :raises serializers.ValidationError: if ``value`` is not a key of
        ``SORTINGS``; the message lists the accepted keys in sorted order.
    """
    if not value:
        return None

    try:
        sorting = SORTINGS[value]
    except KeyError:
        raise serializers.ValidationError('invalid order_by value: it should be one of {}'.format(sorted(SORTINGS)))
    return sorting


def _dump_train_brand(train_brand):
    """Serialize a train brand into a plain dict.

    :param train_brand: brand object exposing ``id``, ``deluxe``, ``high_speed``
        and the ``L_title()`` / ``L_title_short()`` title accessors; may be falsy.
    :returns: a dict with ``id``, ``title``, ``short_title``, ``is_deluxe`` and
        ``is_high_speed`` keys, or ``None`` when ``train_brand`` is falsy.
    """
    if not train_brand:
        return None

    serialized = {
        'id': train_brand.id,
        'title': train_brand.L_title(),
        'short_title': train_brand.L_title_short(),
        'is_deluxe': train_brand.deluxe,
        'is_high_speed': train_brand.high_speed,
    }
    return serialized


def _dump_segment_train(segment, lang):
    """Build the ``train`` section of a serialized segment."""
    return {
        'number': segment.train_number,
        'display_number': segment.display_number,
        'has_dynamic_pricing': segment.has_dynamic_pricing,
        'two_storey': segment.two_storey,
        'is_suburban': segment.is_suburban,
        'coach_owners': segment.coach_owners,
        # `lang` is used as an attribute name to pick the localized title.
        'title': getattr(segment.train_title, lang),
        'brand': _dump_train_brand(segment.train_brand),
        'thread_type': segment.thread_type,
        'first_country_code': segment.first_country_code,
        'last_country_code': segment.last_country_code,
        'provider': segment.provider,
        'raw_train_name': segment.raw_train_name,
        't_subtype_id': segment.t_subtype_id,
    }


def _dump_segment_endpoint(station, local_dt):
    """Build a ``departure``/``arrival`` section: station, its settlement and local time."""
    return {
        'station': dump_point(station),
        'settlement': dump_point(station.settlement),
        'local_datetime': dump_datetime(local_dt),
    }


def dump_segment(segment, lang):
    """Serialize one train segment into a plain dict.

    :param segment: segment object carrying train attributes, departure/arrival
        stations with their local datetimes, places, broken classes and
        facility ids.
    :param lang: language name used via ``getattr`` on ``segment.train_title``
        to select the localized train title.
    :returns: dict with ``train``, ``departure``, ``arrival``, ``duration``,
        ``places``, ``broken_classes`` and ``facilities`` keys. ``facilities``
        is ``None`` when ``segment.facilities_ids`` is ``None``.
    """
    train = _dump_segment_train(segment, lang)
    departure = _dump_segment_endpoint(segment.departure_station, segment.departure_local_dt)
    arrival = _dump_segment_endpoint(segment.arrival_station, segment.arrival_local_dt)
    # NOTE(review): the difference of two *local* datetimes is only a correct
    # travel duration if both are tz-aware; confirm they are when the
    # endpoints lie in different time zones.
    duration = timedelta2minutes(segment.arrival_local_dt - segment.departure_local_dt)
    places = dump_places(segment)
    broken_classes = segment.broken_classes

    facilities_ids = segment.facilities_ids
    if facilities_ids is None:
        facilities = None
    else:
        facilities = [facility_provider.get_code_by(pk) for pk in facilities_ids]

    return {
        'train': train,
        'departure': departure,
        'arrival': arrival,
        'duration': duration,
        'places': places,
        'broken_classes': broken_classes,
        'facilities': facilities,
    }


def _parse_direction_points(query_params, experiment_flags):
    """Parse departure and arrival points straight from the query parameters.

    Blank points are accepted only when the TRAIN_ANSWER_POINTLESS_QUERY
    experiment flag is present in ``experiment_flags``.

    :returns: ``(departure_point, arrival_point)`` as returned by ``parse_point``.
    """
    allow_blank = ExperimentFlag.TRAIN_ANSWER_POINTLESS_QUERY in experiment_flags

    departure = parse_point(
        query_params=query_params,
        point_key_name='departure_point_key',
        settlement_geoid_name='departure_settlement_geoid',
        allow_blank=allow_blank
    )
    arrival = parse_point(
        query_params=query_params,
        point_key_name='arrival_point_key',
        settlement_geoid_name='arrival_settlement_geoid',
        allow_blank=allow_blank
    )
    return departure, arrival


def _validate_direction_points(departure_point, arrival_point):
    """Reject coinciding endpoints and endpoints outside the ticket-sale countries.

    :raises serializers.ValidationError: if the points are equal, or if either
        point's ``country_id`` is not in ``COUNTRIES_WHERE_SALE_TICKETS_IDS``.
    """
    if departure_point == arrival_point:
        raise serializers.ValidationError('arrival point should be different from the departure point')

    sale_allowed = (
        departure_point.country_id in COUNTRIES_WHERE_SALE_TICKETS_IDS and
        arrival_point.country_id in COUNTRIES_WHERE_SALE_TICKETS_IDS
    )
    if not sale_allowed:
        raise serializers.ValidationError('sale is prohibited for these direction points')


@traced_function
def load_query(query_params):
    """Build a ``DirectionQuery`` from raw request query parameters.

    Points are first resolved with ``find_route_with_train_stations``. If it
    does not yield both a departure and an arrival point, both are re-parsed
    directly from the query parameters; the ``original_*`` points always come
    from the route lookup. When both points are known, they are validated.

    :param query_params: mapping of request query parameters (read via ``.get``).
    :returns: a populated ``DirectionQuery`` with ``limit=None``.
    :raises serializers.ValidationError: for equal points, points outside
        ``COUNTRIES_WHERE_SALE_TICKETS_IDS`` or an unknown ``order_by`` value;
        the other ``parse_*`` helpers may raise their own validation errors.
    """
    experiment_flags = parse_experiment_flags(query_params.get('exp_flags'))
    (departure_point, arrival_point,
     original_departure_point, original_arrival_point) = find_route_with_train_stations(query_params, experiment_flags)

    if departure_point is None or arrival_point is None:
        departure_point, arrival_point = _parse_direction_points(query_params, experiment_flags)

    # Either point may still be None here (pointless-query experiment); skip checks then.
    if departure_point is not None and arrival_point is not None:
        _validate_direction_points(departure_point, arrival_point)

    # An empty string is normalized to None.
    main_reqid = query_params.get('main_reqid') or None

    departure_date = parse_date(
        value=query_params.get('departure_date'),
        # Without a departure point the date is interpreted in Moscow time.
        local_tz=departure_point.pytz if departure_point else MSK_TZ,
        ignore_past=False,
    )
    language = parse_language(query_params.get('language'))
    sorting = _parse_sorting(query_params.get('order_by')) or DEFAULT_SORTING
    filters = Filters.load(query_params)

    return DirectionQuery(
        departure_point=departure_point,
        arrival_point=arrival_point,
        original_departure_point=original_departure_point,
        original_arrival_point=original_arrival_point,
        departure_date=departure_date,
        language=language,
        experiment_flags=experiment_flags,
        sorting=sorting,
        filters=filters,
        tld=query_params.get('tld'),
        limit=None,
        main_reqid=main_reqid,
    )
