import dataclasses
import logging
import time
import typing as tp

import gevent
import mongoengine
import uhashring

from sepelib.core.constants import MINUTE_SECONDS
from walle import locks
from walle.util import cloud_tools
from walle.util.cloud_tools import get_tier

log = logging.getLogger(__name__)


@dataclasses.dataclass
class MongoPartitionerShard:
    """Handle for one partitioner shard together with its per-shard lock."""

    # Shard identifier (the value hashed onto the ring); numeric ids are parsed with int() by callers.
    id: str
    # Lock object guarding this shard.
    lock: locks.PartitionerShardInterruptableLock

    def __str__(self):
        return "DMC shard {}".format(self.id)


class _StopException(BaseException):
    """We need special exception for supervisor precise stopping"""


class MongoPartitionerService:
    """Assign shards to processes through a consistent hash ring built from a lock "party".

    Each process joins the party for ``partitioner_type`` by holding a party lock. The party
    membership is read back periodically and turned into a ketama ``uhashring.HashRing``. A
    shard belongs to this process when the ring maps the shard id to this process's identifier.
    """

    def __init__(self, partitioner_type: str):
        self._partitioner_type = partitioner_type
        # Identifier of this process. It is both the party member name and the ring node name
        # that get_shard() compares against.
        self._this_node = cloud_tools.get_process_identifier()
        self._party_lock = locks.PartitionerPartyInterruptableLock(partitioner_type, self._this_node)
        self._updater = None
        self._party_supervisor = None
        # None means "membership unknown", in which case get_shard() claims no shards.
        self._current_hashring = None
        self._stopped = False

    def _start_party_lock_supervisor(self):
        """Keep holding the party lock, re-creating and re-acquiring it whenever it is lost.

        Runs in its own greenlet until ``_stopped`` is set or the greenlet is killed with
        ``_StopException``.
        """
        while not self._stopped:
            try:
                while not self._party_lock.acquire():
                    gevent.idle()
                self._update_party()
                while not self._stopped:
                    time.sleep(10)
            except _StopException:
                break
            except BaseException:
                # Deliberately broad, so whatever interrupts the lock holder does not end the supervisor.
                # NOTE(review): presumably the interruptable lock raises here when the lock is lost; confirm.
                log.exception("Party lock %s was lost", self._partitioner_type)
                self._party_lock = locks.PartitionerPartyInterruptableLock(self._partitioner_type, self._this_node)
        log.info(f"Party supervisor {self._partitioner_type} was stopped")

    def start(self):
        """Start the party supervisor, wait until the party lock is held, then start the updater.

        This blocks, polling every 0.1 s, until the party lock reports it is acquired.
        """
        self._party_supervisor = gevent.spawn(self._start_party_lock_supervisor)
        # Re-read the attribute on each poll because the supervisor may replace the lock object.
        while not self._party_lock.acquired():
            time.sleep(0.1)
        self._updater = gevent.spawn(self._update)

    def stop(self):
        """Stop the background greenlets and release the party lock if it is still held.

        Safe to call when start() never ran or has not finished: greenlets that were never
        spawned are skipped instead of passing None to gevent.kill().
        """
        self._stopped = True
        if self._updater is not None:
            gevent.kill(self._updater)
        if self._party_supervisor is not None:
            # _StopException makes the supervisor break out directly, bypassing its lost-lock handler.
            gevent.kill(self._party_supervisor, _StopException)
        if self._party_lock.acquired():
            self._party_lock.release()

    def _update(self):
        """Rebuild the hash ring every MINUTE_SECONDS until stopped.

        On any error the ring is reset to None, so this process claims no shards until a later
        rebuild succeeds.
        """
        while not self._stopped:
            try:
                time.sleep(MINUTE_SECONDS)
                self._update_party()
            except Exception:
                self._current_hashring = None
                log.exception(f"Uncaught update {self._partitioner_type} party error")

    def _update_party(self):
        """Rebuild the ketama hash ring from the current party members."""
        current_party = self._party_lock.get_whole_party()
        log.info("Current %s party is %s", self._partitioner_type, current_party)
        self._current_hashring = uhashring.HashRing(nodes=current_party, hash_fn="ketama")

    def get_shard(self, shard_id, log_state=False) -> tp.Optional[MongoPartitionerShard]:
        """Return a shard handle if ``shard_id`` maps to this node on the ring, otherwise None.

        :param shard_id: shard identifier hashed onto the ring.
        :param log_state: if true, emit a debug record naming this node and the shard.
        :return: a MongoPartitionerShard carrying a fresh per-shard lock, or None when the ring is
            unknown (falsy) or assigns the shard to another node.
        """
        if log_state:
            log.debug(
                "Partitioner '%s': this node '%s' shard '%s'", self._partitioner_type, self._this_node, shard_id
            )
        # Read the ring once, because the updater greenlet may replace it or reset it to None.
        hashring = self._current_hashring
        if not hashring or hashring.get_node(shard_id) != self._this_node:
            return None
        return MongoPartitionerShard(
            shard_id,
            locks.PartitionerShardInterruptableLock(self._partitioner_type, shard_id),
        )

    def get_numeric_shards(self, total_shards_count: int) -> list[MongoPartitionerShard]:
        """Return the shards owned by this node among ids "0" .. str(total_shards_count - 1)."""
        result = []
        for index in range(total_shards_count):
            if shard := self.get_shard(str(index)):
                result.append(shard)
        return result


def get_host_mongo_shard_query(shard: MongoPartitionerShard, shards_count: int):
    """Build a mongoengine filter that selects the hosts of ``shard`` in the current tier.

    Documents match when ``inv % shards_count == int(shard.id)`` (mongoengine ``mod`` operator)
    and ``tier`` equals ``get_tier()``.
    """
    remainder = int(shard.id)
    modulo_condition = (shards_count, remainder)
    return mongoengine.Q(inv__mod=modulo_condition, tier=get_tier())
