from marshmallow import Schema, fields, post_load, utils, EXCLUDE

from travel.avia.price_index.models.filters import TransferFilter, Filters, AirportsFilter, TimeFilter


class TransferFilterSchema(Schema):
    """Schema that deserializes transfer filter options into a ``TransferFilter``.

    Every field is optional and loads as ``None`` when its key is absent.
    Multi-word fields are read from camelCase keys (see each ``data_key``).
    """

    class Meta:
        # Input keys without a matching field declaration are excluded
        # from the loaded data.
        unknown = EXCLUDE

    count = fields.Integer(missing=None)
    min_duration = fields.Integer(data_key='minDuration', missing=None)
    max_duration = fields.Integer(data_key='maxDuration', missing=None)
    has_airport_change = fields.Boolean(data_key='hasAirportChange', missing=None)
    has_night = fields.Boolean(data_key='hasNight', missing=None)

    @post_load
    def make_transfer_filter(self, data, *args, **kwargs):
        """Build a ``TransferFilter`` from the loaded field values."""
        transfer_filter = TransferFilter(**data)
        return transfer_filter


class EnsureList(fields.List):
    """List field that also accepts a single bare value.

    An input that ``utils.is_collection`` does not report as a collection is
    wrapped in a one-element list before the regular ``fields.List``
    deserialization runs.
    """

    def _deserialize(self, value, attr, data, **kwargs):
        items = value if utils.is_collection(value) else [value]
        return super(EnsureList, self)._deserialize(items, attr, data, **kwargs)


class TimeFilterSchema(Schema):
    """Schema that deserializes time filter options into a ``TimeFilter``.

    Each field is an ``EnsureList`` of integers, so a single integer is
    accepted as well as a list. A field whose key is absent loads as ``None``.
    """

    class Meta:
        # Input keys without a matching field declaration are excluded
        # from the loaded data.
        unknown = EXCLUDE

    forward_arrival = EnsureList(fields.Integer(), data_key='forwardArrival', missing=None)
    forward_departure = EnsureList(fields.Integer(), data_key='forwardDeparture', missing=None)
    backward_arrival = EnsureList(fields.Integer(), data_key='backwardArrival', missing=None)
    backward_departure = EnsureList(fields.Integer(), data_key='backwardDeparture', missing=None)

    @post_load
    def make_time_filter(self, data, *args, **kwargs):
        """Build a ``TimeFilter`` from the loaded field values."""
        time_filter = TimeFilter(**data)
        return time_filter


class AirportsFilterSchema(Schema):
    """Schema that deserializes airport filter options into an ``AirportsFilter``.

    Each field is a list of integers read from a camelCase key (see each
    ``data_key``). A field whose key is absent loads as a new empty list.
    """

    class Meta:
        # Input keys without a matching field declaration are excluded
        # from the loaded data.
        unknown = EXCLUDE

    # ``missing=list`` (a callable) yields a fresh list on every load.
    # A literal ``[]`` would be one shared instance handed to every loaded
    # object, so mutating it on one filter would leak into later loads.
    forward_departure = fields.List(fields.Integer(), missing=list, data_key='forwardDeparture')
    forward_arrival = fields.List(fields.Integer(), missing=list, data_key='forwardArrival')
    forward_transfers = fields.List(fields.Integer(), missing=list, data_key='forwardTransfers')
    backward_departure = fields.List(fields.Integer(), missing=list, data_key='backwardDeparture')
    backward_arrival = fields.List(fields.Integer(), missing=list, data_key='backwardArrival')
    backward_transfers = fields.List(fields.Integer(), missing=list, data_key='backwardTransfers')

    @post_load
    def make_airport_filter(self, data, *args, **kwargs):
        """Build an ``AirportsFilter`` from the loaded field values."""
        return AirportsFilter(**data)


class FiltersSchema(Schema):
    """Top-level schema that deserializes all filter options into ``Filters``.

    It combines the baggage flag and airline list with the nested transfer,
    time and airport filter schemas. Every field is optional.
    """

    class Meta:
        # Input keys without a matching field declaration are excluded
        # from the loaded data.
        unknown = EXCLUDE

    with_baggage = fields.Bool(missing=None, data_key='withBaggage')
    # ``missing=list`` (a callable) yields a fresh list on every load.
    # A literal ``[]`` would be one shared instance handed to every loaded
    # ``Filters`` object.
    airlines = fields.List(fields.Integer(), missing=list)
    # When a nested section is absent, the default is produced by loading an
    # empty dict through the nested schema, which gives a filter object
    # populated with that schema's own defaults.
    transfer_filters = fields.Nested(
        TransferFilterSchema, missing=lambda: TransferFilterSchema().load({}), data_key='transfer'
    )
    time_filters = fields.Nested(TimeFilterSchema, missing=lambda: TimeFilterSchema().load({}), data_key='time')
    airports_filters = fields.Nested(
        AirportsFilterSchema, missing=lambda: AirportsFilterSchema().load({}), data_key='airport'
    )

    @post_load
    def make_filters(self, data, *args, **kwargs):
        """Build a ``Filters`` object from the loaded field values."""
        return Filters(**data)
