# coding: utf-8
import six

from awacs.wrappers.base import Holder, Chain, ChainableModuleWrapperBase, ModuleWrapperBase


def clone_pb(pb):
    """Return an independent copy of the message *pb*.

    A fresh instance of ``type(pb)`` is created and filled via ``CopyFrom``,
    so later changes to either object do not affect the other.
    """
    message_cls = type(pb)
    copy = message_cls()
    copy.CopyFrom(pb)
    return copy


def clone_pb_dict(d):
    """Return a new dict with the same keys as *d*, each value copied by ``clone_pb``."""
    cloned = {}
    for key, message in six.iteritems(d):
        cloned[key] = clone_pb(message)
    return cloned


class Warning(object):
    """One finding reported by a checker against a specific balancer.

    NOTE: this name shadows the built-in ``Warning`` inside this module.
    """

    def __init__(self, namespace_id, balancer_id, rule, path, message):
        # Balancer the finding belongs to.
        self.namespace_id, self.balancer_id = namespace_id, balancer_id
        # Identifier of the rule that produced the finding.
        self.rule = rule
        # Sequence of path component names; joined with ' -> ' in repr().
        self.path = path
        # Human-readable description of the problem.
        self.message = message

    def __repr__(self):
        joined_path = ' -> '.join(self.path)
        return 'Balancer: {}/{}\nPath: {}\nMessage: {}'.format(
            self.namespace_id, self.balancer_id, joined_path, self.message)


class BaseChecker(object):
    """Base class for balancer checks that accumulate ``Warning`` objects.

    The class attribute ``RULE`` (``'UNDEFINED'`` here) is copied onto every
    warning emitted through :meth:`warn`.
    """

    RULE = 'UNDEFINED'

    def __init__(self, namespace_id, balancer_id):
        self.namespace_id = namespace_id
        self.balancer_id = balancer_id
        # Findings recorded by warn(), in the order they were emitted.
        self.warnings = []

    def warn(self, path, message):
        """Record a warning located at *path*.

        :param path: sequence of ``(name, module)`` pairs; only the names are
            kept on the stored warning.
        :param message: human-readable description of the problem.
        """
        path_names = [step_name for step_name, _ in path]
        self.warnings.append(Warning(
            namespace_id=self.namespace_id,
            balancer_id=self.balancer_id,
            rule=self.RULE,
            path=path_names,
            message=message,
        ))

    @staticmethod
    def find_preceding_modules(path, cls):
        """Return the entries of *path* (except the last) whose module is a *cls*.

        The final element of *path* is never examined. It is typically the
        module currently being checked. Each match is returned as
        ``(prefix, module)``. *prefix* is the tuple of ``(name, module)``
        pairs that come before that module in *path*.
        """
        matches = []
        prefix = []
        for step_name, step_module in path[:-1]:
            if isinstance(step_module, cls):
                matches.append((tuple(prefix), step_module))
            prefix.append((step_name, step_module))
        return matches


class BaseVisitor(BaseChecker):
    """Checker that routes each visited module to a ``check_<ClassName>`` method.

    Modules whose class has no matching callable handler are silently skipped.
    """

    def visit(self, module, path):
        handler = getattr(self, 'check_' + module.__class__.__name__, None)
        # callable(None) is False, so this also covers the "no handler" case.
        if not callable(handler):
            return
        handler(module, path)


class BaseBalancerSuggester(BaseChecker):
    """Abstract checker operating on a whole balancer spec message.

    Subclasses must override :meth:`suggest`.
    """

    def suggest(self, balancer_spec_pb):
        """Inspect *balancer_spec_pb*; not implemented in the base class."""
        raise NotImplementedError


class BaseUpstreamSuggester(BaseChecker):
    """Abstract checker operating on a single upstream spec message.

    Subclasses must override :meth:`suggest`.
    """

    def suggest(self, upstream_id, upstream_spec_pb):
        """Inspect upstream *upstream_id*; not implemented in the base class."""
        raise NotImplementedError


def visit_chainable_module(module, checker, path=()):
    """Report *module* to ``checker.visit`` and descend into its children.

    The nested object (if any) is visited with the unchanged *path*. Each
    named branch is visited with *path* extended by ``(name, branch)``.
    """
    checker.visit(module, path)
    nested = module.nested
    if nested:
        visit(nested, checker, path)
    branches = module.get_named_branches()
    for branch_name, branch in six.iteritems(branches):
        branch_path = path + ((branch_name, branch),)
        visit(branch, checker, branch_path)


def visit_module(module, checker, path=()):
    """Report a non-chainable *module* to the checker and descend into its branches.

    Dispatch goes through ``checker.visit``, exactly as in
    ``visit_chainable_module``. This keeps a single dispatch point, so a
    checker that customizes ``visit`` sees every module, chainable or not.
    It also avoids a copy of the ``check_<ClassName>`` lookup that lacked the
    ``callable`` guard.

    Each named branch is visited with *path* extended by ``(name, branch)``.
    """
    checker.visit(module, path)
    for name, branch in six.iteritems(module.get_named_branches()):
        visit(branch, checker, path + ((name, branch),))


def visit_modules(modules, checker, path=()):
    """Visit the entries of a chain in order.

    Each entry is visited with *path* extended by the
    ``(module_name, module)`` pairs of all entries preceding it in *modules*.
    An empty or ``None`` *modules* is a no-op.

    Written as a loop so that long chains cost neither a stack frame nor a
    slice copy per element.
    """
    if not modules:
        return
    for entry in modules:
        visit(entry, checker, path)
        # Later entries see this one as part of their ancestry.
        path = path + ((entry.module_name, entry.module),)


def visit_holder(holder, checker, path=()):
    """Visit whatever *holder* wraps: a single module or a chain.

    A module is visited with *path* extended by
    ``(holder.module_name, holder.module)``. A chain is visited with the
    unchanged *path*.

    :raises AssertionError: if the holder contains neither.
    """
    if not holder.module:
        if not holder.chain:
            raise AssertionError('empty holder')
        visit(holder.chain, checker, path)
        return
    module_path = path + ((holder.module_name, holder.module),)
    visit(holder.module, checker, module_path)


def visit(obj, checker, path=()):
    """Walk *obj* and everything beneath it, feeding modules to *checker*.

    *path* is a tuple of ``(name, module)`` pairs describing how *obj* was
    reached. The kinds are tried in a fixed order, and
    ``ChainableModuleWrapperBase`` is checked before ``ModuleWrapperBase``.
    The latter is presumably the more general base class; TODO confirm.

    :raises AssertionError: if *obj* is not a supported wrapper kind.
    """
    if isinstance(obj, Holder):
        visit_holder(obj, checker, path)
        return
    if isinstance(obj, Chain):
        visit_modules(obj.modules, checker, path)
        return
    if isinstance(obj, ChainableModuleWrapperBase):
        visit_chainable_module(obj, checker, path)
        return
    if isinstance(obj, ModuleWrapperBase):
        visit_module(obj, checker, path)
        return
    raise AssertionError('unvisitable obj: {}'.format(type(obj)))
