# coding: utf-8
import json
from datetime import datetime, timedelta

import nanny_rpc_client
from awacs import yamlparser
from infra.awacs.proto import api_pb2, api_stub, modules_pb2, model_pb2


def check_balancer_pb(balancer_pb):
    """
    Make sure the balancer's YAML config can be parsed with the local modules_pb2 definitions.

    Errors raised by ``yamlparser.parse`` are deliberately not caught.

    :type balancer_pb: model_pb2.Balancer
    :rtype: model_pb2.Balancer
    """
    yaml_config = balancer_pb.spec.yandex_balancer.yaml
    if not yaml_config:
        return balancer_pb
    # Config coming from the API that our Arcadia proto definitions can't parse means those
    # definitions are probably outdated -- and we must not go on in that case:
    # https://st.yandex-team.ru/BALANCERSUPPORT-1369
    yamlparser.parse(modules_pb2.Holder, yaml_config)
    return balancer_pb


def check_domain_pb(domain_pb):
    """
    Make sure the domain's YAML config can be parsed with the local modules_pb2 definitions.

    Errors raised by ``yamlparser.parse`` are deliberately not caught. The same object
    is returned so the call can be used inline.

    :type domain_pb: model_pb2.Domain
    :rtype: model_pb2.Domain
    """
    if domain_pb.spec.yandex_balancer.yaml:
        # If we get from API something that we can't parse with our
        # Arcadia proto definitions, they're probably outdated --
        # and we don't want to continue: https://st.yandex-team.ru/BALANCERSUPPORT-1369
        yamlparser.parse(modules_pb2.Holder, domain_pb.spec.yandex_balancer.yaml)
    return domain_pb


def check_upstream_pb(upstream_pb):
    """
    Verify that the upstream's YAML config parses with the current modules_pb2 definitions.

    Parse errors from ``yamlparser.parse`` propagate to the caller.

    :type upstream_pb: model_pb2.Upstream
    :rtype: model_pb2.Upstream
    """
    upstream_yaml = upstream_pb.spec.yandex_balancer.yaml
    if upstream_yaml:
        # The API may hand us config that our Arcadia proto definitions can't parse --
        # that means they are outdated, and we refuse to go on:
        # https://st.yandex-team.ru/BALANCERSUPPORT-1369
        yamlparser.parse(modules_pb2.Holder, upstream_yaml)
    return upstream_pb


