# coding: utf-8
from __future__ import unicode_literals, absolute_import, division, print_function

from django.core.exceptions import ObjectDoesNotExist
from marshmallow import ValidationError, fields, missing
from pytz import UTC

from common.models.geo import Point, Station
from common.models.transport import TransportType


class PointField(fields.Field):
    """Field that (de)serializes a ``Point`` through its point key.

    Deserialization resolves the incoming value with ``Point.get_by_key``;
    serialization emits the object's ``point_key`` attribute.
    """

    default_error_messages = {
        'not_found': 'Point {value} was not found',
        'bad_value': 'Wrong value "{value}" for PointField'
    }

    def _deserialize(self, value, attr, data):
        """Resolve ``value`` to a point object.

        Raises:
            ValidationError: 'bad_value' when ``Point.get_by_key`` raises
                ValueError, 'not_found' when it raises ObjectDoesNotExist.
        """
        try:
            return Point.get_by_key(value)
        except ValueError:
            self.fail('bad_value', value=value)
        except ObjectDoesNotExist:
            self.fail('not_found', value=value)

    def _serialize(self, value, attr, obj):
        """Return the point's ``point_key``, or None for a None attribute."""
        # marshmallow forwards None attribute values to _serialize; mirror the
        # built-in fields and emit None instead of raising AttributeError.
        if value is None:
            return None
        return value.point_key


class FlagField(fields.Boolean):
    """Boolean field that drops itself from the output when it is False.

    The value is first serialized by ``fields.Boolean``. A result of exactly
    ``False`` is replaced by ``missing`` so the key is omitted. Any other
    result is returned unchanged.
    """

    def _serialize(self, value, attr, obj):
        as_bool = super(FlagField, self)._serialize(value, attr, obj)
        return missing if as_bool is False else as_bool


class AwareDateTime(fields.DateTime):
    """DateTime field that converts values to UTC before formatting them."""

    def _serialize(self, dt, attr, obj):
        # None is forwarded untouched; the parent DateTime decides its output.
        # NOTE(review): astimezone() on a naive datetime raises ValueError on
        # Python 2 and assumes local time on Python 3.6+ -- confirm callers
        # always pass timezone-aware values.
        in_utc = None if dt is None else dt.astimezone(UTC)
        return super(AwareDateTime, self)._serialize(in_utc, attr, obj)


class TransportTypeField(fields.Field):
    """Field that deserializes a transport code into a ``TransportType``."""

    # Declared like the sibling fields so callers can override the message
    # through the standard ``error_messages`` argument.
    default_error_messages = {
        'wrong_code': 'Wrong transport code "{value}"',
    }

    def _deserialize(self, value, attr, data):
        """Look up the ``TransportType`` whose ``code`` equals ``value``.

        Raises:
            ValidationError: 'wrong_code' when no matching record exists.
        """
        try:
            return TransportType.objects.get(code=value)
        except TransportType.DoesNotExist:
            self.fail('wrong_code', value=value)


class StationField(fields.Integer):
    """Integer field whose deserialized id is resolved to a ``Station``."""

    default_error_messages = {
        'station_not_found': 'Station does not exist',
    }

    # NOTE(review): serialization is inherited from fields.Integer, which
    # formats the attribute as a number -- confirm the serialized attribute
    # is an id rather than a Station instance.

    def _deserialize(self, value, attr, data):
        # fields.Integer parses and validates the raw id before the lookup.
        pk = super(StationField, self)._deserialize(value, attr, data)
        try:
            station = Station.objects.get(id=pk)
        except Station.DoesNotExist:
            self.fail('station_not_found')
        else:
            return station


class CallableStringField(fields.String):
    """String field that tolerates a callable attribute value.

    If the attribute value is callable, it is invoked with no arguments and
    the result is serialized. Otherwise the value is serialized directly.
    """

    def _serialize(self, value, attr, obj):
        resolved = value() if callable(value) else value
        return super(CallableStringField, self)._serialize(resolved, attr, obj)


def validate_station_express_code(station):
    """Ensure ``station`` has a truthy 'express' code.

    Raises:
        ValidationError: if ``station.get_code('express')`` returns a falsy
            value.
    """
    express_code = station.get_code('express')
    if express_code:
        return
    raise ValidationError('No express code for the station')
