import re
from typing import Iterable

from infra.awacs.proto import modules_pb2
from awacs.wrappers.luautil import read_string
from awacs.wrappers.main import Holder, Balancer2, ErrorDocument


def get_nested(h: Holder):
    """Return the first module yielded by ``h.walk_chain()``.

    :raises AssertionError: if *h* is empty.
    """
    assert not h.is_empty()
    chain = h.walk_chain()
    first_module = next(chain)
    return first_module


def list_set_fields(pb):
    """Return the set of names of the fields reported by ``pb.ListFields()``."""
    names = set()
    for field_descriptor, _unused_value in pb.ListFields():
        names.add(field_descriptor.name)
    return names


_FSM_HTTP_METHODS = ('GET', 'HEAD', 'POST', 'PUT', 'PATCH', 'DELETE', 'CONNECT', 'OPTIONS')
# Matches a match_fsm ``match`` expression that selects requests by HTTP method only:
# either a single method ("GET.*") or a parenthesized alternation ("(GET|POST).*").
_FSM_METHODS_RE = re.compile(
    r'^({0}|\((?:{0}\|?)+\))\.\*$'.format('(?:{})'.format('|'.join(_FSM_HTTP_METHODS))))


def _parse_fsm_methods(match):
    """Return the HTTP methods listed in a method-only match_fsm ``match`` expression.

    Returns None when *match* is not of the form ``METHOD.*`` / ``(M1|M2|...).*``.
    """
    match_obj = _FSM_METHODS_RE.match(match)
    if not match_obj:
        return None
    used_methods = match_obj.group(1)
    if used_methods.startswith('('):
        # The regex guarantees a matching closing parenthesis.
        used_methods = used_methods[1:-1]
    methods = used_methods.split('|')
    # The alternation part of the regex also accepts concatenated ("(GETPOST)")
    # or empty ("(GET|)") alternatives; those have no method-matcher equivalent.
    if any(method not in _FSM_HTTP_METHODS for method in methods):
        return None
    return methods


def _set_method_matcher(l7um_matcher_pb, methods):
    """Make *l7um_matcher_pb* match any of *methods* (upper-cased).

    A single method is stored directly in ``method``; several methods become an
    ``or_`` of single-method matchers.
    """
    if len(methods) == 1:
        l7um_matcher_pb.method = methods[0].upper()
    else:
        for method in methods:
            l7um_matcher_pb.or_.add().method = method.upper()


def regexp_section_matcher_to_l7um_matcher_pb(rs_matcher_pb: modules_pb2.Matcher) -> modules_pb2.L7UpstreamMacro.Matcher:
    """Convert a regexp_section matcher into an L7UpstreamMacro matcher.

    Based on process_match_section from uem_anttsov.py. An empty matcher becomes
    ``any``. The match_fsm, match_not, match_method, match_and and match_or
    matchers are translated; nested matchers are converted recursively.

    :raises AssertionError: if the matcher uses a construct that has no
        L7UpstreamMacro equivalent.
    """
    l7um_matcher_pb = modules_pb2.L7UpstreamMacro.Matcher()

    if rs_matcher_pb == modules_pb2.Matcher():
        l7um_matcher_pb.any = True
    elif rs_matcher_pb.HasField('match_fsm'):
        match_fsm_pb = rs_matcher_pb.match_fsm
        if match_fsm_pb.host:
            l7um_matcher_pb.host_re = read_string(match_fsm_pb.host)
        elif match_fsm_pb.uri:
            l7um_matcher_pb.uri_re = read_string(match_fsm_pb.uri)
        elif match_fsm_pb.url:
            l7um_matcher_pb.url_re = read_string(match_fsm_pb.url)
        elif match_fsm_pb.path:
            l7um_matcher_pb.path_re = read_string(match_fsm_pb.path)
        elif match_fsm_pb.cgi:
            cgi = match_fsm_pb.cgi
            if match_fsm_pb.surround:
                # NOTE(review): the wrapper is escaped before read_string unescapes it;
                # confirm the resulting regex matches the intended "surround" semantics.
                cgi = r'\\.*' + cgi + r'\\.*'
            l7um_matcher_pb.cgi_re = read_string(cgi)
        elif match_fsm_pb.match:
            methods = _parse_fsm_methods(match_fsm_pb.match)
            if methods is None:
                raise AssertionError('match_fsm_pb {}'.format(match_fsm_pb))
            _set_method_matcher(l7um_matcher_pb, methods)
        elif match_fsm_pb.HasField('header'):
            l7um_matcher_pb.header.name = match_fsm_pb.header.name
            l7um_matcher_pb.header.re = match_fsm_pb.header.value
        else:
            raise AssertionError('match_fsm_pb {}'.format(match_fsm_pb))
    elif rs_matcher_pb.HasField('match_not'):
        l7um_matcher_pb.not_.CopyFrom(regexp_section_matcher_to_l7um_matcher_pb(rs_matcher_pb.match_not))
    elif rs_matcher_pb.HasField('match_method'):
        methods = rs_matcher_pb.match_method.methods
        if not methods:
            raise AssertionError('matcher.match_method.methods is empty')
        # Several methods are OR-ed together, as for a multi-method match_fsm expression.
        _set_method_matcher(l7um_matcher_pb, methods)
    elif len(rs_matcher_pb.match_and) > 0:
        for nested_matcher_pb in rs_matcher_pb.match_and:
            l7um_matcher_pb.and_.add().CopyFrom(regexp_section_matcher_to_l7um_matcher_pb(nested_matcher_pb))
    elif len(rs_matcher_pb.match_or) > 0:
        for nested_matcher_pb in rs_matcher_pb.match_or:
            l7um_matcher_pb.or_.add().CopyFrom(regexp_section_matcher_to_l7um_matcher_pb(nested_matcher_pb))
    else:
        raise AssertionError('matcher {}'.format(rs_matcher_pb))
    return l7um_matcher_pb


