"""Create group of bulk updaters and run them in parallel."""
import logging
import time
from contextlib import contextmanager

from gevent.event import Event
from gevent.lock import Semaphore, RLock
from gevent.pool import Group
from gevent.timeout import Timeout
from pymongo.errors import BulkWriteError

from sepelib.core.constants import MINUTE_SECONDS
from walle.stats import stats_manager

log = logging.getLogger(__name__)


def _stats_collector(key):
    return lambda key_part, value: stats_manager.add_sample(tuple(key) + tuple(key_part), value)


class AbstractBulkUpdatersGroup:
    """Scaffolding for a group of bulk updaters that are executed in parallel.

    Operations are collected into a "current" bulk through :meth:`current`.
    Once ``bulk_size_limit`` operations have been registered, the bulk is handed
    to :meth:`execute` and a fresh one is started.  Bulks are executed by
    ``_execute_bulk`` while holding ``_parallel_bulks_lock``, which bounds how
    many of them run at the same time.

    This class is abstract.  A concrete group has to implement:

    * ``_start_bulk()`` -- reset ``_operations_count`` and create fresh bulk(s);
      it is also called from ``__init__``;
    * ``_get_current()`` -- return the object yielded by :meth:`current`;
    * ``execute()`` -- schedule execution of the current bulk(s), e.g. by
      spawning ``_execute_bulk`` in ``_workers_pool`` and attaching
      ``_save_stats`` to the spawned greenlet.
    """

    # empirical default values, seems optimal for the performance (when running from one instance).
    def __init__(
        self,
        stats_key,
        bulk_size_limit=2000,
        bulk_execution_timeout=MINUTE_SECONDS,
        parallel_bulks_limit=20,
        parallel_bulks_lock=None,
    ):
        """Set up counters, locks and the worker pool, then start the first bulk.

        :param stats_key: iterable of key parts used as a prefix for every stats sample.
        :param bulk_size_limit: number of operations after which the current bulk is executed.
        :param bulk_execution_timeout: timeout (seconds) for a single bulk execution.
        :param parallel_bulks_limit: size of the private semaphore created when no
            ``parallel_bulks_lock`` is given.
        :param parallel_bulks_lock: optional externally shared lock; when given,
            ``parallel_bulks_limit`` is ignored.
        """
        self._stats_add_sample = _stats_collector(tuple(stats_key))

        self._bulk_size_limit = bulk_size_limit
        self._bulk_execution_timeout = bulk_execution_timeout

        if parallel_bulks_lock is not None:
            self._parallel_bulks_lock = parallel_bulks_lock
        else:
            # NB: this lock is not global
            self._parallel_bulks_lock = Semaphore(value=parallel_bulks_limit)

        self._lock = RLock()
        self._workers_pool = Group()
        self._stopped_event = Event()

        self._total_inserted = 0
        self._total_modified = 0

        self._operations_count = 0
        # Subclasses must have prepared everything `_start_bulk` needs before
        # delegating to this constructor.
        self._start_bulk()

    @contextmanager
    def current(self):
        """Yield the current bulk under the group lock and count one operation.

        After the ``with`` body completes normally the operation counter is
        incremented; when it reaches ``bulk_size_limit`` the bulk is executed
        and a new one is started.  If the body raises, the exception propagates
        and the operation is not counted.
        """
        with self._lock:
            yield self._get_current()

            self._operations_count += 1

            if self._operations_count >= self._bulk_size_limit:
                self.execute()
                # `_start_bulk` is responsible for resetting `_operations_count`.
                self._start_bulk()

    def _execute_bulk(self, bulk, lock):
        """Execute ``bulk`` while holding ``lock``, limited by the execution timeout.

        :returns: tuple ``(inserted, modified, execution_time)``; the time is
            measured from before the lock is acquired, so it includes waiting
            for the lock.
        :raises Timeout: only for a ``Timeout`` that is not this call's own timer.
        """
        timer = Timeout(self._bulk_execution_timeout)
        start = time.time()
        # The timer is entered after the lock, so waiting for the lock does not
        # count against the execution timeout.
        with lock, timer:
            try:
                result = bulk.execute()
            except BulkWriteError as e:
                result = e.details
                self._handle_bulk_write_error(e)
            except Timeout as timeout:
                if timeout is not timer:
                    # Somebody else's timeout (e.g. armed by outer code): do not swallow it.
                    raise
                result = {"nInserted": 0, "nModified": 0}  # unknown really
                log.error("Bulk operation timed out in %ss.", self._bulk_execution_timeout)

        execution_time = time.time() - start
        return result.get("nInserted", 0), result.get("nModified", 0), execution_time

    def _save_stats(self, greenlet):
        """Link callback: record stats of a finished ``_execute_bulk`` greenlet.

        Only a tuple result is treated as stats.  A greenlet killed via
        :meth:`stop` finishes with a ``GreenletExit`` instance as its value
        (see gevent ``Greenlet`` docs), which is truthy but must not be unpacked.
        """
        result = greenlet.value
        if isinstance(result, tuple):
            inserted, modified, execution_time = result
            self._stats_add_sample(("bulk", "inserted"), inserted)
            self._stats_add_sample(("bulk", "modified"), modified)

            self._stats_add_sample(("bulk", "processing_time"), execution_time)

            with self._lock:
                self._total_inserted += inserted
                self._total_modified += modified
        elif greenlet.exception:
            log.debug("No stats for bulk because greenlet failed with %s", greenlet.exception)

    def _start_bulk(self):
        """Create fresh bulk(s) and reset ``_operations_count``; implemented by subclasses."""
        raise NotImplementedError

    def _get_current(self):
        """Return the object to be yielded by :meth:`current`; implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def _handle_bulk_write_error(error):
        """Log a ``BulkWriteError`` together with its details."""
        log.error("Failed to execute bulk operation: %s: %s.", error, error.details)

    def execute(self):
        """Schedule execution of the current bulk(s); implemented by subclasses."""
        raise NotImplementedError

    def stop(self):
        """Mark the group as stopped and kill all running workers."""
        self._stopped_event.set()
        self._workers_pool.kill()

    def finish(self):
        """Execute pending operations, wait for all workers and report group totals.

        Does nothing if the group has been stopped.  An error raised by a worker
        is re-raised here (``join(raise_error=True)``).
        """
        if self._stopped_event.is_set():
            return
        with self._lock:
            if self._operations_count:
                self.execute()
                # NOTE(review): no fresh bulk is started here, so the group
                # looks single-use after finish() -- confirm against callers.
                self._operations_count = 0

        self._workers_pool.join(raise_error=True)

        self._stats_add_sample(("group", "inserted"), self._total_inserted)
        self._stats_add_sample(("group", "modified"), self._total_modified)
        self._stats_add_sample(("group", "total"), self._total_inserted + self._total_modified)


class SingleCollectionBulkGroup(AbstractBulkUpdatersGroup):
    """Bulk updaters group that writes into a single collection.

    ``current()`` yields one unordered bulk operation obtained from the
    collection's ``initialize_unordered_bulk_op()``.
    """

    _collection = None

    def __init__(
        self,
        collection,
        stats_key,
        bulk_size_limit=2000,
        bulk_execution_timeout=MINUTE_SECONDS,
        parallel_bulks_limit=20,
        parallel_bulks_lock=None,
    ):
        """Remember the target collection and initialize the base group.

        Parameters other than ``collection`` are passed through unchanged to the
        base constructor.
        """
        # The collection has to be in place first: the base constructor
        # immediately calls `_start_bulk`.
        self._collection = collection

        base_options = {
            "bulk_size_limit": bulk_size_limit,
            "bulk_execution_timeout": bulk_execution_timeout,
            "parallel_bulks_limit": parallel_bulks_limit,
            "parallel_bulks_lock": parallel_bulks_lock,
        }
        super().__init__(stats_key, **base_options)

    def _get_current(self):
        """Return the bulk that is currently being filled."""
        return self._bulk

    def _start_bulk(self):
        """Reset the operation counter and open a new unordered bulk."""
        self._operations_count = 0
        self._bulk = self._collection.initialize_unordered_bulk_op()

    def execute(self):
        """Spawn a worker executing the current bulk unless the group is stopped."""
        if not self._stopped_event.is_set():
            worker = self._workers_pool.spawn(self._execute_bulk, self._bulk, self._parallel_bulks_lock)
            # Stats are collected once the worker is done.
            worker.rawlink(self._save_stats)


class MultipleCollectionsBulkGroup(AbstractBulkUpdatersGroup):
    """Bulk updaters group that keeps one unordered bulk per collection.

    ``current()`` yields a tuple of bulks in the same order as ``collections``.
    The operation counter is incremented once per ``current()`` use, regardless
    of how many of the bulks were touched.
    """

    def __init__(
        self,
        collections,
        stats_key,
        bulk_size_limit=2000,
        bulk_execution_timeout=MINUTE_SECONDS,
        parallel_bulks_limit=20,
        parallel_bulks_lock=None,
    ):
        """Remember the target collections and initialize the base group.

        :param collections: iterable of collections; it is iterated every time a
            new set of bulks is started, so it must be re-iterable.
        Other parameters are passed through unchanged to the base constructor.
        """
        # Must be set before the base constructor runs: it calls `_start_bulk`.
        self._collections = collections
        super().__init__(
            stats_key,
            bulk_size_limit=bulk_size_limit,
            bulk_execution_timeout=bulk_execution_timeout,
            parallel_bulks_limit=parallel_bulks_limit,
            parallel_bulks_lock=parallel_bulks_lock,
        )

    def _get_current(self):
        """Return the tuple of bulks that are currently being filled."""
        return self._bulks

    def _start_bulk(self):
        """Reset the operation counter and open a new unordered bulk per collection."""
        self._operations_count = 0
        self._bulks = tuple(c.initialize_unordered_bulk_op() for c in self._collections)

    def execute(self):
        """Spawn one worker per current bulk unless the group has been stopped."""
        # Same guard as in SingleCollectionBulkGroup.execute: once stop() has
        # killed the pool, no new workers may be spawned into it.
        if self._stopped_event.is_set():
            return
        for bulk in self._bulks:
            greenlet = self._workers_pool.spawn(self._execute_bulk, bulk, self._parallel_bulks_lock)
            greenlet.rawlink(self._save_stats)
