from marshmallow import fields, post_load, Schema

from travel.rasp.pathfinder_proxy.client_tariffs.train_api_result import TrainApiResult, TrainTariffInfo
from travel.rasp.pathfinder_proxy.const import CacheType
from travel.rasp.pathfinder_proxy.tariff_storages.base_tariff_storage import BaseTariffStorage, TariffInfoSchema


class TrainTariffInfoSchema(TariffInfoSchema):
    """Schema for one train tariff entry; loads into ``TrainTariffInfo``.

    Adds an optional ``provider`` string on top of whatever fields
    ``TariffInfoSchema`` declares (that base class is not visible here).
    """

    # Optional provider name:
    #   * key absent on load            -> None      (``missing=None``)
    #   * attribute absent on dump      -> None      (``default=None``)
    #   * explicit null in the payload  -> accepted  (``allow_none=True``)
    provider = fields.String(allow_none=True, missing=None, default=None)

    @post_load
    def load_object(self, data):
        """Turn the deserialized field dict into a ``TrainTariffInfo``.

        Every loaded key is forwarded as a keyword argument, so the field
        names declared on this schema (and its base) must match the
        ``TrainTariffInfo`` constructor parameters.
        """
        return TrainTariffInfo(**data)


class TrainApiResultSchema(Schema):
    """Schema for a train API result; loads into ``TrainApiResult``.

    ``tariffs`` is a list whose items are loaded through
    ``TrainTariffInfoSchema``, so each element becomes a ``TrainTariffInfo``
    before this schema's own ``post_load`` runs.
    """

    # NOTE(review): neither field is ``required`` nor has ``missing=``; if a
    # key is absent from the input it is simply omitted from ``data`` and
    # ``TrainApiResult`` must supply its own default — confirm that it does.
    querying = fields.Boolean()
    tariffs = fields.Nested(TrainTariffInfoSchema, many=True)

    @post_load
    def load_train_api_result(self, data):
        """Build a ``TrainApiResult`` from the deserialized field dict.

        Loaded keys are passed straight through as keyword arguments.
        """
        return TrainApiResult(**data)


class TrainApiStorage(BaseTariffStorage):
    """Tariff storage specialised for train API results.

    Binds the ``CacheType.TRAIN_API`` cache type and the schema used for
    ``TrainApiResult`` payloads. Presumably ``BaseTariffStorage`` reads these
    two class attributes to pick the cache and (de)serialize entries — confirm
    against the base class, which is not visible here.
    """

    _CACHE_TYPE = CacheType.TRAIN_API
    # ``strict=True`` (marshmallow 2.x): validation errors raise
    # ``ValidationError`` instead of being returned in ``result.errors``.
    # A single schema instance is shared by every instance of this class.
    _RESULT_SCHEMA = TrainApiResultSchema(strict=True)
