from functools import wraps

from google.protobuf import text_format
from infra.awacs.proto import modules_pb2

from awacs.model.util import clone_pb
from awacs.wrappers.main import IpdispatchSection
from .model import BaseVisitor, BaseBalancerSuggester, visit


# Reference "admin" ipdispatch section: loopback addresses (IPv4 and IPv6), the
# port taken from the "port" variable, and an http -> admin module chain.
# Checker._validate_admin_section accepts an admin section only if it is equal
# to this message.
DEFAULT_IPDISPATCH_ADMIN_SECTION_PB = text_format.Parse('''ips {
  value: "127.0.0.1"
}
ips {
  value: "::1"
}
ports {
  f_value {
    type: GET_PORT_VAR
    get_port_var_params {
      var: "port"
    }
  }
}
nested {
  modules {
    http {
    }
  }
  modules {
    admin {
    }
  }
}''', modules_pb2.IpdispatchSection())

# Reference "stats_storage" ipdispatch section: listens on 127.0.0.4 on the
# port taken from the "port" variable and chains report (uuid "service_total",
# just_storage) -> http -> errordocument(204).
# Checker._validate_stats_storage accepts a stats_storage section only if it is
# equal to this message.
DEFAULT_STATS_STORAGE_SECTION_PB = text_format.Parse('''ips {
  value: "127.0.0.4"
}
ports {
  f_value {
    type: GET_PORT_VAR
    get_port_var_params {
      var: "port"
    }
  }
}
nested {
  modules {
    report {
      uuid: "service_total"
      ranges: "default"
      just_storage: true
    }
  }
  modules {
    http {
    }
  }
  modules {
    errordocument {
      status: 204
    }
  }
}
''', modules_pb2.IpdispatchSection())

# Port entry that resolves to the "port" variable via GET_PORT_VAR.
# SectionVisitor.check_IpdispatchSection compares section ports against it to
# detect sections bound to the instance port.
DEFAULT_GET_PORT_VAR_PORT_PB = text_format.Parse('''f_value {
  type: GET_PORT_VAR
  get_port_var_params {
    var: "port"
  }
}''', modules_pb2.IpdispatchSection.Port())

# Reference regexp module: includes every upstream (filter "any") ordered by
# the "order" label. The section visitors require the terminating regexp to be
# equal to this message unless arbitrary include_upstreams are allowed.
DEFAULT_REGEXP_PB = text_format.Parse('''include_upstreams {
  filter {
    any: true
  }
  order {
    label {
      name: "order"
    }
  }
}''', modules_pb2.RegexpModule())


def list_set_fields(pb):
    """Return the set of field names that ``pb.ListFields()`` reports.

    :param pb: object exposing ``ListFields()`` that yields
        ``(descriptor, value)`` pairs (e.g. a protobuf message)
    :rtype: set[str]
    """
    names = set()
    for descriptor, _ in pb.ListFields():
        names.add(descriptor.name)
    return names


