# coding: utf-8
import datetime
import hashlib
import os
from functools import wraps

import nanny_rpc_client
import six
import vcr
from infra.awacs.proto import api_pb2, api_stub

from awacs.model.balancer.generator import validate_config
from awacs.model.balancer.vector import (Vector, BalancerVersion, UpstreamVersion,
                                         BackendVersion, EndpointSetVersion, KnobVersion, CertVersion, DomainVersion)
from awacs.lib.strutils import to_full_id
from awacs.wrappers import main  # noqa


def vector_to_specs(d, vector):
    """
    Fetch the revision specs referenced by a vector, skipping deleted versions.

    :type d: App
    :param vector: vector whose ``balancer_version`` and per-kind ``*_versions``
        mappings are read.
    :returns: tuple ``(balancer_spec_pb, domain_spec_pbs, upstream_spec_pbs,
        backend_spec_pbs, endpoint_set_spec_pbs, knob_spec_pbs, cert_spec_pbs)``.
        Every dict is keyed by the version object itself (not by its full id).
    """
    def fetch_live(versions, getter_name):
        # Resolve the getter per item (as a plain loop would) and drop deleted versions.
        return {
            ver: getattr(d, getter_name)(ver.version)
            for ver in six.itervalues(versions)
            if not ver.deleted
        }

    balancer_spec_pb = d.get_balancer_rev(vector.balancer_version.version)
    # RPCs are issued in this order: domains, upstreams, backends, certs, endpoint sets.
    domain_spec_pbs = fetch_live(vector.domain_versions, 'get_domain_rev')
    upstream_spec_pbs = fetch_live(vector.upstream_versions, 'get_upstream_rev')
    backend_spec_pbs = fetch_live(vector.backend_versions, 'get_backend_rev')
    cert_spec_pbs = fetch_live(vector.cert_versions, 'get_cert_rev')
    endpoint_set_spec_pbs = fetch_live(vector.endpoint_set_versions, 'get_endpoint_set_rev')
    # NOTE(review): knob specs are never fetched here; App has no knob revision getter.
    knob_spec_pbs = {}

    return (balancer_spec_pb, domain_spec_pbs, upstream_spec_pbs, backend_spec_pbs,
            endpoint_set_spec_pbs, knob_spec_pbs, cert_spec_pbs)


def vector_to_config_holder(vector,
                            namespace_pb,
                            balancer_spec_pb,
                            domain_spec_pbs,
                            upstream_spec_pbs,
                            backend_spec_pbs,
                            endpoint_set_spec_pbs,
                            knob_spec_pbs,
                            cert_spec_pbs):
    """
    Run ``validate_config`` over a vector and its fetched specs.

    The namespace id is taken from the first element of the vector's
    ``balancer_version.balancer_id``.

    :returns: the ``balancer`` attribute of the object ``validate_config`` returns.
    """
    balancer_version = vector.balancer_version
    holder = validate_config(
        namespace_pb=namespace_pb,
        namespace_id=balancer_version.balancer_id[0],
        balancer_version=balancer_version,
        balancer_spec_pb=balancer_spec_pb,
        upstream_spec_pbs=upstream_spec_pbs,
        backend_spec_pbs=backend_spec_pbs,
        endpoint_set_spec_pbs=endpoint_set_spec_pbs,
        knob_spec_pbs=knob_spec_pbs,
        cert_spec_pbs=cert_spec_pbs,
        domain_spec_pbs=domain_spec_pbs,
    )
    return holder.balancer


def _collect_active_versions(namespace_id, state_pbs, version_cls):
    """
    Map each object's full id to the version of its active revision.

    :param namespace_id: namespace used to qualify ids via ``to_full_id``.
    :param state_pbs: mapping of object id (as keyed in the balancer state; may or
        may not already be namespace-qualified) -> state pb with a ``statuses`` list.
    :param version_cls: version class exposing ``from_rev_status_pb(full_id, rev_pb)``.
    :returns: dict full_id -> version for revisions whose ``active.status`` is
        ``'True'``; if several revisions are active, the last one listed wins.
        Objects with no active revision are absent from the result.
    """
    active_versions = {}
    for maybe_full_id, state_pb in six.iteritems(state_pbs):
        # Qualify the id once per object, not once per revision.
        full_id = to_full_id(namespace_id, maybe_full_id)
        for rev_pb in state_pb.statuses:
            if rev_pb.active.status == 'True':
                active_versions[full_id] = version_cls.from_rev_status_pb(full_id, rev_pb)
    return active_versions


