from __future__ import annotations

import pickle
import os
from typing import Optional

from infra.awacs.proto import model_pb2, modules_pb2
from infra.awacs.tools.awacstoolslib.util import AwacsClient
from awacs.model.balancer.generator import validate_config
from awacs.model.balancer.vector import (Vector, BalancerVersion, UpstreamVersion,
                                         BackendVersion, EndpointSetVersion, KnobVersion, CertVersion, DomainVersion)
from awacs.model import objects
from awacs.model.db import MongoStorage
from awacs.lib.strutils import to_full_id
from awacs.wrappers import main  # noqa
from .util import clone_pb, clone_pb_dict, get_diff


def assert_balancer_configs_are_the_same(a: AwacsClient, namespace_id: str) -> model_pb2.Balancer:
    """Check that all balancers of a namespace share one ``spec.yandex_balancer.config``.

    :param a: awacs client used to list the namespace's balancers.
    :param namespace_id: namespace whose balancers are compared.
    :return: the last balancer pb yielded by ``a.list_balancers(namespace_id)``.
    :raises AssertionError: when a balancer's config differs from the first one seen
        (the message carries a textual diff), or when the namespace has no balancers.
    """
    reference_config_pb = None
    last_balancer_pb = None
    for last_balancer_pb in a.list_balancers(namespace_id):
        current_config_pb = last_balancer_pb.spec.yandex_balancer.config
        if reference_config_pb is None:
            # The first balancer becomes the reference every other one is compared with.
            reference_config_pb = current_config_pb
            continue
        if reference_config_pb != current_config_pb:
            diff = get_diff(str(reference_config_pb), str(current_config_pb))
            raise AssertionError(f'Not all balancer configs are the same in namespace {namespace_id}:\n{diff}')

    if last_balancer_pb is None:
        raise AssertionError(f'No balancers found in namespace {namespace_id}')
    return last_balancer_pb