class AwacsClient(object):
    """
    Convenience wrapper around the awacs RPC API.

    All service stubs share one ``nanny_rpc_client.SessionedRpcClient``. Several methods pass
    returned protos through the module-level ``check_*_pb`` helpers, which run
    ``yamlparser.parse`` on their YAML config (see BALANCERSUPPORT-1369).
    """

    # Values accepted by iter_all_balancers() for ``yp_cluster_in`` / ``location_in``.
    _KNOWN_YP_CLUSTERS = ('SAS', 'MAN', 'VLA', 'MYT', 'IVA')
    _KNOWN_LOCATIONS = _KNOWN_YP_CLUSTERS + ('XDC',)
    # Page size used when listing upstreams.
    _UPSTREAMS_PAGE_LIMIT = 1000

    def __init__(self, awacs_token, awacs_rpc_url):
        """
        :param awacs_token: OAuth token passed to the RPC client
        :param awacs_rpc_url: awacs RPC endpoint URL
        """
        self.awacs_rpc = nanny_rpc_client.SessionedRpcClient(rpc_url=awacs_rpc_url, oauth_token=awacs_token)
        self.namespace_service_stub = api_stub.NamespaceServiceStub(self.awacs_rpc)
        self.l3_balancer_service_stub = api_stub.L3BalancerServiceStub(self.awacs_rpc)
        self.balancer_service_stub = api_stub.BalancerServiceStub(self.awacs_rpc)
        self.upstream_service_stub = api_stub.UpstreamServiceStub(self.awacs_rpc)
        self.cert_service_stub = api_stub.CertificateServiceStub(self.awacs_rpc)
        self.backend_service_stub = api_stub.BackendServiceStub(self.awacs_rpc)
        self.domain_service_stub = api_stub.DomainServiceStub(self.awacs_rpc)
        self.endpoint_set_service_stub = api_stub.EndpointSetServiceStub(self.awacs_rpc)
        self.knob_service_stub = api_stub.KnobServiceStub(self.awacs_rpc)
        self.weight_section_service_stub = api_stub.WeightSectionServiceStub(self.awacs_rpc)
        self.statistics_stub = api_stub.AwacsStatisticsServiceStub(self.awacs_rpc)

    def get_inclusion_graph(self, namespace_id, max_age=timedelta(hours=3)):
        """
        Fetch the inclusion graph stored in the namespace's aspects set.

        :param namespace_id: awacs namespace id
        :param max_age: maximum age of the last graph-building attempt for the result to be
            considered ok; defaults to 3 hours
        :type max_age: datetime.timedelta
        :returns: ``(ok, graph)``. ``ok`` is False if the last attempt's ``succeeded.status``
            is not ``'True'`` or if it finished more than ``max_age`` ago. ``graph`` is the
            ``json.loads``-decoded ``inclusion_graph_json``.
        :rtype: (bool, object)
        """
        req_pb = api_pb2.GetNamespaceAspectsSetRequest(id=namespace_id)
        resp_pb = self.namespace_service_stub.get_namespace_aspects_set(req_pb)
        graph_pb = resp_pb.aspects_set.content.graph
        last_attempt_pb = graph_pb.status.last_attempt
        succeeded = last_attempt_pb.succeeded.status == 'True'
        # ToDatetime() is compared with the naive datetime.utcnow(), i.e. both are naive UTC.
        fresh = datetime.utcnow() - last_attempt_pb.finished_at.ToDatetime() <= max_age
        return succeeded and fresh, json.loads(graph_pb.content.inclusion_graph_json)

    def get_load_statistics_entry(self, dt):
        """
        Return the load statistics entry whose ``start`` is ``dt``.

        :type dt: datetime.datetime
        """
        req_pb = api_pb2.GetLoadStatisticsEntryRequest()
        req_pb.start.FromDatetime(dt)
        resp_pb = self.statistics_stub.get_load_statistics_entry(req_pb)
        return resp_pb.entry

    def get_usage_statistics_entry(self, dt):
        """
        Return the usage statistics entry whose ``start`` is ``dt``.

        :type dt: datetime.datetime
        """
        req_pb = api_pb2.GetUsageStatisticsEntryRequest()
        req_pb.start.FromDatetime(dt)
        resp_pb = self.statistics_stub.get_usage_statistics_entry(req_pb)
        return resp_pb.entry

    @staticmethod
    def _utc_today_midnight():
        """Return today's 00:00:00 as a naive UTC datetime."""
        return datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)

    def get_yesterday_max_rps_stats_by_namespace(self):
        """
        Return per-namespace max RPS from the load statistics entry starting at today's UTC midnight.

        Values are cast to ``int``, matching get_yesterday_max_rps_stats_by_balancer().

        :rtype: Dict[six.text_type, int]
        """
        data_pb = self.get_load_statistics_entry(self._utc_today_midnight()).date_statistics
        return {namespace_id: int(entry_pb.max)
                for namespace_id, entry_pb in sorted(data_pb.by_namespace.items())}

    def get_yesterday_max_rps_stats_by_balancer(self):
        """
        Return per-balancer max RPS from the load statistics entry starting at today's UTC midnight.

        Flat balancer keys are split on ``/`` into tuples
        (presumably ``(namespace_id, balancer_id)`` -- see rtype).

        :rtype: Dict[(six.text_type, six.text_type), int]
        """
        data_pb = self.get_load_statistics_entry(self._utc_today_midnight()).date_statistics
        return {tuple(flat_balancer_id.split('/')): int(entry_pb.max)
                for flat_balancer_id, entry_pb in sorted(data_pb.by_balancer.items())}

    def iter_all_namespaces(self, field_mask_pb=None):
        """
        Yield all namespaces, paging with ListNamespaces until an empty page is returned.

        :type field_mask_pb: google.protobuf.field_mask_pb2.FieldMask
        """
        skip = 0
        limit = 500
        while True:
            req_pb = api_pb2.ListNamespacesRequest(skip=skip, limit=limit)
            if field_mask_pb is not None:
                req_pb.field_mask.CopyFrom(field_mask_pb)
            resp_pb = self.namespace_service_stub.list_namespaces(req_pb)
            if not resp_pb.namespaces:
                break
            for pb in resp_pb.namespaces:
                yield pb
            skip += limit

    def list_namespace_ids(self):
        """Yield namespace ids from a single (unpaginated) ListNamespaceSummaries call."""
        req_pb = api_pb2.ListNamespaceSummariesRequest()
        resp_pb = self.namespace_service_stub.list_namespace_summaries(req_pb)
        for pb in resp_pb.summaries:
            yield pb.id

    def list_balancer_revs(self, namespace_id, balancer_id, limit):
        """Return up to ``limit`` revisions of the given balancer."""
        req_pb = api_pb2.ListBalancerRevisionsRequest(namespace_id=namespace_id, id=balancer_id, limit=limit)
        resp_pb = self.balancer_service_stub.list_balancer_revisions(req_pb)
        return resp_pb.revisions

    def iter_all_balancers(self,
                           skip_incomplete=False,
                           yp_cluster_in=(),
                           location_in=(),
                           namespace_id_in=None,
                           full_balancer_id_in=None,
                           yp_lite_only=False):
        """
        Yield balancers across namespaces that match all the given filters.

        This is a generator, so argument validation happens on the first ``next()`` call.

        :param skip_incomplete: skip balancers with ``spec.incomplete`` set
        :param yp_cluster_in: keep only balancers whose upper-cased YP cluster is listed
        :param location_in: keep only balancers whose upper-cased YP cluster -- or, when that
            is empty, gencfg DC -- is listed
        :param namespace_id_in: restrict to these namespace ids
        :param full_balancer_id_in: restrict to these ``(namespace_id, balancer_id)`` pairs
        :param yp_lite_only: skip balancers whose location has no YP cluster
        :raises ValueError: if ``yp_cluster_in`` or ``location_in`` has an unknown value
        """
        for yp_cluster in yp_cluster_in:
            if yp_cluster not in self._KNOWN_YP_CLUSTERS:
                raise ValueError('Unknown YP cluster "{!r}"'.format(yp_cluster))
        for location in location_in:
            if location not in self._KNOWN_LOCATIONS:
                raise ValueError('Unknown location "{!r}"'.format(location))

        wanted_full_ids = None
        if full_balancer_id_in is not None:
            # Materialize once as a set of tuples: ids are matched against
            # ``(namespace_id, id)`` tuples below, membership is O(1), and a one-shot
            # iterator argument is not consumed twice.
            wanted_full_ids = {tuple(full_id) for full_id in full_balancer_id_in}

        if wanted_full_ids is not None and namespace_id_in is not None:
            namespace_ids = sorted({namespace_id for namespace_id, _ in wanted_full_ids} & set(namespace_id_in))
        elif wanted_full_ids is not None:
            namespace_ids = sorted({namespace_id for namespace_id, _ in wanted_full_ids})
        elif namespace_id_in is not None:
            namespace_ids = sorted(namespace_id_in)
        else:
            namespace_ids = self.list_namespace_ids()

        for namespace_id in namespace_ids:
            # list_balancers() has already run check_balancer_pb() on every balancer it returns.
            for balancer_pb in self.list_balancers(namespace_id):
                if skip_incomplete and balancer_pb.spec.incomplete:
                    continue
                location_pb = balancer_pb.meta.location
                if yp_lite_only and not location_pb.yp_cluster:
                    continue
                balancer_yp_cluster = location_pb.yp_cluster.upper()
                balancer_gencfg_dc = location_pb.gencfg_dc.upper()
                if yp_cluster_in and balancer_yp_cluster not in yp_cluster_in:
                    continue
                # A balancer's location is its YP cluster, falling back to its gencfg DC.
                if location_in and (balancer_yp_cluster or balancer_gencfg_dc) not in location_in:
                    continue
                full_id = (balancer_pb.meta.namespace_id, balancer_pb.meta.id)
                if wanted_full_ids and full_id not in wanted_full_ids:
                    continue
                yield balancer_pb

    def iter_all_upstreams(self):
        """Yield every upstream of every namespace, each passed through check_upstream_pb()."""
        for namespace_id in self.list_namespace_ids():
            for upstream_pb in self.list_upstreams(namespace_id):
                yield check_upstream_pb(upstream_pb)

    # certs
    def iter_all_certs(self):
        """Yield every certificate of every namespace."""
        for namespace_id in self.list_namespace_ids():
            for cert_pb in self.list_certs(namespace_id):
                yield cert_pb

    def get_cert_rev(self, id):
        """Return the spec of the certificate revision ``id``."""
        req_pb = api_pb2.GetCertificateRevisionRequest(id=id)
        resp_pb = self.cert_service_stub.get_certificate_revision(req_pb)
        return resp_pb.revision.spec

    def list_certs(self, namespace_id):
        """Return the certificates of a namespace from a single ListCertificates call."""
        req_pb = api_pb2.ListCertificatesRequest(namespace_id=namespace_id)
        resp_pb = self.cert_service_stub.list_certificates(req_pb)
        return resp_pb.certificates

    def update_cert(self, namespace_id, cert_id, version, auth_pb, comment=''):
        """Replace a certificate's ``meta.auth``; ``version`` is sent as ``meta.version``."""
        req_pb = api_pb2.UpdateCertificateRequest()
        req_pb.meta.id = cert_id
        req_pb.meta.namespace_id = namespace_id
        req_pb.meta.comment = comment
        req_pb.meta.auth.CopyFrom(auth_pb)
        req_pb.meta.version = version
        resp_pb = self.cert_service_stub.update_certificate(req_pb)
        return resp_pb.certificate

    # /certs

    def list_balancers(self, namespace_id) -> list[model_pb2.Balancer]:
        """
        Return all balancers of a namespace, each validated with check_balancer_pb().

        The assertion checks that the single ListBalancers call returned all ``total`` items.
        """
        req_pb = api_pb2.ListBalancersRequest(namespace_id=namespace_id)
        resp_pb = self.balancer_service_stub.list_balancers(req_pb)
        assert resp_pb.total == len(resp_pb.balancers)
        for balancer_pb in resp_pb.balancers:
            check_balancer_pb(balancer_pb)
        return resp_pb.balancers

    def _list_all_upstreams(self, req_pb):
        """
        Page through ListUpstreams with ``req_pb`` and return everything received as a list.

        Overwrites ``req_pb.limit`` and ``req_pb.skip``. Asserts that the number of
        received upstreams equals the ``total`` reported by the last response.
        """
        upstream_pbs = []
        req_pb.limit = self._UPSTREAMS_PAGE_LIMIT
        while True:
            resp_pb = self.upstream_service_stub.list_upstreams(req_pb)
            upstream_pbs.extend(resp_pb.upstreams)
            # Stop once everything has arrived, or on an empty page: if the listing shrinks
            # mid-pagination, re-requesting would otherwise never terminate.
            if len(upstream_pbs) >= resp_pb.total or not resp_pb.upstreams:
                break
            # Resume right after what was actually received instead of assuming full pages.
            req_pb.skip = len(upstream_pbs)
        assert len(upstream_pbs) == resp_pb.total
        return upstream_pbs

    def list_upstreams(self, namespace_id):
        """Return all upstreams of a namespace as a list (paginated)."""
        req_pb = api_pb2.ListUpstreamsRequest(namespace_id=namespace_id)
        return self._list_all_upstreams(req_pb)

    def list_upstream_ids(self, namespace_id):
        """Return the sorted ids of all upstreams of a namespace (paginated, ``meta.id`` only)."""
        req_pb = api_pb2.ListUpstreamsRequest(namespace_id=namespace_id)
        req_pb.field_mask.paths.append('meta.id')
        return sorted(pb.meta.id for pb in self._list_all_upstreams(req_pb))

    def list_balancer_ids(self, namespace_id):
        """Return the sorted ids of all balancers of a namespace (``meta.id`` only)."""
        req_pb = api_pb2.ListBalancersRequest(namespace_id=namespace_id)
        req_pb.field_mask.paths.append('meta.id')
        resp_pb = self.balancer_service_stub.list_balancers(req_pb)
        assert resp_pb.total == len(resp_pb.balancers)
        return sorted([pb.meta.id for pb in resp_pb.balancers])

    def get_namespace(self, namespace_id):
        """Return the namespace with the given id."""
        req_pb = api_pb2.GetNamespaceRequest(id=namespace_id)
        resp_pb = self.namespace_service_stub.get_namespace(req_pb)
        return resp_pb.namespace

    def get_balancer(self, namespace_id, balancer_id):
        """Return a balancer, validated with check_balancer_pb()."""
        req_pb = api_pb2.GetBalancerRequest(namespace_id=namespace_id, id=balancer_id)
        resp_pb = self.balancer_service_stub.get_balancer(req_pb)
        return check_balancer_pb(resp_pb.balancer)

    def get_balancer_state(self, namespace_id, balancer_id):
        """Return the state of a balancer."""
        req_pb = api_pb2.GetBalancerStateRequest(namespace_id=namespace_id, id=balancer_id)
        resp_pb = self.balancer_service_stub.get_balancer_state(req_pb)
        return resp_pb.state

    def get_balancer_rev(self, id):
        """Return the spec of the balancer revision ``id``."""
        req_pb = api_pb2.GetBalancerRevisionRequest(id=id)
        resp_pb = self.balancer_service_stub.get_balancer_revision(req_pb)
        return resp_pb.revision.spec

    # upstream
    def get_upstream(self, namespace_id, upstream_id):
        """Return an upstream (not passed through check_upstream_pb())."""
        req_pb = api_pb2.GetUpstreamRequest(namespace_id=namespace_id, id=upstream_id)
        resp_pb = self.upstream_service_stub.get_upstream(req_pb)
        return resp_pb.upstream

    def update_namespace(self, namespace_id, auth_pb=None, spec_pb=None, meta_version=None, annotations=None, comment=''):
        """
        Update a namespace; only the non-None arguments are put into the request.

        ``annotations`` are merged into ``meta.annotations`` of the request.
        """
        req_pb = api_pb2.UpdateNamespaceRequest()
        req_pb.meta.id = namespace_id
        req_pb.meta.comment = comment
        if auth_pb is not None:
            req_pb.meta.auth.CopyFrom(auth_pb)
        if spec_pb is not None:
            req_pb.spec.CopyFrom(spec_pb)
        if meta_version is not None:
            req_pb.meta.version = meta_version
        if annotations is not None:
            req_pb.meta.annotations.update(annotations)
        resp_pb = self.namespace_service_stub.update_namespace(req_pb)
        return resp_pb.namespace

    def update_upstream(self, namespace_id, upstream_id, version, auth_pb=None, spec_pb=None, comment=''):
        """Update an upstream's auth and/or spec; ``version`` is sent as ``meta.version``."""
        req_pb = api_pb2.UpdateUpstreamRequest()
        req_pb.meta.id = upstream_id
        req_pb.meta.namespace_id = namespace_id
        req_pb.meta.comment = comment
        if auth_pb is not None:
            req_pb.meta.auth.CopyFrom(auth_pb)
        if spec_pb is not None:
            req_pb.spec.CopyFrom(spec_pb)
        req_pb.meta.version = version
        resp_pb = self.upstream_service_stub.update_upstream(req_pb)
        return resp_pb.upstream

    # /upstream

    def get_upstream_rev(self, id):
        """Return the spec of the upstream revision ``id``."""
        req_pb = api_pb2.GetUpstreamRevisionRequest(id=id)
        resp_pb = self.upstream_service_stub.get_upstream_revision(req_pb)
        return resp_pb.revision.spec

    def remove_upstream(self, namespace_id, upstream_id, version):
        """Remove an upstream at the given ``version``; returns nothing."""
        req_pb = api_pb2.RemoveUpstreamRequest()
        req_pb.namespace_id = namespace_id
        req_pb.id = upstream_id
        req_pb.version = version
        self.upstream_service_stub.remove_upstream(req_pb)

    def update_balancer(self, namespace_id, balancer_id, version, spec_pb=None, auth_pb=None, location_pb=None,
                        comment=''):
        """Update a balancer's spec, location and/or auth; only non-None arguments are sent."""
        req_pb = api_pb2.UpdateBalancerRequest()
        req_pb.meta.namespace_id = namespace_id
        req_pb.meta.id = balancer_id
        req_pb.meta.version = version
        req_pb.meta.comment = comment
        if spec_pb is not None:
            req_pb.spec.CopyFrom(spec_pb)
        if location_pb is not None:
            req_pb.meta.location.CopyFrom(location_pb)
        if auth_pb is not None:
            req_pb.meta.auth.CopyFrom(auth_pb)
        resp_pb = self.balancer_service_stub.update_balancer(req_pb)
        return resp_pb.balancer

    # L3 balancers
    def list_l3_balancers(self, namespace_id):
        """Return all L3 balancers of a namespace from a single (asserted complete) call."""
        req_pb = api_pb2.ListL3BalancersRequest(namespace_id=namespace_id)
        resp_pb = self.l3_balancer_service_stub.list_l3_balancers(req_pb)
        assert resp_pb.total == len(resp_pb.l3_balancers)
        return resp_pb.l3_balancers

    def list_l3_balancer_ids(self, namespace_id):
        """Return the sorted ids of all L3 balancers of a namespace (``meta.id`` only)."""
        req_pb = api_pb2.ListL3BalancersRequest(namespace_id=namespace_id)
        req_pb.field_mask.paths.append('meta.id')
        resp_pb = self.l3_balancer_service_stub.list_l3_balancers(req_pb)
        assert resp_pb.total == len(resp_pb.l3_balancers)
        return sorted([pb.meta.id for pb in resp_pb.l3_balancers])

    def iter_all_l3_balancers(self):
        """Yield every L3 balancer of every namespace."""
        for namespace_id in self.list_namespace_ids():
            for l3_balancer_pb in self.list_l3_balancers(namespace_id):
                yield l3_balancer_pb

    def get_l3_balancer(self, namespace_id, l3_balancer_id):
        """Return an L3 balancer."""
        req_pb = api_pb2.GetL3BalancerRequest(namespace_id=namespace_id, id=l3_balancer_id)
        resp_pb = self.l3_balancer_service_stub.get_l3_balancer(req_pb)
        return resp_pb.l3_balancer

    def update_l3_balancer(self, namespace_id, l3_balancer_id, version, spec_pb, comment=''):
        """Replace an L3 balancer's spec; ``version`` is sent as ``meta.version``."""
        req_pb = api_pb2.UpdateL3BalancerRequest()
        req_pb.meta.namespace_id = namespace_id
        req_pb.meta.id = l3_balancer_id
        req_pb.meta.version = version
        req_pb.meta.comment = comment
        req_pb.spec.CopyFrom(spec_pb)
        resp_pb = self.l3_balancer_service_stub.update_l3_balancer(req_pb)
        return resp_pb.l3_balancer

    # backends
    def get_backend(self, namespace_id, backend_id):
        """Return a backend."""
        req_pb = api_pb2.GetBackendRequest(namespace_id=namespace_id, id=backend_id)
        resp_pb = self.backend_service_stub.get_backend(req_pb)
        return resp_pb.backend

    def get_backend_rev(self, id):
        """Return the spec of the backend revision ``id``."""
        req_pb = api_pb2.GetBackendRevisionRequest(id=id)
        resp_pb = self.backend_service_stub.get_backend_revision(req_pb)
        return resp_pb.revision.spec

    def iter_all_backends(self):
        """Yield every backend of every namespace."""
        for namespace_id in self.list_namespace_ids():
            for backend_pb in self.list_backends(namespace_id):
                yield backend_pb

    def list_backend_revs(self, namespace_id, backend_id):
        """Return the revisions of a backend from a single ListBackendRevisions call."""
        req_pb = api_pb2.ListBackendRevisionsRequest(namespace_id=namespace_id, id=backend_id)
        resp_pb = self.backend_service_stub.list_backend_revisions(req_pb)
        return resp_pb.revisions

    def list_backends(self, namespace_id):
        """Return all backends of a namespace from a single (asserted complete) call."""
        req_pb = api_pb2.ListBackendsRequest(namespace_id=namespace_id)
        resp_pb = self.backend_service_stub.list_backends(req_pb)
        assert resp_pb.total == len(resp_pb.backends)
        return resp_pb.backends

    def update_backend(self, namespace_id, backend_id, version, auth_pb=None, spec_pb=None, comment=''):
        """Update a backend's auth and/or spec; ``version`` is sent as ``meta.version``."""
        req_pb = api_pb2.UpdateBackendRequest()
        req_pb.meta.id = backend_id
        req_pb.meta.namespace_id = namespace_id
        req_pb.meta.comment = comment
        if auth_pb is not None:
            req_pb.meta.auth.CopyFrom(auth_pb)
        if spec_pb is not None:
            req_pb.spec.CopyFrom(spec_pb)
        req_pb.meta.version = version
        resp_pb = self.backend_service_stub.update_backend(req_pb)
        return resp_pb.backend

    def remove_backend(self, namespace_id, backend_id, version):
        """Remove a backend at the given ``version``; returns nothing."""
        req_pb = api_pb2.RemoveBackendRequest()
        req_pb.namespace_id = namespace_id
        req_pb.id = backend_id
        req_pb.version = version
        self.backend_service_stub.remove_backend(req_pb)

    # endpoint sets
    def list_endpoint_sets(self, namespace_id):
        """Return all endpoint sets of a namespace from a single (asserted complete) call."""
        req_pb = api_pb2.ListEndpointSetsRequest(namespace_id=namespace_id)
        resp_pb = self.endpoint_set_service_stub.list_endpoint_sets(req_pb)
        assert resp_pb.total == len(resp_pb.endpoint_sets)
        return resp_pb.endpoint_sets

    def iter_all_endpoint_sets(self):
        """Yield every endpoint set of every namespace."""
        for namespace_id in self.list_namespace_ids():
            for endpoint_set_pb in self.list_endpoint_sets(namespace_id):
                yield endpoint_set_pb

    def list_endpoint_set_revs(self, namespace_id, backend_id):
        """Return the revisions of an endpoint set; ``backend_id`` is sent as the request ``id``."""
        req_pb = api_pb2.ListEndpointSetRevisionsRequest(namespace_id=namespace_id, id=backend_id)
        resp_pb = self.endpoint_set_service_stub.list_endpoint_set_revisions(req_pb)
        return resp_pb.revisions

    def get_endpoint_set(self, namespace_id, backend_id):
        """Return an endpoint set; ``backend_id`` is sent as the request ``id``."""
        req_pb = api_pb2.GetEndpointSetRequest(namespace_id=namespace_id, id=backend_id)
        resp_pb = self.endpoint_set_service_stub.get_endpoint_set(req_pb)
        return resp_pb.endpoint_set

    def get_endpoint_set_rev(self, id):
        """Return the spec of the endpoint set revision ``id``."""
        req_pb = api_pb2.GetEndpointSetRevisionRequest(id=id)
        resp_pb = self.endpoint_set_service_stub.get_endpoint_set_revision(req_pb)
        return resp_pb.revision.spec

    # aspect sets
    def list_balancer_aspects_sets(self, namespace_id):
        """Return the balancer aspects sets of a namespace."""
        req_pb = api_pb2.ListBalancerAspectsSetsRequest(namespace_id=namespace_id)
        resp_pb = self.balancer_service_stub.list_balancer_aspects_sets(req_pb)
        return resp_pb.aspects_sets

    def iter_all_balancer_aspects_sets(self, namespace_id_in=()):
        """Yield balancer aspects sets of ``namespace_id_in`` (or of all namespaces if empty)."""
        if namespace_id_in:
            namespace_ids = namespace_id_in
        else:
            namespace_ids = self.list_namespace_ids()
        for namespace_id in namespace_ids:
            for aspects_set_pb in self.list_balancer_aspects_sets(namespace_id):
                yield aspects_set_pb

    # domains
    def list_domains(self, namespace_id):
        """Return all domains of a namespace, each validated with check_domain_pb()."""
        req_pb = api_pb2.ListDomainsRequest(namespace_id=namespace_id)
        resp_pb = self.domain_service_stub.list_domains(req_pb)
        assert resp_pb.total == len(resp_pb.domains)
        for domain_pb in resp_pb.domains:
            check_domain_pb(domain_pb)
        return resp_pb.domains

    def get_domain_rev(self, id):
        """Return the spec of the domain revision ``id``."""
        req_pb = api_pb2.GetDomainRevisionRequest(id=id)
        resp_pb = self.domain_service_stub.get_domain_revision(req_pb)
        return resp_pb.revision.spec

    # weight section
    def get_weight_section_rev(self, id):
        """Return the spec of the weight section revision ``id``."""
        req_pb = api_pb2.GetWeightSectionRevisionRequest(id=id)
        resp_pb = self.weight_section_service_stub.get_weight_section_revision(req_pb)
        return resp_pb.revision.spec
