from functools import wraps

from infra.awacs.proto import modules_pb2

from awacs.wrappers.luautil import read_string
from awacs.wrappers.main import StatsEater, LogHeaders, Threshold, Headers, Shared, Holder, Balancer2, ErrorDocument, \
    Report, RegexpSection, Rewrite, ResponseHeaders, Compressor, RpsLimiterMacro
from .model import BaseUpstreamSuggester
from .tlem import SectionVisitor
import six


def list_set_fields(pb):
    """Return the names of all fields currently populated in message *pb*.

    Relies on ``pb.ListFields()`` yielding ``(field_descriptor, value)`` pairs;
    only the descriptor's ``name`` is used.
    """
    names = set()
    for field_descriptor, _ in pb.ListFields():
        names.add(field_descriptor.name)
    return names


def h(method):
    """Decorate a check method so its ``(ok, msg)`` result is folded into ``self``.

    The decorated method must return an ``(ok, msg)`` pair. The wrapper
    AND-accumulates ``ok`` into ``self.ok`` (via ``&=``). When ``ok`` is falsy,
    it appends ``'; ' + msg`` to ``self.msg``. The wrapper itself returns None.
    """
    @wraps(method)
    def wrapper(self, *args, **kwargs):
        passed, message = method(self, *args, **kwargs)
        self.ok &= passed
        if passed:
            return None
        self.msg += '; ' + message

    return wrapper


