import bson
import dataclasses
import logging
import typing

from pymongo import ASCENDING
from infra.rtc_sla_tentacles.backend.lib.config.interface import ConfigInterface
from infra.rtc_sla_tentacles.backend.lib.metrics.types import SloType
from infra.rtc_sla_tentacles.backend.lib.mongo.client import MongoClient

from infra.rtc_sla_tentacles.backend.lib.api.proto import incidents_pb2


class IncidentStorageException(Exception):
    """Raised by :class:`IncidentStorage` when an incident operation cannot be completed."""


# Type of incident identifiers: the BSON ObjectId stored in each document's ``_id``.
TIncidentId = bson.ObjectId


@dataclasses.dataclass(frozen=True)
class Incident:
    """One SLO incident as stored in the incidents collection.

    Equality and hashing deliberately ignore ``_id``, so an incident that has
    an ID can compare equal to an ID-less instance carrying the same data.
    Comparing with an object that is not an ``Incident`` returns
    ``NotImplemented`` (``==`` then yields ``False``) instead of raising.
    """

    # Storage ID; None when the incident has no ID yet (e.g. built from a
    # protobuf message with an empty ``id``). compare=False keeps it out of
    # both the generated __eq__ and the generated __hash__, so the two agree.
    _id: typing.Optional[TIncidentId] = dataclasses.field(compare=False)
    start_time_ts: int
    end_time_ts: int
    allocation_zone: str
    slo_type: SloType
    assignee: str
    minutes_out_of_slo: int
    startrek_ticket_id: str

    def to_protobuf(self) -> incidents_pb2.TIncident:
        """Serialize into a ``TIncident`` message.

        ``id`` is left unset when ``_id`` is None.

        :raises IncidentStorageException: if ``slo_type`` is not a known SLO type.
        """
        result = incidents_pb2.TIncident()
        if self._id is not None:
            result.id = str(self._id)
        result.start_time_ts = self.start_time_ts
        result.end_time_ts = self.end_time_ts
        result.allocation_zone = self.allocation_zone
        result.slo_type = IncidentStorage.slo_type_to_protobuf(self.slo_type)
        result.assignee = self.assignee
        result.minutes_out_of_slo = self.minutes_out_of_slo
        result.startrek_ticket_id = self.startrek_ticket_id
        return result

    @staticmethod
    def from_protobuf(message: incidents_pb2.TIncident) -> "Incident":
        """Build an :class:`Incident` from a ``TIncident`` message.

        An empty ``message.id`` yields ``_id=None``.

        :raises IncidentStorageException: if ``message.id`` is not a valid
            ObjectId string or ``message.slo_type`` is not a known value.
        """
        incident_id = IncidentStorage.id_from_string(message.id) if message.id else None
        return Incident(
            _id=incident_id,
            start_time_ts=message.start_time_ts,
            end_time_ts=message.end_time_ts,
            allocation_zone=message.allocation_zone,
            slo_type=IncidentStorage.slo_type_from_protobuf(message.slo_type),
            assignee=message.assignee,
            minutes_out_of_slo=message.minutes_out_of_slo,
            startrek_ticket_id=message.startrek_ticket_id,
        )


