# coding=utf-8
from marshmallow import Schema, fields, pre_load, ValidationError, post_load, EXCLUDE

from travel.avia.price_index.lib.national_version_provider import national_version_provider
from travel.avia.price_index.models.query import Query, IndexQuery


def _national_version_to_pk(item):
    national_version = item['national_version']
    m = national_version_provider.get_by_code(national_version)

    if m is None:
        raise ValidationError('National version [{}] is not allowed'.format(national_version), ['national_version'])
    item['national_version'] = m.pk
    return item


class IndexQuerySchema(Schema):
    """Load a price-index request payload into an ``IndexQuery``.

    Unknown input keys are dropped. A pre-load hook replaces the incoming
    ``national_version`` code with its primary key before the fields are
    deserialized. The loaded values are then passed as keyword arguments
    to ``IndexQuery``.
    """

    class Meta:
        unknown = EXCLUDE

    # Read from the "national_version" input key, which holds the pk by the
    # time fields are deserialized (see the pre-load hook below).
    national_version_id = fields.Integer(data_key='national_version', required=True)
    forward_date = fields.Date(required=True)
    # Optional; absent input loads as None.
    backward_date = fields.Date(missing=None)
    from_id = fields.Integer(required=True)
    to_id = fields.Integer(required=True)
    # Passenger counts are optional and fall back to one adult, no children.
    adults_count = fields.Integer(missing=1, default=1)
    children_count = fields.Integer(missing=0, default=0)
    infants_count = fields.Integer(missing=0, default=0)
    raw_data = fields.Raw(required=True)

    @pre_load
    def national_version_to_pk(self, data, *args, **kwargs):
        """Swap the national version code in *data* for its primary key."""
        converted = _national_version_to_pk(data)
        return converted

    @post_load
    def make_query(self, data, *args, **kwargs):
        """Build an ``IndexQuery`` from the loaded field values."""
        query = IndexQuery(**data)
        return query


class QuerySchema(Schema):
    """Load a search request payload into a ``Query``.

    Unknown input keys are dropped. A pre-load hook replaces the incoming
    ``national_version`` code with its primary key before the fields are
    deserialized. Unlike the index schema, all passenger counts are
    required here.
    """

    class Meta:
        unknown = EXCLUDE

    # Read from the "national_version" input key, which holds the pk by the
    # time fields are deserialized (see the pre-load hook below).
    national_version_id = fields.Integer(data_key='national_version', required=True)
    forward_date = fields.Date(required=True)
    # Optional; absent input loads as None.
    backward_date = fields.Date(missing=None)
    from_id = fields.Integer(required=True)
    to_id = fields.Integer(required=True)
    adults_count = fields.Integer(required=True)
    children_count = fields.Integer(required=True)
    infants_count = fields.Integer(required=True)
    # Obsolete field: still accepted on input, but discarded before the
    # Query is built.
    is_business = fields.Bool(missing=False)

    @pre_load
    def national_version_to_pk(self, data, *args, **kwargs):
        """Swap the national version code in *data* for its primary key."""
        converted = _national_version_to_pk(data)
        return converted

    @post_load
    def make_query(self, data, *args, **kwargs):
        """Build a ``Query`` from the loaded values, minus the obsolete flag."""
        # Query does not accept is_business; the key may be absent on a
        # partial load, hence the pop default.
        data.pop('is_business', None)
        query = Query(**data)
        return query
