# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

from datetime import datetime

from django.http import QueryDict
from marshmallow import fields, pre_load, post_load, ValidationError
from pytz import timezone, UnknownTimeZoneError

from common.models.transport import TransportType
from common.utils.namedtuple import namedtuple_with_defaults

from travel.rasp.morda_backend.morda_backend.serialization.segment import INVALID_VALUE
from travel.rasp.morda_backend.morda_backend.serialization.schema_bases import PointsShowHiddenQuerySchema, ExperimentsFlagsMixin
from travel.rasp.morda_backend.morda_backend.tariffs.bus.serialization import BusSettlementKeysSchemaMixin


# Record type returned by ContextQuerySchema.make_context. Explicit defaults
# are declared here only for ``is_mobile`` and ``allow_change_context``.
SearchContext = namedtuple_with_defaults(
    'SearchContext',
    [
        'point_from',
        'point_to',
        'when',
        'nearest',
        'transport_type',
        'timezones',
        'national_version',
        'is_mobile',
        'allow_change_context',
        'group_trains',
        'bus_settlement_keys',
        'exps_flags',
    ],
    defaults={
        'is_mobile': False,
        'allow_change_context': True,
    },
)


def validate_when(when):
    """Validate the ``when`` query parameter.

    Accepted values are: an empty/falsy value (parameter not given), the
    keywords ``'today'`` and ``'tomorrow'``, or a date that parses with the
    ``%Y-%m-%d`` format.

    :raises ValidationError: if ``when`` is anything else.
    """
    # A tuple (rather than a set) keeps the membership test from raising on
    # unhashable input.
    if not when or when in ('today', 'tomorrow'):
        return
    try:
        datetime.strptime(when, "%Y-%m-%d")
    except (TypeError, ValueError):
        # strptime reports a malformed date with ValueError and a non-string
        # argument with TypeError; any other exception is a real bug and must
        # not be disguised as a validation failure.
        raise ValidationError("Invalid value '{}' for parameter 'when'".format(when))


def validate_transport_type(transport_type_code):
    """Ensure a non-empty transport type code exists in ``TransportType``.

    Falsy codes are accepted without touching the database.

    :raises ValidationError: if no ``TransportType`` row has this code.
    """
    if transport_type_code:
        code_is_known = TransportType.objects.filter(code=transport_type_code).exists()
        if not code_is_known:
            raise ValidationError("Invalid transport type '{}'".format(transport_type_code))


class ContextQuerySchema(PointsShowHiddenQuerySchema, BusSettlementKeysSchemaMixin, ExperimentsFlagsMixin):
    """Load search query parameters into a ``SearchContext``.

    Fields declared here are combined with those contributed by the base
    schema and the mixins; ``make_context`` iterates all of
    ``self.declared_fields`` when filling in absent values.
    """
    transport_type = fields.String(load_from='transportType', dump_to='transportType', validate=validate_transport_type)
    # ``default=[]`` is one list object owned by the class; make_context hands
    # out a copy of it (see _value_for_absent_field) rather than the object itself.
    timezones = fields.List(fields.String(), default=[])
    when = fields.String(validate=validate_when)
    nearest = fields.Boolean(default=False)
    national_version = fields.String(load_from='nationalVersion', dump_to='nationalVersion', default='ru')
    is_mobile = fields.Boolean(default=False, load_from='isMobile')
    allow_change_context = fields.Boolean(default=True, load_from='allowChangeContext')
    # Enable grouping of train segments, https://st.yandex-team.ru/RASPFRONT-9549
    group_trains = fields.Boolean(default=False)

    @pre_load
    def prepare(self, data):
        """Convert a Django ``QueryDict`` to a plain dict before loading.

        ``timezones`` is re-read with ``getlist`` so that every value of the
        repeated parameter is kept. Non-``QueryDict`` input is returned as is.
        """
        result = data
        if isinstance(data, QueryDict):
            result = data.dict()
            timezones = data.getlist('timezones')
            result['timezones'] = timezones
        return result

    @post_load
    def make_context(self, data):
        """Build a ``SearchContext`` from the loaded data.

        Timezone names are converted with ``pytz.timezone``, declared fields
        absent from the input are filled via ``_value_for_absent_field``, and
        ``exps_flags`` is taken from ``self.get_exps_flags()``.

        :raises ValidationError: if any timezone name is unknown to pytz.
        """
        if 'timezones' in data:
            try:
                data['timezones'] = [timezone(tzname) for tzname in data['timezones']]
            except UnknownTimeZoneError:
                raise ValidationError({'timezones': INVALID_VALUE})

        for key, field in self.declared_fields.items():
            if key not in data:
                data[key] = self._value_for_absent_field(field)

        data['exps_flags'] = self.get_exps_flags()

        return SearchContext(**data)

    @staticmethod
    def _value_for_absent_field(field):
        """Return the value to store for ``field`` when it is missing from the input.

        Returns ``None`` when ``field.default`` equals ``field.missing`` (i.e. no
        explicit default was declared, assuming ``missing`` was left unset).
        Otherwise returns the default, shallow-copying list and dict defaults so
        the object declared on the class is never shared by, or mutated through,
        different ``SearchContext`` instances.
        """
        default = field.default
        if default == field.missing:
            return None
        if isinstance(default, list):
            return list(default)
        if isinstance(default, dict):
            return dict(default)
        return default
