from datetime import datetime

import json
import pytz
from marshmallow import fields, post_load, Schema, validates_schema, ValidationError

from travel.rasp.pathfinder_proxy.const import MSK_TZ, TTransport


class PointsQuerySchema(Schema):
    """Base query schema holding an origin point and a destination point.

    Both points are loaded from the camelCase keys ``pointFrom``/``pointTo``
    and dumped back under the same names. A query whose two points are equal
    is rejected by ``validate_points``.
    """

    point_from = fields.String(required=True, load_from='pointFrom', dump_to='pointFrom')
    point_to = fields.String(required=True, load_from='pointTo', dump_to='pointTo')

    @validates_schema
    def validate_points(self, data):
        """Reject input that lacks a point or uses the same point twice.

        :param data: the loaded field values keyed by attribute name.
        :raises ValidationError: ``'no point_from'`` / ``'no point_to'`` when
            the corresponding key is absent from ``data``; ``'same points'``
            when both points are equal.
        """
        # Presence is re-checked here (besides ``required``), presumably so
        # that the equality check below can never hit a KeyError.
        required_points = (
            ('point_from', 'no point_from'),
            ('point_to', 'no point_to'),
        )
        for field_name, error_message in required_points:
            if field_name not in data:
                raise ValidationError(error_message)

        if data['point_from'] == data['point_to']:
            raise ValidationError('same points')


class SERPTransfersQuery(PointsQuerySchema):
    """Query schema for SERP transfers: two points, a tld and a departure date."""

    # Top-level domain; 'ru' is used when the key is absent.
    tld = fields.String(required=True, missing='ru')
    # Departure date, read from the camelCase ``departureDt`` key.
    departure_dt = fields.Date(load_from="departureDt")

    @post_load
    def ensure_departure_dt(self, data):
        """Fill in ``departure_dt`` with today's Moscow date when it is missing.

        The field is declared as ``fields.Date``, so the default is a
        ``datetime.date`` as well. Callers therefore always receive the same
        type, whether or not the client sent ``departureDt``. Mixing
        ``datetime`` and ``date`` would make comparisons raise TypeError.

        :param data: loaded field values; mutated in place.
        :returns: the same ``data`` mapping.
        """
        if data.get('departure_dt') is None:
            now_msk = pytz.UTC.localize(datetime.utcnow()).astimezone(MSK_TZ)
            # Take the calendar date in Moscow, not in UTC, so that the
            # default day matches the local day near midnight.
            data['departure_dt'] = now_msk.date()
        return data


class TransportTypesList(fields.Field):
    """Field that turns every occurrence of a query key into transport types.

    Each raw name is resolved with ``TTransport.get_by_name``. The resolved
    values are returned sorted, presumably so that the result does not depend
    on the order of the query parameters.
    """

    def _deserialize(self, value, attr, data):
        """Resolve all raw values of ``attr`` into a sorted list of transport types.

        :param value: raw value picked for this field. It is only used when
            ``data`` has no multi-value ``getall`` accessor.
        :param attr: input key being read.
        :param data: whole input mapping. A multi-dict with ``getall`` is
            expected (presumably the request query; confirm against callers).
        :returns: sorted list of resolved transport types; empty when the key
            is absent.
        :raises ValidationError: if any name is not a known transport code.
        """
        transport_types = []
        for name in self._raw_names(value, attr, data):
            transport_type = TTransport.get_by_name(name)
            if transport_type is None:
                raise ValidationError('Wrong transport code "{}"'.format(name))
            transport_types.append(transport_type)
        return sorted(transport_types)

    @staticmethod
    def _raw_names(value, attr, data):
        """Return every raw name supplied for ``attr``."""
        getall = getattr(data, 'getall', None)
        if getall is not None:
            # A multi-dict keeps repeated keys; collect all of them.
            return getall(attr, [])
        # Plain mapping: only the single value (or an already-built list)
        # is available.
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return value
        return [value]


class TransferVariantsWithPricesQuerySchema(PointsQuerySchema):
    """Query schema for transfer variants with prices.

    Extends the two points with locale settings, an optional date, a
    transport-type filter and two boolean request flags.
    """

    # Locale settings; each falls back to 'ru' when absent.
    tld = fields.String(missing='ru', required=True)
    language = fields.String(missing='ru', required=True)
    # Optional date, read from the ``when`` key as-is.
    when = fields.Date()
    # Transport types read from the ``transportType`` key; empty list when absent.
    transport_types = TransportTypesList(missing=[], load_from='transportType')
    # Boolean flags read from camelCase keys; False when absent.
    is_bot = fields.Boolean(missing=False, load_from='isBot')
    include_price_fee = fields.Boolean(missing=False, load_from='includePriceFee')


def bandit_type_from_headers(headers):
    """Extract ``TRAINS_bandit_type`` from the ``X-Ya-Uaas-Experiments`` header.

    The header is parsed as JSON. Parsing is best-effort: a missing or empty
    header, malformed JSON, or JSON that is not an object all yield None
    instead of raising.

    :param headers: mapping-like object with a ``get`` method.
    :returns: the ``TRAINS_bandit_type`` value from the JSON object, or None.
    """
    uaas_experiments = headers.get('X-Ya-Uaas-Experiments')
    if not uaas_experiments:
        return None

    try:
        uaas_map = json.loads(uaas_experiments)
    except (TypeError, ValueError, RecursionError):
        # The header is client-controlled. ValueError covers malformed JSON
        # and bad encodings. RecursionError covers pathologically nested input.
        return None

    if not isinstance(uaas_map, dict):
        # Valid JSON but not an object (e.g. a list or number): nothing to look up.
        return None

    return uaas_map.get('TRAINS_bandit_type')


def icookie_from_headers(headers):
    """Return the ``X-Ya-ICookie`` header value, or None if it is absent."""
    header_name = 'X-Ya-ICookie'
    return headers.get(header_name)


def req_id_from_headers(headers):
    """Return the ``X-Request-Id`` header value, or None if it is absent."""
    header_name = 'X-Request-Id'
    return headers.get(header_name)


def device_from_headers(headers):
    """Return the ``X-Ya-User-Device`` header value, or None if it is absent."""
    header_name = 'X-Ya-User-Device'
    return headers.get(header_name)


def yandex_uid_from_headers(headers):
    """Return the ``X-Ya-YandexUid`` header value, or None if it is absent."""
    header_name = 'X-Ya-YandexUid'
    return headers.get(header_name)