def balancer_state_to_active_vector(namespace_id, balancer_id, balancer_state_pb):
    """
    Build a ``Vector`` of the currently active revisions from a balancer state pb.

    :param namespace_id: namespace of the balancer.
    :param balancer_id: balancer id within the namespace.
    :param balancer_state_pb: balancer state with ``balancer``, ``upstreams``,
        ``domains``, ``backends``, ``endpoint_sets``, ``knobs`` and ``certificates``.
    :returns: ``Vector`` whose balancer version is ``None`` when no balancer
        revision is active.
    """
    balancer_active_version = None
    for rev_pb in balancer_state_pb.balancer.statuses:
        if rev_pb.active.status == 'True':
            # Last active revision wins, matching the per-object collections below.
            balancer_active_version = BalancerVersion.from_rev_status_pb((namespace_id, balancer_id), rev_pb)

    upstream_active_versions = _collect_active_versions(
        namespace_id, balancer_state_pb.upstreams, UpstreamVersion)
    backend_active_versions = _collect_active_versions(
        namespace_id, balancer_state_pb.backends, BackendVersion)
    endpoint_set_active_versions = _collect_active_versions(
        namespace_id, balancer_state_pb.endpoint_sets, EndpointSetVersion)
    knob_active_versions = _collect_active_versions(
        namespace_id, balancer_state_pb.knobs, KnobVersion)
    cert_active_versions = _collect_active_versions(
        namespace_id, balancer_state_pb.certificates, CertVersion)
    domain_active_versions = _collect_active_versions(
        namespace_id, balancer_state_pb.domains, DomainVersion)

    # Positional order matters: Vector(balancer, upstreams, domains, backends,
    # endpoint_sets, knobs, certs).
    return Vector(balancer_active_version,
                  upstream_active_versions,
                  domain_active_versions,
                  backend_active_versions,
                  endpoint_set_active_versions,
                  knob_active_versions,
                  cert_active_versions)


def _cassette_name(prefix, args, kw):
    """
    Build a deterministic VCR cassette file name for a call's arguments.

    Built-in ``hash()`` of a ``str`` is randomized per process on Python 3
    (PYTHONHASHSEED), so names built from it change between runs and recorded
    cassettes are never replayed. A SHA-256 digest is stable across runs.
    """
    key = str(args) + str(kw)
    if not isinstance(key, bytes):  # Python 3 text; on Python 2 ``str`` is already bytes
        key = key.encode('utf-8')
    return prefix + '_' + hashlib.sha256(key).hexdigest()[:16] + '.yml'


def _cls_decorator(cls):
    """
    Optionally wrap a class's ``list_*`` / ``get_*`` methods with VCR recording.

    The class is returned untouched unless ``VCR_ENABLE`` or ``VCR_RECORD_MODE``
    is set to a non-empty value. When enabled, each matching callable is replaced
    by a wrapper that runs the call inside a cassette. The cassette is stored
    under ``./cassettes`` and named from ``VCR_CASSETE`` (default: today's date)
    plus a digest of the call arguments (excluding ``self``).

    :returns: ``cls`` (modified in place when enabled).
    """
    if not (os.getenv('VCR_ENABLE', default=False) or os.getenv('VCR_RECORD_MODE', default=False)):
        return cls

    cassettes_dir = './cassettes'
    # NOTE: the environment variable is spelled 'VCR_CASSETE' (sic); kept for compatibility.
    cassette_prefix = os.getenv('VCR_CASSETE', default=str(datetime.date.today()))

    stat_vcr = vcr.VCR(
        cassette_library_dir=cassettes_dir,
        record_mode=os.getenv('VCR_RECORD_MODE', default='new_episodes'),
        filter_headers=['authorization'],
        match_on=['method', 'scheme', 'host', 'port', 'path', 'query', 'raw_body'],
    )

    def make_wrapper(func):
        # Factory function so each wrapper closes over its own ``func``.
        @wraps(func)
        def wrapper(obj_self, *args, **kw):
            with stat_vcr.use_cassette(_cassette_name(cassette_prefix, args, kw)):
                return func(obj_self, *args, **kw)

        return wrapper

    # Snapshot the names first: the loop rebinds attributes on ``cls``.
    for attr in list(cls.__dict__):
        func = getattr(cls, attr)
        if callable(func) and attr.startswith(('list_', 'get_')):
            setattr(cls, attr, make_wrapper(func))
    return cls