@dataclasses.dataclass(frozen=True)
class IncidentFilter:
    """Criteria for selecting incidents.

    A zero timestamp or an empty list imposes no constraint on that attribute
    (see ``IncidentStorage._get_find_query``).
    """

    # Inclusive lower bound on an incident's start_time_ts; 0 disables it.
    started_after_ts: int = 0
    # Inclusive upper bound on an incident's end_time_ts; 0 disables it.
    ended_before_ts: int = 0
    allocation_zones: typing.List[str] = dataclasses.field(default_factory=list)
    slo_types: typing.List[SloType] = dataclasses.field(default_factory=list)
    assignees: typing.List[str] = dataclasses.field(default_factory=list)
    startrek_ticket_ids: typing.List[str] = dataclasses.field(default_factory=list)

    def to_protobuf(self) -> incidents_pb2.TIncidentFilter:
        """Serialize this filter into a ``TIncidentFilter`` message."""
        message = incidents_pb2.TIncidentFilter()
        message.started_after_ts = self.started_after_ts
        message.ended_before_ts = self.ended_before_ts
        message.allocation_zones.extend(self.allocation_zones)
        converted_slo_types = [IncidentStorage.slo_type_to_protobuf(value) for value in self.slo_types]
        message.slo_types.extend(converted_slo_types)
        message.assignees.extend(self.assignees)
        message.startrek_ticket_ids.extend(self.startrek_ticket_ids)
        return message

    @staticmethod
    def from_protobuf(message: incidents_pb2.TIncidentFilter):
        """Build an :class:`IncidentFilter` from a ``TIncidentFilter`` message."""
        return IncidentFilter(
            started_after_ts=message.started_after_ts,
            ended_before_ts=message.ended_before_ts,
            allocation_zones=list(message.allocation_zones),
            slo_types=[IncidentStorage.slo_type_from_protobuf(value) for value in message.slo_types],
            assignees=list(message.assignees),
            startrek_ticket_ids=list(message.startrek_ticket_ids),
        )


@dataclasses.dataclass(frozen=True)
class IncidentFiltersList:
    """Values offered for filtering incidents (as built by ``IncidentStorage.list_filters``)."""

    allocation_zones: typing.List[str] = dataclasses.field(default_factory=list)
    slo_types: typing.List[SloType] = dataclasses.field(default_factory=list)
    assignees: typing.List[str] = dataclasses.field(default_factory=list)

    def to_protobuf(self) -> incidents_pb2.TIncidentFiltersList:
        """Serialize into a ``TIncidentFiltersList`` message."""
        message = incidents_pb2.TIncidentFiltersList()
        message.allocation_zones.extend(self.allocation_zones)
        converted_slo_types = [IncidentStorage.slo_type_to_protobuf(value) for value in self.slo_types]
        message.slo_types.extend(converted_slo_types)
        message.assignees.extend(self.assignees)
        return message

    @staticmethod
    def from_protobuf(message: incidents_pb2.TIncidentFiltersList):
        """Build an :class:`IncidentFiltersList` from a ``TIncidentFiltersList`` message."""
        return IncidentFiltersList(
            allocation_zones=list(message.allocation_zones),
            slo_types=[IncidentStorage.slo_type_from_protobuf(value) for value in message.slo_types],
            assignees=list(message.assignees),
        )