def h(method):
    """Decorate a check method so that its verdict is accumulated on ``self``.

    The wrapped method must return an ``(ok, msg)`` pair. The wrapper folds
    ``ok`` into ``self.ok`` and, on failure, appends ``'; ' + msg`` to
    ``self.msg``. The wrapper itself returns nothing.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        passed, reason = method(self, *args, **kwargs)
        self.ok &= passed
        if passed:
            return
        self.msg += '; ' + reason

    return wrapper


class SectionVisitor(BaseVisitor):
    """Base visitor that validates one ipdispatch section and collects options.

    Every ``check_*`` method is wrapped with ``h``: it returns ``(ok, msg)`` and
    the wrapper accumulates the verdict in ``self.ok`` / ``self.msg``. Options
    discovered along the way are stored in ``self.opts``; ``get_rv`` returns
    all three.
    """
    # Class-level defaults; the `h` wrapper rebinds them on the instance.
    ok, msg = True, ''

    # Port that is treated as the section's default one; set by subclasses.
    DEFAULT_PORT = None

    def __init__(self, namespace_id, balancer_id):
        super(SectionVisitor, self).__init__(namespace_id, balancer_id)
        self.opts = {}

    def get_rv(self):
        """Return the accumulated ``(ok, msg, opts)`` triple."""
        return self.ok, self.msg, self.opts

    @h
    def check_IpdispatchSection(self, module, path):
        """Validate the section's own fields and derive port-related options.

        Sets ``custom_ports`` and, when applicable, ``bind_on_instance_port``
        and ``use_instance_port_in_section_log_name`` in ``self.opts``.
        """
        set_fields = list_set_fields(module.pb)
        allowed_set_fields = {'ips', 'ports', 'nested'}
        # Any field outside the allowed set is forbidden (subset test, not a
        # strict-superset test, so partially filled sections are caught too).
        if not set_fields <= allowed_set_fields:
            return False, 'forbidden set fields: {}'.format(set_fields - allowed_set_fields)

        req_set_fields = {'ips', 'ports', 'nested'}
        # Every required field must be present; this also guarantees that
        # `ports` is non-empty before it is indexed below.
        if not req_set_fields <= set_fields:
            return False, 'required fields are not set: {}'.format(req_set_fields - set_fields)

        for ip_pb in module.pb.ips:
            if ip_pb.value != '*':
                return False, '{} != "*"'.format(ip_pb)

        custom_ports_present = False
        for port_pb in module.pb.ports:
            if port_pb.value == self.DEFAULT_PORT:
                continue
            if port_pb == DEFAULT_GET_PORT_VAR_PORT_PB:
                self.opts['bind_on_instance_port'] = True
            else:
                custom_ports_present = True

        # Once at least one custom port is present, every literal port
        # (including the default one) is recorded, in declaration order.
        custom_ports = []
        if custom_ports_present:
            for port_pb in module.pb.ports:
                if port_pb != DEFAULT_GET_PORT_VAR_PORT_PB:
                    custom_ports.append(port_pb.value)
        self.opts['custom_ports'] = custom_ports

        if module.pb.ports[0] == DEFAULT_GET_PORT_VAR_PORT_PB:
            self.opts['use_instance_port_in_section_log_name'] = True

        return True, ''

    @classmethod
    def get_header_actions(cls, module):
        """Translate a headers-like module into a list of L7Macro header actions.

        Actions are emitted in a fixed order: delete, create, create_weak,
        create_func, create_func_weak, append, copy_weak, append_func; entries
        within each group are sorted by key.

        :raises AssertionError: if the module sets ``append_weak``,
            ``append_func_weak`` or ``copy``, which are not translated
        """
        actions = []
        if module.pb.delete:
            a_pb = modules_pb2.L7Macro.HeaderAction()
            a_pb.delete.target_re = module.pb.delete
            actions.append(a_pb)
        for entry_pb in sorted(module.pb.create, key=lambda pb: pb.key):
            a_pb = modules_pb2.L7Macro.HeaderAction()
            a_pb.create.target = entry_pb.key
            a_pb.create.value = entry_pb.value
            actions.append(a_pb)
        for entry_pb in sorted(module.pb.create_weak, key=lambda pb: pb.key):
            a_pb = modules_pb2.L7Macro.HeaderAction()
            a_pb.create.target = entry_pb.key
            a_pb.create.value = entry_pb.value
            a_pb.create.keep_existing = True
            actions.append(a_pb)
        for k, v in sorted(module.pb.create_func.items()):
            a_pb = modules_pb2.L7Macro.HeaderAction()
            a_pb.create.target = k
            a_pb.create.func = v
            actions.append(a_pb)
        for k, v in sorted(module.pb.create_func_weak.items()):
            a_pb = modules_pb2.L7Macro.HeaderAction()
            a_pb.create.target = k
            a_pb.create.func = v
            a_pb.create.keep_existing = True
            actions.append(a_pb)
        for entry_pb in sorted(module.pb.append, key=lambda pb: pb.key):
            a_pb = modules_pb2.L7Macro.HeaderAction()
            a_pb.append.target = entry_pb.key
            a_pb.append.value = entry_pb.value
            actions.append(a_pb)
        for entry_pb in sorted(module.pb.copy_weak, key=lambda pb: pb.key):
            a_pb = modules_pb2.L7Macro.HeaderAction()
            # In copy_weak the key is the source header and the value is the target.
            a_pb.copy.target = entry_pb.value
            a_pb.copy.source = entry_pb.key
            a_pb.copy.keep_existing = True
            actions.append(a_pb)

        for k, v in sorted(module.pb.append_func.items()):
            a_pb = modules_pb2.L7Macro.HeaderAction()
            # append_func must become an `append` action; emitting `create`
            # here would overwrite the header instead of appending to it.
            # NOTE(review): relies on HeaderAction.append having a `func` field.
            a_pb.append.target = k
            a_pb.append.func = v
            actions.append(a_pb)

        assert not module.pb.append_weak
        assert not module.pb.append_func_weak
        assert not module.pb.copy
        return actions

    @h
    def check_Headers(self, module, path):
        """Append the module's header actions to ``opts['header_actions']``."""
        actions = self.get_header_actions(module)
        self.opts.setdefault('header_actions', []).extend(actions)
        return True, ''

    @h
    def check_LogHeaders(self, module, path):
        """Record a ``log`` header action built from ``name_re`` and ``cookie_fields``."""
        a_pb = modules_pb2.L7Macro.HeaderAction()
        a_pb.log.target_re = module.pb.name_re
        a_pb.log.cookie_fields.extend(module.pb.cookie_fields)
        actions = [a_pb]
        self.opts.setdefault('header_actions', []).extend(actions)
        return True, ''

    @h
    def check_ResponseHeaders(self, module, path):
        """Store the module's actions as ``opts['response_header_actions']``.

        Unlike request headers, the value is replaced rather than extended.
        """
        actions = self.get_header_actions(module)
        self.opts['response_header_actions'] = actions
        return True, ''

    @h
    def check_Shared(self, module, path):
        """Accept only ``shared`` modules with uuid ``modules`` or ``upstreams``."""
        if module.pb.uuid not in ('modules', 'upstreams'):
            return False, 'shared with wrong uuid'
        self.opts['terminates_with_shared'] = True
        self.opts['shared_uuid'] = module.pb.uuid
        return True, ''

    @h
    def check_ExpGetterMacro(self, module, path):
        """Remember the exp_getter_macro service name in ``opts``."""
        self.opts['exp_getter_macro'] = {
            'service_name': module.pb.service_name
        }
        return True, ''


