# coding: utf-8
from __future__ import unicode_literals, absolute_import, division, print_function

from marshmallow import ValidationError, fields, missing
from pytz import UTC

from common.models.geo import Country, Station
from common.models.transport import TransportType


class FlagField(fields.Boolean):
    """Boolean field that is only emitted when it is set.

    The value is first serialized by ``fields.Boolean``. If that yields
    exactly ``False``, marshmallow's ``missing`` sentinel is returned instead,
    which is intended to drop the key from the output. Any other serialized
    value is returned unchanged.
    """

    def _serialize(self, value, attr, obj):
        serialized = super(FlagField, self)._serialize(value, attr, obj)
        # Identity check on purpose: only a literal False is suppressed.
        return missing if serialized is False else serialized


class AwareDateTime(fields.DateTime):
    """DateTime field that normalizes values to UTC before formatting.

    ``None`` is forwarded untouched to ``fields.DateTime``. Any other value is
    converted with ``astimezone(UTC)`` first.

    NOTE(review): ``astimezone`` is meant for timezone-aware datetimes. Naive
    values are not special-cased here, so confirm that callers always pass
    aware ones.
    """

    def _serialize(self, dt, attr, obj):
        utc_dt = None if dt is None else dt.astimezone(UTC)
        return super(AwareDateTime, self)._serialize(utc_dt, attr, obj)


class TransportTypeField(fields.Field):
    """Deserialize a transport type code into a ``TransportType`` object.

    The incoming value is matched against ``TransportType.code``. An unknown
    code is reported as a ``ValidationError`` that quotes the value.
    """

    def _deserialize(self, value, attr, data):
        try:
            transport_type = TransportType.objects.get(code=value)
        except TransportType.DoesNotExist:
            error_text = u'Wrong transport code "{}"'.format(value)
            raise ValidationError(error_text)
        else:
            return transport_type


class StationField(fields.Integer):
    """Integer field that deserializes a station id into a ``Station``.

    The raw input is first parsed by ``fields.Integer``. The resulting id is
    then looked up by primary key. A missing row is reported through
    ``self.fail('station_not_found')``.
    """

    default_error_messages = {
        'station_not_found': 'Station does not exist',
    }

    def _deserialize(self, value, attr, data):
        parsed_id = super(StationField, self)._deserialize(value, attr, data)
        try:
            station = Station.objects.get(id=parsed_id)
        except Station.DoesNotExist:
            self.fail('station_not_found')
        else:
            return station


def validate_station_express_code(station):
    """Reject a station that has no 'express' code.

    :param station: object exposing ``get_code(system)``, presumably a
        ``Station`` instance (verify against callers).
    :raises ValidationError: if ``station.get_code('express')`` is falsy.
    """
    express_code = station.get_code('express')
    if express_code:
        return
    raise ValidationError('No express code for the station')


class CountryTypeField(fields.Field):
    """Deserialize a country code into a ``Country`` object.

    The incoming value is matched against ``Country.code``. An unknown code
    is reported as a ``ValidationError`` that quotes the value.
    """

    def _deserialize(self, value, attr, data):
        try:
            country = Country.objects.get(code=value)
        except Country.DoesNotExist:
            error_text = u'Wrong country code "{}"'.format(value)
            raise ValidationError(error_text)
        else:
            return country
