# coding: utf-8
import re
import os.path

import six
import yaml
from google.protobuf import descriptor as pbdescriptor

from infra.awacs.proto import modules_pb2
from awacs.wrappers.base import REGISTRY
from awacs.lib import OrderedDict
from .core import Builder


# Base URL of the modules reference on the wiki.
BASE_WIKI_URL = 'https://wiki.yandex-team.ru/cplb/awacs/modules-reference/'
# Short alias for the protobuf field descriptor class (TYPE_* / LABEL_* constants).
FD = pbdescriptor.FieldDescriptor

# Protobuf scalar field type -> JSON schema type name.
_types = {
    # floating point and fixed-width types
    FD.TYPE_DOUBLE: 'number',
    FD.TYPE_FLOAT: 'number',
    FD.TYPE_FIXED32: 'number',
    FD.TYPE_FIXED64: 'number',
    FD.TYPE_SFIXED32: 'number',
    FD.TYPE_SFIXED64: 'number',
    # variable-length integer types
    FD.TYPE_INT32: 'integer',
    FD.TYPE_INT64: 'integer',
    FD.TYPE_UINT32: 'integer',
    FD.TYPE_UINT64: 'integer',
    FD.TYPE_SINT32: 'integer',
    FD.TYPE_SINT64: 'integer',
    # everything else
    FD.TYPE_BOOL: 'boolean',
    FD.TYPE_STRING: 'string',
    FD.TYPE_BYTES: 'string',
}

# Matches a markdown link whose target is an in-page anchor, e.g. "[text](#some-anchor)".
# Group 1 is the link text, group 2 is the anchor including the leading "#".
# The link text is matched lazily so that several links on one line are matched
# one by one instead of being merged into a single match. The anchor character
# class is spelled out explicitly: "A-z" also covers the ASCII punctuation
# between "Z" and "a" and leaves out digits (as in "#balancer2").
md_rel_link_re = re.compile(r'\[(.*?)\]\((#[A-Za-z0-9_-]+)\)')


def fix_md(s):
    """
    Turn in-page markdown links into absolute links to the wiki reference.

    Every link matched by ``md_rel_link_re`` keeps its text, and its anchor is
    prefixed with ``BASE_WIKI_URL``.

    :type s: str
    :rtype: str
    """
    absolute_link = r'[\1]({}\2)'.format(BASE_WIKI_URL)
    return md_rel_link_re.sub(absolute_link, s)


