# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

from importlib import import_module
from inspect import isclass

import six
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin, openapi
from marshmallow import Schema

from common.dev_tools.swagger.endpoint import EndpointEnumerator, path_from_endpoint

# Keep a handle on the unpatched converter method: field2property_hack delegates to it
# after the class attribute is replaced further down in this module.
field2property_original = getattr(openapi.OpenAPIConverter, 'field2property')


# TODO: удалить после заезда в аркадию, там это уже есть
@staticmethod
def _observed_name_hack(field, name):
    dump_to = getattr(field, 'dump_to', None)
    load_from = getattr(field, 'load_from', None)
    return dump_to or load_from or name


def field2property_hack(self, field, use_refs=True, dump=True, name=None):
    """
    Delegate to the original ``OpenAPIConverter.field2property`` and strip ``default: None``.

    A falsy result from the original (``None`` or an empty dict) is replaced by a
    new empty dict. If the resulting property has a ``default`` key whose value is
    ``None``, that key is removed. Presumably this keeps fields without a real
    default from advertising ``null`` as their default (confirm intent).

    :param self: the ``OpenAPIConverter`` instance (this function is installed as a method)
    :param field: the marshmallow field to convert
    :param use_refs: forwarded unchanged to the original implementation
    :param dump: forwarded unchanged to the original implementation
    :param name: forwarded unchanged to the original implementation
    :return: dict describing the OpenAPI property
    """
    prop = field2property_original(self, field, use_refs=use_refs, dump=dump, name=name)
    if not prop:
        return {}
    if 'default' in prop and prop['default'] is None:
        del prop['default']
    return prop


# Install the replacements on the converter class itself (import-time side effect),
# so every OpenAPIConverter, including existing instances, picks them up.
setattr(openapi.OpenAPIConverter, '_observed_name', _observed_name_hack)
setattr(openapi.OpenAPIConverter, 'field2property', field2property_hack)


def build_spec(spec_info, urlpatterns, schema_modules):
    """
    Build an OpenAPI spec from URL patterns and marshmallow schemas.

    :param spec_info: dict of keyword arguments for ``apispec.APISpec`` (see its
        constructor). It is not modified. A shallow copy is passed on, and its
        ``plugins`` entry becomes a new list. That list holds the caller's plugins
        (any iterable, or absent/None) plus a fresh ``MarshmallowPlugin`` unless one
        is already present.
    :param urlpatterns: URL patterns handed to ``EndpointEnumerator``; every API
        endpoint it reports is added to the spec as a path.
    :param schema_modules: iterable of modules (or dotted paths to modules) holding
        the schemas; each schema class is registered as a definition under its
        class name.
    :return: apispec.APISpec
    """
    spec_kwargs = dict(spec_info)
    # Build a new list instead of appending to the caller's container. The container
    # may be a tuple, whose plugins used to be dropped. Mutating a caller's list would
    # also make a second call with the same spec_info reuse the MarshmallowPlugin
    # instance created for the first spec.
    plugins = list(spec_kwargs.get('plugins') or ())
    if not any(isinstance(plugin, MarshmallowPlugin) for plugin in plugins):
        plugins.append(MarshmallowPlugin())
    spec_kwargs['plugins'] = plugins

    spec = APISpec(**spec_kwargs)

    for schema in get_marshmellow_schemas(schema_modules):
        spec.definition(schema.__name__, schema=schema)

    enumerator = EndpointEnumerator(patterns=urlpatterns)
    for endpoint in enumerator.get_api_endpoints():
        spec.add_path(path_from_endpoint(endpoint))
    return spec


def get_marshmellow_schemas(schema_modules):
    """
    Yield the marshmallow ``Schema`` subclasses found at the top level of the given modules.

    :param schema_modules: iterable of module objects or dotted module paths;
        string entries are imported with ``importlib.import_module``.
    :return: generator of schema classes, in module order and then namespace
        order. Each class is yielded at most once, even when it is visible in
        several of the modules (e.g. defined in one and imported into another).
        Callers such as ``build_spec`` therefore register each schema only once.
    """
    seen = set()
    for schema_module in schema_modules:
        if isinstance(schema_module, six.string_types):
            schema_module = import_module(schema_module)

        # Snapshot the namespace: this generator is lazy, and the module's __dict__
        # could change while it is suspended (e.g. a submodule import adds an
        # attribute), which would break iteration over the live dict.
        for obj in list(vars(schema_module).values()):
            if not (isclass(obj) and issubclass(obj, Schema)):
                continue
            # Skip the base class itself. The match is by name, so any class called
            # "Schema" (e.g. a project-level base) is skipped too.
            if obj.__name__ == 'Schema':
                continue
            if obj in seen:
                continue
            seen.add(obj)
            yield obj