@_cls_decorator
class App(object):
    """Thin wrapper over the awacs RPC service stubs used by these helpers."""

    def __init__(self, awacs_rpc_url, awacs_token):
        # One retrying RPC client shared by all per-service stubs.
        self.awacs_rpc = nanny_rpc_client.RetryingRpcClient(rpc_url=awacs_rpc_url, oauth_token=awacs_token)
        self.namespace_service_stub = api_stub.NamespaceServiceStub(self.awacs_rpc)
        self.balancer_service_stub = api_stub.BalancerServiceStub(self.awacs_rpc)
        self.domain_service_stub = api_stub.DomainServiceStub(self.awacs_rpc)
        self.upstream_service_stub = api_stub.UpstreamServiceStub(self.awacs_rpc)
        self.backend_service_stub = api_stub.BackendServiceStub(self.awacs_rpc)
        self.endpoint_set_service_stub = api_stub.EndpointSetServiceStub(self.awacs_rpc)
        self.cert_service_stub = api_stub.CertificateServiceStub(self.awacs_rpc)
        self.knob_service_stub = api_stub.KnobServiceStub(self.awacs_rpc)

    def get_namespace(self, namespace_id):
        """Return the namespace pb with the given id."""
        req_pb = api_pb2.GetNamespaceRequest(id=namespace_id)
        resp_pb = self.namespace_service_stub.get_namespace(req_pb)
        return resp_pb.namespace

    def list_namespace_ids(self):
        """Return the sorted list of all namespace ids (single request, ``meta.id`` only)."""
        req_pb = api_pb2.ListNamespacesRequest()
        req_pb.field_mask.paths.append('meta.id')
        resp_pb = self.namespace_service_stub.list_namespaces(req_pb)
        # Single-shot request: expects the server to return every item at once.
        assert resp_pb.total == len(resp_pb.namespaces)
        return sorted([pb.meta.id for pb in resp_pb.namespaces])

    def get_balancer_state(self, namespace_id, balancer_id):
        """Return the state pb of the given balancer."""
        req_pb = api_pb2.GetBalancerStateRequest(namespace_id=namespace_id, id=balancer_id)
        resp_pb = self.balancer_service_stub.get_balancer_state(req_pb)
        return resp_pb.state

    def list_balancers(self, namespace_id):
        """Return all balancer pbs of a namespace (single request)."""
        req_pb = api_pb2.ListBalancersRequest(namespace_id=namespace_id)
        resp_pb = self.balancer_service_stub.list_balancers(req_pb)
        assert resp_pb.total == len(resp_pb.balancers)
        return resp_pb.balancers

    def list_balancer_ids(self, namespace_id):
        """Return the sorted balancer ids of a namespace (single request, ``meta.id`` only)."""
        req_pb = api_pb2.ListBalancersRequest(namespace_id=namespace_id)
        req_pb.field_mask.paths.append('meta.id')
        resp_pb = self.balancer_service_stub.list_balancers(req_pb)
        assert resp_pb.total == len(resp_pb.balancers)
        return sorted([pb.meta.id for pb in resp_pb.balancers])

    def list_upstream_ids(self, namespace_id):
        """
        Return the sorted upstream ids of a namespace, fetched in pages of 1000.

        :raises RuntimeError: if the server returns an empty page before ``total``
            ids have been collected (previously this re-sent the same request forever).
        """
        req_pb = api_pb2.ListUpstreamsRequest(namespace_id=namespace_id)
        req_pb.field_mask.paths.append('meta.id')
        req_pb.limit = 1000
        ids = []
        while True:
            resp_pb = self.upstream_service_stub.list_upstreams(req_pb)
            ids.extend(pb.meta.id for pb in resp_pb.upstreams)
            if len(ids) >= resp_pb.total:
                break
            if not resp_pb.upstreams:
                # No progress possible: the next request would use the same skip.
                raise RuntimeError(
                    'list_upstreams returned an empty page at skip={} while total={}'.format(
                        len(ids), resp_pb.total))
            req_pb.skip = len(ids)
        assert len(ids) == resp_pb.total
        return sorted(ids)

    def get_balancer(self, namespace_id, balancer_id):
        """Return the balancer pb with the given namespace and id."""
        req_pb = api_pb2.GetBalancerRequest(namespace_id=namespace_id, id=balancer_id)
        resp_pb = self.balancer_service_stub.get_balancer(req_pb)
        return resp_pb.balancer

    def update_balancer(self, namespace_id, balancer_id, version, spec_pb, comment=''):
        """Replace a balancer's spec at ``version`` and return the updated balancer pb."""
        req_pb = api_pb2.UpdateBalancerRequest()
        req_pb.meta.namespace_id = namespace_id
        req_pb.meta.id = balancer_id
        req_pb.meta.version = version
        req_pb.meta.comment = comment
        req_pb.spec.CopyFrom(spec_pb)
        resp_pb = self.balancer_service_stub.update_balancer(req_pb)
        return resp_pb.balancer

    def get_upstream(self, namespace_id, upstream_id):
        """Return the upstream pb with the given namespace and id."""
        req_pb = api_pb2.GetUpstreamRequest(namespace_id=namespace_id, id=upstream_id)
        resp_pb = self.upstream_service_stub.get_upstream(req_pb)
        return resp_pb.upstream

    def update_upstream(self, namespace_id, upstream_id, version, spec_pb, comment=''):
        """Replace an upstream's spec at ``version`` and return the updated upstream pb."""
        req_pb = api_pb2.UpdateUpstreamRequest()
        req_pb.meta.namespace_id = namespace_id
        req_pb.meta.id = upstream_id
        req_pb.meta.version = version
        req_pb.meta.comment = comment
        req_pb.spec.CopyFrom(spec_pb)
        resp_pb = self.upstream_service_stub.update_upstream(req_pb)
        return resp_pb.upstream

    def get_balancer_rev(self, id):
        """Return the spec of the balancer revision with the given id."""
        req_pb = api_pb2.GetBalancerRevisionRequest(id=id)
        resp_pb = self.balancer_service_stub.get_balancer_revision(req_pb)
        return resp_pb.revision.spec

    def list_upstreams(self, namespace_id):
        """Return all upstream pbs of a namespace (single request)."""
        req_pb = api_pb2.ListUpstreamsRequest(namespace_id=namespace_id)
        resp_pb = self.upstream_service_stub.list_upstreams(req_pb)
        assert resp_pb.total == len(resp_pb.upstreams)
        return resp_pb.upstreams

    def get_upstream_rev(self, id):
        """Return the spec of the upstream revision with the given id."""
        req_pb = api_pb2.GetUpstreamRevisionRequest(id=id)
        resp_pb = self.upstream_service_stub.get_upstream_revision(req_pb)
        return resp_pb.revision.spec

    def list_domains(self, namespace_id):
        """Return all domain pbs of a namespace (single request)."""
        req_pb = api_pb2.ListDomainsRequest(namespace_id=namespace_id)
        resp_pb = self.domain_service_stub.list_domains(req_pb)
        assert resp_pb.total == len(resp_pb.domains)
        return resp_pb.domains

    def get_domain_rev(self, id):
        """Return the spec of the domain revision with the given id."""
        req_pb = api_pb2.GetDomainRevisionRequest(id=id)
        resp_pb = self.domain_service_stub.get_domain_revision(req_pb)
        return resp_pb.revision.spec

    def list_certs(self, namespace_id):
        """Return all certificate pbs of a namespace (single request)."""
        req_pb = api_pb2.ListCertificatesRequest(namespace_id=namespace_id)
        resp_pb = self.cert_service_stub.list_certificates(req_pb)
        assert resp_pb.total == len(resp_pb.certificates)
        return resp_pb.certificates

    def get_cert_rev(self, id):
        """Return the spec of the certificate revision with the given id."""
        req_pb = api_pb2.GetCertificateRevisionRequest(id=id)
        resp_pb = self.cert_service_stub.get_certificate_revision(req_pb)
        return resp_pb.revision.spec

    def list_backends(self, namespace_id):
        """Return all backend pbs of a namespace (single request)."""
        req_pb = api_pb2.ListBackendsRequest(namespace_id=namespace_id)
        resp_pb = self.backend_service_stub.list_backends(req_pb)
        assert resp_pb.total == len(resp_pb.backends)
        return resp_pb.backends

    def get_backend_rev(self, id):
        """Return the spec of the backend revision with the given id."""
        req_pb = api_pb2.GetBackendRevisionRequest(id=id)
        resp_pb = self.backend_service_stub.get_backend_revision(req_pb)
        return resp_pb.revision.spec

    def list_endpoint_sets(self, namespace_id):
        """Return all endpoint set pbs of a namespace (single request)."""
        req_pb = api_pb2.ListEndpointSetsRequest(namespace_id=namespace_id)
        resp_pb = self.endpoint_set_service_stub.list_endpoint_sets(req_pb)
        assert resp_pb.total == len(resp_pb.endpoint_sets)
        return resp_pb.endpoint_sets

    def get_endpoint_set_rev(self, id):
        """Return the spec of the endpoint set revision with the given id."""
        req_pb = api_pb2.GetEndpointSetRevisionRequest(id=id)
        resp_pb = self.endpoint_set_service_stub.get_endpoint_set_revision(req_pb)
        return resp_pb.revision.spec

    def list_knobs(self, namespace_id):
        """Return all knob pbs of a namespace (single request)."""
        req_pb = api_pb2.ListKnobsRequest(namespace_id=namespace_id)
        resp_pb = self.knob_service_stub.list_knobs(req_pb)
        assert resp_pb.total == len(resp_pb.knobs)
        return resp_pb.knobs