"""Provides distributed locks with help of MongoDB."""

import functools
import logging
import typing as tp
from datetime import datetime, timedelta

import gevent
import gevent.lock
import mongoengine
import pytz
from apscheduler.schedulers.gevent import GeventScheduler as Scheduler
from gevent.event import Event
from mongoengine import StringField
from mongoengine.fields import DateTimeField
from pymongo.errors import ConnectionFailure, OperationFailure

from sepelib.mongo.util import register_model
from walle.models import Document
from walle.stats import stats_manager as stats
from walle.util import cloud_tools
from walle.util.misc import StopWatch
from walle.util.mongo.retry import MongoRetry

log = logging.getLogger(__name__)

# pymongo exception types that signal the lock database is unreachable or rejected an operation.
# NOTE(review): not referenced in this module itself; presumably exported for callers — confirm before removing.
MONGO_ERRORS = (ConnectionFailure, OperationFailure)


class LockError(Exception):
    """Common base of the ordinary (``Exception``-derived) errors raised by the Mongo lock helpers.

    Takes exactly one argument: the human-readable error message.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message)


class LockIsLostError(BaseException):
    """A lock that is currently held can no longer be trusted.

    Inherits from `BaseException` rather than `Exception` on purpose: generic ``except Exception`` handlers
    must not swallow it before `InterruptableLock.__exit__` gets the chance to translate it.
    """

    def __init__(self, locked_object_id: str, message: str):
        # Id of the lock that was lost, so nested locks can tell whose error this is.
        self.locked_object_id = locked_object_id
        super().__init__(message)

    def to_exception(self) -> LockError:
        """Return an ordinary `LockError` carrying the same message."""
        message = str(self)
        return LockError(message)


class LockIsExpiredError(LockIsLostError):
    """The lock document owned by this instance is no longer in the database (e.g. its TTL passed)."""


class LockConnectionError(LockIsLostError):
    """The lock could not be maintained because the connection to the lock database failed."""


class LockDatabaseError(LockIsLostError):
    """The lock could not be maintained because an operation on the lock database failed."""


class LockedObjectError(LockError):
    """The object is already locked: a lock document for it already exists."""


class RepeatedAcquireError(LockError):
    """`acquire` was called on a lock that this instance already holds."""


class RepeatedReleaseError(LockError):
    """`release` was called on a lock that this instance does not hold (e.g. released twice)."""


class LockTimeoutError(LockError):
    """A blocking acquisition did not succeed within the requested timeout."""


@register_model
class MongoLock(Document):
    """Lock document stored in the ``locks`` collection, one document per locked object."""

    # The locked object's id is the primary key, so inserting a second lock for the same object is rejected.
    locked_object_id = StringField(primary_key=True, required=True, help_text="Locked object id")
    instance = StringField(required=True, help_text="Identifier of instance which acquires lock")
    # Covered by a TTL index with expireAfterSeconds=0: MongoDB removes the document after this moment.
    locked_until = DateTimeField(required=True, help_text="Timestamp until lock is currently acquired")

    meta = {
        "collection": "locks",
        "indexes": [
            {"fields": ["locked_until"], "expireAfterSeconds": 0},
        ],
    }


class _LockWrapper:
    """Local handle for one `MongoLock` document owned by this process.

    Tracks whether the lock is held and its last known expiration, and serialises every database
    operation on the lock through a gevent re-entrant lock.
    """

    def __init__(self, locked_object_id: str) -> None:
        self.locked_object_id = locked_object_id
        # Owner id written into the lock document; every query below filters on it.
        self.instance = cloud_tools.get_process_identifier()
        self.acquired = False
        # Expiration last written to the database by this instance (None until the first acquire).
        self.locked_until: tp.Optional[datetime] = None
        # Set by release(), cleared at the start of acquire().
        self.released_event = Event()
        self.operation_lock = gevent.lock.RLock()

    def acquire(self, locking_time: timedelta) -> bool:
        """Insert the lock document, held until now + `locking_time`.

        :returns: True when the lock was acquired.
        :raises RepeatedAcquireError: if this wrapper already holds the lock.
        :raises LockedObjectError: if a lock document for this object already exists.
        """
        with self.operation_lock:
            # Checked under the operation lock so it cannot interleave with another acquire/release.
            if self.acquired:
                raise RepeatedAcquireError(
                    "Lock on object {} is already acquired by this instance".format(self.locked_object_id)
                )
            self.released_event.clear()
            lock = MongoLock(
                locked_object_id=self.locked_object_id,
                instance=self.instance,
                locked_until=datetime.utcnow() + locking_time,
            )

            try:
                # force_insert: fail on an existing document instead of overwriting someone else's lock.
                lock.save(force_insert=True)
            except mongoengine.NotUniqueError as e:
                raise LockedObjectError(
                    "Trying to acquire lock on already locked object {}".format(self.locked_object_id)
                ) from e

            self.locked_until = lock.locked_until
            self.acquired = True
            return True

    def release(self) -> None:
        """Delete this instance's lock document and mark the lock as released.

        Deleting a document that no longer exists is not an error. The local state is reset even if the
        delete itself fails, so the wrapper never gets stuck in the "acquired" state; a document left behind
        is removed by the TTL index on `locked_until`. Any database error is re-raised after the reset.
        """
        with self.operation_lock:
            try:
                MongoLock.objects(
                    locked_object_id=self.locked_object_id, instance=self.instance, locked_until=self.locked_until
                ).delete()
            finally:
                self.released_event.set()
                self.acquired = False

    def extend_timeout(self, extending_time: timedelta) -> None:
        """Keep the lock alive in the database.

        If the lock expires within `extending_time` from now, `locked_until` is moved forward by
        `extending_time` (counted from the later of now and the current expiration). Otherwise the method
        only checks that the lock document owned by this instance still exists. Does nothing if the lock is
        not acquired or has already been released.

        :raises LockIsExpiredError: if the lock document owned by this instance is not found in the database.
        """
        with self.operation_lock:
            # Must come first: locked_until is None before the first acquire, and the error message formats it.
            if not self.acquired or self.released_event.is_set():
                return

            owned_lock = MongoLock.objects(
                locked_object_id=self.locked_object_id, instance=self.instance, locked_until=self.locked_until
            )
            now = datetime.utcnow()
            if now + extending_time >= self.locked_until:
                new_locked_until = max(now, self.locked_until) + extending_time
                # A falsy result means no document matched, i.e. our lock document is gone.
                if not owned_lock.update_one(set__locked_until=new_locked_until):
                    raise self._expired_error()
                self.locked_until = new_locked_until
            else:
                try:
                    owned_lock.get()
                except mongoengine.DoesNotExist:
                    raise self._expired_error() from None

    def _expired_error(self):
        """Build the error reported when the lock document owned by this instance is missing."""
        return LockIsExpiredError(
            self.locked_object_id,
            "Lock on object {} for instance {} is expired. It was acquired until {}".format(
                self.locked_object_id, self.instance, self.locked_until.strftime("%Y.%m.%d %H:%M:%S %z")
            ),
        )


class InterruptableLock:
    """
    This lock manages `MongoLock` for specified object and provides the following functionality:

    If the mongodb connection breaks or the lock is lost between acquire() and release(), the greenlet that
    called acquire() is killed with `LockIsLostError`. The heartbeat stops watching the lock on release(),
    so connection errors after release() do not interrupt anything.
    Please notice that `LockIsLostError` is derived from `BaseException`. This is done to get more
    guarantees that the code will actually be interrupted.

    `__exit__()` translates a `LockIsLostError` for this lock into `LockError`, which is derived from
    `Exception`, so you can catch any errors outside of the context manager with an `except Exception` clause.

    Attention: this lock must be used only in gevent:
    * Interruption works only in gevent.
    * The code is not thread-safe.
    * gevent.sleep is used while waiting for the lock.
    """

    # The heartbeat extends the lock by locking_time // EXTENDING_DIVISOR at a time.
    EXTENDING_DIVISOR = 3

    def __init__(
        self,
        locked_object: str,
        locking_time: tp.Union[int, timedelta] = timedelta(minutes=10),
        prefix: tp.Optional[str] = None,
        blocking: bool = True,
        timeout: tp.Optional[int] = None,
    ) -> None:
        """Construct new `InterruptableLock` object.

        :param locked_object: id of locked object.
        :param locking_time: lock lifetime written on acquire, as a `timedelta` or a number of seconds;
            values below 3 minutes are raised to 3 minutes
        :param prefix: optional prefix for locked object id, joined as ``prefix/locked_object``
        :param blocking: default for the ``blocking`` argument of `acquire`
        :param timeout: default for the ``timeout`` argument of `acquire`, in seconds
        """
        locked_object_id = prefix + "/" + locked_object if prefix is not None else locked_object
        self.__lock = _LockWrapper(locked_object_id)
        self.__locked_object_id = locked_object_id
        self.__locking_time = max(
            timedelta(minutes=3),
            (locking_time if isinstance(locking_time, timedelta) else timedelta(seconds=locking_time)),
        )
        self.__greenlet: tp.Optional[gevent.Greenlet] = None
        self.__blocking = blocking
        self.__timeout = timeout
        self.__acquired = False
        # True once the heartbeat has killed the owning greenlet for the current acquisition.
        self.__interrupted = False
        self.__release_retry = MongoRetry(log, retry_attempts=3)

    def acquired(self) -> bool:
        """Return whether this object currently holds the lock."""
        return self.__acquired

    def acquire(self, blocking: tp.Optional[bool] = None, timeout: tp.Optional[int] = None) -> bool:
        """Tries to acquire lock.

        On success the lock is registered in the module heartbeat, which periodically extends it. If the
        lock is lost afterwards, the heartbeat kills the greenlet that called `acquire` with a `LockIsLostError`.

        Because of possible interruptions it is not recommended to use this method directly, use provided
        context manager instead.

        :param blocking: if true, try to acquire lock many times until timeout, otherwise make only one attempt;
            defaults to the value given to the constructor
        :param timeout: optional, limits time (seconds) to acquire lock if blocking is true;
            defaults to the value given to the constructor
        :returns: True if the lock was acquired, False if a non-blocking attempt failed
        :raises RepeatedAcquireError: if this object already holds the lock
        :raises LockTimeoutError: if `timeout` elapsed while blocking
        """
        if self.__acquired:
            raise RepeatedAcquireError(
                "Lock on object {} is already acquired by this instance".format(self.__locked_object_id)
            )

        if blocking is None:
            blocking = self.__blocking
        if timeout is None:
            timeout = self.__timeout

        if not self.__acquire(blocking, timeout):
            return False

        self.__acquired = True
        # A new acquisition starts uninterrupted; without this reset a previous interruption would make the
        # heartbeat callback refuse to interrupt this holder if the lock is lost again.
        self.__interrupted = False
        self.__greenlet = gevent.getcurrent()
        try:
            _HB.add_lock(self.__lock, self.__heartbeat_callback, self.__locking_time // self.EXTENDING_DIVISOR)
        except Exception:
            # Without the heartbeat the lock would silently expire, so give it back before failing.
            _HB.delete_lock(self.__lock)
            self.__acquired = False
            try:
                self.__lock.release()
            except Exception:
                log.exception(
                    "Failed to release lock on %s after heartbeat registration failure", self.__locked_object_id
                )
            raise

        return True

    def __acquire(self, blocking: tp.Optional[bool] = None, timeout: tp.Optional[int] = None) -> bool:
        """Insert the lock document, retrying every 2 seconds while `blocking`.

        :returns: True once acquired; False after a single failed attempt when not blocking.
        :raises LockTimeoutError: if `timeout` seconds passed without acquiring the lock.
        :raises RepeatedAcquireError: if the underlying wrapper already holds the lock.
        """
        start_time = datetime.utcnow()
        acquiring_interval = 2
        error = None
        while True:
            if timeout is not None and start_time + timedelta(seconds=timeout) < datetime.utcnow():
                raise LockTimeoutError(
                    "Failed to acquire lock on {} after {} seconds because of {}".format(
                        self.__locked_object_id, timeout, error
                    )
                ) from error
            try:
                self.__lock.acquire(self.__locking_time)
                return True
            except RepeatedAcquireError:
                # Only this object releases the wrapper, so waiting cannot fix this; retrying would spin.
                raise
            except Exception as e:
                error = e
                if not blocking:
                    log.info("Failed to acquire non-blocking mongo lock on %s: %s", self.__locked_object_id, e)
                    return False
                log.warning("Failed to acquire lock on %s because of %s", self.__locked_object_id, e)
                gevent.sleep(acquiring_interval)

    def __heartbeat_callback(self, e: BaseException):
        """Called by the heartbeat when the lock could not be extended/verified; interrupts the holder."""
        if isinstance(e, LockIsExpiredError):
            error = e
            log.error("Mongo lock is lost due to TTL. Aborting lock on %s with error %s...", self.__locked_object_id, e)
        elif isinstance(e, ConnectionFailure):
            error = LockConnectionError(
                self.__locked_object_id, "Lock is lost because of problems with connection: {}".format(e)
            )
            log.error("Connection to mongo is broken. Aborting lock on %s...", self.__locked_object_id)
        elif isinstance(e, OperationFailure):
            error = LockDatabaseError(self.__locked_object_id, "Operation on lock's database failed: {}".format(e))
            log.error("Extending timeout in database failed. Aborting lock on %s...", self.__locked_object_id)
        else:
            error = LockIsLostError(self.__locked_object_id, "Lock is lost because of exception: {}".format(e))
            log.error(
                "Failed to check lock state with exception %s. Aborting lock on %s...", e, self.__locked_object_id
            )
        if not self.__acquired or self.__interrupted or self.__lock.released_event.is_set():
            log.error(
                "Cancelling abortion of lock on %s: %s.",
                self.__locked_object_id,
                "it's already interrupted" if self.__interrupted else "it's not acquired already",
            )
            return

        self.__interrupted = True
        gevent.kill(self.__greenlet, error)

    def release(self) -> None:
        """Stop the heartbeat for this lock and delete the lock document (retried on Mongo errors).

        :raises RepeatedReleaseError: if the lock is not held.
        """
        if not self.__acquired:
            raise RepeatedReleaseError("Lock is not acquired")

        _HB.delete_lock(self.__lock)
        try:
            self.__release_retry(self.__lock.release)
        finally:
            self.__acquired = False

    def __enter__(self):
        if not self.acquire():
            raise LockError(f"Can't acquire lock {self.__locked_object_id}")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.acquired():
            try:
                self.release()
            except Exception as e:
                log.error("Failed to release lock on %s: %s", self.__locked_object_id, e)
                raise

        if isinstance(exc_val, LockIsLostError) and exc_val.locked_object_id == self.__locked_object_id:
            # Our lock was lost: surface it as an ordinary Exception-derived LockError.
            raise exc_val.to_exception() from exc_val
        # Anything else (including a LockIsLostError of an outer lock) propagates unchanged.
        return False


# TODO(rocco66): it is not used, remove?
def lost_mongo_lock_retry(func):
    """Wrap `func` so that calls interrupted by `LockIsLostError` are run again.

    The call is repeated, without delay or attempt limit, for as long as it raises `LockIsLostError`.
    Once an attempt completes without that error:

    * if no `LockIsLostError` was seen at all, the result is returned;
    * otherwise the most recent `LockIsLostError` is raised, even though the last attempt succeeded.

    NOTE(review): the second case is surprising for a "retry" helper — confirm it is intended before use.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        last_error = None

        while True:
            try:
                result = func(*args, **kwargs)
            except LockIsLostError as e:
                last_error = e
                log.error("Retry %s call due to error: %s", func, last_error)
                continue

            if last_error is None:
                return result
            raise last_error

    return wrapper


