from rest_framework import serializers

from ..domain import domain_objects
from ..utils import subfields


class AchieveryBaseSerializer(serializers.Serializer):
    """Common base for the project's serializers.

    Resolves the acting user and role registry from the serializer context,
    works out which fields were requested, and restores objects that arrive
    with an ``id`` by loading them through
    ``objects(user, role_registry).get(id=...)`` on the domain object class.
    """

    # Key into ``domain_objects`` naming the class this serializer restores.
    domain_object_class = None
    # Name this serializer is nested under in its parent; None when unset.
    field_name = None
    # Enclosing serializer when nested; None for a top-level serializer.
    parent = None

    @property
    def user(self):
        """Return the acting user.

        A truthy ``context['user']`` wins. Otherwise the result of
        ``context['request'].user.get_profile()`` is returned; if that lookup
        fails with AttributeError/KeyError, the (falsy) context value --
        normally None -- is returned instead.
        """
        context_user = self.context.get('user')
        if context_user:
            return context_user
        try:
            return self.context['request'].user.get_profile()
        except (AttributeError, KeyError):
            # Best effort: no request in context, or its user cannot
            # provide a profile.
            return context_user

    @property
    def role_registry(self):
        """Return ``context['role_registry']``, or None when absent."""
        registry = self.context.get('role_registry')
        return registry

    def get_allowed_fields(self):
        """Return the field whitelist for this serializer, or None.

        Sources are tried in order:

        1. when ``field_name`` is set and the parent has
           ``get_allowed_fields``, the parent's whitelist narrowed to this
           field via ``subfields``;
        2. ``context['fields']``;
        3. ``context['request'].PARAMS.fields``.

        None means no source supplied a whitelist.
        """
        is_nested = (self.field_name and
                     hasattr(self.parent, 'get_allowed_fields'))
        if is_nested:
            parent_fields = self.parent.get_allowed_fields()
            return subfields(parent_fields, self.field_name)

        try:
            from_context = self.context['fields']
        except KeyError:
            pass
        else:
            return from_context

        try:
            return self.context['request'].PARAMS.fields
        except (AttributeError, KeyError):
            return None

    def get_domain_object_class(self):
        """Look up ``domain_object_class`` in the ``domain_objects`` registry."""
        registry_key = self.domain_object_class
        return domain_objects[registry_key]

    def restore_object(self, attrs, instance=None):
        """Restore an object from validated ``attrs``.

        With no (truthy) ``instance`` and an ``id`` key in ``attrs``, the
        existing object is fetched through
        ``objects(self.user, self.role_registry).get(id=attrs['id'])`` on the
        domain object class. Every other case is delegated to the parent
        serializer's ``restore_object``.
        """
        if instance or 'id' not in attrs:
            return super(AchieveryBaseSerializer, self).restore_object(
                attrs, instance)

        domain_class = self.get_domain_object_class()
        lookup = domain_class.objects(self.user, self.role_registry)
        return lookup.get(id=attrs['id'])


class DynamicFieldsSerializer(AchieveryBaseSerializer):
    """Serializer that drops fields not present in the active whitelist.

    The whitelist comes from ``get_allowed_fields()``. The base-class
    sources (nested parent, ``context['fields']``, the request's
    ``PARAMS.fields``) are consulted first. Only when none of them supplies
    a whitelist is the ``fields`` constructor argument used, and that
    argument defaults to ``DEFAULT_FIELDS``. Whitelist entries may be dotted
    or double-underscored paths; only their first segment selects a
    top-level field here.
    """

    # Whitelist used when no ``fields`` kwarg is passed; None means no
    # class-level restriction.
    DEFAULT_FIELDS = None

    def __init__(self, *args, **kwargs):
        # ``fields`` belongs to this class, so it is popped before the
        # parent constructor sees kwargs. It is assigned before the parent
        # constructor runs in case that constructor already calls
        # get_fields().
        self.allowed_fields = kwargs.pop('fields', self.DEFAULT_FIELDS)
        super(DynamicFieldsSerializer, self).__init__(*args, **kwargs)

    def get_allowed_fields(self):
        """Return the active field whitelist, or None for no restriction.

        A whitelist found by the base class (nested parent, context or
        request) takes precedence. When the base class finds none (None),
        the ``fields`` constructor argument / ``DEFAULT_FIELDS`` is used, so
        that argument actually constrains the output.
        """
        base = super(DynamicFieldsSerializer, self)
        allowed_fields = base.get_allowed_fields()
        if allowed_fields is None:
            # getattr: tolerate instances whose __init__ never ran.
            allowed_fields = getattr(self, 'allowed_fields', None)
        return allowed_fields

    def get_fields(self):
        """Return the parent's fields minus those outside the whitelist.

        An empty or missing whitelist leaves every field in place.
        """
        fields = super(DynamicFieldsSerializer, self).get_fields()
        allowed_fields = self.get_allowed_fields()

        if allowed_fields:
            # 'author__name' and 'author.name' both select 'author'.
            allowed = set(f.replace('__', '.').split('.')[0]
                          for f in allowed_fields)
            for name in set(fields.keys()) - allowed:
                fields.pop(name)
        return fields

    def initialize(self, parent, field_name):
        """Run the parent's initialize, record ``field_name``, rebuild fields.

        Fields are recomputed because, with ``field_name`` now set (and,
        presumably, ``parent`` set by the parent's initialize),
        get_allowed_fields() can derive the whitelist from the parent.
        """
        super(DynamicFieldsSerializer, self).initialize(parent, field_name)
        self.field_name = field_name
        self.fields = self.get_fields()
