# -*- coding: utf-8 -*-

from marshmallow import Schema, ValidationError, validates_schema

from common.models.geo import Country
from common.serialization.fields import PointField, PointShowHiddenField


class PointsBaseQuerySchema(Schema):
    """Base schema for queries built around ``point_from`` and ``point_to``.

    Subclasses declare the two point fields; this class only contributes the
    cross-field validation they all share.
    """

    @validates_schema(skip_on_field_errors=False)
    def validate_points(self, data):
        """Check that both points are present, are not countries and differ.

        Registered with ``skip_on_field_errors=False``, so it runs even when
        field-level validation has already produced errors; a point missing
        from ``data`` is then reported as ``'no_such_point'``.

        :param data: the deserialized data of this schema.
        :raises ValidationError: with a dict mapping ``'point_from'`` /
            ``'point_to'`` to ``'no_such_point'`` or ``'country_point'``,
            or ``'same_points'`` to ``'same_points'``.
        """
        if data is None:
            # NOTE(review): marshmallow 2.x passes ``None`` to schema
            # validators when the raw input was not a mapping. The "invalid
            # input type" error is already recorded at that point. Calling
            # ``.get`` here would turn it into an AttributeError.
            return

        schema_errors = {}

        for field in ('point_from', 'point_to'):
            point = data.get(field)
            if point is None:
                # Missing or None, e.g. when the field's own deserialization failed.
                schema_errors[field] = 'no_such_point'
            elif isinstance(point, Country):
                schema_errors[field] = 'country_point'

        # Compare the points only once both are known to be acceptable.
        if not schema_errors and data['point_from'] == data['point_to']:
            schema_errors['same_points'] = 'same_points'

        if schema_errors:
            raise ValidationError(schema_errors)


class PointsQuerySchema(PointsBaseQuerySchema):
    """Points query schema in which hidden points are treated as nonexistent."""

    point_from = PointField(
        required=True,
        load_from='pointFrom',
        dump_to='pointFrom',
    )
    point_to = PointField(
        required=True,
        load_from='pointTo',
        dump_to='pointTo',
    )


class PointsShowHiddenQuerySchema(PointsBaseQuerySchema):
    """Points query schema in which hidden points are treated as existing."""

    point_from = PointShowHiddenField(
        required=True,
        load_from='pointFrom',
        dump_to='pointFrom',
    )
    point_to = PointShowHiddenField(
        required=True,
        load_from='pointTo',
        dump_to='pointTo',
    )


class ExperimentsFlagsMixin(object):
    """Mixin exposing the ``exps_flags`` of the request stored in ``self.context``.

    The host class must provide a ``context`` mapping; the request is looked
    up under its ``'request'`` key.
    """

    def get_exps_flags(self):
        """Return the request's ``exps_flags``, or a new empty set.

        An empty set is returned when there is no request in the context,
        when the request has no ``exps_flags`` attribute, or when that
        attribute is falsy. Otherwise the attribute value is returned as is.
        """
        request = self.context.get('request')
        return getattr(request, 'exps_flags', None) or set()
