import dataclasses
import logging
import typing

# noinspection PyUnresolvedReferences
from pymongo import ASCENDING, DESCENDING, MongoClient
# noinspection PyUnresolvedReferences
from pymongo.collection import Collection
from pymongo.errors import PyMongoError

from infra.rtc_sla_tentacles.backend.lib.harvesters_snapshots.snapshots import HarvesterSnapshotLabel, HarvesterSnapshot
from infra.rtc_sla_tentacles.backend.lib.funccall_stats_server import server as stat_server


class InsertSnapshotError(Exception):
    """Raised when MongoDB acknowledges fewer snapshot documents than were sent for insertion."""


class SnapshotNotFoundError(Exception):
    """Raised when no stored snapshot is available for the requested harvester."""


class HarvesterSnapshotManager:
    """
        Harvesters' snapshots manager hides MongoDB when accessing
        harvesters' stored data. Data is represented by `HarvesterSnapshot`
        objects. The `HarvesterSnapshotLabel` structure is the 'key' of a
        snapshot object, and is a part of it.

        Storage layout: every harvester type has its own collection,
        `harvester_<harvester_type>`, in the client's default database.
        A snapshot is written as one or more documents ("chunks"). All chunks
        of one snapshot carry the same `label`, `debug_info` and `meta` and
        differ only in `data`. When `label.chunk_size` is set, the list stored
        at `data[label.data_list_path]` is split across chunks of at most
        `chunk_size` items each.
    """
    logger = logging.getLogger("harvesters.snapshot_manager")

    def __init__(self, mongo_client: MongoClient):
        self._mongo_client = mongo_client

    def write_snapshot(self, snapshot: HarvesterSnapshot):
        """
            Store `snapshot`, one document per data chunk.

            :raises InsertSnapshotError: if MongoDB reports fewer inserted ids
                than documents sent.
            :raises ValueError: if `snapshot.label.chunk_size` is negative.
            Errors from `insert_many` (e.g. `PyMongoError`) are not caught here.
        """
        collection = self._get_collection_handlers(snapshot.label.harvester_type)
        docs_for_insertion = [{
            "label": dataclasses.asdict(snapshot.label),
            "debug_info": snapshot.debug_info,
            "meta": snapshot.meta,
            "data": data,
        } for data in self._split_snapshot_on_chunks(snapshot)]
        with stat_server.mongo_timing():
            insert_result = collection.insert_many(docs_for_insertion)
        if len(insert_result.inserted_ids) != len(docs_for_insertion):
            raise InsertSnapshotError(snapshot.label)

    @staticmethod
    def _split_snapshot_on_chunks(snapshot: HarvesterSnapshot) -> typing.List[typing.Dict]:
        """
            Split `snapshot.data` into the `data` payloads of the stored documents.

            No chunking happens when there is no data, no `chunk_size`, or the
            list at `data_list_path` is empty. In those cases the whole
            `snapshot.data` is returned as a single payload. Otherwise each
            payload is `{data_list_path: <slice of at most chunk_size items>}`.
            NOTE: other top-level keys of `snapshot.data` are not stored when
            chunking.

            :raises ValueError: if `chunk_size` is negative.
        """
        data_list_path = snapshot.label.data_list_path
        chunk_size = snapshot.label.chunk_size
        if not snapshot.data or not chunk_size or not snapshot.data[data_list_path]:
            return [snapshot.data]
        if chunk_size < 0:
            # A negative step would never reach the end of the list.
            raise ValueError(f"chunk_size must be positive, got {chunk_size!r}")
        items = snapshot.data[data_list_path]
        return [
            {data_list_path: items[start:start + chunk_size]}
            for start in range(0, len(items), chunk_size)
        ]

    @staticmethod
    def _make_data_from_chunks(label: HarvesterSnapshotLabel,
                               chunks_cursor: list) -> typing.Dict[str, list]:
        """
            Rebuild snapshot data from its stored chunk documents (inverse of
            `_split_snapshot_on_chunks`).

            :param label: provides `chunk_size` and `data_list_path`.
            :param chunks_cursor: non-empty list of chunk documents, in write order.
            :return: the unchunked `data` as stored, or
                `{data_list_path: <concatenated lists of all chunks>}`.
        """
        first_data = chunks_cursor[0]["data"]
        if not label.chunk_size or not first_data:
            # Unchunked snapshot, or empty/None data that the writer stored
            # whole. Return it verbatim rather than indexing into it.
            return first_data
        merged: list = []
        for chunk in chunks_cursor:
            merged.extend(chunk["data"][label.data_list_path])
        return {label.data_list_path: merged}

    def read_snapshot(self, label: HarvesterSnapshotLabel, projection=None) -> typing.Optional[HarvesterSnapshot]:
        """
            Load the snapshot identified by `label` and reassemble its chunks.

            :param label: `ts`, `harvester_type` and `harvester_name` select the
                documents. `chunk_size` and `data_list_path` drive reassembly.
            :param projection: optional MongoDB projection for the chunk
                documents. `debug_info` and `meta` are always added to it.
            :return: the snapshot, or None if no document matches `label` or
                MongoDB raised `PyMongoError` (which is logged).
        """
        try:
            collection = self._get_collection_handlers(label.harvester_type)
            find_one_query = {
                "label.ts": {"$eq": label.ts},
                "label.harvester_type": {"$eq": label.harvester_type},
                "label.harvester_name": {"$eq": label.harvester_name}
            }
            full_projection = None
            if projection:
                # debug_info/meta are needed to build HarvesterSnapshot below.
                # NOTE(review): presumably callers pass inclusion-style
                # projections. MongoDB rejects inclusions mixed into an
                # exclusion projection.
                full_projection = projection.copy()
                full_projection["debug_info"] = True
                full_projection["meta"] = True

            with stat_server.mongo_timing():
                # find() without sort has no guaranteed order. The `_id`s were
                # generated client-side by a single insert_many() call, so
                # ascending `_id` restores the order in which chunks were written.
                chunks = list(
                    collection.find(find_one_query, full_projection).sort("_id", ASCENDING)
                )
            if not chunks:
                return None
            return HarvesterSnapshot(label=label,
                                     debug_info=chunks[0]["debug_info"],
                                     meta=chunks[0]["meta"],
                                     data=self._make_data_from_chunks(label, chunks))
        except PyMongoError as _exc:
            self.logger.error("label='%r', read_snapshot_error='%r'", label, str(_exc))
            return None

    def find_labels(self,
                    harvester_type: str,
                    harvester_name: str,
                    meta_query: typing.Optional[dict] = None,
                    ts: typing.Optional[int] = None,
                    limit: int = 0) -> typing.List[HarvesterSnapshotLabel]:
        """
            List labels of stored chunk documents, newest `label.ts` first.

            One label is produced per matching document. A chunked snapshot
            therefore yields its label once per chunk.

            :param meta_query: optional MongoDB condition applied to the `meta` field.
            :param ts: if given, only labels with exactly this timestamp.
            :param limit: maximum number of results; 0 means no limit (pymongo semantics).
            :return: the labels, or [] if MongoDB raised `PyMongoError` (which is logged).
        """
        try:
            find_query = {
                "label.harvester_type": {"$eq": harvester_type},
                "label.harvester_name": {"$eq": harvester_name},
            }
            if ts is not None:  # It is Integer, may be '0', so check against 'is not None'.
                find_query["label.ts"] = {"$eq": ts}
            if meta_query:
                find_query["meta"] = meta_query
            collection = self._get_collection_handlers(harvester_type)
            result = []
            with stat_server.mongo_timing():
                for record in collection.find(find_query).sort("label.ts", DESCENDING).limit(limit):
                    stored_label = record["label"]
                    result.append(HarvesterSnapshotLabel(
                        ts=stored_label["ts"],
                        harvester_type=stored_label["harvester_type"],
                        harvester_name=stored_label["harvester_name"],
                        chunk_size=stored_label["chunk_size"],
                        data_list_path=stored_label["data_list_path"]
                    ))
            return result
        except PyMongoError as _exc:
            # `%r` rather than `%d`: `ts` may be None. A formatting TypeError
            # here would hide the real MongoDB error.
            self.logger.error(
                "ts='%r', harvester_type='%r', harvester_name='%r', meta_query='%r', limit='%r', "
                "find_labels_error='%r'",
                ts, harvester_type, harvester_name, meta_query, limit, str(_exc),
            )
            return []

    def read_last_snapshot(self, harvester_type: str, harvester_name: typing.Optional[str] = None,
                           projection=None, check_freshness: bool = False) -> typing.Optional[HarvesterSnapshot]:
        """
            Load the newest snapshot of a harvester.

            :param harvester_type: harvester type; also selects the collection.
            :param harvester_name: harvester name. Defaults to `harvester_type` when empty.
            :param projection: passed through to `read_snapshot`.
            :param check_freshness: not implemented yet (see TODO); currently has no effect.
            :return: the snapshot, or None if `read_snapshot` could not read it.
            :raises SnapshotNotFoundError: if no snapshot label is found (this
                also happens when the label lookup itself failed, since
                `find_labels` returns [] on errors).
        """
        if not harvester_name:
            harvester_name = harvester_type
        snapshot_labels = self.find_labels(harvester_type, harvester_name, limit=1)
        if not snapshot_labels:
            raise SnapshotNotFoundError("No snapshot available")
        target_label = snapshot_labels[0]
        if check_freshness:
            # TODO(rocco66): TENTACLES-327 check for 2.5 x harvester_interval
            pass
        return self.read_snapshot(target_label, projection)

    def read_last_snapshot_data(self, *args, **kwargs):
        """
            Return only the `data` of `read_last_snapshot(*args, **kwargs)`.

            :raises SnapshotNotFoundError: if there is no snapshot, or it could not be read.
        """
        snapshot = self.read_last_snapshot(*args, **kwargs)
        if snapshot is None:
            # read_snapshot() returns None on MongoDB errors or when the chunks
            # are gone. Raise the documented error instead of an AttributeError.
            raise SnapshotNotFoundError("Snapshot could not be read")
        return snapshot.data

    def get_last_snapshot_labels_from_harvester(
        self,
        harvester_type: str
    ) -> typing.Optional[typing.List[HarvesterSnapshotLabel]]:
        """
            For every harvester name of `harvester_type` that has documents
            with non-null `meta`, return the newest such label.

            :return: at most one label per harvester name, or None if listing
                names raised `PyMongoError` (which is logged). Per-name lookups
                go through `find_labels`, which logs its own errors and then
                contributes no label.
        """
        try:
            collection = self._get_collection_handlers(harvester_type)
            with stat_server.mongo_timing():
                harvester_names = collection.distinct(
                    "label.harvester_name",
                    {
                        "label.harvester_type": {"$eq": harvester_type},
                        "meta": {"$ne": None}
                    }
                )
            result = []
            for harvester_name in harvester_names:
                result.extend(self.find_labels(harvester_type=harvester_type,
                                               harvester_name=harvester_name,
                                               meta_query={"$ne": None},
                                               limit=1))
            return result
        except PyMongoError as _exc:
            self.logger.error("harvester_type='%r', get_last_snapshot_labels_from_harvester='%r'",
                              harvester_type, str(_exc))
            return None

    def clean_old_snapshots(self, harvester_type: str, harvester_name: str,
                            ts_border: int) -> typing.Optional[int]:
        """
            Delete one harvester's documents whose `label.ts` is less than `ts_border`.

            Nothing is deleted unless at least one document with `label.ts`
            greater than `ts_border` exists. That document is outside the
            deletion range, so newer data always survives.

            :return: number of deleted documents (chunks, not snapshots); 0 when
                nothing newer than the border exists; None if MongoDB raised
                `PyMongoError` (which is logged).
        """
        label_filter = {
            "label.harvester_type": {"$eq": harvester_type},
            "label.harvester_name": {"$eq": harvester_name},
        }
        try:
            collection = self._get_collection_handlers(harvester_type)
            with stat_server.mongo_timing():
                # find_one() returns None when nothing matches. A bare find()
                # cursor is always truthy, which would turn this guard into a no-op.
                after_border = collection.find_one(
                    {"label.ts": {"$gt": ts_border}, **label_filter},
                    {"_id": 1},
                )
            if after_border is None:
                return 0
            with stat_server.mongo_timing():
                result = collection.delete_many({"label.ts": {"$lt": ts_border}, **label_filter})
            return result.deleted_count
        except PyMongoError as _exc:
            self.logger.error(
                "'%s/%s', border=%s, delete_snapshot_error='%r'",
                harvester_type, harvester_name, ts_border, str(_exc),
            )
            return None

    def _get_collection_handlers(self, harvester_name: str) -> Collection:
        """
            Return the `harvester_<name>` collection of the client's default database.

            NOTE: despite the parameter name, every caller in this class passes
            the harvester *type*. Collections are per type, not per name.
        """
        mongo_database_handler = self._mongo_client.get_database(None)
        return mongo_database_handler[f"harvester_{harvester_name}"]
