# coding: utf-8
import collections

import gevent.pool
import six
from gevent.threadpool import ThreadPool

from infra.swatlib import gutils

from infra.awacs.proto import api_pb2, internals_pb2, model_pb2, modules_pb2
from awacs import yamlparser
from awacs.lib import yp_service_discovery
from awacs.lib.rpc.exceptions import ForbiddenError, BadRequestError
from awacs.model import cache, objects
from awacs.model.balancer.generator import get_yandex_config_pb, validate_config
from awacs.model.balancer.stateholder import BalancerStateHolder
from awacs.model.balancer.vector import UpstreamVersion, DomainVersion, BackendVersion, KnobVersion, CertVersion, EndpointSetVersion
from awacs.model.validation import get_yaml
from awacs.model.util import clone_pb
from awacs.web.validation import info
from awacs.web.util import AwacsBlueprint
from awacs.wrappers import rps_limiter_settings
from awacs.wrappers.base import Holder, ValidationCtx, ANY_MODULE
from awacs.wrappers.l7upstreammacro import L7UpstreamMacro

from sepelib.core import config

# Blueprint on which every RPC method of this module is registered via
# `@info_service_bp.method(...)`; '/api' is presumably the mount prefix (see AwacsBlueprint).
info_service_bp = AwacsBlueprint('rpc_info_service', __name__, '/api')


@info_service_bp.method('GetNamespaceObjectsCount',
                        request_type=api_pb2.GetNamespaceObjectsCountRequest,
                        response_type=api_pb2.GetNamespaceObjectsCountResponse,
                        max_in_flight=5)
def get_namespace_objects_count(req_pb, _):
    """
    Report per-kind object counts for the namespace ``req_pb.id``.

    Each ``*_total`` field of the response is filled by the matching counter,
    called with ``namespace_id=req_pb.id``.

    :type req_pb: api_pb2.GetNamespaceObjectsCountRequest
    :rtype: api_pb2.GetNamespaceObjectsCountResponse
    """
    info.validate_request(req_pb)
    awacs_cache = cache.IAwacsCache.instance()
    namespace_id = req_pb.id

    resp_pb = api_pb2.GetNamespaceObjectsCountResponse()
    counters = [
        ('backends_total', awacs_cache.count_backends),
        ('balancers_total', awacs_cache.count_balancers),
        ('certificates_total', awacs_cache.count_certs),
        ('domains_total', awacs_cache.count_domains),
        ('dns_records_total', awacs_cache.count_dns_records),
        ('knobs_total', awacs_cache.count_knobs),
        ('l3_balancers_total', awacs_cache.count_l3_balancers),
        ('upstreams_total', awacs_cache.count_upstreams),
        # These two kinds are counted through their per-object-type caches, not IAwacsCache.
        ('weight_sections_total', objects.WeightSection.cache.count),
        ('l7heavy_configs_total', objects.L7HeavyConfig.cache.count),
    ]
    for field_name, count in counters:
        setattr(resp_pb, field_name, count(namespace_id=namespace_id))
    return resp_pb


@info_service_bp.method('GetNamespaceUnresolvedEndpointSets',
                        request_type=api_pb2.GetNamespaceUnresolvedEndpointSetsRequest,
                        response_type=api_pb2.GetNamespaceUnresolvedEndpointSetsResponse,
                        max_in_flight=5)