def balancer2_to_l7um_balancer_settings_pb(balancer2: Balancer2) -> modules_pb2.L7UpstreamMacro.BalancerSettings:
    """Convert a ``balancer2`` module into L7UpstreamMacro balancer settings.

    The following are translated from ``balancer2`` and its
    ``generated_proxy_backends.proxy_options``:

    - attempts and fast attempts
    - proxy timeouts
    - HTTP-response retry policy
    - balancing method
    - watermark policy
    - reattempt rate limiting

    :raises AssertionError: if the configuration uses a construct that has no
        L7UpstreamMacro equivalent, or conflicting settings.
    """
    l7um_balancer_settings_pb = modules_pb2.L7UpstreamMacro.BalancerSettings()

    balancer2_pb = balancer2.pb
    po_pb = balancer2_pb.generated_proxy_backends.proxy_options

    # Attempts: an explicit count, or "all endpoints" when given via f_attempts.
    if balancer2_pb.attempts:
        l7um_balancer_settings_pb.attempts = balancer2_pb.attempts
    elif balancer2_pb.HasField('f_attempts'):
        l7um_balancer_settings_pb.attempt_all_endpoints = True
    else:
        raise AssertionError('no attempts specified')

    # connection_attempts and fast_attempts both map onto fast_attempts,
    # so specifying both is ambiguous.
    if balancer2_pb.connection_attempts:
        if balancer2_pb.fast_attempts:
            raise AssertionError('both connection_attempts and fast_attempts are specified')
        l7um_balancer_settings_pb.fast_attempts = balancer2_pb.connection_attempts
    elif balancer2_pb.HasField('f_connection_attempts'):
        l7um_balancer_settings_pb.fast_attempt_all_endpoints = True
    elif balancer2_pb.fast_attempts:
        l7um_balancer_settings_pb.fast_attempts = balancer2_pb.fast_attempts

    if po_pb.connect_timeout:
        l7um_balancer_settings_pb.connect_timeout = po_pb.connect_timeout
    # backend_timeout falls back to '10s' when not set.
    l7um_balancer_settings_pb.backend_timeout = po_pb.backend_timeout or '10s'
    l7um_balancer_settings_pb.keepalive_count = po_pb.keepalive_count
    l7um_balancer_settings_pb.backend_read_timeout = po_pb.backend_read_timeout
    l7um_balancer_settings_pb.client_read_timeout = po_pb.client_read_timeout

    if balancer2_pb.fast_503.value:
        l7um_balancer_settings_pb.fast_attempts_type = l7um_balancer_settings_pb.CONNECT_FAILURE_AND_503
        if balancer2_pb.fast_attempts:
            l7um_balancer_settings_pb.fast_attempts = balancer2_pb.fast_attempts
        elif balancer2_pb.HasField('f_fast_attempts'):
            l7um_balancer_settings_pb.fast_attempt_all_endpoints = True

    # fail_on_5xx explicitly set to false disables HTTP-response retries;
    # when it is true or unset, 5xx responses are retried.
    if po_pb.HasField('fail_on_5xx') and not po_pb.fail_on_5xx.value:
        l7um_balancer_settings_pb.do_not_retry_http_responses = True
    else:
        l7um_balancer_settings_pb.retry_http_responses.codes.append('5xx')

    if po_pb.status_code_blacklist_exceptions:
        l7um_balancer_settings_pb.retry_http_responses.exceptions.extend(po_pb.status_code_blacklist_exceptions)

    if po_pb.allow_connection_upgrade:
        l7um_balancer_settings_pb.allow_connection_upgrade = po_pb.allow_connection_upgrade

    if po_pb.status_code_blacklist:
        l7um_balancer_settings_pb.retry_http_responses.codes.extend(po_pb.status_code_blacklist)

    l7um_balancer_settings_pb.buffering = po_pb.buffering

    if balancer2_pb.HasField('weighted2'):
        l7um_balancer_settings_pb.compat.method = l7um_balancer_settings_pb.compat.WEIGHTED2
    elif balancer2_pb.HasField('active'):
        l7um_balancer_settings_pb.compat.method = l7um_balancer_settings_pb.compat.ACTIVE
        l7um_balancer_settings_pb.health_check.delay = balancer2_pb.active.delay
        l7um_balancer_settings_pb.health_check.request = balancer2_pb.active.request
        if balancer2_pb.active.HasField('steady') and not balancer2_pb.active.steady.value:
            l7um_balancer_settings_pb.health_check.compat.not_steady = True
    elif balancer2_pb.HasField('rr'):
        l7um_balancer_settings_pb.compat.method = l7um_balancer_settings_pb.compat.RR
    elif balancer2_pb.HasField('dynamic'):
        l7um_balancer_settings_pb.max_pessimized_endpoints_share = balancer2_pb.dynamic.max_pessimized_share
    else:
        raise AssertionError('balancer2 has no supported balancing method (weighted2, active, rr, dynamic)')

    policy_kinds = []
    if balancer2.balancing_policy:
        policy_kinds = balancer2.balancing_policy.list_policy_kinds() or []

    wm_policy = False
    if policy_kinds == ['watermark_policy', 'unique_policy']:
        if not balancer2_pb.balancing_policy.HasField('watermark_policy'):
            raise AssertionError('watermark_policy is listed but not set in balancing_policy')
        wm_pb = balancer2_pb.balancing_policy.watermark_policy
        l7um_balancer_settings_pb.compat.watermark_policy.lo = wm_pb.lo
        l7um_balancer_settings_pb.compat.watermark_policy.hi = wm_pb.hi
        wm_policy = True
    elif policy_kinds == ['unique_policy'] or policy_kinds == []:
        pass
    else:
        raise AssertionError('unexpected policies {}'.format(policy_kinds))

    if balancer2.attempts_rate_limiter:
        if wm_policy:
            raise AssertionError('attempts_rate_limiter cannot be combined with watermark_policy')
        l7um_balancer_settings_pb.max_reattempts_share = balancer2.attempts_rate_limiter.pb.limit
    elif not wm_policy:
        l7um_balancer_settings_pb.do_not_limit_reattempts = True
    return l7um_balancer_settings_pb


