import collections

import six
from awacs.wrappers.main import Pinger, Threshold, Rewrite, ErrorDocument

from infra.awacs.tools.awacstoolslib.util import clone_pb
from .model import BaseVisitor


def to_names(path):
    """Return just the module names from *path*.

    *path* is a sequence of ``(name, module)`` entries; the first element of
    every entry is collected, preserving order.
    """
    return [entry[0] for entry in path]


# Registries filled by UpstreamChecker. Each one maps a serialized config
# fragment to the list of ``repr(full_id)`` of the upstreams that use it, so
# identical fragments across upstreams are grouped together.
(
    THRESHOLDS,              # threshold module settings (JSON form)
    OUTER_BALANCER2S,        # outer balancer2 settings, with backends cleared
    INNER_BALANCER2S,        # inner balancer2 settings, with include_backends cleared
    OUTER_BY_NAME_CALLS,     # f_name calls of outer by_name_policy
    REGEXP_MATCHERS,         # regexp_section matchers of accepted balancer2 upstreams
    SLB_PING_MATCHERS,       # regexp_section matchers of slbping upstreams
    PLATFORM_PING_MATCHERS,  # regexp_section matchers of health-check upstreams
) = [collections.defaultdict(list) for _ in range(7)]


def list_set_field_names(pb):
    """Return the names of the fields populated on protobuf message *pb*.

    ``pb.ListFields()`` yields ``(field_descriptor, value)`` pairs for the set
    fields only; the descriptor names are collected into a set.
    """
    names = set()
    for descriptor, _value in pb.ListFields():
        names.add(descriptor.name)
    return names