def get_namespace_unresolved_endpoint_sets(req_pb, auth_subject):
    """
    List YP endpoint sets of the namespace's backends that resolve as NOT_EXISTS or EMPTY.

    Only backends whose selector type is YP_ENDPOINT_SETS or YP_ENDPOINT_SETS_SD are checked.
    In USED_IN_L7 mode, backends are further restricted to those listed in some balancer
    state of the namespace. Resolution runs in a gevent pool of 10 greenlets, one task per
    backend; the resolver is called with client name ``awacs:<login>``.

    :type req_pb: api_pb2.GetNamespaceUnresolvedEndpointSetsRequest
    :rtype: api_pb2.GetNamespaceUnresolvedEndpointSetsResponse
    """
    info.validate_request(req_pb)
    awacs_cache = cache.IAwacsCache.instance()
    resolver = yp_service_discovery.IResolver.instance()
    client_name = 'awacs:{}'.format(auth_subject.login)
    unresolved_statuses = (internals_pb2.NOT_EXISTS, internals_pb2.EMPTY)

    def collect_unresolved(backend_pb):
        """Resolve every YP endpoint set of ``backend_pb``; return descriptors of the unresolved ones."""
        found = []
        for es_pb in gutils.idle_iter(backend_pb.spec.selector.yp_endpoint_sets):
            sd_rsp_pb = resolver.resolve_endpoints(internals_pb2.TReqResolveEndpoints(
                cluster_name=es_pb.cluster,
                endpoint_set_id=es_pb.endpoint_set_id,
                client_name=client_name,
            ))
            if sd_rsp_pb.resolve_status in unresolved_statuses:
                found.append(api_pb2.GetNamespaceUnresolvedEndpointSetsResponse.EndpointSet(
                    status=sd_rsp_pb.resolve_status,
                    cluster=es_pb.cluster,
                    id=es_pb.endpoint_set_id,
                    backend_id=backend_pb.meta.id,
                ))
        return found

    only_used_in_l7 = req_pb.mode == req_pb.USED_IN_L7
    backends_used_in_l7 = set()
    if only_used_in_l7:
        for balancer_state_pb in awacs_cache.list_all_balancer_states(req_pb.id):
            backends_used_in_l7.update(balancer_state_pb.backends.keys())

    pool = gevent.pool.Pool(10)
    greenlets = []
    for backend_pb in gutils.idle_iter(awacs_cache.list_all_backends(namespace_id=req_pb.id)):
        selector_pb = backend_pb.spec.selector
        if selector_pb.type not in (selector_pb.YP_ENDPOINT_SETS, selector_pb.YP_ENDPOINT_SETS_SD):
            continue
        if only_used_in_l7 and backend_pb.meta.id not in backends_used_in_l7:
            continue
        greenlets.append(pool.apply_async(collect_unresolved, (backend_pb,)))

    # NOTE(review): in current gevent, join() returns False instead of raising when the 5s
    # timeout expires, and the .get() calls below have no timeout, so they still wait for
    # every greenlet; the 5s limit therefore does not bound request latency. Confirm intent.
    pool.join(raise_error=True, timeout=5)

    resp_pb = api_pb2.GetNamespaceUnresolvedEndpointSetsResponse()
    for greenlet in gutils.idle_iter(greenlets):
        resp_pb.endpoint_sets.extend(greenlet.get())
    return resp_pb


@info_service_bp.method('GetEasyModeUpstreamIdsForBackend',
                        request_type=api_pb2.GetEasyModeUpstreamIdsForBackendRequest,
                        response_type=api_pb2.GetEasyModeUpstreamIdsForBackendResponse)
def get_backend_easy_mode_upstream_ids(req_pb, auth_subject):
    """
    Return the easy-mode upstream ids the cache associates with backend ``req_pb.id``.

    The cache yields full ids (presumably ``(namespace_id, upstream_id)`` tuples);
    only element ``[1]`` of each is put into the response.

    :type req_pb: api_pb2.GetEasyModeUpstreamIdsForBackendRequest
    :rtype: api_pb2.GetEasyModeUpstreamIdsForBackendResponse
    """
    info.validate_request(req_pb)

    awacs_cache = cache.IAwacsCache.instance()
    full_ids = awacs_cache.list_easy_mode_upstreams_ids(req_pb.namespace_id, req_pb.id)

    resp_pb = api_pb2.GetEasyModeUpstreamIdsForBackendResponse()
    resp_pb.ids.extend(full_id[1] for full_id in full_ids)
    return resp_pb


