# coding=utf-8
from marshmallow import Schema, fields, pre_load, ValidationError, validate, post_load, EXCLUDE

from travel.avia.price_index.lib.national_version_provider import national_version_provider
from travel.avia.price_index.models.batch_prices_form import BatchPricesForm
from travel.avia.price_index.models.query import Query


class MinRequest(Schema):
    """Schema for one entry of the ``min_requests`` list of a batch query.

    Keys not declared below are dropped on load (``unknown = EXCLUDE``).
    """

    class Meta:
        # Silently ignore unknown keys instead of failing validation.
        unknown = EXCLUDE

    # Mandatory date of the forward leg.
    forward_date = fields.Date(required=True)
    # Optional date of the backward leg; loads as None when omitted
    # (presumably meaning a one-way request -- confirm against Query).
    backward_date = fields.Date(missing=None)
    # Mandatory integer identifiers of the two endpoints of the request.
    from_id = fields.Integer(required=True)
    to_id = fields.Integer(required=True)
    # Passenger counts; when omitted they load as 1 adult, 0 children, 0 infants.
    adults_count = fields.Integer(missing=1)
    children_count = fields.Integer(missing=0)
    infants_count = fields.Integer(missing=0)


class BatchMinRequestsQuery(Schema):
    """Deserialize a batch query payload into a ``BatchPricesForm``.

    Expected input keys:

    * ``national_version`` -- national version code; it is resolved to a
      primary key through ``national_version_provider`` before field
      validation and loaded as ``national_version_id``.
    * ``query_source`` -- optional integer.
    * ``min_requests`` -- non-empty list of ``MinRequest`` payloads.

    Keys not declared on the schema are dropped on load (``unknown = EXCLUDE``).
    """

    class Meta:
        # Silently ignore unknown keys instead of failing validation.
        unknown = EXCLUDE

    national_version_id = fields.Integer(required=True, data_key='national_version')
    # The input key equals the attribute name, so no ``data_key`` is needed.
    # (The marshmallow 2 ``load_from`` argument is not a valid field option
    # in marshmallow 3 and was only being stored as unused field metadata.)
    query_source = fields.Integer(required=False)
    min_requests = fields.Nested(MinRequest, many=True, required=True, validate=validate.Length(min=1))

    @pre_load
    def national_version_to_pk(self, item, *args, **kwargs):
        """Replace the national version code in ``item`` with its primary key.

        :param item: raw input payload; updated in place and returned.
        :returns: the payload with ``item['national_version']`` set to the
            ``pk`` of the object returned by ``national_version_provider``.
        :raises ValidationError: if the provider returns ``None`` for the
            given code; the error is attached to ``national_version``.
        """
        try:
            national_version = item['national_version']
        except (KeyError, TypeError):
            # Key is absent or the payload is not a mapping: hand the data
            # over untouched so the regular schema validation (``required``
            # check / input type check) reports the problem instead of an
            # unhandled exception escaping from this hook.
            return item

        m = national_version_provider.get_by_code(national_version)

        if m is None:
            # ``field_name`` must be a single string: marshmallow uses it as
            # a dict key when normalizing error messages, so a list would
            # fail with "unhashable type".
            raise ValidationError('National version [{}] is not allowed'.format(national_version), 'national_version')
        item['national_version'] = m.pk
        return item

    @post_load
    def parse_form(self, data, *args, **kwargs):
        """Build a ``BatchPricesForm`` from the validated data.

        Each entry of ``data['min_requests']`` becomes a ``Query`` carrying
        the same ``national_version_id`` as the form itself.

        :param data: dict of loaded field values.
        :returns: ``BatchPricesForm`` with one ``Query`` per min request.
        """
        national_version_id = data['national_version_id']
        return BatchPricesForm(
            national_version_id=national_version_id,
            # ``query_source`` is optional and absent from ``data`` when not sent.
            query_source=data.get('query_source'),
            queries=[
                Query(
                    national_version_id=national_version_id,
                    from_id=r['from_id'],
                    to_id=r['to_id'],
                    forward_date=r['forward_date'],
                    backward_date=r['backward_date'],
                    adults_count=r.get('adults_count', 1),
                    children_count=r.get('children_count', 0),
                    infants_count=r.get('infants_count', 0),
                )
                for r in data['min_requests']
            ],
        )