class NamespaceConfig:
    """Snapshot of a namespace's active balancer configuration and every spec it references.

    Built by :meth:`from_api` (or restored from a pickle cache via :meth:`maybe_from_cache`)
    and rendered with :meth:`to_pb` / :meth:`to_lua`. The ``*_spec_pbs`` dicts produced by
    :meth:`from_api` are keyed by the version objects taken from ``active_vector``.
    """

    def __init__(self,
                 active_vector,
                 namespace_pb,
                 balancer_spec_pb,
                 domain_spec_pbs,
                 upstream_spec_pbs,
                 backend_spec_pbs,
                 endpoint_set_spec_pbs,
                 knob_spec_pbs,
                 cert_spec_pbs,
                 weight_section_spec_pbs):
        self.active_vector = active_vector
        self.namespace_pb = namespace_pb
        self.balancer_spec_pb = balancer_spec_pb
        self.domain_spec_pbs = domain_spec_pbs
        self.upstream_spec_pbs = upstream_spec_pbs
        self.backend_spec_pbs = backend_spec_pbs
        self.endpoint_set_spec_pbs = endpoint_set_spec_pbs
        self.knob_spec_pbs = knob_spec_pbs
        self.cert_spec_pbs = cert_spec_pbs
        # NOTE: singular attribute name is kept as-is. Instances are pickled whole (see
        # to_cache), so existing caches and possibly other callers depend on this name.
        self.weight_section_spec_pb = weight_section_spec_pbs

    @property
    def namespace_id(self):
        """Id of the namespace this config belongs to (``namespace_pb.meta.id``)."""
        return self.namespace_pb.meta.id

    def _find_upstream(self, upstream_id):
        """Return ``(upstream_version, upstream_spec_pb)`` for ``upstream_id``, or ``None``.

        Only the second element of each version's ``upstream_id`` pair is compared, and the
        first match wins. Every upstream lookup in this class goes through here so that
        get/replace/sync always agree on which entry they touch.
        """
        for upstream_version, upstream_spec_pb in self.upstream_spec_pbs.items():
            _, short_upstream_id = upstream_version.upstream_id
            if short_upstream_id == upstream_id:
                return upstream_version, upstream_spec_pb
        return None

    def get_upstream_spec(self, upstream_id):
        """Return a clone of the spec of ``upstream_id``, or ``None`` if it is not present."""
        found = self._find_upstream(upstream_id)
        if found is None:
            return None
        return clone_pb(found[1])

    def must_get_upstream_spec(self, upstream_id):
        """Like :meth:`get_upstream_spec`, but raise ``AssertionError`` if the upstream is missing."""
        upstream_spec_pb = self.get_upstream_spec(upstream_id)
        if upstream_spec_pb is None:
            raise AssertionError(f'Upstream {upstream_id} not found')
        return upstream_spec_pb

    def replace_upstream_spec(self, upstream_id, upstream_spec_pb):
        """Overwrite the stored spec of ``upstream_id`` in place with ``upstream_spec_pb``.

        :return: ``True`` on success.
        :raises AssertionError: if the upstream is not present.
        """
        found = self._find_upstream(upstream_id)
        if found is None:
            raise AssertionError(f'Upstream {upstream_id} not found')
        found[1].CopyFrom(upstream_spec_pb)
        return True

    def sync_upstream_spec(self, awacs_client: AwacsClient, upstream_id: str, comment: str = ''):
        """Push the locally stored spec of ``upstream_id`` to awacs.

        The update is issued against the revision recorded in the active vector.

        :raises AssertionError: if the upstream is not present.
        """
        found = self._find_upstream(upstream_id)
        if found is None:
            raise AssertionError(f'Upstream {upstream_id} not found')
        upstream_version, upstream_spec_pb = found
        awacs_client.update_upstream(self.namespace_id, upstream_id, upstream_version.version,
                                     spec_pb=upstream_spec_pb, comment=comment)

    def _to_holder(self) -> main.Holder:
        """Validate clones of all stored specs and return the resulting balancer holder."""
        # Clones are passed so that validate_config cannot mutate this snapshot.
        return validate_config(
            namespace_pb=clone_pb(self.namespace_pb),
            namespace_id=self.active_vector.balancer_version.balancer_id[0],
            balancer_version=self.active_vector.balancer_version,
            balancer_spec_pb=clone_pb(self.balancer_spec_pb),
            upstream_spec_pbs=clone_pb_dict(self.upstream_spec_pbs),
            backend_spec_pbs=clone_pb_dict(self.backend_spec_pbs),
            endpoint_set_spec_pbs=clone_pb_dict(self.endpoint_set_spec_pbs),
            knob_spec_pbs=clone_pb_dict(self.knob_spec_pbs),
            cert_spec_pbs=clone_pb_dict(self.cert_spec_pbs),
            domain_spec_pbs=clone_pb_dict(self.domain_spec_pbs),
            weight_section_spec_pbs=clone_pb_dict(self.weight_section_spec_pb),
        ).balancer

    def to_pb(self) -> modules_pb2.Holder:
        """Return the validated balancer holder as a protobuf message."""
        return self._to_holder().pb

    def to_lua(self) -> str:
        """Render the validated balancer holder as top-level Lua config text."""
        return self._to_holder().to_config().to_top_level_lua()

    def copy(self):
        """Return a copy with cloned pbs. ``active_vector`` is shared, not copied."""
        return NamespaceConfig(
            self.active_vector,
            clone_pb(self.namespace_pb),
            clone_pb(self.balancer_spec_pb),
            clone_pb_dict(self.domain_spec_pbs),
            clone_pb_dict(self.upstream_spec_pbs),
            clone_pb_dict(self.backend_spec_pbs),
            clone_pb_dict(self.endpoint_set_spec_pbs),
            clone_pb_dict(self.knob_spec_pbs),
            clone_pb_dict(self.cert_spec_pbs),
            clone_pb_dict(self.weight_section_spec_pb),
        )

    @staticmethod
    def _get_cache_path(cache_dir, namespace_id):
        """Path of the pickle cache file for ``namespace_id`` inside ``cache_dir``."""
        return os.path.join(cache_dir, f'{namespace_id}.pickle')

    def to_cache(self, cache_dir):
        """Pickle this config to ``<cache_dir>/<namespace_id>.pickle``.

        ``cache_dir`` is created if needed, including missing parent directories.
        An already existing directory is not an error.
        """
        # makedirs(exist_ok=True) also avoids the exists()/mkdir() race of a check-then-create.
        os.makedirs(cache_dir, exist_ok=True)
        path = self._get_cache_path(cache_dir, self.namespace_id)
        with open(path, 'wb') as f:
            return self.to_file(f)

    def to_file(self, f):
        """Pickle this config into the binary file object ``f``."""
        pickle.dump(self, f)

    @classmethod
    def maybe_from_cache(cls, cache_dir, namespace_id) -> Optional[NamespaceConfig]:
        """Load the cached config for ``namespace_id``, or return ``None`` if no cache file exists."""
        path = cls._get_cache_path(cache_dir, namespace_id)
        try:
            f = open(path, 'rb')
        except FileNotFoundError:
            return None
        with f:
            return cls.from_file(f)

    @classmethod
    def from_file(cls, f) -> NamespaceConfig:
        """Unpickle a config from the binary file object ``f``.

        Security: unpickling can execute arbitrary code. Only load files written by
        :meth:`to_file` / :meth:`to_cache` from a trusted location.
        """
        return pickle.load(f)

    @classmethod
    def from_api(cls, awacs_client: AwacsClient, namespace_id: str,
                 db: Optional[MongoStorage] = None) -> NamespaceConfig:
        """Build a config from the currently active revisions of ``namespace_id``.

        Spec revisions are read from ``db`` when it is given, otherwise from the awacs API.

        :raises AssertionError: if the namespace's balancers differ, if there are none,
            or if no balancer revision is active.
        """
        balancer_pb = assert_balancer_configs_are_the_same(awacs_client, namespace_id)
        balancer_id = balancer_pb.meta.id

        namespace_pb = awacs_client.get_namespace(namespace_id)
        balancer_state_pb = awacs_client.get_balancer_state(namespace_id, balancer_id)

        active_vector = balancer_state_to_active_vector(namespace_id, balancer_id, balancer_state_pb)
        if not active_vector.balancer_version:
            raise AssertionError(f'No active configuration for balancer {namespace_id}:{balancer_id}, nothing to check')

        if db is None:
            specs = vector_to_specs_from_api(awacs_client, active_vector)
        else:
            specs = vector_to_specs_from_db(db, active_vector)
        (balancer_spec_pb,
         domain_spec_pbs,
         upstream_spec_pbs,
         backend_spec_pbs,
         endpoint_set_spec_pbs,
         knob_spec_pbs,
         cert_spec_pbs,
         weight_section_spec_pbs) = specs

        return cls(active_vector,
                   namespace_pb,
                   balancer_spec_pb,
                   domain_spec_pbs,
                   upstream_spec_pbs,
                   backend_spec_pbs,
                   endpoint_set_spec_pbs,
                   knob_spec_pbs,
                   cert_spec_pbs,
                   weight_section_spec_pbs)


