# coding=utf-8
from marshmallow import Schema, fields, pre_load, ValidationError, validate, EXCLUDE

from travel.avia.price_index.lib.national_version_provider import national_version_provider


class DirectionRequest(Schema):
    """Schema for one direction, given as a ``from_id`` / ``to_id`` pair.

    Both identifiers are required integers; input keys that are not declared
    here are dropped on load.
    """

    class Meta:
        # Silently ignore input keys that are not declared as fields.
        unknown = EXCLUDE

    # Identifier of the direction's starting point.
    from_id = fields.Integer(required=True)
    # Identifier of the direction's end point.
    to_id = fields.Integer(required=True)


class TopDirectionsByDateWindowQuery(Schema):
    """Request schema for a "top directions by date window" query.

    Input keys that are not declared here are dropped on load. The incoming
    ``national_version`` code is replaced with the ``pk`` of the matching
    record from ``national_version_provider`` by a ``pre_load`` hook, and is
    then loaded as the integer ``national_version_id``.
    """

    class Meta:
        # Silently ignore input keys that are not declared as fields.
        unknown = EXCLUDE

    # Read from the 'national_version' input key. By the time this field is
    # deserialized the pre_load hook has swapped the code for a primary key.
    national_version_id = fields.Integer(required=True, data_key='national_version')
    # Non-empty list of DirectionRequest items.
    directions = fields.Nested(DirectionRequest, many=True, required=True, validate=validate.Length(min=1))
    forward_date = fields.Date(required=True)
    # Optional; loads as None when the key is absent.
    backward_date = fields.Date(missing=None)
    window_size = fields.Integer(required=True)
    results_per_direction = fields.Integer(required=True)

    @pre_load
    def national_version_to_pk(self, item, *args, **kwargs):
        """Replace the ``national_version`` code in the raw input with its pk.

        ``item`` is modified in place and returned. When ``item`` has no
        ``national_version`` key, or does not support key lookup at all, it is
        returned unchanged so that the regular field / input-type validation
        reports the problem.

        :param item: raw input data handed to the schema's ``load``.
        :returns: ``item``, with ``national_version`` replaced by a pk.
        :raises ValidationError: if the provider does not know the code.
        """
        try:
            national_version = item['national_version']
        except (KeyError, TypeError):
            # This hook runs before field validation: do not leak a raw
            # KeyError/TypeError for a payload that is missing the key or is
            # not a mapping -- let the schema produce its own error instead.
            return item

        m = national_version_provider.get_by_code(national_version)
        if m is None:
            # The field name must be a plain string: it is used as a key of
            # the resulting error dict, so a list here is unhashable.
            raise ValidationError(
                'National version [{}] is not allowed'.format(national_version),
                'national_version',
            )

        item['national_version'] = m.pk
        return item