class HttpSectionVisitor(SectionVisitor):
    """Section visitor for the plain-HTTP section (default port 80)."""
    DEFAULT_PORT = 80

    def __init__(self, namespace_id, balancer_id, allow_any_include_upstreams):
        super(HttpSectionVisitor, self).__init__(namespace_id, balancer_id)
        self.allow_any_include_upstreams = allow_any_include_upstreams

    @h
    def check_ExtendedHttpMacro(self, module, path):
        """Validate extended_http_macro fields (ignoring ``nested``) and collect options."""
        macro_pb = clone_pb(module.pb)
        macro_pb.ClearField('nested')

        present = list_set_fields(macro_pb)
        permitted = {'report_uuid', 'keepalive_drop_probability', 'maxlen', 'maxreq'}
        extra = present - permitted
        if extra:
            return False, 'forbidden set fields: {}'.format(extra)

        report_uuid = macro_pb.report_uuid
        if report_uuid:
            if report_uuid != 'http':
                return False, 'module_pb.report_uuid != "http"'
            self.opts['enable_total_signals'] = True

        if macro_pb.HasField('keepalive_drop_probability'):
            probability = macro_pb.keepalive_drop_probability.value
            self.opts['keepalive_drop_probability'] = probability

        # No fields are required here, so this comparison never succeeds.
        required = set()
        if present < required:
            return False, 'required fields are not set: {}'.format(required - present)
        return True, ''

    @h
    def check_Regexp(self, module, path):
        """Validate the module chain leading from extended_http_macro to the regexp.

        ``path`` is a sequence of ``(name, module)`` pairs; its last element is
        expected to be the regexp itself.
        """
        first_name, _ = path[0]
        if first_name != 'extended_http_macro':
            return False, 'path[0] is not extended_http_macro'

        permitted_names = {'headers', 'response_headers', 'log_headers', 'headers_forwarder', 'shared',
                           'rps_limiter_macro'}
        permitted_shared_uuids = ('modules', 'upstreams', 'rps_limiter_macro')
        for idx, (step_name, step) in enumerate(path[1:-1], start=1):
            if step_name not in permitted_names:
                return False, 'path[{}] is not allowed: {}'.format(idx, step_name)
            if step_name == 'shared':
                if step.pb.uuid not in permitted_shared_uuids:
                    return False, 'path[{}] is a shared with not allowed uuid: {}'.format(idx, step.pb.uuid)
                # A shared in the middle of the path means the section does not end with it.
                self.opts['terminates_with_shared'] = False
            elif step_name == 'rps_limiter_macro':
                self.opts['rps_limiter_record_name'] = step.pb.record_name

        last_name, last = path[-1]
        assert last_name == 'regexp'
        if not self.allow_any_include_upstreams:
            if last.pb != DEFAULT_REGEXP_PB:
                return False, 'path[-1] is not default regexp: {}'.format(last.pb)
        elif not last.include_upstreams:
            return False, 'path[-1] must have include_upstreams'

        return True, ''