def vector_to_specs_from_api(awacs_client: AwacsClient, vector):
    """Fetch, through the awacs API, the spec of every non-deleted revision in ``vector``.

    :param awacs_client: client whose ``get_*_rev`` methods return spec pbs.
    :param vector: active vector whose ``*_versions`` dicts name the revisions to fetch.
    :return: tuple ``(balancer_spec_pb, domain_spec_pbs, upstream_spec_pbs, backend_spec_pbs,
        endpoint_set_spec_pbs, knob_spec_pbs, cert_spec_pbs, weight_section_pbs)``. Each dict maps
        a version object to the spec fetched for it. ``knob_spec_pbs`` is always empty on this path.
    """
    def alive(versions_by_id):
        # Lazily skip revisions flagged as deleted; the ids themselves are not needed.
        return (ver for ver in versions_by_id.values() if not ver.deleted)

    balancer_spec_pb = awacs_client.get_balancer_rev(vector.balancer_version.version)
    domain_spec_pbs = {ver: awacs_client.get_domain_rev(ver.version)
                       for ver in alive(vector.domain_versions)}
    upstream_spec_pbs = {ver: awacs_client.get_upstream_rev(ver.version)
                         for ver in alive(vector.upstream_versions)}
    backend_spec_pbs = {ver: awacs_client.get_backend_rev(ver.version)
                        for ver in alive(vector.backend_versions)}
    cert_spec_pbs = {ver: awacs_client.get_cert_rev(ver.version)
                     for ver in alive(vector.cert_versions)}
    endpoint_set_spec_pbs = {ver: awacs_client.get_endpoint_set_rev(ver.version)
                             for ver in alive(vector.endpoint_set_versions)}
    weight_section_pbs = {ver: awacs_client.get_weight_section_rev(ver.version)
                          for ver in alive(vector.weight_section_versions)}
    knob_spec_pbs = {}

    return (balancer_spec_pb, domain_spec_pbs, upstream_spec_pbs, backend_spec_pbs,
            endpoint_set_spec_pbs, knob_spec_pbs, cert_spec_pbs, weight_section_pbs)