def start_heartbeat(scheduler: Scheduler) -> None:
    """Schedule the Mongo lock heartbeat (`_HB.run`) on `scheduler` every 20 seconds."""
    timezone = pytz.timezone("Europe/Moscow")
    scheduler.add_job(
        _HB.run,
        "interval",
        misfire_grace_time=10,
        name="Heartbeat for Mongo locks",
        # Aware start date in the trigger's own zone: a naive utcnow() would be interpreted as local time of
        # `timezone` and shift the start by that zone's UTC offset.
        start_date=datetime.now(timezone),
        seconds=20,
        timezone=timezone,
    )


def stop_heartbeat() -> None:
    """Stop the module heartbeat, giving the currently registered locks a short chance to be released."""
    log.info("Waiting for releasing all mongo locks...")
    heartbeat = _HB
    heartbeat.stop()


class _Heartbeat:
    """Process-wide keeper of the Mongo locks held by this instance.

    Every `run` goes through the registered locks: a lock that is close to expiring gets its
    `locked_until` pushed forward, any other lock is only checked for presence in the database.
    When a lock cannot be confirmed, its owner's callback receives the error and the lock is
    dropped from the registry.
    """

    def __init__(self) -> None:
        # Registered lock -> (callback reporting a lost lock, time by which the lock is extended).
        self.__locks: tp.Dict[_LockWrapper, tp.Tuple[tp.Callable[[BaseException], None], timedelta]] = {}
        self.__data_lock = gevent.lock.RLock()
        self.__stopping = False

    def add_lock(
        self, lock: _LockWrapper, callback: tp.Callable[[BaseException], None], extending_time: timedelta
    ) -> None:
        """Start maintaining `lock`; `callback` is invoked with the error if the lock is lost.

        :raises LockError: while `stop` is in progress.
        """
        with self.__data_lock:
            if self.__stopping:
                raise LockError("Mongo lock heartbeat is stopping")
            self.__locks[lock] = (callback, extending_time)

    def delete_lock(self, lock: _LockWrapper):
        """Stop maintaining `lock`; unknown locks are ignored."""
        with self.__data_lock:
            if lock in self.__locks:
                del self.__locks[lock]

    def run(self) -> None:
        """Perform one heartbeat pass over a snapshot of the registered locks."""
        stopwatch = StopWatch()
        log.info("Heartbeat job started")
        try:
            with self.__data_lock:
                if self.__stopping:
                    return
                snapshot = dict(self.__locks)

            log.info("Locked objects in heartbeat %s", ", ".join(lock.locked_object_id for lock in snapshot))

            dead = []
            for lock, (callback, extending_time) in snapshot.items():
                if not self.__refresh(lock, callback, extending_time):
                    dead.append(lock)

            if dead:
                with self.__data_lock:
                    for lock in dead:
                        self.__locks.pop(lock, None)
        except Exception as e:
            log.exception("Heartbeat job failed with exception %s. Execution time %s seconds", e, stopwatch.get())
        else:
            log.info("Heartbeat job finished successfully. Execution time %s seconds", stopwatch.get())
        finally:
            stats.add_sample(("locks", "heartbeat_job_iteration_time"), stopwatch.get())

    def __refresh(self, lock, callback, extending_time) -> bool:
        """Extend or verify one lock; on failure pass the error to `callback` and return False."""
        if lock.locked_until <= datetime.utcnow() + extending_time:
            action = "Extending timeout"
        else:
            action = "Checking availability"
        log.info("%s of lock for object %s with extending time %s", action, lock.locked_object_id, extending_time)

        # The retry gives up early once the owner releases the lock.
        retry = MongoRetry(log, retry_attempts=3, interrupt_event=lock.released_event)
        try:
            retry(lock.extend_timeout, extending_time)
        except (LockIsLostError, Exception) as e:
            callback(e)
            return False

        log.info("%s of lock for object %s finished successfully", action, lock.locked_object_id)
        return True

    def stop(self):
        """Refuse new locks, wait up to one second per registered lock for its release, then reset.

        Locks that were not released in time are only logged; the registry is cleared and the
        heartbeat accepts new locks again afterwards.
        """
        with self.__data_lock:
            self.__stopping = True
            pending = dict(self.__locks)

        for lock in pending:
            released = lock.released_event.wait(timeout=1)
            if not released:
                log.error(
                    "Failed to wait releasing mongo lock on %s. It will be released by TTL after %s",
                    lock.locked_object_id,
                    lock.locked_until.strftime("%Y.%m.%d %H:%M:%S %z"),
                )

        with self.__data_lock:
            self.__locks = {}
            self.__stopping = False


# Single heartbeat shared by every InterruptableLock in the process; scheduled by start_heartbeat()
# and shut down by stop_heartbeat().
_HB = _Heartbeat()
