import six
import inject

from awacs.lib.vectors.state_handler import StateHandler, Rev
from awacs.lib.vectors.vector import Vector, ValidationError
from awacs.lib.vectors.vector_discovered import DiscoveredVector
from awacs.lib.vectors.vector_mutable import MutableVector
from awacs.lib.vectors.version import L7HeavyConfigVersion, WeightSectionVersion
from awacs.model import cache, objects, util
from awacs.lib.strutils import to_full_id
from infra.awacs.proto import model_pb2


class L7HeavyConfigVector(Vector):
    """
    Vector pairing an L7-heavy config version (the main version) with weight section versions.
    """
    __main_version_class__ = L7HeavyConfigVersion
    __version_classes__ = (WeightSectionVersion,)

    def __init__(self, l7heavy_config_version, weight_section_versions):
        """
        :param l7heavy_config_version: main version of this vector
        :type l7heavy_config_version: L7HeavyConfigVersion
        :param weight_section_versions: weight section versions belonging to this vector
        """
        self.l7heavy_config_version, self.weight_section_versions = (
            l7heavy_config_version, weight_section_versions)
        # Both attributes are set before Vector.__init__ runs, as in the sibling vector
        # classes (presumably the base initializer inspects them — verify in Vector).
        super(L7HeavyConfigVector, self).__init__()


class MutableL7HeavyConfigVector(MutableVector):
    """
    Mutable vector pairing an L7-heavy config version with weight section versions.

    Adds a validation step for L7-heavy configs placed in the "common" group: such a config
    must be accompanied by weight sections that all define COMMON_REQUIRED_LOCATIONS.
    """
    __main_version_class__ = L7HeavyConfigVersion
    __version_classes__ = (WeightSectionVersion,)

    # Locations every weight section must define when the L7-heavy config is in the common group.
    COMMON_REQUIRED_LOCATIONS = frozenset(['SAS', 'MAN', 'VLA'])

    def __init__(self, l7heavy_config_version, weight_section_versions, validated_pbs):
        """
        :param l7heavy_config_version: main version of this vector
        :type l7heavy_config_version: L7HeavyConfigVersion
        :param weight_section_versions: mapping whose values are WeightSectionVersion objects
        :param validated_pbs: forwarded unchanged to MutableVector.__init__
        """
        self.l7heavy_config_version = l7heavy_config_version
        self.weight_section_versions = weight_section_versions

        # super() resolves to Vector instead of MutableVector here, this is a workaround
        MutableVector.__init__(self, validated_pbs=validated_pbs)

    def validate(self, ctx):
        """
        Validate constraints specific to L7-heavy configs located in the common group.

        If the config spec's group_id equals util.COMMON_L7HEAVY_GROUP_ID, the vector must contain
        at least one weight section, and every weight section spec must list all locations from
        COMMON_REQUIRED_LOCATIONS. Configs in other groups are not checked here.

        :type ctx: context.OpCtx
        :raises: ValidationError
        """
        l7heavy_spec_pb = objects.L7HeavyConfig.find_rev_spec(self.l7heavy_config_version)
        if l7heavy_spec_pb.group_id != util.COMMON_L7HEAVY_GROUP_ID:
            return

        ctx.log.debug(u'start common group validation')
        if not self.weight_section_versions:
            raise ValidationError(u'L7-heavy config located in group "common" must have at least one section')

        for ws_version in six.itervalues(self.weight_section_versions):
            spec_pb = objects.WeightSection.find_rev_spec(ws_version)
            locations = {location_pb.name for location_pb in spec_pb.locations}
            if not self.COMMON_REQUIRED_LOCATIONS.issubset(locations):
                # The offending weight section version is attached as the error's cause.
                raise ValidationError(
                    u'All weights sections must contains locations ({}) if L7Heavy config located in group'
                    u' "common"'.format(', '.join(sorted(self.COMMON_REQUIRED_LOCATIONS))), cause=ws_version)

    def has_anything_to_validate(self):
        """
        :returns: True when the vector has a main version, False otherwise
        :rtype: bool
        """
        return self.main_version is not None

    def remove_orphan_versions(self):
        """No-op: this vector does not prune any versions."""
        pass