def balancer2_to_l7um_dc_balancer_settings_pb(balancer2: Balancer2) -> modules_pb2.L7UpstreamMacro.DcBalancerSettings:
    """Convert a DC-level ``balancer2`` module into L7UpstreamMacro DC balancer settings.

    Only the ``rr`` balancing method is accepted. A ``by_name_policy`` +
    ``unique_policy`` pair maps to LOCAL_THEN_BY_DC_WEIGHT. A bare
    ``unique_policy`` or no policy maps to BY_DC_WEIGHT.

    :raises AssertionError: if attempts are missing, the method is not ``rr``,
        or the balancing policies cannot be mapped.
    """
    balancer2_pb = balancer2.pb

    dc_balancer_settings_pb = modules_pb2.L7UpstreamMacro.DcBalancerSettings()
    if balancer2_pb.attempts:
        dc_balancer_settings_pb.attempts = balancer2_pb.attempts
    elif balancer2_pb.HasField('f_attempts'):
        dc_balancer_settings_pb.attempt_all_dcs = True
    else:
        raise AssertionError('no attempts specified')

    # Explicit raise instead of assert so the check survives `python -O`.
    if not balancer2_pb.HasField('rr'):
        raise AssertionError('only rr balancing method is supported for DC balancing')
    if not balancer2_pb.rr.weights_file:
        dc_balancer_settings_pb.compat.disable_dynamic_weights = True

    policy_kinds = []
    if balancer2.balancing_policy:
        policy_kinds = balancer2.balancing_policy.list_policy_kinds() or []

    if policy_kinds == ['by_name_policy', 'unique_policy']:
        dc_balancer_settings_pb.method = dc_balancer_settings_pb.LOCAL_THEN_BY_DC_WEIGHT
    elif policy_kinds == ['unique_policy'] or policy_kinds == []:
        dc_balancer_settings_pb.method = dc_balancer_settings_pb.BY_DC_WEIGHT
    else:
        raise AssertionError(
            "it is not possible to determine which policy should be used: "
            "balancer2.balancing_policy.list_policy_kinds() == {} :(".format(policy_kinds))
    return dc_balancer_settings_pb


