# coding: utf-8
import enum
import six

from awacs.wrappers.base import Holder, Chain, ChainableModuleWrapperBase, ModuleWrapperBase
from awacs.wrappers.main import Balancer2Backend, RegexpSection


# Tag string 'awacs-ui'. It is not referenced in the visible part of this module;
# presumably it is attached to Warning tags for warnings meant to be surfaced in
# the awacs UI. TODO: confirm against callers.
UI_TAG = 'awacs-ui'


def clone_pb(pb):
    """Return a new instance of ``type(pb)`` populated from *pb* via ``CopyFrom``.

    The returned object is never *pb* itself, so callers may mutate it freely.
    """
    duplicate = type(pb)()
    duplicate.CopyFrom(pb)
    return duplicate


def clone_pb_dict(d):
    """Return a new dict with the same keys as *d*, each value copied with clone_pb."""
    copies = {}
    for key, message in six.iteritems(d):
        copies[key] = clone_pb(message)
    return copies


class FullId(object):
    """Identifier of an object: its type, the namespace it lives in, and its id.

    Serializes to / from a plain dict with keys ``type``, ``namespace_id`` and ``id``.
    """

    def __init__(self, type, namespace_id, id):
        # NOTE: ``type`` and ``id`` shadow builtins here, but they are part of the
        # public keyword interface, so they keep their names.
        self.type, self.namespace_id, self.id = type, namespace_id, id

    def __repr__(self):
        template = u'{type} {namespace_id}:{id}'
        return template.format(type=self.type, namespace_id=self.namespace_id, id=self.id)

    def to_json(self):
        """Return a JSON-compatible dict; inverse of :meth:`from_json`."""
        return dict(type=self.type, namespace_id=self.namespace_id, id=self.id)

    @classmethod
    def from_json(cls, data):
        """Build an instance from a dict produced by :meth:`to_json`."""
        obj_type, namespace_id, obj_id = data['type'], data['namespace_id'], data['id']
        return cls(obj_type, namespace_id, obj_id)


class Warning(object):
    """A single rule violation reported by a checker.

    NOTE: this class shadows the builtin ``Warning`` inside this module.
    """

    class Severity(enum.Enum):
        CRIT = 0

    def __init__(self, full_id, rule, path, message, severity=Severity.CRIT, tags=()):
        """
        :param full_id: id of the affected object; must provide ``to_json()``
            (it is serialized via ``full_id.to_json()`` and restored via ``FullId.from_json``)
        :type full_id: FullId
        :param rule: violated rule identifier, e.g. "ARL"
        :type rule: six.text_type
        :param path: a path to the affected module within the config tree
        :type path: list[six.text_type]
        :param message: a human-readable message
        :type message: six.text_type
        :param severity: warning severity
        :type severity: Warning.Severity
        :param tags: tags attached to the warning; serialized sorted, so they must be orderable
        :type tags: set[six.text_type] | frozenset[six.text_type]
        """
        self.full_id = full_id
        self.rule = rule
        self.path = path
        self.message = message
        self.severity = severity
        self.tags = tags

    def __repr__(self):
        return (
            u'Where: {}\n'
            u'Path: {}\n'
            u'Message: {}'
        ).format(repr(self.full_id), ' -> '.join(self.path), self.message)

    @classmethod
    def from_json(cls, data):
        """Rebuild a warning from a dict produced by :meth:`to_json`.

        ``tags`` are serialized as a sorted list; they are restored as a frozenset
        so the round-tripped object carries the documented set-like type.
        """
        return cls(FullId.from_json(data['full_id']),
                   data['rule'],
                   data['path'],
                   data['message'],
                   cls.Severity(data['severity']),
                   frozenset(data['tags']))

    def to_json(self):
        """Return a JSON-compatible dict; severity is stored as its enum value, tags sorted."""
        return {
            'full_id': self.full_id.to_json(),
            'rule': self.rule,
            'path': list(self.path),
            'message': self.message,
            'severity': self.severity.value,
            'tags': sorted(self.tags),
        }


class BaseChecker(object):
    """Base class for checkers that accumulate Warning objects about one object.

    Every warning recorded via :meth:`warn` is stamped with this checker's
    ``full_id`` and ``RULE``.
    """

    # Rule identifier written into every emitted warning; presumably overridden
    # by concrete checkers.
    RULE = 'UNDEFINED'

    def __init__(self, full_id):
        """
        :param full_id: id of the object being checked; copied into every warning
            (Warning serializes it via ``full_id.to_json()``, so it is a FullId, not a tuple)
        :type full_id: FullId
        """
        self.full_id = full_id
        self.warnings = []

    def warn(self, path, message, tags=frozenset()):
        """Record a warning for the module at the end of *path*.

        :param path: sequence of ``(name, module)`` pairs leading to the offending
            module; each step is rendered with ``construct_human_readable_name``
        :param message: human-readable description of the problem
        :param tags: tags attached to the warning
        """
        readable_path = [
            construct_human_readable_name(step_name, step_module)
            for step_name, step_module in path
        ]
        self.warnings.append(Warning(
            full_id=self.full_id,
            rule=self.RULE,
            path=readable_path,
            message=message,
            tags=tags,
        ))

    @staticmethod
    def find_preceding_modules(path, cls):
        """Find ancestors of the last path step that are instances of *cls*.

        :param path: sequence of ``(name, module)`` pairs; the last pair is excluded
        :param cls: class (or tuple of classes) passed to ``isinstance``
        :returns: list of ``(prefix, module)`` in path order, where ``prefix`` is a
            tuple of the ``(name, module)`` pairs that precede ``module`` in *path*
        """
        ancestors = path[:-1]
        return [
            (tuple((n, m) for n, m in ancestors[:index]), module)
            for index, (_name, module) in enumerate(ancestors)
            if isinstance(module, cls)
        ]