class DiscoveredL7HeavyConfigVector(DiscoveredVector):
    """
    Discovered vector of an L7-heavy config version and the weight section versions
    found in the same namespace.
    """
    __main_version_class__ = L7HeavyConfigVersion
    __version_classes__ = (WeightSectionVersion,)

    def __init__(self, l7heavy_config_version, weight_section_versions):
        """
        :param l7heavy_config_version: main version of this vector
        :type l7heavy_config_version: L7HeavyConfigVersion
        :param weight_section_versions: dict of (namespace_id, weight_section_id) -> WeightSectionVersion
        """
        self.l7heavy_config_version = l7heavy_config_version
        self.weight_section_versions = weight_section_versions
        super(DiscoveredL7HeavyConfigVector, self).__init__()

    @classmethod
    def from_cache(cls, namespace_id, main_id):
        """
        Build the vector from the objects cache.

        The main version comes from the cached L7-heavy config (namespace_id, main_id). Every
        weight section cached for namespace_id is included, keyed by (namespace_id, weight_section_id).

        :param namespace_id: namespace of the L7-heavy config
        :param main_id: id of the L7-heavy config
        :rtype: DiscoveredL7HeavyConfigVector
        """
        # must_get presumably raises when the config is absent from the cache — verify in cache API
        l7heavy_config_version = objects.L7HeavyConfig.version.from_pb(
            objects.L7HeavyConfig.cache.must_get(namespace_id, main_id)
        )
        weight_section_versions = {(pb.meta.namespace_id, pb.meta.id): objects.WeightSection.version.from_pb(pb)
                                   for pb in objects.WeightSection.cache.list(namespace_id)}
        return cls(l7heavy_config_version, weight_section_versions)

    def find_versions_to_update_in_state(self, vectors):
        """
        Extend the base result with weight section changes taken from the namespace cache.

        For each weight section cached in the config's namespace:
          * absent from the current state vector and not deleted -> added;
          * latest version compares greater than the current one -> added;
          * otherwise, if the latest version is deleted -> scheduled for deletion.

        :type vectors: state_handler.vectors
        :returns: (versions_to_add, versions_to_delete)
        :rtype: (set[Version], set[Version])
        """
        versions_to_add, versions_to_delete = super(
            DiscoveredL7HeavyConfigVector, self).find_versions_to_update_in_state(vectors)

        # Only the namespace part of the (namespace_id, l7heavy_config_id) id is needed.
        ns_id, _ = self.l7heavy_config_version.id
        for weight_section_pb in objects.WeightSection.cache.list(ns_id):
            latest_ws_version = WeightSectionVersion.from_pb(weight_section_pb)
            full_ws_id = (ns_id, weight_section_pb.meta.id)
            current_ws_version = vectors.current.weight_section_versions.get(full_ws_id)
            if current_ws_version is None:
                # Unknown to the current state: only start tracking it if it is not deleted.
                if not latest_ws_version.deleted:
                    versions_to_add.add(latest_ws_version)
            elif latest_ws_version > current_ws_version:
                versions_to_add.add(latest_ws_version)
            elif latest_ws_version.deleted:
                versions_to_delete.add(latest_ws_version)
        return versions_to_add, versions_to_delete


class L7HeavyConfigStateHandler(StateHandler):
    """
    State handler for model_pb2.L7HeavyConfigState, persisted through
    objects.L7HeavyConfig.state.zk.update.
    """
    __slots__ = ()
    __protobuf__ = model_pb2.L7HeavyConfigState
    __vector_class__ = L7HeavyConfigVector
    __mutable_vector_class__ = MutableL7HeavyConfigVector
    __zk_update_method__ = objects.L7HeavyConfig.state.zk.update

    @property
    def full_id(self):
        """Full id built by to_full_id from the state's namespace_id and l7heavy_config_id."""
        return to_full_id(self._pb.namespace_id, self._pb.l7heavy_config_id)

    def mark_version_as_active(self, versions):
        """
        Mark the revisions of the given versions as active (status u'True') in the stored state.

        :param versions: iterable of Version objects whose revisions should be activated
        :returns: True if the last update attempt changed at least one revision
        :rtype: bool
        """
        # Materialize once: the loop below may run more than once, and a one-shot
        # iterator would be empty on a second pass.
        versions = list(versions)
        updated = False
        for state_pb in self.update_zk():
            # NOTE(review): update_zk() is assumed to re-yield a fresh state on retry after a
            # concurrent-modification conflict — verify. Reset so only the attempt that is
            # finally persisted decides the return value.
            updated = False
            self._pb = state_pb
            for version in versions:
                rev_pb = self._get_rev_pb(version)
                updated |= Rev.set_active(rev_pb, status=u'True')
        return updated