class HttpsSectionVisitor(SectionVisitor):
    """Section visitor for the HTTPS section (default port 443)."""
    DEFAULT_PORT = 443

    def __init__(self, namespace_id, balancer_id, allow_any_include_upstreams):
        super(HttpsSectionVisitor, self).__init__(namespace_id, balancer_id)
        self.allow_any_include_upstreams = allow_any_include_upstreams

    @h
    def check_ExtendedHttpMacro(self, module, path):
        """Validate extended_http_macro SSL settings and collect HTTPS options.

        On success ``self.opts`` gets ``enable_http2`` and ``cert_id`` and,
        when applicable, ``secondary_cert_id``, ``enable_tlsv1_3``,
        ``disable_sslv3`` and ``enable_total_signals``.
        """
        module_pb = clone_pb(module.pb)
        # `nested` is validated by the other check_* methods, not here.
        module_pb.ClearField('nested')

        set_fields = list_set_fields(module_pb)
        allowed_set_fields = {'report_uuid', 'enable_ssl', 'force_ssl', 'ssl_sni_contexts', 'disable_sslv3',
                              'ssl_sni_ja3_enabled', 'enable_http2', 'disable_tlsv1_3'
                              }
        if not set_fields <= allowed_set_fields:
            return False, 'forbidden set fields: {}'.format(set_fields - allowed_set_fields)

        if module_pb.HasField('disable_tlsv1_3') and not module_pb.disable_tlsv1_3.value:
            self.opts['enable_tlsv1_3'] = True

        if module_pb.report_uuid:
            if module_pb.report_uuid != 'https':
                return False, 'module_pb.report_uuid != "https"'
            self.opts['enable_total_signals'] = True

        if module_pb.disable_sslv3:
            self.opts['disable_sslv3'] = True

        if module_pb.HasField('force_ssl') and not module_pb.force_ssl.value:
            return False, 'force_ssl: false'

        req_set_fields = {'enable_ssl', 'ssl_sni_contexts'}
        # Every required field must be present. A strict-subset test would miss
        # a missing required field whenever some optional field is also set.
        if not req_set_fields <= set_fields:
            return False, 'required fields are not set: {}'.format(req_set_fields - set_fields)

        if not module_pb.enable_ssl:
            return False, 'enable_ssl is false'
        if len(module_pb.ssl_sni_contexts) != 1:
            return False, 'len(module_pb.ssl_sni_contexts) != 1'

        context_pb = module_pb.ssl_sni_contexts[0]
        if context_pb.value.servername_regexp != 'default':
            return False, 'module_pb.ssl_sni_contexts[0].value.servername_regexp != "default"'

        self.opts['enable_http2'] = module_pb.enable_http2

        # An explicit c_cert id wins; otherwise the context key is used as the cert id.
        if context_pb.value.HasField('c_cert'):
            self.opts['cert_id'] = context_pb.value.c_cert.id
        else:
            self.opts['cert_id'] = context_pb.key

        if context_pb.value.HasField('c_secondary_cert'):
            self.opts['secondary_cert_id'] = context_pb.value.c_secondary_cert.id
        elif context_pb.value.HasField('secondary_cert'):
            self.opts['secondary_cert_id'] = context_pb.key
        return True, ''

    @h
    def check_Regexp(self, module, path):
        """Validate the module chain leading from extended_http_macro to the regexp.

        ``path`` is a sequence of ``(name, module)`` pairs; its last element is
        expected to be the regexp itself.
        """
        name, _ = path[0]
        if name != 'extended_http_macro':
            return False, 'path[0] is not extended_http_macro'

        allowed_names = {'headers', 'response_headers', 'log_headers', 'shared', 'rps_limiter_macro'}
        for i, (name, step) in enumerate(path[1:-1], start=1):
            if name == 'shared':
                if step.pb.uuid not in ('modules', 'upstreams'):
                    return False, 'path[{}] is a shared with not allowed uuid: {}'.format(i, step.pb.uuid)
                # A shared in the middle of the path means the section does not end with it.
                self.opts['terminates_with_shared'] = False
            if name == 'rps_limiter_macro':
                if not step.pb.use_sd_backends:
                    return False, 'rps_limiter_macro.use_sd_backends: false is not allowed'
                self.opts['rps_limiter_record_name'] = step.pb.record_name
            if name not in allowed_names:
                return False, 'path[{}] is not allowed: {}'.format(i, name)

        name, regexp = path[-1]
        assert name == 'regexp'
        if self.allow_any_include_upstreams:
            if not regexp.include_upstreams:
                return False, 'path[-1] must have include_upstreams'
        elif regexp.pb != DEFAULT_REGEXP_PB:
            return False, 'path[-1] is not default regexp: {}'.format(regexp.pb)
        return True, ''


