# -*- coding: utf-8 -*-

from __future__ import absolute_import

from datetime import (
    date,
    datetime,
)
import enum
import types

from passport.backend.utils.common import filter_none
from passport.backend.vault.api.utils.ulid import ULID
import simplejson as json


class JsonSerializator(object):
    """Recursively turn Python values into JSON-friendly structures.

    ``process_field`` applies the following rules, first match wins:

    * a caller-supplied custom type processor whose type matches the value;
    * ``date``/``datetime`` -> string formatted as ``'%Y-%m-%d %H:%M:%S'``;
    * ``enum.Enum`` members -> the member name;
    * list/tuple/generator -> list (see ``_process_list``);
    * dict -> dict (see ``_process_dict``);
    * ``ULID`` -> ``str(value)``;
    * anything else is returned untouched.

    ``include`` is only threaded through to nested calls and to custom
    processors; this class never inspects it itself.
    """

    def __init__(self, max_serialization_depth, custom_field_processors=None, custom_type_processors=None):
        """Store the depth limit and the optional custom processors.

        :param max_serialization_depth: deepest container nesting level that
            is still serialized; deeper lists collapse to ``None`` and deeper
            dicts to ``{}``.
        :param custom_field_processors: mapping of dict key -> callable
            invoked as ``fn(value, depth, include=..., exclude=..., ...)``
            instead of the default handling for that key's value.
        :param custom_type_processors: mapping of type -> callable invoked as
            ``fn(field, depth, include=..., exclude=..., ...)`` for values
            that are instances of that type.
        """
        self._max_serialization_depth = max_serialization_depth
        self._custom_field_processors = custom_field_processors or {}
        self._custom_type_processors = custom_type_processors or {}

    def _process_list(self, field, depth, include=None, exclude=None, *args, **kwargs):
        """Serialize every item of an iterable one nesting level deeper.

        :returns: ``None`` when the depth limit is exceeded, otherwise the
            processed items passed through ``filter_none`` (project helper;
            presumably drops ``None`` entries -- confirm against its source).
        """
        depth += 1
        if exclude is None:
            exclude = []
        if depth > self._max_serialization_depth:
            return None
        # include/exclude are passed positionally on purpose: with keyword
        # passing, the unpacked *args would occupy those two slots and raise
        # "got multiple values for argument 'include'".
        processed = [
            self.process_field(item, depth, include, exclude, *args, **kwargs)
            for item in field
        ]
        return filter_none(processed)

    def _process_dict(self, field, depth, include=None, exclude=None, *args, **kwargs):
        """Serialize a mapping one nesting level deeper.

        Keys listed in ``exclude`` are skipped, values that serialize to
        ``None`` are dropped, and surviving keys are themselves run through
        ``process_field``.

        :returns: a new dict; empty when the depth limit is exceeded.
        """
        depth += 1
        result = {}
        if exclude is None:
            exclude = []
        if depth > self._max_serialization_depth:
            return result
        for key in field:
            if key in exclude:
                continue
            if key in self._custom_field_processors:
                # Custom processors keep receiving include/exclude by keyword,
                # exactly as before: their signatures are not known here.
                value = self._custom_field_processors[key](
                    field[key], depth, include=include, exclude=exclude, *args, **kwargs
                )
            else:
                # Positional include/exclude so that extra *args land after
                # them instead of colliding with them.
                value = self.process_field(field[key], depth, include, exclude, *args, **kwargs)
            if value is not None:
                # Keys are converted without the extra args/kwargs.
                key = self.process_field(key, depth, include=include, exclude=exclude)
                result[key] = value
        return result

    def process_field(self, field, depth, include=None, exclude=None, *args, **kwargs):
        """Convert a single value according to the class-level rules.

        :param field: value to convert.
        :param depth: current container nesting level.
        :param include: passed through untouched to nested calls/processors.
        :param exclude: dict keys to skip while serializing mappings.
        :param args: extra positional arguments forwarded to nested calls and
            custom processors.
        :param kwargs: extra keyword arguments forwarded the same way.
        :returns: the converted value, or ``field`` itself when no rule applies.
        """
        # Custom type processors take precedence over the built-in rules.
        for type_, processor in self._custom_type_processors.items():
            if isinstance(field, type_):
                return processor(field, depth, include=include, exclude=exclude, *args, **kwargs)
        if isinstance(field, (date, datetime)):
            return field.strftime('%Y-%m-%d %H:%M:%S')
        if isinstance(field, (enum.Enum, enum.IntEnum)):
            return field.name
        if isinstance(field, (list, tuple)) or isinstance(field, types.GeneratorType):
            return self._process_list(field, depth, include, exclude, *args, **kwargs)
        if isinstance(field, dict):
            return self._process_dict(field, depth, include, exclude, *args, **kwargs)
        if isinstance(field, ULID):
            return str(field)
        return field


class JsonSerializable(object):
    """Mixin that lets an object serialize itself through a serializator.

    Subclasses may override ``serializator`` (the class instantiated for each
    ``serialize`` call) and ``max_serialization_depth`` (the default depth
    limit handed to it).
    """

    # Class instantiated with ``max_serialization_depth=...`` on every call.
    serializator = JsonSerializator
    # Depth limit used when ``serialize`` is not given a truthy ``max_depth``.
    max_serialization_depth = 1

    def serialize(self, max_depth=None, include=None, exclude=None, *args, **kwargs):
        """Serialize this object with a fresh serializator instance.

        :param max_depth: depth limit for this call; any falsy value
            (``None`` or ``0``) falls back to ``max_serialization_depth``.
        :param include: forwarded to the serializator's ``process_field``.
        :param exclude: forwarded to the serializator's ``process_field``.
        :param args: extra positional arguments forwarded to ``process_field``.
        :param kwargs: extra keyword arguments forwarded to ``process_field``.
        :returns: whatever the serializator's ``process_field`` returns.
        """
        depth_limit = max_depth or self.max_serialization_depth
        worker = self.serializator(max_serialization_depth=depth_limit)
        # depth/include/exclude are passed positionally on purpose: with
        # keyword passing, the unpacked *args would be bound to those same
        # parameters and raise "got multiple values for argument 'depth'".
        return worker.process_field(self, 0, include, exclude, *args, **kwargs)


class ExploitableJSONEncoder(json.JSONEncoder):
    """JSON encoder that knows how to emit a few extra object kinds.

    Handled in ``default``:

    * ``JsonSerializable`` instances -> result of their ``serialize()``;
    * ``bytes``/``bytearray`` -> text decoded as UTF-8;
    * ``ULID`` -> ``str(value)``.

    Everything else is delegated to the parent encoder's ``default``.
    """

    def default(self, o):
        """Return a JSON-encodable stand-in for ``o``."""
        if isinstance(o, JsonSerializable):
            return o.serialize()

        if isinstance(o, (bytes, bytearray)):
            return o.decode('utf-8')

        if isinstance(o, ULID):
            return str(o)

        # Unknown type: let the base encoder decide what to do with it.
        return super(ExploitableJSONEncoder, self).default(o)