class BaseVisitor(BaseChecker):
    """Checker that routes each visited module to a ``check_<ClassName>`` method.

    Modules whose class has no matching callable ``check_*`` attribute are ignored.
    """

    def visit(self, module, path):
        handler = getattr(self, 'check_' + module.__class__.__name__, None)
        # callable(None) is False, so this also covers the "no such method" case.
        if not callable(handler):
            return
        handler(module, path)


def construct_human_readable_name(name, branch):
    """Build a human-readable label for one step of a config-tree path.

    :param name: the step's name within its parent
    :param branch: the module found at that step
    :returns: *name* decorated with identifying details when *branch* is a
        Balancer2Backend (its ``pb.name``) or a RegexpSection whose matcher has a
        ``match_fsm`` without a ``header`` field (its set fields as ``key="value"``
        pairs); otherwise *name* unchanged.
    """
    if isinstance(branch, Balancer2Backend):
        return u'{}[name="{}"]'.format(name, branch.pb.name)
    if isinstance(branch, RegexpSection):
        matcher = branch.matcher
        match_fsm = matcher.match_fsm if matcher else None
        if match_fsm:
            # Materialize once: the field list is used both for the membership
            # test and for building the label.
            set_fields = list(match_fsm.list_set_fields())
            # Header matchers are left undecorated.
            if u'header' not in set_fields:
                attrs = u', '.join(
                    u'{}="{}"'.format(field, getattr(match_fsm.pb, field))
                    for field in set_fields
                )
                return u'{}[{}]'.format(name, attrs)
    return name


def visit_chainable_module(module, checker, path=()):
    """Visit a chainable module, then its nested chain, then its named branches.

    The module itself goes through ``checker.visit``. The nested object is visited
    with the same *path*. Each named branch is visited with ``(name, branch)``
    appended to *path*.
    """
    checker.visit(module, path)

    nested = module.nested
    if nested:
        visit(nested, checker, path)

    for branch_name, branch in six.iteritems(module.get_named_branches()):
        branch_path = path + ((branch_name, branch),)
        visit(branch, checker, branch_path)


def visit_module(module, checker, path=()):
    """Run the checker's ``check_<ClassName>`` method on *module*, then visit its branches.

    NOTE(review): unlike visit_chainable_module, this looks up ``check_*`` directly
    instead of calling ``checker.visit``, so an overridden ``visit()`` is not invoked
    for these modules. Confirm this is intended.
    """
    method_name = 'check_%s' % module.__class__.__name__
    handler = getattr(checker, method_name, None)
    if handler is not None:
        handler(module, path)

    branches = module.get_named_branches()
    for branch_name, branch in six.iteritems(branches):
        visit(branch, checker, path + ((branch_name, branch),))


def visit_modules(modules, checker, path=()):
    """Visit the holders of a chain in order.

    Each holder is visited with the path accumulated so far. After it is visited,
    the holder's ``(module_name, module)`` pair is appended to the path, so later
    holders see the earlier ones as their ancestors.

    Implemented as a loop rather than recursion on ``modules[1:]``. This avoids
    copying the remaining tail on every step (quadratic in chain length) and
    consuming one stack frame per chain element.

    :param modules: sequence of holder-like objects exposing ``module_name`` and
        ``module``; empty or None means nothing to visit
    """
    if not modules:
        return
    for holder in modules:
        visit(holder, checker, path)
        path = path + ((holder.module_name, holder.module),)


def visit_holder(holder, checker, path=()):
    """Visit whatever *holder* contains: a single module or a chain.

    A held module is visited with ``(holder.module_name, module)`` appended to
    *path*. A held chain is visited with *path* unchanged.

    :raises AssertionError: if the holder has neither a module nor a chain
    """
    module = holder.module
    if module:
        visit(module, checker, path + ((holder.module_name, module),))
        return

    chain = holder.chain
    if not chain:
        raise AssertionError('empty holder')
    visit(chain, checker, path)


def visit(obj, checker, path=()):
    """Dispatch *obj* to the visitor matching its wrapper type.

    Candidates are tried in order, and the first ``isinstance`` match wins. The
    order matters: ChainableModuleWrapperBase must be tested before the more
    general ModuleWrapperBase.

    :raises AssertionError: if *obj* is none of the supported wrapper types
    """
    handlers = (
        (Holder, visit_holder),
        (Chain, lambda chain, chk, p: visit_modules(chain.modules, chk, p)),
        (ChainableModuleWrapperBase, visit_chainable_module),
        (ModuleWrapperBase, visit_module),
    )
    for wrapper_type, handler in handlers:
        if isinstance(obj, wrapper_type):
            handler(obj, checker, path)
            return
    raise AssertionError('unvisitable obj: {}'.format(type(obj)))