@info_service_bp.method('GetNormalizedInstanceMacroFromBalancerYaml',
                        request_type=api_pb2.GetNormalizedInstanceMacroFromBalancerYamlRequest,
                        response_type=api_pb2.GetNormalizedInstanceMacroFromBalancerYamlResponse,
                        max_in_flight=5)
def expand_l7_macro_to_instance_macro(req_pb, auth_subject):
    """
    Root-only: normalize a balancer YAML config, first expanding a top-level ``l7_macro``.

    Errors raised by ``get_yandex_config_pb`` while parsing the YAML are not converted here.

    :type req_pb: api_pb2.GetNormalizedInstanceMacroFromBalancerYamlRequest
    :rtype: api_pb2.GetNormalizedInstanceMacroFromBalancerYamlResponse
    :raises ForbiddenError: if the caller's login is not in ``run.root_users``.
    :raises BadRequestError: if ``l7_macro.include_domains`` is set.
    """
    root_users = config.get_value('run.root_users', default=())
    if auth_subject.login not in root_users:
        raise ForbiddenError('Method is only allowed for roots.')
    info.validate_request(req_pb)

    spec_pb = model_pb2.BalancerSpec()
    spec_pb.yandex_balancer.yaml = req_pb.yaml
    holder = Holder(get_yandex_config_pb(spec_pb))

    has_l7_macro = holder.pb.HasField('l7_macro')
    if has_l7_macro and holder.pb.l7_macro.HasField('include_domains'):
        raise BadRequestError('"l7_macro.include_domains" is not supported.')
    if has_l7_macro:
        holder.expand_immediate_contained_macro()
    holder.to_normal_form_XXX()

    normalized_yaml = get_yaml(holder.pb)
    return api_pb2.GetNormalizedInstanceMacroFromBalancerYamlResponse(yaml=normalized_yaml)


@info_service_bp.method('GetNormalizedConfigFromUpstreamYaml',
                        request_type=api_pb2.GetNormalizedConfigFromUpstreamYamlRequest,
                        response_type=api_pb2.GetNormalizedConfigFromUpstreamYamlResponse,
                        max_in_flight=5)
def get_normalized_config_from_upstream_yaml(req_pb, auth_subject):
    """
    Validate an upstream config and return it in normal form, both as YAML and as a protobuf.

    The config is taken from the ``input`` oneof: ``yaml`` is parsed into ``modules_pb2.Holder``;
    ``config`` is used as-is. If ``namespace_id`` is set, the namespace's upstream specs are
    loaded into the validation context. An ``l7_upstream_macro`` is converted into a
    ``regexp_section`` and its macros are expanded before normalization.

    :type req_pb: api_pb2.GetNormalizedConfigFromUpstreamYamlRequest
    :rtype: api_pb2.GetNormalizedConfigFromUpstreamYamlResponse
    :raises BadRequestError: if no input is set, the YAML cannot be parsed, or
        validation/normalization fails.
    """
    info.validate_request(req_pb)

    field = req_pb.WhichOneof('input')
    if field is None:
        raise BadRequestError('"yaml" or "config" must be set')

    if field == 'yaml':
        try:
            upstream_pb = yamlparser.parse(modules_pb2.Holder, req_pb.yaml)
        except yamlparser.Error as e:
            # Pass text, not the exception object, matching GenerateLua's error handling.
            raise BadRequestError(six.text_type(e))
    else:
        assert field == 'config'
        upstream_pb = req_pb.config

    validation_ctx = ValidationCtx(
        rps_limiter_allowed_installations=set(rps_limiter_settings.get_available_installation_names()),
        config_type=ValidationCtx.CONFIG_TYPE_UPSTREAM
    )
    if req_pb.namespace_id:
        # Make the namespace's existing upstreams visible to validation.
        validation_ctx.namespace_id = req_pb.namespace_id
        c = cache.IAwacsCache.instance()
        for pb in gutils.idle_iter(c.list_all_upstreams(namespace_id=req_pb.namespace_id)):
            validation_ctx.upstream_spec_pbs[(pb.meta.namespace_id, pb.meta.id)] = pb.spec

    try:
        up = Holder(upstream_pb)
        up.validate(validation_ctx, preceding_modules=(ANY_MODULE,))
        if up.pb.HasField('l7_upstream_macro'):
            # Replace the macro with its regexp_section equivalent and expand it.
            h = modules_pb2.Holder()
            h.regexp_section.CopyFrom(L7UpstreamMacro(up.pb.l7_upstream_macro).to_regexp_section_pb())
            up = Holder(h)
            up.expand_macroses(ctx=validation_ctx)
        up.to_normal_form_XXX()
    except Exception as e:
        # Any validation/normalization failure is reported to the client as a bad request.
        raise BadRequestError(six.text_type(e))

    return api_pb2.GetNormalizedConfigFromUpstreamYamlResponse(yaml=get_yaml(up.pb), config=up.pb)


