# coding: utf-8

from marshmallow import fields, Schema
from marshmallow import post_load

from common.models.geo import Settlement
from common.serialization.common_schemas import MultiValueDictSchemaMixin
from common.utils.namedtuple import namedtuple_with_defaults
from travel.rasp.morda_backend.morda_backend.serialization.fields import TransportTypeField
from travel.rasp.morda_backend.morda_backend.serialization.schema_bases import PointsQuerySchema
from travel.rasp.morda_backend.morda_backend.tariffs.serialization import TariffsSchema

# Fields carried by the minimal static-tariff query.
min_static_tariff_query_fields = [
    'point_from',
    'point_to',
    'transport_types',
    'client_settlement',
    'national_version',
]
# The full query extends the minimal one with the requested dates and a thread uid.
static_tariff_query_fields = min_static_tariff_query_fields + [
    'dates',
    'thread_uid',
]

# Both query types fall back to the 'ru' national version when none is passed.
# A separate defaults dict is passed to each call on purpose: the helper is
# defined elsewhere and may keep a reference to it.
MinStaticTariffQuery = namedtuple_with_defaults(
    'MinStaticTariffQuery',
    min_static_tariff_query_fields,
    defaults={'national_version': 'ru'},
)
StaticTariffQuery = namedtuple_with_defaults(
    'StaticTariffQuery',
    static_tariff_query_fields,
    defaults={'national_version': 'ru'},
)


class StaticTariffsQuerySchemaMixIn(object):
    """Schema mixin that turns loaded data into a static-tariff query object.

    Concrete schemas set ``static_tariff_query_class`` to the callable (for
    example a namedtuple type) that is called with the loaded fields as
    keyword arguments.
    """

    # Left as None here; concrete schemas are expected to override it.
    static_tariff_query_class = None

    @post_load
    def make_query(self, data):
        """Build the query object once marshmallow has deserialized ``data``.

        When ``client_settlement`` is present, its value is used as the ``id``
        for ``Settlement.objects.get`` and is replaced in ``data`` by the
        fetched object. Errors from that lookup (presumably the model's
        ``DoesNotExist`` - confirm) are not caught here.

        :param data: dict of deserialized fields; mutated in place.
        :return: ``self.static_tariff_query_class(**data)``.
        """
        if 'client_settlement' in data:
            settlement_id = data['client_settlement']
            data['client_settlement'] = Settlement.objects.get(id=settlement_id)

        query_class = self.static_tariff_query_class
        return query_class(**data)


class StaticTariffsQuerySchema(StaticTariffsQuerySchemaMixIn, PointsQuerySchema, MultiValueDictSchemaMixin):
    """Schema for a full static-tariff query; ``load`` yields a ``StaticTariffQuery``.

    The point fields and ``client_settlement`` are presumably declared on
    ``PointsQuerySchema`` (not visible here) - confirm.
    """

    static_tariff_query_class = StaticTariffQuery
    # Loaded from the ``date`` key as a list of dates.
    dates = fields.List(fields.Date(), load_from='date')
    national_version = fields.String(required=True)
    # In marshmallow 2, ``default`` applies only when dumping. ``missing``
    # supplies None on load when ``threadUid`` is absent, so the query object
    # always gets an explicit ``thread_uid`` instead of relying on the
    # namedtuple's own defaults.
    thread_uid = fields.String(load_from='threadUid', default=None, missing=None)
    # Loaded from the ``transportType`` key as a list of transport types.
    transport_types = fields.List(TransportTypeField(), load_from='transportType')


class StaticTariffsResponseSchema(Schema):
    """Response schema whose ``tariffs`` field is a list handled by ``TariffsSchema``."""

    # many=True: the attribute is a collection, and each element uses TariffsSchema.
    tariffs = fields.Nested(nested=TariffsSchema, many=True)