def vector_to_specs_from_db(db: MongoStorage, vector):
    """Load from ``db`` the spec of every non-deleted revision referenced by ``vector``.

    :param db: storage whose ``must_get_*_rev`` methods return revision objects with a ``spec``.
    :param vector: active vector whose ``*_versions`` dicts name the revisions to load.
    :return: tuple ``(balancer_spec_pb, domain_spec_pbs, upstream_spec_pbs, backend_spec_pbs,
        endpoint_set_spec_pbs, knob_spec_pbs, cert_spec_pbs, weight_section_pbs)``. It has the same
        shape as ``vector_to_specs_from_api``. ``knob_spec_pbs`` and ``weight_section_pbs``
        are always empty on this path.
    """
    balancer_spec_pb = db.must_get_balancer_rev(vector.balancer_version.version).spec
    domain_spec_pbs = {}
    upstream_spec_pbs = {}
    backend_spec_pbs = {}
    endpoint_set_spec_pbs = {}
    knob_spec_pbs = {}
    cert_spec_pbs = {}
    weight_section_pbs = {}

    for domain_version in vector.domain_versions.values():
        if domain_version.deleted:
            continue
        domain_spec_pbs[domain_version] = db.must_get_domain_rev(domain_version.version).spec

    for upstream_version in vector.upstream_versions.values():
        if upstream_version.deleted:
            continue
        upstream_spec_pbs[upstream_version] = db.must_get_upstream_rev(upstream_version.version).spec

    for backend_version in vector.backend_versions.values():
        if backend_version.deleted:
            continue
        backend_spec_pbs[backend_version] = db.must_get_backend_rev(backend_version.version).spec

    for cert_version in vector.cert_versions.values():
        if cert_version.deleted:
            continue
        cert_spec_pbs[cert_version] = db.must_get_cert_rev(cert_version.version).spec

    # The attribute is ``endpoint_set_versions`` (plural), as in vector_to_specs_from_api.
    # The previous ``endpoint_set_version`` spelling raised AttributeError.
    for endpoint_set_version in vector.endpoint_set_versions.values():
        if endpoint_set_version.deleted:
            continue
        endpoint_set_spec_pbs[endpoint_set_version] = db.must_get_endpoint_set_rev(endpoint_set_version.version).spec

    # TODO(romanovich@): weight sections are not loaded from the DB yet, so ``weight_section_pbs``
    # stays empty here. Intended: for each non-deleted version in vector.weight_section_versions,
    # store db.must_get_weight_section_rev(version.version).spec.

    return balancer_spec_pb, domain_spec_pbs, upstream_spec_pbs, backend_spec_pbs, endpoint_set_spec_pbs, knob_spec_pbs, cert_spec_pbs, weight_section_pbs


def _collect_active_versions(namespace_id, state_pbs, version_cls):
    """Map full object ids to the version of their revision marked active.

    :param namespace_id: namespace used by ``to_full_id`` to qualify possibly-short ids.
    :param state_pbs: mapping of object id -> state pb exposing a ``statuses`` sequence.
    :param version_cls: class providing ``from_rev_status_pb(full_id, rev_pb)``.
    :return: dict of full id -> version. If several revisions of one object are marked
        active, the last one in ``statuses`` wins. Objects with no active revision are omitted.
    """
    active_versions = {}
    for maybe_full_id, state_pb in state_pbs.items():
        # The full id depends only on the object, so compute it once, not once per revision.
        full_id = to_full_id(namespace_id, maybe_full_id)
        for rev_pb in state_pb.statuses:
            if rev_pb.active.status == 'True':
                active_versions[full_id] = version_cls.from_rev_status_pb(full_id, rev_pb)
    return active_versions


def balancer_state_to_active_vector(namespace_id, balancer_id, balancer_state_pb):
    """Build a ``Vector`` of the revisions marked active in ``balancer_state_pb``.

    A revision is active when ``rev_pb.active.status == 'True'``. When several are, the
    last one in ``statuses`` wins. The balancer version is ``None`` if no balancer
    revision is active. Ids of all other objects are normalized with ``to_full_id``.
    """
    balancer_active_version = None
    for rev_pb in balancer_state_pb.balancer.statuses:
        if rev_pb.active.status == 'True':
            balancer_active_version = BalancerVersion.from_rev_status_pb((namespace_id, balancer_id), rev_pb)

    # Collected in the same order as before (upstreams, backends, endpoint sets, knobs,
    # certs, domains, weight sections), even though Vector takes them in a different order.
    upstream_active_versions = _collect_active_versions(
        namespace_id, balancer_state_pb.upstreams, UpstreamVersion)
    backend_active_versions = _collect_active_versions(
        namespace_id, balancer_state_pb.backends, BackendVersion)
    endpoint_set_active_versions = _collect_active_versions(
        namespace_id, balancer_state_pb.endpoint_sets, EndpointSetVersion)
    knob_active_versions = _collect_active_versions(
        namespace_id, balancer_state_pb.knobs, KnobVersion)
    cert_active_versions = _collect_active_versions(
        namespace_id, balancer_state_pb.certificates, CertVersion)
    domain_active_versions = _collect_active_versions(
        namespace_id, balancer_state_pb.domains, DomainVersion)
    weight_section_active_versions = _collect_active_versions(
        namespace_id, balancer_state_pb.weight_sections, objects.WeightSection.version)

    return Vector(balancer_active_version,
                  upstream_active_versions,
                  domain_active_versions,
                  backend_active_versions,
                  endpoint_set_active_versions,
                  knob_active_versions,
                  cert_active_versions,
                  weight_section_active_versions)
