# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import, division, print_function

from django.utils.datastructures import MultiValueDict
from marshmallow import Schema, ValidationError, fields, pre_load, validates_schema

from common.models.geo import Country
from travel.rasp.train_api.serialization.fields import PointField


class MultiValueDictSchemaMixin(object):
    """
    Schema mixin that reads values of ``fields.List`` fields from a MultiValueDict,
    for example from ``request.GET``.

    ``MultiValueDict.dict()`` keeps only the last value of every key, so a List
    field would otherwise receive a single string. This pre-load hook replaces
    the value of each List field with the full ``getlist()`` result. Input that
    is not a MultiValueDict is passed through unchanged.
    """
    @pre_load
    def prepare_multivaluedict(self, data):
        """
        Convert a MultiValueDict into a plain dict suitable for loading.

        For every List field, both its attribute name and its ``load_from`` key
        are checked. marshmallow 2 looks the field up under the attribute name
        first and falls back to ``load_from``, so either key may carry the list.

        :param data: raw input passed to ``Schema.load``.
        :return: ``data`` itself if it is not a MultiValueDict. Otherwise a new
            dict with the last value for ordinary keys and full value lists for
            the keys of List fields.
        """
        if not isinstance(data, MultiValueDict):
            return data

        result = data.dict()
        for field_name, field in self.fields.items():
            if not isinstance(field, fields.List):
                continue

            # Cover every key marshmallow may read this field from, not only
            # ``load_from``. Otherwise a list sent under the attribute name
            # would arrive as a single string.
            keys = {field_name}
            if field.load_from:
                keys.add(field.load_from)

            for key in keys:
                if key in data:
                    result[key] = data.getlist(key)
        return result


class PointsQuerySchema(Schema):
    """
    Query schema for a pair of points: ``pointFrom`` and ``pointTo``.

    Each point is deserialized by ``PointField``. The pair is then checked by a
    schema-level validator, which rejects missing points, ``Country`` points and
    two identical points.
    """
    point_from = PointField(load_from='pointFrom', dump_to='pointFrom', required=True)
    point_to = PointField(load_from='pointTo', dump_to='pointTo', required=True)

    @validates_schema(skip_on_field_errors=False)
    def validate_points(self, data):
        """
        Validate the deserialized points as a pair.

        Runs even when field-level errors occurred (``skip_on_field_errors=False``).
        A point whose field failed to load is presumably absent from ``data``,
        so it is reported here as ``no_such_point``.

        Error codes, keyed by attribute name (``point_from`` / ``point_to``):
          * ``no_such_point`` -- the point is missing or ``None``;
          * ``country_point`` -- the point is a ``Country`` instance.
        If neither point has an error and both are equal, ``same_points`` is
        reported under the ``same_points`` key.

        :raises ValidationError: with a dict of the errors above, if any.
        """
        errors = {}

        for attr in ('point_from', 'point_to'):
            point = data.get(attr)
            if point is None:
                errors[attr] = 'no_such_point'
            elif isinstance(point, Country):
                errors[attr] = 'country_point'

        # Compare the points only when both are valid individually.
        if not errors and data['point_from'] == data['point_to']:
            errors['same_points'] = 'same_points'

        if errors:
            raise ValidationError(errors)