# gevent ThreadPool with maxsize=1 (at most one worker thread), shared by all GenerateLua
# requests: it is passed to validate_config() as `threadpool` in get_lua_by_balancer_spec.
lua_generator_threadpool = ThreadPool(maxsize=1)


@info_service_bp.method('GenerateLua',
                        request_type=api_pb2.GenerateLuaRequest,
                        response_type=api_pb2.GenerateLuaResponse,
                        max_in_flight=3)
def get_lua_by_balancer_spec(req_pb, auth_subject):
    """
    Root-only: render the top-level Lua config of a balancer.

    The balancer spec comes from ``balancer_yaml`` when it is non-empty, otherwise from the
    stored balancer. Upstreams, domains, knobs, backends, endpoint sets, certs and weight
    sections are read from the cache for ``namespace_id``. When ``upstream_id`` is set, that
    upstream's config is replaced by one parsed from ``upstream_yaml``. Backends of other
    namespaces listed in the balancer state's current vector are added as well.

    :type req_pb: api_pb2.GenerateLuaRequest
    :rtype: api_pb2.GenerateLuaResponse
    :raises ForbiddenError: if the caller's login is not in ``run.root_users``.
    :raises BadRequestError: if ``upstream_yaml`` cannot be parsed or ``validate_config`` fails.
    """
    root_users = config.get_value('run.root_users', default=())
    if auth_subject.login not in root_users:
        raise ForbiddenError('Method is only allowed for roots.')
    info.validate_request(req_pb)

    namespace_id = req_pb.namespace_id
    balancer_id = req_pb.balancer_id
    awacs_cache = cache.IAwacsCache.instance()
    state_holder = BalancerStateHolder(namespace_id, balancer_id,
                                       awacs_cache.must_get_balancer_state(namespace_id, balancer_id))

    if req_pb.balancer_yaml:
        balancer_spec_pb = model_pb2.BalancerSpec()
        balancer_spec_pb.yandex_balancer.yaml = req_pb.balancer_yaml
    else:
        balancer_spec_pb = awacs_cache.must_get_balancer(namespace_id, balancer_id).spec

    upstream_spec_pbs = {}
    for upstream_pb in gutils.idle_iter(awacs_cache.list_all_upstreams(namespace_id=namespace_id)):
        spec_pb = upstream_pb.spec
        if req_pb.upstream_id and upstream_pb.meta.id == req_pb.upstream_id:
            # Substitute the caller-supplied config for this single upstream.
            try:
                override_pb = yamlparser.parse(modules_pb2.Holder, req_pb.upstream_yaml)
            except yamlparser.Error as e:
                raise BadRequestError(e)
            spec_pb = clone_pb(upstream_pb.spec)
            # NOTE(review): Clear() drops every yandex_balancer field (including `mode`),
            # not only the yaml; confirm this is intended for easy-mode upstreams.
            spec_pb.yandex_balancer.Clear()
            spec_pb.yandex_balancer.config.CopyFrom(override_pb)
        upstream_spec_pbs[UpstreamVersion.from_pb(upstream_pb)] = spec_pb

    domain_spec_pbs = {
        DomainVersion.from_pb(pb): pb.spec
        for pb in gutils.idle_iter(awacs_cache.list_all_domains(namespace_id=namespace_id))
    }
    knob_spec_pbs = {
        KnobVersion.from_pb(pb): pb.spec
        for pb in gutils.idle_iter(awacs_cache.list_all_knobs(namespace_id=namespace_id))
    }
    backend_spec_pbs = {
        BackendVersion.from_pb(pb): pb.spec
        for pb in gutils.idle_iter(awacs_cache.list_all_backends(namespace_id=namespace_id))
    }
    # Backends living in other namespaces, as referenced by the current vector.
    for (backend_ns_id, backend_id), backend_version in six.iteritems(state_holder.curr_vector.backend_versions):
        if backend_ns_id != namespace_id:
            backend_spec_pbs[backend_version] = awacs_cache.must_get_backend(backend_ns_id, backend_id).spec
    endpoint_set_spec_pbs = {
        EndpointSetVersion.from_pb(pb): pb.spec
        for pb in gutils.idle_iter(awacs_cache.list_all_endpoint_sets(namespace_id=namespace_id))
    }
    cert_spec_pbs = {
        CertVersion.from_pb(pb): pb.spec
        for pb in gutils.idle_iter(awacs_cache.list_all_certs(namespace_id=namespace_id))
    }
    weight_section_spec_pbs = {
        objects.WeightSection.version.from_pb(pb): pb.spec
        for pb in gutils.idle_iter(objects.WeightSection.cache.list(namespace_id=namespace_id))
    }

    namespace_pb = awacs_cache.must_get_namespace(namespace_id)
    try:
        balancer = validate_config(
            namespace_pb=namespace_pb,
            namespace_id=namespace_id,
            # Balancer version comes from the active vector; backends above from the current one.
            balancer_version=state_holder.active_vector.balancer_version,
            balancer_spec_pb=balancer_spec_pb,
            upstream_spec_pbs=upstream_spec_pbs,
            backend_spec_pbs=backend_spec_pbs,
            endpoint_set_spec_pbs=endpoint_set_spec_pbs,
            knob_spec_pbs=knob_spec_pbs,
            cert_spec_pbs=cert_spec_pbs,
            domain_spec_pbs=domain_spec_pbs,
            weight_section_spec_pbs=weight_section_spec_pbs,
            threadpool=lua_generator_threadpool,
            threadpool_interval=0.1,
        ).balancer
    except Exception as e:
        raise BadRequestError(six.text_type(e))

    assert balancer.module_name == 'main'

    allowed_installations = set(namespace_pb.spec.rps_limiter_allowed_installations.installations)
    balancer.expand_macroses(ctx=ValidationCtx(rps_limiter_allowed_installations=allowed_installations))
    return api_pb2.GenerateLuaResponse(lua=balancer.module.to_config().to_top_level_lua())