class JsonSchemaBuilder(object):
    """
    Builds a JSON schema (as plain dicts) from protobuf message descriptors.

    The schema of every processed message is stored in ``self._definitions``
    and messages refer to each other through ``$ref`` entries.
    """
    # Protobuf scalar field types rendered as JSON schema type "number".
    NUMBER_FD_TYPES = frozenset({FD.TYPE_DOUBLE, FD.TYPE_FLOAT, FD.TYPE_FIXED64, FD.TYPE_FIXED32,
                                 FD.TYPE_SFIXED32, FD.TYPE_SFIXED64})
    # Protobuf scalar field types rendered as JSON schema type "integer".
    INT_FD_TYPES = frozenset({FD.TYPE_INT64, FD.TYPE_UINT64, FD.TYPE_INT32, FD.TYPE_UINT32,
                              FD.TYPE_SINT32, FD.TYPE_SINT64})
    # Holder properties that are kept only on the Holder schema built without
    # preceding messages and are removed from every other Holder schema.
    TOP_LEVEL_HOLDER_PROPERTIES = [
        'instance_macro', 'main', 'ipdispatch_section', 'regexp_section',
        'regexp_path_section', 'prefix_path_router_section'
    ]

    def __init__(self):
        """Start with no finished definitions and nothing in progress."""
        # definition id -> finished schema dict
        self._definitions = dict()
        # ids of the definitions whose schemas are currently being built
        self._being_processed = set()

    @staticmethod
    def _camel_case_to_snake_case(name):
        """
        Convert a CamelCase identifier to snake_case.

        Every uppercase character is replaced by "_" followed by its lowercase
        form; leading underscores of the result are stripped.

        :type name: str
        :rtype: str
        """
        pieces = ['_' + ch.lower() if ch.isupper() else ch for ch in name]
        return ''.join(pieces).lstrip('_')

    @staticmethod
    def _is_map_entry(message_desc):
        """
        Tell whether `message_desc` is a map entry or an ordered map entry,
        as reported by `Builder`.

        :type message_desc: pbdescriptor.Descriptor
        """
        plain_entry = Builder.is_map_entry(message_desc)
        if plain_entry:
            return plain_entry
        return Builder.is_ordered_map_entry(message_desc)

    @staticmethod
    def _is_squashed(field_desc):
        """
        Tell whether the awacs options of a message-typed field mark it as squashed.

        A squashed field is asserted not to be repeated.

        :type field_desc: pbdescriptor.FieldDescriptor
        :rtype: bool
        """
        assert field_desc.message_type
        options = Builder.get_awacs_field_options(field_desc)
        if not (options and options.squashed):
            return False
        assert field_desc.label != FD.LABEL_REPEATED
        return True

    @classmethod
    def _get_squashed_fields(cls, message_desc):
        """
        Map the names of squashed message-typed fields of `message_desc`
        to the full names of their message types.

        :type message_desc: pbdescriptor.Descriptor
        :rtype: dict[str, str]
        """
        return {
            name: desc.message_type.full_name
            for name, desc in six.iteritems(message_desc.fields_by_name)
            if desc.type == FD.TYPE_MESSAGE and cls._is_squashed(desc)
        }

    def _field_desc_to_schema(self, field_desc, preceding_messages,
                              is_repeated=False, is_dynamic=False, allowed_calls=(), is_embedded=False):
        """
        Build the JSON schema of a single field.

        Message-typed fields become a reference to the message schema, an array
        of such references (repeated) or an object with `additionalProperties`
        (repeated map entries). Scalar and enum fields become a typed schema,
        wrapped into an array if the field is repeated.

        :type field_desc: pbdescriptor.FieldDescriptor
        :type preceding_messages: tuple[str]
        :param bool is_repeated: whether this field is repeated (even despite not having REPEATED label)
        :param bool is_dynamic: whether this field allows function calls
        :param allowed_calls: must be non-empty if and only if `is_dynamic` is set (asserted)
        :param bool is_embedded: if set, the field name is not used as the schema title
        :rtype: dict
        :raises AssertionError: if the field has an unsupported type
        """
        assert isinstance(field_desc, pbdescriptor.FieldDescriptor)
        assert (is_dynamic and allowed_calls) or (not is_dynamic and not allowed_calls)

        repeated = is_repeated | (field_desc.label == FD.LABEL_REPEATED)

        if field_desc.type == FD.TYPE_MESSAGE:
            nested_desc = field_desc.message_type
            if not repeated:
                return self.message_desc_to_schema(nested_desc, preceding_messages)
            if self._is_map_entry(nested_desc):
                # maps are described by the schema of their values
                value_desc = nested_desc.fields_by_name['value']
                return {
                    'type': 'object',
                    'additionalProperties': self._field_desc_to_schema(value_desc, preceding_messages),
                }
            return {
                'type': 'array',
                'items': self.message_desc_to_schema(nested_desc, preceding_messages),
            }

        schema = {}
        if not is_embedded:
            schema['title'] = field_desc.name

        fd_type = field_desc.type
        if fd_type in self.NUMBER_FD_TYPES:
            json_type = 'number'
        elif fd_type in self.INT_FD_TYPES:
            json_type = 'integer'
        elif fd_type in (FD.TYPE_STRING, FD.TYPE_ENUM):
            json_type = 'string'
        elif fd_type == FD.TYPE_BOOL:
            json_type = 'boolean'
        else:
            raise AssertionError('unknown field type: {}'.format(field_desc.type))

        schema['type'] = json_type
        if fd_type == FD.TYPE_ENUM:
            schema['enum'] = list(field_desc.enum_type.values_by_name.keys())
        if is_dynamic and json_type != 'string':
            # hack to allow scalar values like "!f func(1,2,3)" in yaml
            schema['type'] = [json_type, 'string']

        if repeated:
            return {
                'type': 'array',
                'items': schema,
            }
        return schema

    @classmethod
    def _list_possible_doc_filenames(cls, message_desc):
        """
        List candidate names of the YAML docs file for a message, most
        preferred first.

        The first candidate, if any, is named after the `Holder` field whose
        message type is `message_desc`. The others are derived from the full
        message name (without the "awacs.modules." prefix) by dropping its
        leading dotted parts one at a time.

        :type message_desc: pbdescriptor.Descriptor
        :rtype: list[str]
        """
        candidates = []

        holder_field_name = None
        for name, holder_field in six.iteritems(modules_pb2.Holder.DESCRIPTOR.fields_by_name):
            field_message = holder_field.message_type
            if field_message and field_message.full_name == message_desc.full_name:
                # no break: the last matching field wins
                holder_field_name = name
        if holder_field_name:
            candidates.append(holder_field_name + '.yml')

        prefix = 'awacs.modules.'
        relative_name = message_desc.full_name
        if relative_name.startswith(prefix):
            relative_name = relative_name[len(prefix):]

        name_parts = relative_name.split('.')
        for start in range(len(name_parts)):
            candidates.append('.'.join(name_parts[start:]) + '.yml')

        return candidates

    @classmethod
    def get_wrapper_cls_by_full_name(cls, full_name):
        """
        Look up `full_name` in the wrappers `REGISTRY`.

        :type full_name: str
        :returns: whatever `REGISTRY.get` returns for this name
        """
        wrapper_cls = REGISTRY.get(full_name)
        return wrapper_cls

    @classmethod
    def get_message_docs(cls, message_desc):
        """
        Collect wrapper metadata and hand-written YAML docs for a message.

        :type message_desc: pbdescriptor.Descriptor
        :rtype: dict
        :returns: an empty dict if no wrapper class is registered for the
            message. Otherwise a dict with the keys `defaults`, `allowed_calls`,
            `required_props`, `required_oneofs`, `required_anyofs` and
            `wrapper_cls`, plus `docs` and `field_docs` if a non-empty docs
            file has been found.
        """
        # function-scope imports, as this module does not import them at the top
        import io
        import awacs

        wrapper_cls = cls.get_wrapper_cls_by_full_name(message_desc.full_name)
        if not wrapper_cls:
            # nothing is reported for messages without a wrapper, so there is
            # no point in touching the docs files at all
            return {}

        docs_dir = os.path.join(os.path.dirname(awacs.__file__), '../../docs/src/')
        docs = {}
        for docs_fn in cls._list_possible_doc_filenames(message_desc):
            docs_fp = os.path.join(docs_dir, docs_fn)
            if os.path.exists(docs_fp):
                # io.open decodes the file itself, which works the same way on
                # Python 2 and 3 (a text-mode read() has no .decode() on Python 3)
                with io.open(docs_fp, encoding='utf-8') as f:
                    docs = yaml.safe_load(f.read())
                # only the first existing candidate is used
                break

        rv = {
            'defaults': getattr(wrapper_cls, 'DEFAULTS', {}),
            'allowed_calls': getattr(wrapper_cls, 'ALLOWED_CALLS', {}),
            'required_props': wrapper_cls.REQUIRED,
            'required_oneofs': wrapper_cls.get_required_oneofs(),
            'required_anyofs': wrapper_cls.REQUIRED_ANYOFS,
            'wrapper_cls': wrapper_cls,
        }
        if not docs:
            # no docs file, or an empty one
            return rv

        rv.update({
            'docs': docs,
            'field_docs': docs.get('fields', {}),
        })
        return rv

    def _message_desc_to_schema_properties(self, message_desc, preceding_messages):
        """
        Build the schemas of the fields of a message and collect its requirements.

        Fields whose names start with "deprecated_" and "f_"-prefixed fields of
        the `Call` message type are left out.

        :type message_desc: pbdescriptor.Descriptor
        :type preceding_messages: tuple[str]
        :rtype: (OrderedDict, list, list, list, str)
        :returns: a tuple of: field name -> field schema, sorted required
            properties, sorted required oneofs, sorted required anyofs and the
            message description
        """
        assert isinstance(message_desc, pbdescriptor.Descriptor)

        docs = self.get_message_docs(message_desc)

        description = fix_md(docs.get('docs', {}).get('desc', ''))
        field_docs = docs.get('field_docs', {})
        required_props = docs.get('required_props', [])
        required_oneofs = docs.get('required_oneofs', [])
        required_anyofs = docs.get('required_anyofs', [])
        allowed_calls = docs.get('allowed_calls', {})

        props = OrderedDict()

        def get_field_schema(field_name, field_desc):
            """Build the schema of one field and decorate it with its docs."""
            allowed_field_calls = allowed_calls.get(field_name, ())
            target_field_desc = field_desc
            schema_kwargs = {}

            embedded_field_name = None
            if field_desc.type == FD.TYPE_MESSAGE:
                embedded_field_name = Builder.get_embedded_field_name(field_desc.message_type)
            if embedded_field_name:
                # The schema is built from the embedded field instead of the
                # wrapping message, so the allowed calls are the ones declared
                # for the embedded field. They are looked up only here, so that
                # other message-typed fields keep their own allowed calls.
                embedded_field_docs = self.get_message_docs(field_desc.message_type)
                allowed_field_calls = embedded_field_docs.get('allowed_calls', {}).get(embedded_field_name, ())
                target_field_desc = field_desc.message_type.fields_by_name[embedded_field_name]
                schema_kwargs = {
                    'is_repeated': field_desc.label == FD.LABEL_REPEATED,
                    'is_embedded': True,
                }

            # A field is dynamic exactly when it has a non-empty list of allowed
            # calls, which is what _field_desc_to_schema asserts.
            field_schema = self._field_desc_to_schema(
                target_field_desc,
                preceding_messages,
                is_dynamic=bool(allowed_field_calls),
                allowed_calls=allowed_field_calls,
                **schema_kwargs)

            if field_name in field_docs:
                field_doc = field_docs[field_name]
                if 'desc' in field_doc:
                    field_schema['description'] = field_doc['desc']
                if allowed_field_calls:
                    # the schema may have no description yet (e.g. a scalar
                    # field documented without "desc"), hence .get()
                    field_schema['description'] = field_schema.get('description', u'') + (
                        u'  \n'
                        u'Допускает использование следующих функций: `{}`.'
                    ).format('`, `'.join(allowed_field_calls))
                if 'default' in field_doc:
                    default = field_doc['default']
                    if default not in (None, 'None', 'none'):
                        field_schema['default'] = default
            return field_schema

        for field_name, field_desc in six.iteritems(message_desc.fields_by_name):
            if field_name.startswith('deprecated_'):
                # skip deprecated fields
                continue

            if (field_desc.type == FD.TYPE_MESSAGE and field_name.startswith('f_') and
                    field_desc.message_type.name == 'Call'):
                # skip !f-fields for now
                continue

            props[field_name] = get_field_schema(field_name, field_desc)

        return props, sorted(required_props), sorted(required_oneofs), sorted(required_anyofs), description

    @staticmethod
    def _get_schema_definition_id(full_name, preceding_messages):
        """
        Return the key under which the schema of a message is stored.

        It is the full message name, except for `awacs.modules.Holder` without
        preceding messages, which gets a "-0" suffix.

        :type full_name: str
        :type preceding_messages: tuple[str]
        :rtype: str
        """
        is_root_holder = full_name == 'awacs.modules.Holder' and not preceding_messages
        return full_name + '-0' if is_root_holder else full_name

    @staticmethod
    def _is_message_module_or_macro(message_desc):
        """
        Tell by the message name whether it is a module, a macro, a holder or
        one of the known section messages.

        :type message_desc: pbdescriptor.Descriptor
        :rtype: bool
        """
        name = message_desc.name
        if name.endswith(('Module', 'Macro', 'Holder')):
            return True
        return name in ('IpdispatchSection', 'RegexpSection',
                        'RegexpPathSection', 'PrefixPathRouterSection')

    def _customize_schema(self, s, message_desc, preceding_messages):
        """
        Adjust the properties of a `Holder` schema; other schemas are returned as is.

        A holder without preceding messages keeps only
        `TOP_LEVEL_HOLDER_PROPERTIES`; any other holder has exactly these
        properties removed.

        :type s: dict
        :type message_desc: pbdescriptor.Descriptor
        :type preceding_messages: tuple[str]
        :rtype: dict
        """
        if message_desc.name != 'Holder':
            return s

        holder_props = s['properties']
        if preceding_messages:
            for prop in self.TOP_LEVEL_HOLDER_PROPERTIES:
                del holder_props[prop]
        else:
            s['properties'] = {prop: holder_props[prop] for prop in self.TOP_LEVEL_HOLDER_PROPERTIES}
        return s

    def message_desc_to_schema(self, message_desc, preceding_messages):
        """
        Build the schema of a message, store it in `self._definitions` and
        return a reference to it.

        If the schema is already built or is being built right now (recursive
        messages), only the reference is returned.

        :type message_desc: pbdescriptor.Descriptor
        :type preceding_messages: tuple[str]
        :rtype: dict
        :returns: a dict with `$ref` and, for a freshly built schema, `description`
        """
        assert isinstance(message_desc, pbdescriptor.Descriptor)

        definition_id = self._get_schema_definition_id(message_desc.full_name, preceding_messages)
        if definition_id in self._definitions or definition_id in self._being_processed:
            return {'$ref': '#/definitions/{}'.format(definition_id)}

        self._being_processed.add(definition_id)

        updated_preceding_messages = preceding_messages + (message_desc.full_name,)
        properties, required, required_oneofs, required_anyofs, description = self._message_desc_to_schema_properties(
            message_desc, preceding_messages=updated_preceding_messages)

        is_balancer2 = message_desc.full_name == 'awacs.modules.Balancer2Module'
        if is_balancer2:
            # The properties of the balancing policy are merged right into the
            # module. They are nested in this message, so they get the updated
            # chain, like the module's own fields.
            policy_properties = self._message_desc_to_schema_properties(
                message_desc.fields_by_name['balancing_policy'].message_type,
                preceding_messages=updated_preceding_messages)[0]
            properties.update(policy_properties)

        # every required oneof/anyof group becomes a list of "required" alternatives
        additional_schemas = [
            {'oneOf': [{'required': [oneof_field]} for oneof_field in sorted(oneof_group)]}
            for oneof_group in required_oneofs
        ]
        additional_schemas.extend(
            {'anyOf': [{'required': [anyof_field]} for anyof_field in sorted(anyof_group)]}
            for anyof_group in required_anyofs
        )

        if not additional_schemas:
            mixin = {}
        elif len(additional_schemas) == 1:
            mixin = additional_schemas[0]
        else:
            mixin = {
                'allOf': additional_schemas,
            }

        for item in required_oneofs:
            description += (
                u'  \n'
                u'В точности одно из следующих полей должно быть указано: `{}`.'
            ).format('`, `'.join(item))
        for item in required_anyofs:
            description += (
                u'  \n'
                u'Хотя бы одно из следующих полей должно быть указано: `{}`.'
            ).format('`, `'.join(item))

        s = {
            'type': 'object',
            'title': message_desc.name,
            'id': message_desc.full_name,
            'additionalProperties': True,
            'properties': properties,
            # a copy: squashed fields are removed from it below
            'required': sorted(required),
            'description': fix_md(description),
        }
        s.update(mixin)

        module_properties = sorted(
            module_field_name
            for module_field_name, module_field_desc in six.iteritems(message_desc.fields_by_name)
            if (module_field_desc.type == FD.TYPE_MESSAGE and
                self._is_message_module_or_macro(module_field_desc.message_type))
        )
        if module_properties:
            s['moduleProperties'] = module_properties

        squashed_fields = self._get_squashed_fields(message_desc)

        if is_balancer2:
            # Its properties are already merged above. pop() instead of del, so
            # that a policy field that is not marked as squashed is not an error.
            squashed_fields.pop('balancing_policy', None)

        if squashed_fields:
            # the message schema is combined with the schemas of its squashed fields
            all_of = [s]
            for squashed_field_name, squashed_message_full_name in six.iteritems(squashed_fields):
                squashed_definition_id = self._get_schema_definition_id(squashed_message_full_name,
                                                                        updated_preceding_messages)
                if squashed_field_name in s['required']:
                    s['required'].remove(squashed_field_name)
                all_of.append({
                    '$ref': '#/definitions/{}'.format(squashed_definition_id),
                })
            s = {'allOf': all_of}

        s = self._customize_schema(s, message_desc, preceding_messages)

        self._definitions[definition_id] = s
        return {'$ref': "#/definitions/{}".format(definition_id), 'description': description}

    def build_schema(self, message_desc):
        """
        Build the complete schema with `message_desc` as the root message.

        The result is the reference to the root message schema together with
        all the definitions collected by this builder so far.

        :type message_desc: pbdescriptor.Descriptor
        :rtype: dict
        """
        root = self.message_desc_to_schema(message_desc, preceding_messages=())
        root.update(definitions=self._definitions)
        return root