# noinspection PyProtectedMember
class IncidentStorage:
    """Read/write access to the incidents collection in MongoDB.

    Also provides conversions between ``SloType`` and the protobuf
    ``ESloType`` enum, and between ID strings and BSON ObjectIds.
    """

    logger = logging.getLogger("common.incident_storage")

    def __init__(self, config_interface: ConfigInterface, mongo_client: MongoClient):
        """Bind to the incidents collection named in the incidents config.

        :param config_interface: source of the incidents config; this class
            reads its ``incidents_collection_name``, ``allocation_zones`` and
            ``assignees`` keys.
        :param mongo_client: client whose ``get_database(None)`` database
            holds the collection (presumably the connection's default
            database; confirm in MongoClient).
        """
        self._config = config_interface.get_incidents_config()
        mongo_database_handler = mongo_client.get_database(None)
        self._mongo_collection = mongo_database_handler.get_collection(self._config["incidents_collection_name"])

    def create_incident(self, incident: Incident) -> TIncidentId:
        """Insert ``incident`` as a new document and return the ID it was stored under.

        ``incident._id`` is dropped before insertion, so a new ID is always assigned.
        """
        incident_dict = dataclasses.asdict(incident)
        del incident_dict["_id"]
        return self._mongo_collection.insert_one(incident_dict).inserted_id

    def delete_incident(self, incident_id: TIncidentId):
        """Delete the incident with ``incident_id``.

        :raises IncidentStorageException: if no document was deleted.
        """
        delete_result = self._mongo_collection.delete_one({"_id": incident_id})
        if not delete_result.deleted_count:
            raise IncidentStorageException("Can not delete incident ID %r: no documents deleted" % incident_id)

    def list_filters(self) -> IncidentFiltersList:
        """Return the filter values clients may choose from.

        Allocation zones and assignees come from the incidents config. SLO
        types are every member of ``SloType``.
        """
        return IncidentFiltersList(
            list(self._config["allocation_zones"]),
            # NOTE(review): slo_types is declared List[SloType] but receives
            # member *names*; IncidentFiltersList.to_protobuf only accepts them
            # if SloType members compare equal to their names (e.g. a str-based
            # enum whose values equal the names). Confirm before relying on it.
            [slo_type.name for slo_type in SloType],
            list(self._config["assignees"]),
        )

    def read_incidents(self,
                       incident_filter: IncidentFilter,
                       limit: int = 0,
                       offset: int = 0) -> typing.Tuple[int, int, typing.List[Incident]]:
        """Return incidents matching ``incident_filter``, sorted by ascending start time.

        :param incident_filter: selection criteria (see :meth:`_get_find_query`).
        :param limit: maximum number of incidents to return; 0 means no limit.
        :param offset: number of matching incidents to skip; 0 skips none.
        :returns: ``(total_count, total_minutes_out_of_slo, incidents)``. Both
            totals cover *all* matching incidents, not only the returned page.
        """
        query = self._get_find_query(incident_filter)
        self.logger.debug("Query: %r", query)

        total_number_of_incidents = self._mongo_collection.count_documents(query)
        # skip=0 and limit=0 are pymongo's "no skip" / "no limit" defaults, so a
        # single call covers both the paged and the unpaged case.
        cursor = self._mongo_collection.find(query, limit=limit, skip=offset)

        total_minutes_out_of_slo = self._get_total_minutes_out_of_slo(query)

        return total_number_of_incidents, total_minutes_out_of_slo, [
            Incident(**item)
            for item in cursor.sort("start_time_ts", ASCENDING)
        ]

    def read_incident_data(self, incident_id: TIncidentId) -> Incident:
        """Return the incident stored under ``incident_id``.

        :raises IncidentStorageException: if no such document exists.
        """
        query = {"_id": incident_id}
        self.logger.debug("Query: %r", query)
        find_result = self._mongo_collection.find_one(query)
        if find_result is None:
            raise IncidentStorageException("Can not find data for incident ID %r" % incident_id)
        return Incident(**find_result)

    def update_incident(self, incident: Incident):
        """Overwrite the document with ID ``incident._id`` with ``incident``'s data.

        :raises IncidentStorageException: if no document has that ID, or if
            MongoDB reports zero modified documents.
        """
        _filter = {"_id": incident._id}
        incident_data_dict = dataclasses.asdict(incident)
        del incident_data_dict["_id"]
        update_query = {"$set": incident_data_dict}
        self.logger.debug("Update query: filter %r, data %r", _filter, update_query)
        update_result = self._mongo_collection.update_one(_filter, update_query)
        if not update_result.matched_count:
            raise IncidentStorageException("Can not update incident ID %r: no documents matched" % incident._id)
        # NOTE(review): MongoDB reports modified_count == 0 when the stored data
        # already equals the new data, so a no-op update raises here as well.
        if not update_result.modified_count:
            raise IncidentStorageException("Can not update incident ID %r: no documents updated" % incident._id)

    @staticmethod
    def _get_find_query(incident_filter: IncidentFilter) -> dict:
        """Translate ``incident_filter`` into a MongoDB query document.

        Each non-zero timestamp and non-empty list contributes one condition.
        The rules are:

        * one condition is used as-is;
        * several conditions are combined with ``$and``;
        * no conditions gives ``{}``, which matches everything.
        """
        subqueries = []
        if incident_filter.started_after_ts:
            subqueries.append({"start_time_ts": {"$gte": incident_filter.started_after_ts}})
        if incident_filter.ended_before_ts:
            subqueries.append({"end_time_ts": {"$lte": incident_filter.ended_before_ts}})
        if incident_filter.allocation_zones:
            subqueries.append({"allocation_zone": {"$in": incident_filter.allocation_zones}})
        if incident_filter.slo_types:
            subqueries.append({"slo_type": {"$in": incident_filter.slo_types}})
        if incident_filter.assignees:
            subqueries.append({"assignee": {"$in": incident_filter.assignees}})
        if incident_filter.startrek_ticket_ids:
            subqueries.append({"startrek_ticket_id": {"$in": incident_filter.startrek_ticket_ids}})

        if not subqueries:
            return {}
        if len(subqueries) == 1:
            return subqueries[0]
        return {"$and": subqueries}

    def _get_total_minutes_out_of_slo(self, query) -> int:
        """Sum ``minutes_out_of_slo`` over all documents matching ``query``; 0 if none match."""
        aggregate_results = list(self._mongo_collection.aggregate([
            {"$match": query},
            {"$group": {"_id": None, "total_minutes_out_of_slo": {"$sum": "$minutes_out_of_slo"}}},
            {"$project": {"total_minutes_out_of_slo": 1}}
        ]))
        # $group with _id=None yields a single document, or nothing when no document matched.
        return aggregate_results[0]["total_minutes_out_of_slo"] if aggregate_results else 0

    @staticmethod
    def id_from_string(id_str: str) -> TIncidentId:
        """Parse ``id_str`` (a 24-character hex ObjectId string) into an incident ID.

        :raises IncidentStorageException: if ``id_str`` is None or not a valid ObjectId.
        """
        # bson.ObjectId(None) silently generates a brand-new ID instead of
        # failing, which would make a missing ID look like a valid one.
        if id_str is None:
            raise IncidentStorageException("Incident ID is not a BSON value")
        try:
            return bson.ObjectId(id_str)
        except (bson.errors.InvalidId, TypeError) as err:
            raise IncidentStorageException("Incident ID is not a BSON value") from err

    @staticmethod
    def slo_type_to_protobuf(slo_type: SloType) -> incidents_pb2.ESloType:
        """Map an ``SloType`` to the protobuf ``ESloType`` value.

        :raises IncidentStorageException: for a value that is not one of the known SLO types.
        """
        if slo_type == SloType.availability:
            return incidents_pb2.ESloType.AVAILABILITY
        elif slo_type == SloType.redeployed_on_time:
            return incidents_pb2.ESloType.REDEPLOYED_ON_TIME
        elif slo_type == SloType.reallocation:
            return incidents_pb2.ESloType.REALLOCATION
        elif slo_type == SloType.unused_nodes:
            return incidents_pb2.ESloType.UNUSED_NODES
        elif slo_type == SloType.scheduling_errors:
            return incidents_pb2.ESloType.SCHEDULING_ERRORS
        raise IncidentStorageException("Invalid SLO type value")

    @staticmethod
    def slo_type_from_protobuf(slo_type: incidents_pb2.ESloType) -> SloType:
        """Map a protobuf ``ESloType`` value to ``SloType``.

        :raises IncidentStorageException: for a value that is not one of the known SLO types.
        """
        if slo_type == incidents_pb2.ESloType.AVAILABILITY:
            return SloType.availability
        elif slo_type == incidents_pb2.ESloType.REDEPLOYED_ON_TIME:
            return SloType.redeployed_on_time
        elif slo_type == incidents_pb2.ESloType.REALLOCATION:
            return SloType.reallocation
        elif slo_type == incidents_pb2.ESloType.UNUSED_NODES:
            return SloType.unused_nodes
        elif slo_type == incidents_pb2.ESloType.SCHEDULING_ERRORS:
            return SloType.scheduling_errors
        raise IncidentStorageException("Invalid SLO type value")