class Checker(BaseBalancerSuggester):
    """Suggests an ``L7Macro`` config for a balancer built on ``instance_macro``.

    ``suggest`` returns ``(ok, msg, pb)``: on success ``pb`` is the generated
    ``modules_pb2.L7Macro``; on failure ``pb`` is ``None`` and ``msg`` explains
    why the config cannot be converted.
    """
    RULE = 'TLEM'

    def __init__(self, namespace_id, balancer_id, l7_macro_version='0.0.1'):
        super(Checker, self).__init__(namespace_id, balancer_id)
        self.l7_macro_version = l7_macro_version
        self.opts = {}

    def _validate_admin_section(self, ipdispatch_section_pb):
        """Return ``(ok, '')``; ok only if the section equals the reference admin section."""
        ok = ipdispatch_section_pb == DEFAULT_IPDISPATCH_ADMIN_SECTION_PB
        return ok, ''

    def _validate_stats_storage(self, ipdispatch_section_pb):
        """Return ``(ok, '')``; ok only if the section equals the reference stats_storage section."""
        ok = ipdispatch_section_pb == DEFAULT_STATS_STORAGE_SECTION_PB
        return ok, ''

    def _validate_http_section(self, ipdispatch_section_pb, allow_any_include_upstreams):
        """Run ``HttpSectionVisitor`` over the section and return ``(ok, msg, opts)``."""
        v = HttpSectionVisitor(self.namespace_id, self.balancer_id, allow_any_include_upstreams)
        visit(IpdispatchSection(ipdispatch_section_pb), v)
        return v.get_rv()

    def _validate_https_section(self, ipdispatch_section_pb, allow_any_include_upstreams):
        """Run ``HttpsSectionVisitor`` over the section and return ``(ok, msg, opts)``."""
        v = HttpsSectionVisitor(self.namespace_id, self.balancer_id, allow_any_include_upstreams)
        visit(IpdispatchSection(ipdispatch_section_pb), v)
        return v.get_rv()

    def suggest(self, balancer_spec_pb):
        """Try to convert ``balancer_spec_pb`` into an ``L7Macro``.

        :param balancer_spec_pb: balancer spec whose
            ``yandex_balancer.config`` is inspected
        :returns: ``(True, '', l7_macro_pb)`` on success,
            ``(False, reason, None)`` otherwise
        :raises AssertionError: if both the http and the https sections are
            bound to the instance port
        """
        config_pb = balancer_spec_pb.yandex_balancer.config
        if not config_pb.HasField('instance_macro'):
            return False, 'not an instance_macro', None

        im_pb = config_pb.instance_macro
        # 5000 is treated as the default and is not carried over explicitly.
        if im_pb.maxconn and im_pb.maxconn != 5000:
            self.opts['maxconn'] = im_pb.maxconn

        if not (im_pb.HasField('f_workers') and im_pb.f_workers.HasField('get_workers_params')):
            return False, 'hard-corded workers count', None

        set_fields = list_set_fields(im_pb)
        allowed_set_fields = {'tcp_listen_queue', 'maxconn', 'sections', 'sd', 'unistat', 'f_workers'}
        allowed_set_fields.add('thread_mode')
        allowed_set_fields.add('state_directory')
        allowed_set_fields.add('version')

        # Reject any field outside the allowed set. A strict-superset test
        # would only fire when *every* allowed field is also set.
        if not set_fields <= allowed_set_fields:
            return False, 'forbidden set fields: {}'.format(set_fields - allowed_set_fields), None

        if not im_pb.HasField('sd'):
            self.opts['disable_sd'] = True
        if not im_pb.HasField('unistat'):
            self.opts['disable_unistat'] = True
        if not im_pb.tcp_listen_queue:
            self.opts['disable_tcp_listen_queue'] = True
        else:
            if im_pb.tcp_listen_queue != 128:
                return False, 'im_pb.tcp_listen_queue != 128', None
        if im_pb.version == '0.0.2':
            compat = self.opts.setdefault('compat', {})
            compat['enable_persistent_sd_cache'] = True

        req_set_fields = {'sections', 'f_workers'}
        # Every required field must be present. A strict-subset test would let
        # a config without sections through whenever any optional field is set.
        if not req_set_fields <= set_fields:
            return False, 'required fields are not set: {}'.format(req_set_fields - set_fields), None

        sections = {entry_pb.key: entry_pb.value for entry_pb in im_pb.sections}
        expected_section_ids = {'admin', 'stats_storage', 'http_section', 'https_section'}
        unexpected_section_ids = set(sections) - expected_section_ids
        if len(unexpected_section_ids) > 0:
            return False, 'unexpected ipdispatch-sections: {}'.format(unexpected_section_ids), None

        admin_section = sections.get('admin')
        stats_storage = sections.get('stats_storage')
        http_section = sections.get('http_section')
        https_section = sections.get('https_section')

        if admin_section:
            ok, msg = self._validate_admin_section(admin_section)
            if not ok:
                return ok, 'admin: {}'.format(msg), None

        if stats_storage:
            ok, msg = self._validate_stats_storage(stats_storage)
            if not ok:
                return ok, 'stats_storage: {}'.format(msg), None

        # A non-default include_upstreams is tolerated only when the other
        # protocol's section is absent.
        http_opts = None
        if http_section:
            ok, msg, http_opts = self._validate_http_section(http_section,
                                                             allow_any_include_upstreams=not https_section)
            if not ok:
                return ok, 'http_section: {}'.format(msg), None

        https_opts = None
        if https_section:
            ok, msg, https_opts = self._validate_https_section(https_section,
                                                               allow_any_include_upstreams=not http_section)
            if not ok:
                return ok, 'https_section: {}'.format(msg), None

        opts = self.opts

        # some sanity checks
        # With stats_storage present, every existing section must report total signals.
        if stats_storage and not ((not http_section or http_opts.get('enable_total_signals')) and
                                  (not https_section or https_opts.get('enable_total_signals'))):
            return False, 'smth wrong with stats_storage', None

        assert not (http_section and http_opts.get('bind_on_instance_port') and
                    https_section and https_opts.get('bind_on_instance_port'))

        pb = modules_pb2.L7Macro()
        pb.version = self.l7_macro_version
        if opts.get('disable_sd'):
            pb.compat.disable_sd = True
        if opts.get('disable_unistat'):
            pb.compat.disable_unistat = True
        if opts.get('maxconn'):
            pb.compat.maxconn.value = opts.get('maxconn')
        if opts.get('disable_tcp_listen_queue'):
            pb.compat.disable_tcp_listen_queue_limit = True
        if opts.get('compat', {}).get('enable_persistent_sd_cache'):
            # Boolean flag, like the other compat switches above (not the string 'true').
            pb.compat.enable_persistent_sd_cache = True

        if stats_storage:
            pb.monitoring.enable_total_signals = True

        if http_section:
            pb.http.SetInParent()
            if http_opts.get('use_instance_port_in_section_log_name'):
                pb.http.compat.use_instance_port_in_section_log_name = True
            if http_opts.get('bind_on_instance_port'):
                pb.http.compat.bind_on_instance_port = True
            if http_opts['custom_ports']:
                pb.http.ports.extend(http_opts['custom_ports'])

        if https_section:
            pb.https.SetInParent()
            if https_opts.get('use_instance_port_in_section_log_name'):
                pb.https.compat.use_instance_port_in_section_log_name = True
            if https_opts.get('bind_on_instance_port'):
                pb.https.compat.bind_on_instance_port = True
            if https_opts.get('enable_tlsv1_3'):
                pb.https.enable_tlsv1_3 = True
            if not https_opts.get('disable_sslv3'):
                pb.https.compat.enable_sslv3 = True
            pb.https.certs.add(
                id=https_opts['cert_id'],
                secondary_id=https_opts.get('secondary_cert_id', ''),
            )
            if https_opts['custom_ports']:
                pb.https.ports.extend(https_opts['custom_ports'])
            if https_opts['enable_http2']:
                pb.https.enable_http2 = True
                pb.http2.SetInParent()

        if http_section and https_section:
            # When neither section ends with a shared, both sections must
            # carry identical header actions.
            if not http_opts.get('terminates_with_shared') and not https_opts.get('terminates_with_shared'):
                http_header_actions = http_opts.get('header_actions', [])
                https_header_actions = https_opts.get('header_actions', [])
                if http_header_actions != https_header_actions:
                    return False, 'different header actions in http and https', None
                if http_opts.get('response_header_actions', []) != https_opts.get('response_header_actions', []):
                    return False, 'different response header actions in http and https', None

            # True when https_section is declared before http_section.
            place_first = [entry_pb.key for entry_pb in im_pb.sections if
                           entry_pb.key in ['https_section', 'http_section']] \
                          == ['https_section', 'http_section']
            pb.https.compat.place_first = place_first

            # TODO:
            # The section that ends with a shared refers to the uuid; the other one assigns it.
            if http_opts.get('terminates_with_shared'):
                pb.https.compat.assign_shared_uuid = http_opts['shared_uuid']
                pb.http.compat.refer_shared_uuid = http_opts['shared_uuid']
            if https_opts.get('terminates_with_shared'):
                pb.http.compat.assign_shared_uuid = https_opts['shared_uuid']
                pb.https.compat.refer_shared_uuid = https_opts['shared_uuid']

        def process_section(section_opts):
            """Merge one section's collected options into the resulting ``pb``."""
            header_actions = section_opts.get('header_actions', [])
            if repr(pb.headers) != repr(header_actions):
                for apb in header_actions:
                    # Skip actions already added by the other section.
                    if apb not in pb.headers:
                        pb.headers.add().CopyFrom(apb)

            response_header_actions = section_opts.get('response_header_actions', [])
            if repr(pb.response_headers) != repr(response_header_actions):
                for apb in response_header_actions:
                    if apb not in pb.response_headers:
                        pb.response_headers.add().CopyFrom(apb)

            exp_getter_macro_opts = section_opts.get('exp_getter_macro')
            if exp_getter_macro_opts:
                pb.headers.add().uaas.service_name = exp_getter_macro_opts['service_name']

            if section_opts.get('keepalive_drop_probability', None):
                pb.core.compat.keepalive_drop_probability.SetInParent()
                pb.core.compat.keepalive_drop_probability.value = section_opts['keepalive_drop_probability']

            if section_opts.get('rps_limiter_record_name', False):
                pb.rps_limiter.external.record_name = section_opts['rps_limiter_record_name']

        if http_section:
            process_section(http_opts)

        if https_section:
            process_section(https_opts)

        return True, '', pb
