# coding: utf-8

import json

from marshmallow import fields, post_load, Schema, ValidationError

from common.models.geo import Settlement
from common.serialization.common_schemas import MultiValueDictSchemaMixin
from travel.rasp.library.python.common23.date import environment
from common.utils.namedtuple import namedtuple_with_defaults
from travel.rasp.morda_backend.morda_backend.serialization.fields import TransportTypeField
from travel.rasp.morda_backend.morda_backend.serialization.schema_bases import PointsQuerySchema
from travel.rasp.morda_backend.morda_backend.tariffs.serialization import TariffsSegmentSchema, TariffsSchema


class DaemonQuerySchema(PointsQuerySchema, MultiValueDictSchemaMixin):
    """Parse a tariffs query into a DaemonQuery.

    Point fields are inherited from PointsQuerySchema (presumably
    ``point_from`` / ``point_to`` -- see DaemonQuery's field list).
    Where ``load_from`` is set, the incoming parameter name differs from the
    attribute name: ``date`` -> ``dates``, ``clientSettlementId`` ->
    ``client_settlement``, ``transportType`` -> ``transport_types``.
    """

    dates = fields.List(fields.Date(), load_from='date')
    client_settlement = fields.Integer(load_from='clientSettlementId')
    national_version = fields.String(required=True)
    transport_types = fields.List(
        TransportTypeField(),
        load_from='transportType',
    )
    yandexuid = fields.Integer()

    @post_load
    def make_query(self, data):
        """Resolve references, fill defaults and wrap the result in a DaemonQuery.

        ``data`` is modified in place before being unpacked:

        * ``client_settlement``, if present, is replaced by the Settlement
          instance fetched with ``Settlement.objects.get(id=...)``.
          NOTE(review): a lookup failure is not handled here and propagates
          out of ``load()``.
        * ``dates``, if absent, becomes ``[environment.today()]``.
        """
        try:
            settlement_id = data['client_settlement']
        except KeyError:
            pass  # no client settlement was sent
        else:
            data['client_settlement'] = Settlement.objects.get(id=settlement_id)

        dates_given = 'dates' in data
        if not dates_given:
            data['dates'] = [environment.today()]

        return DaemonQuery(**data)


class ResponseSchema(Schema):
    """Serialize a tariffs response.

    * ``querying`` -- boolean flag (presumably "results are still being
      fetched"; TODO confirm with the producer of this payload).
    * ``segments`` -- list of segments, each dumped with TariffsSegmentSchema.
    * ``errors`` -- free-form dictionary.
    """

    querying = fields.Boolean()
    segments = fields.Nested(
        TariffsSegmentSchema,
        many=True,
    )
    errors = fields.Dict()


# Field names of DaemonQuery, in positional order. DaemonQuerySchema.make_query
# builds the tuple from keyword arguments with these names.
daemon_query_fields = [
    'point_from',
    'point_to',
    'transport_types',
    'dates',
    'client_settlement',
    'national_version',
    'yandexuid',
]

# Immutable query object produced by DaemonQuerySchema.make_query;
# national_version falls back to 'ru' when it is not passed explicitly.
# The typename matches the name it is bound to, as MinDynamicTariffsQuery does,
# so repr() shows the real name. Pickle can also resolve the class by
# __qualname__, provided namedtuple_with_defaults sets __module__ to this
# module (not verified here).
DaemonQuery = namedtuple_with_defaults('DaemonQuery', daemon_query_fields,
                                       defaults={'national_version': 'ru'})


# Immutable query object produced by MinDynamicTariffsQuerySchema.make_query;
# national_version falls back to 'ru' when it is not passed explicitly.
MinDynamicTariffsQuery = namedtuple_with_defaults(
    'MinDynamicTariffsQuery',
    [
        'point_from',
        'point_to',
        'transport_types',
        'national_version',
    ],
    defaults=dict(national_version='ru'),
)


class MinDynamicTariffsQuerySchema(PointsQuerySchema, MultiValueDictSchemaMixin):
    """Parse a minimum-dynamic-tariffs request into a MinDynamicTariffsQuery.

    Point fields are inherited from PointsQuerySchema (presumably
    ``point_from`` / ``point_to``, matching MinDynamicTariffsQuery).
    ``transport_types`` is read from the ``transportType`` parameter.
    """

    national_version = fields.String(required=True)
    transport_types = fields.List(
        TransportTypeField(),
        load_from='transportType',
    )

    @post_load
    def make_query(self, data):
        """Wrap the validated fields, unchanged, in a MinDynamicTariffsQuery."""
        query = MinDynamicTariffsQuery(**data)
        return query


class MinDynamicTariffsResponseSchema(Schema):
    """Serialize minimum dynamic tariffs: a ``tariffs`` list dumped with TariffsSchema."""

    tariffs = fields.Nested(
        TariffsSchema,
        many=True,
    )


class TariffsInitResponseSchema(Schema):
    """Serialize the reply to a tariffs init request.

    * ``qids`` -- list of query-id strings (declared required).
    * ``errors`` -- free-form dictionary.
    """

    qids = fields.List(
        fields.String(),
        required=True,
    )
    errors = fields.Dict()


def _parse_skip_partners(value):
    """Decode the JSON-encoded ``skip_partners`` request value.

    :param value: raw value taken from the request; expected to be a JSON
        document (presumably a list of partner codes -- TODO confirm).
    :returns: the decoded JSON value, or None when ``value`` is falsy.
    :raises ValidationError: if ``value`` is not a decodable JSON string, so
        that marshmallow records a field error instead of letting a raw
        ValueError/TypeError escape from ``load()``.
    """
    if not value:
        # NOTE(review): an empty value yields None, whereas an absent field
        # yields [] (``missing=list`` below); consumers must accept both.
        return None
    try:
        return json.loads(value)
    except (TypeError, ValueError):
        # ValueError: malformed JSON (json.JSONDecodeError subclasses it on
        # Python 3); TypeError: value is not a str/bytes at all.
        raise ValidationError('skip_partners must be a JSON-encoded value.')


class TariffsPollQuerySchema(Schema):
    """Parse a tariffs poll request.

    * ``qid`` -- required query id string.
    * ``skip_partners`` -- JSON-encoded value decoded by _parse_skip_partners;
      defaults to an empty list when the field is absent.
    * ``yandexuid`` -- optional integer.
    """

    qid = fields.String(required=True)
    skip_partners = fields.Function(deserialize=_parse_skip_partners, missing=list)
    yandexuid = fields.Integer()