@info_service_bp.method('GetL7MacroUsage',
                        request_type=api_pb2.GetL7MacroUsageRequest,
                        response_type=api_pb2.GetL7MacroUsageResponse,
                        max_in_flight=5)
def get_l7_macro_usage(req_pb, auth_subject):
    """
    Return balancer counts per l7_macro version, as computed by
    ``IAwacsCache.count_balancer_usage_by_l7_macro_version()``.

    :type req_pb: api_pb2.GetL7MacroUsageRequest
    :rtype: api_pb2.GetL7MacroUsageResponse
    """
    c = cache.IAwacsCache.instance()
    by_version = c.count_balancer_usage_by_l7_macro_version()
    # Must be the message type declared as `response_type` in the decorator above.
    return api_pb2.GetL7MacroUsageResponse(balancer_counts_by_version=by_version)


@info_service_bp.method('ListEasyModeUpstreamWeightSections',
                        request_type=api_pb2.ListEasyModeUpstreamWeightSectionsRequest,
                        response_type=api_pb2.ListEasyModeUpstreamWeightSectionsResponse,
                        max_in_flight=5)
def list_upstream_weight_sections(req_pb, auth_subject):
    """
    Propose ST_DC_WEIGHTS weight sections for the namespace's EASY_MODE2 upstreams.

    For each upstream whose ``l7_upstream_macro`` references a weights section id (via
    ``by_dc_scheme.dc_balancer`` or ``traffic_split``), a section is built whose locations are
    the upstream's ``by_dc_scheme.dcs`` sorted by name and upper-cased, sharing a default
    weight of 100 as evenly as possible (the first ``100 % n`` locations get one extra),
    plus a zero-weight DEVNULL fallback unless ``by_dc_scheme.compat.disable_devnull`` is set.

    Upstreams referencing the same id are grouped; the first one's section is kept, and
    ``upstreams_are_incompatible`` is set if a later upstream implies a different section.
    Results are ordered compatible-first, then by section id.

    :type req_pb: api_pb2.ListEasyModeUpstreamWeightSectionsRequest
    :rtype: api_pb2.ListEasyModeUpstreamWeightSectionsResponse
    """
    info.validate_request(req_pb)
    awacs_cache = cache.IAwacsCache.instance()
    namespace_id = req_pb.namespace_id

    infos_by_section_id = {}
    for upstream_pb in awacs_cache.list_all_upstreams(namespace_id):
        yandex_balancer_pb = upstream_pb.spec.yandex_balancer
        if yandex_balancer_pb.mode != model_pb2.YandexBalancerUpstreamSpec.EASY_MODE2:
            continue
        macro_pb = yandex_balancer_pb.config.l7_upstream_macro

        if macro_pb.HasField('by_dc_scheme'):
            section_id = macro_pb.by_dc_scheme.dc_balancer.weights_section_id
        elif macro_pb.HasField('traffic_split'):
            # NOTE(review): by_dc_scheme is unset on this path, so the section built below
            # has no DC locations and only the DEVNULL fallback; confirm this is intended.
            section_id = macro_pb.traffic_split.weights_section_id
        else:
            section_id = ''
        if not section_id:
            continue

        by_dc_pb = macro_pb.by_dc_scheme
        ws_pb = model_pb2.WeightSection()
        ws_pb.meta.namespace_id = namespace_id
        ws_pb.meta.id = section_id
        ws_pb.meta.type = ws_pb.meta.ST_DC_WEIGHTS

        dc_pbs = sorted(by_dc_pb.dcs, key=lambda dc_pb: dc_pb.name)
        dc_count = len(dc_pbs)
        for idx, dc_pb in enumerate(dc_pbs):
            # Division stays inside the loop so that an empty dcs list never divides by zero.
            weight = 100 // dc_count
            if idx < 100 % dc_count:
                weight += 1
            ws_pb.spec.locations.add(name=dc_pb.name.upper(), default_weight=weight)
        if not by_dc_pb.compat.disable_devnull:
            ws_pb.spec.locations.add(name='DEVNULL', default_weight=0, is_fallback=True)

        section_info_pb = infos_by_section_id.get(section_id)
        if section_info_pb is None:
            section_info_pb = api_pb2.ListEasyModeUpstreamWeightSectionsResponse.WeightSectionInfo(
                weight_section=ws_pb)
            infos_by_section_id[section_id] = section_info_pb
        elif ws_pb != section_info_pb.weight_section:
            section_info_pb.upstreams_are_incompatible = True
        section_info_pb.upstream_ids.append(upstream_pb.meta.id)

    resp_pb = api_pb2.ListEasyModeUpstreamWeightSectionsResponse()
    resp_pb.weight_sections.extend(sorted(
        six.itervalues(infos_by_section_id),
        key=lambda section_info_pb: (section_info_pb.upstreams_are_incompatible,
                                     section_info_pb.weight_section.meta.id),
    ))
    return resp_pb