class UpstreamChecker(BaseVisitor):
    """
    Decide whether an upstream is "easy", i.e. built from a small, known set
    of module layouts.

    Very quick and dirty implementation, to be refined and refactored.

    Each ``check_<Module>`` method receives a module and the chain of
    ``(name, module)`` pairs leading to it. Dispatch presumably happens in
    ``BaseVisitor`` (verify against .model).

    A check that accepts the upstream sets ``self.ok = True`` and records
    interesting config fragments in the module-level registries
    (``THRESHOLDS``, ``INNER_BALANCER2S``, ...). A check that rejects it calls
    :meth:`fail`.
    """
    RULE = 'UEM'

    # Verdict for the current upstream: set to True by an accepting check and
    # reset to False by fail().
    ok = False
    # True when the upstream is already an L7 / easy-mode upstream macro.
    already_ok = False

    def fail(self, path, message='Does not look easy'):
        """Warn about *path* and mark the upstream as not easy.

        :returns: ``(False, [], None)``. This matches the ``(ok, rest, matched)``
            shape of the path-matching helpers in :meth:`check_Balancer2`, so
            they can ``return self.fail(...)`` directly.
        """
        self.warn(path, message, severity=1, tags={'total'})
        self.ok = False
        return False, [], None

    def check_Balancer2(self, module, path):
        """Accept a terminal balancer2 that sits in a recognizable module chain.

        Only balancer2 modules with ``generated_proxy_backends`` are inspected,
        and those nested anywhere under ``sink`` or ``on_error`` are skipped.
        After dropping ``stats_eater`` entries, the chain must look like::

            regexp_section [shared] [report]
                {threshold, headers, response_headers, log_headers}
                [shared] balancer2 [backend* [report] balancer2]

        This is either a single balancer2, or an outer balancer2 whose
        per-DC backend holds an inner balancer2. The options of both
        balancers are then checked against allow-lists.

        On success, the normalized configs are recorded in the module-level
        registries.
        """
        original_path = list(path)

        # Path-matching helpers. Each takes the remaining chain ``ms`` of
        # ``(name, module)`` pairs and returns ``(ok, rest, matched)``. The
        # "require" variants report a mismatch (or an exhausted chain) via
        # fail(); the "allow" variants never fail.

        def ignore_all(ms, ignored_name):
            """Drop every entry of *ms* named *ignored_name*."""
            return [(name, m) for name, m in ms if name != ignored_name]

        def require(ms, name):
            """Consume a mandatory head entry called *name*."""
            if not ms or ms[0][0] != name:
                return self.fail(original_path)
            return True, ms[1:], ms[0][1]

        def require_prefix(ms, prefix):
            """Consume a mandatory head entry whose name starts with *prefix*."""
            if not ms or not ms[0][0].startswith(prefix):
                return self.fail(original_path)
            return True, ms[1:], ms[0][1]

        def allow(ms, name):
            """Consume an optional head entry called *name*; matched is None if absent."""
            if ms and ms[0][0] == name:
                return True, ms[1:], ms[0][1]
            return True, ms, None

        def allow_allof(ms, names):
            """Consume a run of head entries named from *names*, each at most once.

            Matched modules are returned as a list in chain order, so the
            checks applied to them run in a deterministic order.
            """
            remaining = set(names)
            found = []
            while ms and ms[0][0] in remaining:
                head_name, m = ms[0]
                remaining.remove(head_name)
                found.append(m)
                ms = ms[1:]
            return True, ms, found

        # Only balancer2s that generate their own proxy backends are checked
        # here. A balancer2 without them is validated only as the outer level
        # of such a terminal balancer2's chain.
        is_terminal = module.generated_proxy_backends
        if not is_terminal:
            return

        # Unconditional debug trace of every inspected terminal balancer2.
        print('watchme', self.full_id.namespace_id, ' -> '.join([n for n, _ in path]))

        # we skip balancer2 inside sink or on_error
        if any(name in ('sink', 'on_error') for name, _ in original_path):
            return

        path = ignore_all(path, 'stats_eater')

        ok, path, regexp_section = require(path, 'regexp_section')
        if not ok:
            return

        # Only an empty matcher or a plain match_fsm on host/path/url
        # (without `surround`) is accepted.
        if not regexp_section.matcher.is_empty() and (
            not regexp_section.matcher.match_fsm or
            regexp_section.matcher.match_fsm.list_set_fields()[0] not in ('host', 'path', 'url') or
            regexp_section.matcher.match_fsm.pb.surround
        ):
            self.fail(path=original_path, message='matcher is too complex')
            return

        ok, path, m = allow(path, 'shared')
        if m and m.pb.uuid != 'backends':
            self.fail(path=original_path, message='shared uuid is not backends')
            return

        ok, path, m = allow(path, 'report')

        ok, path, ms = allow_allof(path, {
            # 'report',
            'threshold',
            # 'rewrite',
            # 'pinger',
            'headers',
            'response_headers',
            'log_headers',
            # 'request_replier',
            # 'antirobot_macro'
        })

        threshold = None
        for m in ms:
            # NOTE: 'pinger' and 'rewrite' are not in the allowed set above, so
            # the Pinger / Rewrite branches only matter if they are re-enabled.
            if isinstance(m, Pinger):
                if m.pb.admin_request_uri:
                    self.fail(path=original_path, message='pinger uses admin_request_uri')
                    return
            if isinstance(m, Threshold):
                threshold = m
                # Only the preset usually seen in maps upstreams is accepted.
                # Everything else is rejected, including
                # 30720 / 71680 / 4s / 1s, which is used in
                # translate-internal.yandex.net.
                if not (threshold.pb.lo_bytes == 734003 and
                        threshold.pb.hi_bytes == 838860 and
                        threshold.pb.pass_timeout == '10s' and
                        threshold.pb.recv_timeout == '1s'):
                    self.fail(path=original_path, message='non-standard threshold: {}'.format(threshold.pb))
                    return
            if isinstance(m, Rewrite):
                for a in m.actions:
                    if a.pb.header_name:
                        self.fail(path=original_path, message='rewrite works w/ headers')
                        return

        ok, path, m = allow(path, 'shared')
        if m and m.pb.uuid != 'backends':
            self.fail(path=original_path, message='shared uuid is not backends')
            return

        ok, path, outer_balancer2 = require(path, 'balancer2')
        if not ok:
            return

        inner_balancer2 = None
        if path:
            # Two-level layout: outer balancer2 -> per-DC backend -> inner balancer2.
            ok, path, backend = require_prefix(path, 'backend')
            if not ok:
                # fail() has already reported the mismatch; `backend` is None here,
                # so it must not be dereferenced.
                return
            if not backend.pb.name.endswith((u'sas', u'man', u'vla', u'iva', u'myt',
                                             u'to_upstream'  # temporary, for slbpings
                                             )):
                self.fail(path=original_path, message='backend is not dc: ' + backend.pb.name)
                return

            ok, path, m = allow(path, 'report')

            ok, path, inner_balancer2 = require(path, 'balancer2')
            if not ok:
                return
        else:
            # Single-level layout: the only balancer2 plays the "inner" role.
            inner_balancer2 = outer_balancer2
            outer_balancer2 = None

        outer_balancer2_by_name_call_pb = None
        if outer_balancer2 is not None:
            outer_balancer2_pb = clone_pb(outer_balancer2.pb)
            # Backends differ between upstreams; clear them so that equal outer
            # settings group together in OUTER_BALANCER2S.
            outer_balancer2_pb.ClearField('backends')

            fields = list_set_field_names(outer_balancer2_pb)
            allowed_fields = {'attempts', 'f_attempts',
                              'rr', 'balancing_policy', 'on_error'}
            unexpected_fields = fields - allowed_fields
            if unexpected_fields:
                self.fail(path=[], message='outer balancer contains more than expected: {}'.format(unexpected_fields))
                return

            attempts = outer_balancer2.get('attempts')
            if attempts.is_func():
                if attempts.value.func_name != 'count_backends':
                    self.fail(path=[], message='outer balancer !f-attempts is not count_backends')
                    return

            policy_kinds = (outer_balancer2.balancing_policy and
                            outer_balancer2.balancing_policy.list_policy_kinds() or [])
            if (policy_kinds not in (['by_name_policy', 'unique_policy'],
                                     ['by_name_policy', 'simple_policy'],
                                     ['unique_policy'],
                                     ['simple_policy'],
                                     [])):
                self.fail(path=[], message='outer balancer policies are strange: {}'.format(policy_kinds))
                return

            if policy_kinds and policy_kinds[0] == 'by_name_policy':
                bn_pb = outer_balancer2.balancing_policy.by_name_policy.pb
                fields = list_set_field_names(bn_pb)
                unexpected_fields = fields - {'f_name', 'balancing_policy'}
                if unexpected_fields:
                    self.fail(path=[],
                              message='outer by_name policy has unexpected fields: {}'.format(unexpected_fields))
                    return
                if not (bn_pb.f_name.HasField('suffix_with_dc_params') or
                        bn_pb.f_name.HasField('prefix_with_dc_params') or
                        bn_pb.f_name.HasField('get_geo_params')):
                    self.fail(path=[],
                              message='outer by_name policy has unexpected call: {}'.format(bn_pb.f_name))
                    return
                outer_balancer2_by_name_call_pb = bn_pb.f_name

            if not outer_balancer2_pb.HasField('rr'):
                self.fail(path=[], message='outer balancer is not rr')
                return
            rr_fields = list_set_field_names(outer_balancer2_pb.rr)
            if rr_fields - {'weights_file'}:
                self.fail(path=[], message='outer balancer rr contains more than weights file: {}'.format(rr_fields))
                return

            if outer_balancer2.on_error and not isinstance(next(outer_balancer2.on_error.walk_chain()), ErrorDocument):
                self.fail(path=[], message='outer balancer on_error '
                                           'is not errordocument')
                return

        if inner_balancer2 is not None:
            inner_balancer2_pb = clone_pb(inner_balancer2.pb)

            fields = list_set_field_names(inner_balancer2_pb)
            allowed_fields = {'attempts', 'f_attempts',
                              'connection_attempts', 'f_connection_attempts',
                              'fast_attempts', 'f_fast_attempts',
                              'fast_503',
                              'weighted2', 'rr', 'active', 'balancing_policy',
                              'generated_proxy_backends', 'attempts_rate_limiter'}
            # In the single-level layout the inner balancer2 owns on_error.
            if outer_balancer2 is None:
                allowed_fields.add('on_error')
            unexpected_fields = fields - allowed_fields
            if unexpected_fields:
                self.fail(path=[], message='inner balancer contains more than expected: {}'.format(unexpected_fields))
                return

            if inner_balancer2.on_error and not isinstance(next(inner_balancer2.on_error.walk_chain()), ErrorDocument):
                self.fail(path=[], message='simple (outer == inner) balancer on_error '
                                           'is not errordocument: {}'.format(inner_balancer2.on_error.module_name))
                return

            policy_kinds = (inner_balancer2.balancing_policy and
                            inner_balancer2.balancing_policy.list_policy_kinds() or [])
            if (policy_kinds not in (['watermark_policy', 'unique_policy'],
                                     ['watermark_policy', 'simple_policy'],
                                     ['unique_policy'],
                                     ['simple_policy'],
                                     [])):
                self.fail(path=[], message='inner balancer policies are strange: {}'.format(policy_kinds))
                return

            if policy_kinds and policy_kinds[0] == 'watermark_policy':
                wm_pb = inner_balancer2.balancing_policy.watermark_policy.pb
                fields = list_set_field_names(wm_pb)
                unexpected_fields = fields - {'lo', 'hi', 'balancing_policy'}
                if unexpected_fields:
                    self.fail(path=[], message='inner wm policy has unexpected fields: {}'.format(unexpected_fields))
                    return

            if inner_balancer2_pb.HasField('generated_proxy_backends'):
                fields = list_set_field_names(inner_balancer2_pb.generated_proxy_backends)
                unexpected_fields = fields - {'include_backends', 'proxy_options'}
                if unexpected_fields:
                    self.fail(path=[], message='inner balancer generated_proxy_backends '
                                               'contains more than expected: {}'.format(unexpected_fields))
                    return

                if inner_balancer2_pb.generated_proxy_backends.HasField('proxy_options'):
                    fields = list_set_field_names(inner_balancer2_pb.generated_proxy_backends.proxy_options)
                    unexpected_fields = fields - {'connect_timeout',
                                                  'backend_timeout',
                                                  'keepalive_count',
                                                  'keepalive_timeout',
                                                  'fail_on_5xx',
                                                  'client_write_timeout',
                                                  'client_read_timeout',
                                                  'backend_write_timeout',
                                                  'backend_read_timeout',
                                                  'allow_connection_upgrade',
                                                  'status_code_blacklist',
                                                  }
                    if unexpected_fields:
                        self.fail(path=[], message='inner balancer generated_proxy_backends proxy_options '
                                                   'contains more than expected: {}'.format(unexpected_fields))
                        return
            if inner_balancer2_pb.HasField('weighted2'):
                set_fields = list_set_field_names(inner_balancer2_pb.weighted2)
                if set_fields:
                    self.fail(path=[], message='inner balancer weighted2 has unexpected options: {}'.format(set_fields))
                    return

            if inner_balancer2_pb.HasField('rr'):
                set_fields = list_set_field_names(inner_balancer2_pb.rr)
                if set_fields:
                    self.fail(path=[], message='inner balancer rr has unexpected options: {}'.format(set_fields))
                    return

            if inner_balancer2_pb.HasField('active'):
                unexpected_fields = list_set_field_names(inner_balancer2_pb.active) - {
                    'delay', 'request', 'steady'
                }
                # if inner_balancer2_pb.active.HasField('steady') and not inner_balancer2_pb.active.steady.value:
                #    self.fail(path=[],
                #              message='inner balancer active steady is false')
                #    return

                if unexpected_fields:
                    self.fail(path=[],
                              message='inner balancer active has unexpected options: {}'.format(unexpected_fields))
                    return

            # Backend lists differ between upstreams; clear them so that equal
            # inner settings group together in INNER_BALANCER2S.
            inner_balancer2_pb.generated_proxy_backends.ClearField('include_backends')

        self.ok = True
        self.two_level = outer_balancer2 is not None
        if threshold is not None:
            from google.protobuf.json_format import MessageToJson
            THRESHOLDS[MessageToJson(threshold.pb)].append(repr(self.full_id))
        if inner_balancer2 is not None:
            INNER_BALANCER2S[six.text_type(inner_balancer2_pb)].append(repr(self.full_id))
        if outer_balancer2 is not None:
            OUTER_BALANCER2S[six.text_type(outer_balancer2_pb)].append(repr(self.full_id))
        if outer_balancer2_by_name_call_pb is not None:
            OUTER_BY_NAME_CALLS[six.text_type(outer_balancer2_by_name_call_pb)].append(repr(self.full_id))
        REGEXP_MATCHERS[six.text_type(regexp_section.pb.matcher)].append(repr(self.full_id))

    def check_SlbPingMacro(self, module, path):
        """Accept the ``slbping`` / ``slb_ping`` upstream when it is a bare slb_ping_macro.

        The chain must be exactly ``regexp_section -> [report ->] slb_ping_macro``.
        On success, its matcher is recorded in SLB_PING_MATCHERS. Other upstream
        ids are ignored.
        """
        upstream_id = self.full_id.id
        if upstream_id not in ('slbping', 'slb_ping'):
            return
        names = to_names(path)
        if names not in (['regexp_section', 'slb_ping_macro'],
                         ['regexp_section', 'report', 'slb_ping_macro']):
            self.fail(path)
            return
        regexp_section = path[0][1]
        SLB_PING_MATCHERS[six.text_type(regexp_section.pb.matcher)].append(repr(self.full_id))
        self.ok = True

    def check_ErrorDocument(self, module, path):
        """Accept health-check and default upstreams that only return an errordocument.

        The accepted chain is ``regexp_section -> [report ->] errordocument``.

        * ``awacs-balancer-health-check``: must have that chain, otherwise fail.
          Its matcher is recorded in PLATFORM_PING_MATCHERS.
        * ``default``: accepted when it has that chain, left untouched otherwise.
        * Any other upstream id: ignored.
        """
        upstream_id = self.full_id.id
        names = to_names(path)
        simple_chains = (['regexp_section', 'errordocument'],
                         ['regexp_section', 'report', 'errordocument'])
        if upstream_id == 'awacs-balancer-health-check':
            if names not in simple_chains:
                self.fail(path)
                return
            regexp_section = path[0][1]
            PLATFORM_PING_MATCHERS[six.text_type(regexp_section.pb.matcher)].append(repr(self.full_id))
            self.ok = True
        elif upstream_id == 'default':
            if names in simple_chains:
                self.ok = True

    def check_L7UpstreamMacro(self, module, path):
        """An L7 upstream macro is already in the target form: accept it as-is."""
        self.ok = True
        self.already_ok = True

    def check_EasyModeUpstreamMacro(self, module, path):
        """An easy-mode upstream macro is already in the target form: accept it as-is."""
        self.ok = True
        self.already_ok = True