def balancer2_to_l7um_flat_scheme_pb(balancer2: Balancer2) -> modules_pb2.L7UpstreamMacro.FlatScheme:
    """Build a FlatScheme from a ``balancer2`` module.

    The scheme combines the converted balancer settings with the on_error handling.
    """
    balancer_settings_pb = balancer2_to_l7um_balancer_settings_pb(balancer2)
    on_error_pb = on_error_to_l7um_on_error_pb(balancer2)
    return modules_pb2.L7UpstreamMacro.FlatScheme(balancer=balancer_settings_pb, on_error=on_error_pb)


def on_error_to_l7um_on_error_pb(balancer2: Balancer2) -> modules_pb2.L7UpstreamMacro.OnError:
    """Translate ``balancer2.on_error`` into an L7UpstreamMacro OnError message.

    Without an on_error section, ``rst`` is set. Otherwise the section must
    start with an ErrorDocument, whose status and content become the static
    response.

    :raises AssertionError: if the on_error section does not start with an ErrorDocument.
    """
    on_error_pb = modules_pb2.L7UpstreamMacro.OnError()
    if not balancer2.on_error:
        on_error_pb.rst = True
        return on_error_pb

    nested_module = get_nested(balancer2.on_error)
    if not isinstance(nested_module, ErrorDocument):
        raise AssertionError("only ErrorDocument supported in balancer2.on_error")
    on_error_pb.static.status = nested_module.pb.status
    on_error_pb.static.content = nested_module.pb.content
    return on_error_pb


def balancer2_to_by_dc_scheme_pb(balancer2: Balancer2) -> modules_pb2.L7UpstreamMacro.ByDcScheme:
    """Build a ByDcScheme from a DC-level ``balancer2`` module.

    The scheme combines the DC balancer settings with the on_error handling.
    """
    dc_balancer_pb = balancer2_to_l7um_dc_balancer_settings_pb(balancer2)
    on_error_pb = on_error_to_l7um_on_error_pb(balancer2)
    return modules_pb2.L7UpstreamMacro.ByDcScheme(dc_balancer=dc_balancer_pb, on_error=on_error_pb)


def errordocument_to_l7um_static_response_pb(errordocument: ErrorDocument) -> modules_pb2.L7UpstreamMacro.StaticResponse:
    """Return a StaticResponse carrying the ErrorDocument's status and content."""
    errordocument_pb = errordocument.pb
    return modules_pb2.L7UpstreamMacro.StaticResponse(
        status=errordocument_pb.status,
        content=errordocument_pb.content,
    )


def to_rewrite_actions(actions: Iterable[modules_pb2.RewriteAction]) -> Iterable[modules_pb2.L7Macro.RewriteAction]:
    """Convert rewrite-module actions into L7Macro RewriteAction messages.

    The ``split`` value selects the target:

    - ``'url'`` or empty selects URL.
    - ``'path'`` selects PATH.
    - ``'cgi'`` selects CGI.
    - Anything else selects NONE.

    The regexp, literal, global and case-insensitivity flags are copied into
    ``pattern``; case_sensitive is the negation of case_insensitive.
    Returns a list.
    """
    rewrite_action_cls = modules_pb2.L7Macro.RewriteAction
    targets_by_split = {
        u'url': rewrite_action_cls.URL,
        u'': rewrite_action_cls.URL,
        u'path': rewrite_action_cls.PATH,
        u'cgi': rewrite_action_cls.CGI,
    }

    def convert(action):
        converted_pb = rewrite_action_cls()
        converted_pb.target = targets_by_split.get(action.split, rewrite_action_cls.NONE)
        converted_pb.replacement = action.rewrite
        pattern_pb = converted_pb.pattern
        pattern_pb.re = action.regexp
        pattern_pb.literal.value = action.literal
        # `global` is a Python keyword, so the field is only reachable via getattr.
        getattr(pattern_pb, 'global').value = getattr(action, 'global')
        pattern_pb.case_sensitive.value = not action.case_insensitive.value
        return converted_pb

    return [convert(action) for action in actions]