class Checker(BaseUpstreamSuggester):
    """Suggest an ``L7UpstreamMacro`` equivalent of a ``regexp_section`` upstream.

    ``suggest`` walks the module chain of the upstream's ``yandex_balancer``
    config and maps each recognised module onto fields of
    ``modules_pb2.L7UpstreamMacro``. Module layouts that have no mapping are
    rejected by raising ``AssertionError``.
    """
    RULE = 'UEM'
    # ctx key read by _fill_dc. When set, a RpsLimiterMacro heading a per-dc
    # chain is skipped, because suggest() has already copied it to the upstream.
    SKIP_RpsLimiterMacro = 'SKIP_RpsLimiterMacro'
    # Backend-name suffixes (``<prefix>_<location>``) accepted by the by-dc scheme.
    KNOWN_LOCATIONS = ('sas', 'man', 'vla', 'iva', 'myt', 'devnull')
    DEVNULL_LOCATION = 'devnull'

    def __init__(self, namespace_id, upstream_id):
        super(Checker, self).__init__(namespace_id, upstream_id)
        # Not read anywhere in this class (suggest() uses a local ``opts``).
        # NOTE(review): kept in case subclasses or callers rely on it.
        self.opts = {}

    @classmethod
    def get_nested(cls, holder_pb):
        """Return the first module of the chain held by *holder_pb*.

        :raises AssertionError: if the holder is empty.
        """
        holder = Holder(holder_pb)
        assert not holder.is_empty()
        return next(holder.walk_chain())

    @classmethod
    def _split_backend_name(cls, name):
        """Split a backend name ``<prefix>_<location>`` into ``(prefix, location)``.

        :returns: the pair, or None when the name has no ``_`` or the suffix
            is not one of ``KNOWN_LOCATIONS``.
        """
        prefix, sep, location = name.rpartition('_')
        if not sep or location not in cls.KNOWN_LOCATIONS:
            return None
        return prefix, location

    @classmethod
    def _fill_balancer_settings(cls, balancer_settings_pb, balancer2):
        """Copy attempts, timeouts, retry and balancing settings of *balancer2*.

        :param balancer_settings_pb: destination message, filled in place.
        :type balancer2: Balancer2
        :raises AssertionError: on attempt/method/policy combinations that have
            no equivalent in the destination message.
        """
        balancer2_pb = balancer2.pb
        po_pb = balancer2.generated_proxy_backends.proxy_options.pb

        if balancer2_pb.attempts:
            balancer_settings_pb.attempts = balancer2_pb.attempts
        elif balancer2_pb.HasField('f_attempts'):
            balancer_settings_pb.attempt_all_endpoints = True
        else:
            raise AssertionError('no attempts!')

        if balancer2_pb.connection_attempts:
            balancer_settings_pb.fast_attempts = balancer2_pb.connection_attempts
            if balancer2_pb.fast_attempts:
                raise AssertionError('both connection_attempts and fast_attempts are set')
        elif balancer2_pb.HasField('f_connection_attempts'):
            balancer_settings_pb.fast_attempt_all_endpoints = True
        elif balancer2_pb.fast_attempts:
            balancer_settings_pb.fast_attempts = balancer2_pb.fast_attempts

        if po_pb.connect_timeout:
            balancer_settings_pb.connect_timeout = po_pb.connect_timeout
        balancer_settings_pb.backend_timeout = po_pb.backend_timeout or '10s'
        balancer_settings_pb.keepalive_count = po_pb.keepalive_count
        balancer_settings_pb.backend_read_timeout = po_pb.backend_read_timeout
        balancer_settings_pb.client_read_timeout = po_pb.client_read_timeout
        # balancer_settings_pb.keepalive_timeout = po_pb.keepalive_timeout

        if balancer2_pb.fast_503.value:
            balancer_settings_pb.fast_attempts_type = balancer_settings_pb.CONNECT_FAILURE_AND_503
            if balancer2_pb.fast_attempts:
                balancer_settings_pb.fast_attempts = balancer2_pb.fast_attempts
            elif balancer2_pb.HasField('f_fast_attempts'):
                balancer_settings_pb.fast_attempt_all_endpoints = True

        # Only an explicit ``fail_on_5xx: false`` sets do_not_retry_http_responses.
        # When the field is unset or true, '5xx' goes into retry_http_responses.codes.
        if po_pb.HasField('fail_on_5xx') and not po_pb.fail_on_5xx.value:
            balancer_settings_pb.do_not_retry_http_responses = True
        else:
            balancer_settings_pb.retry_http_responses.codes.append('5xx')

        if po_pb.status_code_blacklist_exceptions:
            balancer_settings_pb.retry_http_responses.exceptions.extend(po_pb.status_code_blacklist_exceptions)

        if po_pb.allow_connection_upgrade:
            balancer_settings_pb.allow_connection_upgrade = po_pb.allow_connection_upgrade

        if po_pb.status_code_blacklist:
            balancer_settings_pb.retry_http_responses.codes.extend(po_pb.status_code_blacklist)

        balancer_settings_pb.buffering = po_pb.buffering

        if balancer2_pb.HasField('weighted2'):
            balancer_settings_pb.compat.method = balancer_settings_pb.compat.WEIGHTED2
        elif balancer2_pb.HasField('active'):
            balancer_settings_pb.compat.method = balancer_settings_pb.compat.ACTIVE
            balancer_settings_pb.health_check.delay = balancer2_pb.active.delay
            balancer_settings_pb.health_check.request = balancer2_pb.active.request
            if balancer2_pb.active.HasField('steady') and not balancer2_pb.active.steady.value:
                balancer_settings_pb.health_check.compat.not_steady = True
        elif balancer2_pb.HasField('rr'):
            balancer_settings_pb.compat.method = balancer_settings_pb.compat.RR
        elif balancer2_pb.HasField('dynamic'):
            balancer_settings_pb.max_pessimized_endpoints_share = balancer2_pb.dynamic.max_pessimized_share
        else:
            raise AssertionError('unsupported balancing method: expected weighted2, active, rr or dynamic')

        # Supported policy chains: [watermark_policy, unique_policy],
        # [unique_policy], or no balancing_policy at all.
        wm_policy = False
        policy_kinds = (balancer2.balancing_policy and
                        balancer2.balancing_policy.list_policy_kinds() or [])
        if policy_kinds == ['watermark_policy', 'unique_policy']:
            assert balancer2_pb.balancing_policy.HasField('watermark_policy')
            wm_pb = balancer2_pb.balancing_policy.watermark_policy
            balancer_settings_pb.compat.watermark_policy.lo = wm_pb.lo
            balancer_settings_pb.compat.watermark_policy.hi = wm_pb.hi
            wm_policy = True
        elif policy_kinds == ['unique_policy'] or policy_kinds == []:
            pass
        else:
            raise AssertionError('unexpected policies {}'.format(policy_kinds))

        # attempts_rate_limiter cannot be combined with a watermark policy.
        # Without either of them, reattempts are explicitly left unlimited.
        if balancer2.attempts_rate_limiter:
            assert not wm_policy
            balancer_settings_pb.max_reattempts_share = balancer2.attempts_rate_limiter.pb.limit
        elif not wm_policy:
            balancer_settings_pb.do_not_limit_reattempts = True

    @classmethod
    def _fill_on_error(cls, scheme_pb, balancer2):
        """Fill ``scheme_pb.on_error`` from ``balancer2.on_error``.

        A configured on_error must wrap an ErrorDocument; its status and content
        are copied into ``on_error.static``. Without on_error, ``on_error.rst``
        is set instead.

        :raises AssertionError: if on_error holds something other than an ErrorDocument.
        """
        if balancer2.on_error:
            on_error_module = cls.get_nested(balancer2.on_error.pb)
            assert isinstance(on_error_module, ErrorDocument)

            scheme_pb.on_error.static.status = on_error_module.pb.status
            scheme_pb.on_error.static.content = on_error_module.pb.content
        else:
            scheme_pb.on_error.rst = True

    @classmethod
    def _fill_flat_scheme(cls, flat_scheme_pb, balancer2, can_handle_announce_checks=False):
        """Fill a flat (single-balancer) scheme from *balancer2*.

        :type balancer2: Balancer2
        """
        flat_scheme_pb.can_handle_announce_checks = can_handle_announce_checks
        cls._fill_balancer_settings(flat_scheme_pb.balancer, balancer2)
        cls._fill_on_error(flat_scheme_pb, balancer2)

    @classmethod
    def _fill_dc(cls, ctx, dc_pb, modules):
        """Fill one per-location ``dc_pb`` from the module chain of a dc backend.

        Accepted chain (bracketed parts are optional):
        ``[Shared] [RpsLimiterMacro] [Report] [StatsEater] [Shared] Balancer2``.
        A leading RpsLimiterMacro is accepted only when ``ctx[SKIP_RpsLimiterMacro]``
        is truthy.

        :param ctx: dict of flags; only ``SKIP_RpsLimiterMacro`` is read.
        :param modules: list of modules; consumed (popped from the front) in place.
        :returns: the Balancer2 module found after the optional prefix.
        :raises AssertionError: if the chain does not match the layout above.
        """
        m = modules.pop(0)
        if isinstance(m, Shared):
            assert not m.nested
            m = modules.pop(0)
        if isinstance(m, RpsLimiterMacro) and ctx.get(cls.SKIP_RpsLimiterMacro, False):
            m = modules.pop(0)
        if isinstance(m, Report):
            dc_pb.monitoring.uuid = m.pb.uuid
            if m.pb.ranges != 'default':
                dc_pb.monitoring.ranges = m.pb.ranges
            m = modules.pop(0)
        else:
            dc_pb.compat.disable_monitoring = True

        if isinstance(m, StatsEater):
            m = modules.pop(0)
        if isinstance(m, Shared):
            assert not m.nested
            m = modules.pop(0)
        assert isinstance(m, Balancer2), repr(type(m))
        assert m.generated_proxy_backends

        dc_pb.backend_ids.extend(m.generated_proxy_backends.include_backends.pb.ids)
        return m

    @classmethod
    def _fill_by_dc_scheme(cls, ctx, upstream_spec_pb, upstream_spec_pbs, by_dc_scheme_pb, balancer2, can_handle_announce_checks=False):
        """Fill a by-dc scheme from a top-level ``rr`` Balancer2 whose backends are locations.

        Backend names must look like ``<prefix>_<location>``. All backends must
        share one prefix, which becomes ``dc_balancer.weights_section_id``. Each
        non-devnull backend becomes an entry of ``dcs``. Balancer settings are
        taken from the first such backend's Balancer2.

        :type balancer2: Balancer2
        :raises AssertionError: on unsupported attempts, policies, backend names or dc chains.
        """
        by_dc_scheme_pb.can_handle_announce_checks = can_handle_announce_checks

        balancer2_pb = balancer2.pb
        if balancer2_pb.attempts:
            by_dc_scheme_pb.dc_balancer.attempts = balancer2_pb.attempts
        elif balancer2_pb.HasField('f_attempts'):
            by_dc_scheme_pb.dc_balancer.attempt_all_dcs = True
        else:
            raise AssertionError('no attempts!')
        assert balancer2_pb.HasField('rr')
        if not balancer2_pb.rr.weights_file:
            by_dc_scheme_pb.dc_balancer.compat.disable_dynamic_weights = True

        policy_kinds = (balancer2.balancing_policy and
                        balancer2.balancing_policy.list_policy_kinds() or [])
        if policy_kinds == ['by_name_policy', 'unique_policy']:
            by_dc_scheme_pb.dc_balancer.method = by_dc_scheme_pb.dc_balancer.LOCAL_THEN_BY_DC_WEIGHT
        elif policy_kinds == ['unique_policy'] or policy_kinds == []:
            by_dc_scheme_pb.dc_balancer.method = by_dc_scheme_pb.dc_balancer.BY_DC_WEIGHT
        else:
            raise AssertionError("it is not possible to determine which policy should be used: balancer2.balancing_policy.list_policy_kinds() == {} :(".format(policy_kinds))

        prefixes = set()

        seen_devnull = False
        first_dc = True
        for backend in balancer2.backends:
            split = cls._split_backend_name(backend.pb.name)
            if split is None:
                raise AssertionError('it is not possible to determine location: balancer2.backends[*].name == "{}" :('.format(backend.pb.name))
            prefix, dc = split
            prefixes.add(prefix)

            if dc == cls.DEVNULL_LOCATION:
                seen_devnull = True
                continue

            dc_pb = by_dc_scheme_pb.dcs.add(name=dc)
            # Skip the first module yielded by walk_chain() for the backend.
            dc_modules = list(backend.walk_chain())[1:]

            cls.inject_shared(dc_modules, upstream_spec_pb, upstream_spec_pbs)

            # _fill_dc consumes a copy; it returns the Balancer2 it validated.
            dc_balancer2 = cls._fill_dc(ctx, dc_pb, list(dc_modules))
            if first_dc:
                cls._fill_balancer_settings(by_dc_scheme_pb.balancer, dc_balancer2)
                first_dc = False

        if not seen_devnull:
            by_dc_scheme_pb.compat.disable_devnull = True
        else:
            by_dc_scheme_pb.devnull.monitoring.uuid = 'requests_hollywood_to_devnull'
            by_dc_scheme_pb.devnull.static.status = 200
            by_dc_scheme_pb.devnull.static.content = 'OK'

        assert len(prefixes) == 1
        by_dc_scheme_pb.dc_balancer.weights_section_id = prefixes.pop()

        cls._fill_on_error(by_dc_scheme_pb, balancer2)

    @classmethod
    def inject_shared(cls, modules, upstream_spec_pb, upstream_spec_pbs):
        """Replace a trailing ``Shared`` module in *modules* with the chain it refers to.

        Triggered only when ``modules[-1]`` is a ``Shared`` whose uuid is not
        ``'backends'``. The module chains of *upstream_spec_pb* and then of every
        value of *upstream_spec_pbs* are searched, descending into the backends
        of any Balancer2. The target is a ``Shared`` with the same uuid that is
        followed by further modules. The first match's following modules replace
        ``modules[-1]``; *modules* is modified in place. If nothing matches,
        *modules* is left unchanged.
        """
        upstreams = [upstream_spec_pb]
        upstreams.extend(six.itervalues(upstream_spec_pbs))

        def check_and_inject(modules, other_up_modules):
            for i, other_up_m in enumerate(other_up_modules):
                if isinstance(other_up_m, Balancer2):
                    for b in other_up_m.backends:
                        other_up_modules2 = list(b.walk_chain())
                        if check_and_inject(modules, other_up_modules2):
                            return True
                # Only a Shared that is not the last module of its chain carries
                # a definition worth injecting.
                if (isinstance(other_up_m, Shared)
                        and other_up_m.pb.uuid == modules[-1].pb.uuid
                        and i < len(other_up_modules) - 1):
                    injected_modules = other_up_modules[i + 1:]  # skip shared
                    del modules[-1]
                    modules += injected_modules
                    return True
            return False

        if isinstance(modules[-1], Shared) and modules[-1].pb.uuid != 'backends':
            for other_up_spec_pb in upstreams:
                other_up_config_pb = other_up_spec_pb.yandex_balancer.config
                other_up_modules = list(Holder(other_up_config_pb).walk_chain())
                if check_and_inject(modules, other_up_modules):
                    break

    def suggest(self, upstream_id, upstream_spec_pb, upstream_spec_pbs):
        """Convert a ``regexp_section`` upstream into an ``L7UpstreamMacro``.

        :param upstream_id: value written to the result's ``id`` field.
        :param upstream_spec_pb: spec of the upstream being converted.
        :param upstream_spec_pbs: mapping of upstream specs; its values are
            searched when resolving ``Shared`` references (see ``inject_shared``).
        :returns: ``(True, '', rv_pb)`` on success, or
            ``(False, 'not an regexp_section', None)`` when the config has no
            regexp_section.
        :raises AssertionError: when a module or matcher cannot be mapped.
        """
        config_pb = upstream_spec_pb.yandex_balancer.config
        if not config_pb.HasField('regexp_section'):
            return False, 'not an regexp_section', None

        modules = list(Holder(config_pb).walk_chain())
        modules = [m for m in modules if not isinstance(m, StatsEater)]

        self.inject_shared(modules, upstream_spec_pb, upstream_spec_pbs)

        rv_pb = modules_pb2.L7UpstreamMacro()
        rv_pb.version = '0.0.1'
        rv_pb.id = upstream_id

        def next_module():
            """Pop the next module of the chain; None once the chain is exhausted."""
            return modules.pop(0) if modules else None

        m = modules.pop(0)

        def process_match_section(matcher, rv_matcher):
            """Recursively translate a regexp_section matcher into *rv_matcher*."""
            if matcher == modules_pb2.Matcher():
                rv_matcher.any = True
            elif matcher.HasField('match_fsm'):
                match_fsm_pb = matcher.match_fsm
                if match_fsm_pb.host:
                    rv_matcher.host_re = read_string(match_fsm_pb.host)
                elif match_fsm_pb.uri:
                    rv_matcher.uri_re = read_string(match_fsm_pb.uri)
                elif match_fsm_pb.url:
                    rv_matcher.url_re = read_string(match_fsm_pb.url)
                elif match_fsm_pb.path:
                    rv_matcher.path_re = read_string(match_fsm_pb.path)
                elif match_fsm_pb.cgi:
                    cgi = match_fsm_pb.cgi
                    if match_fsm_pb.surround:
                        # NOTE(review): r'\\.*' is the four characters \\.* and is
                        # passed through read_string(); confirm the result is the
                        # intended "anything" wrap around the cgi pattern.
                        cgi = r'\\.*' + cgi + r'\\.*'
                    rv_matcher.cgi_re = read_string(cgi)
                elif match_fsm_pb.match:
                    # Only "<METHOD>.*" or "(<METHOD>|<METHOD>...).*" patterns are
                    # recognised; they are turned into method matchers.
                    import re
                    methods = ('GET', 'HEAD', 'POST', 'PUT', 'PATCH', 'DELETE', 'CONNECT', 'OPTIONS')
                    re_part = '(?:{})'.format('|'.join(methods))
                    match_obj = re.match(r'^({0}|\((?:{0}\|?)+\))\.\*$'.format(re_part), match_fsm_pb.match)
                    if not match_obj:
                        raise AssertionError('match_fsm_pb {}'.format(match_fsm_pb))
                    assert match_obj.group(1)
                    used_methods = match_obj.group(1)
                    if used_methods.startswith('('):
                        used_methods = used_methods[1:]
                    if used_methods.endswith(')'):
                        used_methods = used_methods[:-1]
                    used_methods = used_methods.split('|')
                    if len(used_methods) == 1:
                        rv_matcher.method = used_methods[0].upper()
                    else:
                        for method in used_methods:
                            rv_matcher.or_.add().method = method.upper()
                elif match_fsm_pb.HasField('header'):
                    rv_matcher.header.name = match_fsm_pb.header.name
                    rv_matcher.header.re = match_fsm_pb.header.value
                else:
                    raise AssertionError('match_fsm_pb {}'.format(match_fsm_pb))
            elif matcher.HasField('match_not'):
                process_match_section(matcher.match_not, rv_matcher.not_)
            elif matcher.HasField('match_method'):
                if len(matcher.match_method.methods) == 1:
                    rv_matcher.method = matcher.match_method.methods[0].upper()
                else:
                    raise AssertionError('len(matcher.match_method.methods) != 1')
            elif len(matcher.match_and) > 0:
                for inner_matcher in matcher.match_and:
                    process_match_section(inner_matcher, rv_matcher.and_.add())
            elif len(matcher.match_or) > 0:
                for inner_matcher in matcher.match_or:
                    process_match_section(inner_matcher, rv_matcher.or_.add())
            else:
                raise AssertionError('matcher {}'.format(matcher))

        if not isinstance(m, RegexpSection):
            raise AssertionError('is a {}, expected RegexpSection'.format(m.__class__))
        process_match_section(m.pb.matcher, rv_pb.matcher)

        m = modules.pop(0)

        opts = dict(
            scheme_can_handle_announce_checks=False,
            report_found=False,
        )

        # Each process_* helper receives the current module. If the module is of
        # the helper's type, the helper records it into rv_pb and advances the
        # chain. It returns (current module or None, whether to continue).

        def process_rewrite(m):
            if isinstance(m, Rewrite):
                def to_rewrite_actions(actions):
                    """
                    :type actions: Iterable[modules_pb2.RewriteAction]
                    :rtype: Iterable[modules_pb2.L7Macro.RewriteAction]
                    """
                    rv = []
                    for action in actions:
                        rewrite_a_pb = modules_pb2.L7Macro.RewriteAction()
                        if action.split in (u'url', u''):
                            rewrite_a_pb.target = rewrite_a_pb.URL
                        elif action.split == u'path':
                            rewrite_a_pb.target = rewrite_a_pb.PATH
                        elif action.split == u'cgi':
                            rewrite_a_pb.target = rewrite_a_pb.CGI
                        else:
                            rewrite_a_pb.target = rewrite_a_pb.NONE
                        rewrite_a_pb.replacement = action.rewrite
                        rewrite_a_pb.pattern.re = action.regexp
                        rewrite_a_pb.pattern.literal.value = action.literal

                        # 'global' is a Python keyword, so the field is reached via getattr.
                        getattr(rewrite_a_pb.pattern, 'global').value = getattr(action, 'global')
                        rewrite_a_pb.pattern.case_sensitive.value = not action.case_insensitive.value
                        rv.append(rewrite_a_pb)
                    return rv

                rv_pb.rewrite.extend(to_rewrite_actions(m.pb.actions))
                m = next_module()
            return m, True

        def process_rps_limiter_macro(m, mutate_modules=True):
            if isinstance(m, RpsLimiterMacro):
                rv_pb.rps_limiter.external.record_name = m.pb.record_name
                rv_pb.rps_limiter.external.installation = m.pb.installation
                if mutate_modules:
                    m = next_module()
            return m, True

        def process_threshold(m):
            if isinstance(m, Threshold):
                rv_pb.compat.threshold_profile = rv_pb.compat.THRESHOLD_PROFILE_CORE_MAPS
                m = next_module()
            return m, True

        def process_shared(m):
            if isinstance(m, Shared):
                if m.pb.uuid == 'backends':
                    rv_pb.can_handle_announce_checks = True
                m = next_module()
            return m, True

        def process_report(m):
            if isinstance(m, Report):
                opts['report_found'] = True
                rv_pb.monitoring.uuid = m.pb.uuid
                if m.pb.ranges != 'default':
                    rv_pb.monitoring.ranges = m.pb.ranges
                m = next_module()
            return m, True

        def process_treshold(m):
            # NOTE(review): unlike process_threshold, this consumes a Threshold
            # without setting compat.threshold_profile. It is reached for a
            # Threshold that is still current after process_compressor within a
            # pass. Confirm the difference is intended.
            if isinstance(m, Threshold):
                m = next_module()
            return m, True

        def process_headers(m):
            if isinstance(m, Headers):
                for action_pb in SectionVisitor.get_header_actions(m):
                    rv_pb.headers.add().CopyFrom(action_pb)
                m = next_module()
            return m, True

        def process_response_headers(m):
            if isinstance(m, ResponseHeaders):
                for action_pb in SectionVisitor.get_header_actions(m):
                    rv_pb.response_headers.add().CopyFrom(action_pb)
                m = next_module()
            return m, True

        def process_log_headers(m):
            if isinstance(m, LogHeaders):
                rv_pb.headers.add().log.target_re = m.pb.name_re
                m = next_module()
            return m, True

        def process_compressor(m):
            if isinstance(m, Compressor):
                if m.pb.enable_compression and not m.pb.enable_decompression:
                    rv_pb.compression.codecs.CopyFrom(m.pb.compression_codecs)
                m = next_module()
            return m, True

        def process_shared2(m):
            if isinstance(m, Shared):
                if m.pb.uuid == 'backends':
                    opts['scheme_can_handle_announce_checks'] = True
                m = next_module()
            return m, True

        def process_balancer2(m):
            if isinstance(m, Balancer2):
                if m.generated_proxy_backends:
                    if m.generated_proxy_backends.include_backends:
                        self._fill_flat_scheme(rv_pb.flat_scheme, m,
                                               can_handle_announce_checks=opts['scheme_can_handle_announce_checks'])
                        rv_pb.flat_scheme.backend_ids.extend(
                            m.generated_proxy_backends.include_backends.pb.ids)
                    else:
                        for a in ['instances', 'nanny_snapshots', 'gencfg_groups', 'endpoint_sets']:
                            if len(getattr(m.generated_proxy_backends, a)) > 0:
                                raise AssertionError('generated_proxy_backends: expect include_backends, found %s' % a)

                elif m.backends:
                    ctx = {}
                    # A RpsLimiterMacro heading the per-dc chains is hoisted to
                    # the upstream only when every non-devnull backend starts
                    # with an identical one, or none of them has one. Otherwise
                    # the skip flag stays unset: _fill_dc then does not skip the
                    # RpsLimiterMacro and fails its Balancer2 assertion, instead
                    # of the limiter being silently dropped.
                    rps_limiters = []
                    has_dc_without_rps_limiter = False
                    for backend in m.backends:
                        split = self._split_backend_name(backend.pb.name)
                        if split is not None and split[1] == self.DEVNULL_LOCATION:
                            continue  # devnull backends are never passed to _fill_dc
                        dc_modules = tuple(backend.walk_chain())[1:]
                        if dc_modules and isinstance(dc_modules[0], RpsLimiterMacro):
                            rps_limiters.append(dc_modules[0])
                        else:
                            has_dc_without_rps_limiter = True
                    consistent = (
                        not (rps_limiters and has_dc_without_rps_limiter)
                        and all(r.pb == rps_limiters[0].pb for r in rps_limiters[1:])
                    )
                    if consistent:
                        ctx[self.SKIP_RpsLimiterMacro] = True
                        if rps_limiters:
                            process_rps_limiter_macro(rps_limiters[0], mutate_modules=False)

                    self._fill_by_dc_scheme(ctx, upstream_spec_pb, upstream_spec_pbs, rv_pb.by_dc_scheme, m,
                                            can_handle_announce_checks=opts['scheme_can_handle_announce_checks'])
                else:
                    raise AssertionError('balancer2 has neither generated_proxy_backends nor backends')
                m = next_module()
                return m, m is not None
            return m, False

        def process_error_document(m):
            if isinstance(m, ErrorDocument):
                rv_pb.static_response.status = m.pb.status
                rv_pb.static_response.content = m.pb.content
                m = next_module()
            return m, False

        # Modules that may precede the terminal Balancer2/ErrorDocument, in the
        # order they are tried within one pass.
        chain_processors = (
            process_rewrite,
            process_rps_limiter_macro,
            process_threshold,
            process_shared,
            process_report,
            process_threshold,
            process_headers,
            process_response_headers,
            process_compressor,
            process_treshold,
            process_log_headers,
            process_shared2,
        )

        # Repeat passes until the chain is exhausted or a whole pass consumes nothing.
        prev_m = None
        cont = True
        while cont and m is not None and prev_m != m:
            prev_m = m
            for process in chain_processors:
                m, cont = process(m)
                if not cont or m is None:
                    break

        if not opts['report_found']:
            rv_pb.compat.disable_monitoring = True

        if m is not None:
            m, cont = process_balancer2(m)

        if m is not None:
            m, cont = process_error_document(m)

        if m is not None:
            raise AssertionError('is a {}'.format(m.__class__))

        return True, '', rv_pb
