import json
import zlib
import logging
import datetime as dt
import itertools as it
import functools as ft
import collections

import six
import setproctitle
import concurrent.futures

# The persqueue (Logbroker) SDK is optional: if it cannot be imported, the aliases below
# are set to None instead of failing the import of this module.
try:
    import kikimr.public.sdk.python.persqueue.auth as pq_auth
    import kikimr.public.sdk.python.persqueue.errors as pq_errors
    import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pq_api
    # Set DEBUG level on the loggers named after the SDK/YDB modules listed below
    for module_name in (
        "ydb.public.sdk.python.ydb.connection",
        "ydb.public.sdk.python.ydb.resolver",
        "ydb.public.sdk.python.ydb.pool",
        "kikimr.public.sdk.python.persqueue._core",
        "kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api",
    ):
        logging.getLogger(module_name).setLevel(logging.DEBUG)
except ImportError:
    pq_auth = None
    pq_errors = None
    pq_api = None

from sandbox.services.base import service

from sandbox import common
from sandbox.common import statistics as common_statistics
from sandbox.common.types import statistics as ctss
from sandbox.common import tvm as common_tvm
from sandbox.yasandbox.database import mapping as mp

from sandbox.services.modules.statistics_processor import processors


logger = logging.getLogger(__name__)
# Every signal type this module can handle: regular ones followed by YT ones
# noinspection PyTypeChecker
KNOWN_SIGNALS = list(ctss.SignalType) + list(ctss.YTSignalType)
DEFAULT_WORKER_INTERVAL = 10  # Time between subprocesses' loop iterations
LOGBROKER_API_TIMEOUT = 5  # Timeout (seconds) for waiting on futures returned by the Logbroker SDK

# Mapping (signal group name -> set of signal types in the group).
# The groups defined right here contain signals from Logbroker only
SIGNAL_GROUPS = {
    "serviceq": {
        ctss.SignalType.SERVICEQ_CALL,
        ctss.SignalType.SERVICEQ_TASK,
        ctss.SignalType.SERVICEQ_COUNTER,
        ctss.SignalType.SERVICEQ_SYSTEM_RESOURCES,
        ctss.SignalType.SERVICEQ_REPLICATION_DELAY,
        ctss.SignalType.QUOTA_CONSUMPTION,
        ctss.SignalType.QUOTA_CONSUMPTION_DETAILS,
        ctss.SignalType.API_QUOTA_CONSUMPTION,
    },
    "infrequent": {
        ctss.SignalType.RESOURCES_ON_STORAGES,
    },
}
# Computed before the groups below are added, so it covers "serviceq" and "infrequent" only.
# NOTE(review): EXCEPTION_STATISTICS is not listed in either of those groups, so the subtraction
# looks like a no-op -- confirm
SIGNALS_FROM_LOGBROKER_ONLY = (
    set(it.chain.from_iterable(SIGNAL_GROUPS.values())) - {ctss.SignalType.EXCEPTION_STATISTICS}
)
# Signals from either source: the groups below are registered after SIGNALS_FROM_LOGBROKER_ONLY
# has been computed, so their types are not excluded from the MongoDB collection workers

# TODO: move to LB-only SANDBOX-9127
SIGNAL_GROUPS["agentr"] = set(common_statistics.SignalHandlerForAgentR.types)
SIGNAL_GROUPS["proxy"] = set(common_statistics.UASignalHandlerForProxy.types)

SIGNAL_GROUPS["serviceapi"] = set(common_statistics.UASignalHandlerServiceApi.types)
SIGNAL_GROUPS["serviceapi_calls"] = set(common_statistics.UASignalHandlerServiceApiCalls.types)
SIGNAL_GROUPS["resource_audit"] = set(common_statistics.UASignalHandlerResourceAudit.types)


# State object passed to (and returned from) every worker iteration.
#   last_insert_at -- set by the workers to the time of the last successful processor run
#   interval       -- set by the workers to DEFAULT_WORKER_INTERVAL or 0 depending on the amount of data read
# NOTE(review): presumably `__defs__` supplies the default value for each slot, in order -- confirm
# against `common.patterns.Abstract`
class WorkerState(common.patterns.Abstract):
    __slots__ = ("last_insert_at", "interval")
    __defs__ = (None, DEFAULT_WORKER_INTERVAL)


class StatisticsProcessor(service.MultiprocessedService):
    """
    Service process to push signals from MongoDB to ClickHouse/Solomon/... (SANDBOX-4457).
    Declares one worker per MongoDB collection (signal type) and one worker per
    (signal group, Logbroker endpoint) pair, see `targets`.
    """

    #: NOTE(review): presumably the period of `tick()` calls, interpreted by the base service -- confirm
    tick_interval = 10

    #: Mapping (signal type -> list of processors to handle with).
    #: Signal types not listed explicitly fall back to the generic ClickHouse processor
    PROCESSOR_CLASSES = collections.defaultdict(
        lambda: [processors.clickhouse.ClickHouseProcessor],
        **{
            ctss.SignalType.AUDIT: [processors.clickhouse.AuditProcessor],
            ctss.SignalType.TASK_SESSION_COMPLETION: [processors.clickhouse.SessionCompletionProcessor],
            ctss.SignalType.API_QUOTA_CONSUMPTION: [processors.clickhouse.ApiQuotaConsumption],
            ctss.SignalType.RESOURCES_ON_STORAGES: [processors.clickhouse.MdsConsumption]
        }
    )

    # All YT signal types are handled by the YT processor
    PROCESSOR_CLASSES.update(**{
        signal_type: [processors.yt_processor.YTProcessor]
        for signal_type in ctss.YTSignalType
    })

    #: Mapping (signal type -> amount of signals to fetch from MongoDB at once).
    #: Signal types not listed explicitly fall back to the ClickHouse processor batch size
    CHUNK_SZ = collections.defaultdict(
        lambda: processors.clickhouse.ClickHouseProcessor.BATCH_SZ,
        **{
            ctss.SignalType.TASK_SESSION_COMPLETION: 1000,
        }
    )
    # YT signal types use the batch size of the YT processor
    CHUNK_SZ.update(**{
        signal_type: processors.yt_processor.YTProcessor.BATCH_SZ
        for signal_type in ctss.YTSignalType
    })

    MIN_CHUNK_SZ = 1000  # minimum amount of signals to be sent
    LOGBROKER_MAX_COUNT = 10000  # maximal logbroker messages count per request
    LOGBROKER_MAX_SIZE = 10 << 20  # maximal logbroker messages size per request in bytes
    MAX_WORKER_IDLE_TIME = 40  # maximal amount of time (in seconds) to spend without sending signals

    #: Separator between the original process title and the parts appended by `_set_proc_title`
    PROC_TITLE_DELIMITER = " :: "

    #: Class of the per-worker state objects
    #: (NOTE(review): presumably instantiated by the base service -- confirm)
    worker_state_class = WorkerState

    def __init__(self, *a, **kw):
        """
        Forward all arguments to the base service and prepare the bookkeeping used by `tick()`.
        """
        super(StatisticsProcessor, self).__init__(*a, **kw)
        # names of processor classes whose `initialize()` has already been called by `tick()`
        self.__initialized = set()

        # hide spam about connections being dropped from urllib3's cached pool
        logging.getLogger("requests").setLevel(logging.ERROR)

    @property
    def targets(self):
        """
        Worker targets of the service.

        Contains one `process_collection` target per known signal type that is neither blacklisted by the
        ClickHouse processor nor listed in `SIGNALS_FROM_LOGBROKER_ONLY`, followed by one
        `process_logbroker` target per (signal group, Logbroker endpoint) pair.
        """
        blacklist = processors.clickhouse.ClickHouseProcessor.blacklist
        result = []

        # workers reading signals from MongoDB collections
        for signal_type in KNOWN_SIGNALS:
            if signal_type in blacklist or signal_type in SIGNALS_FROM_LOGBROKER_ONLY:
                continue
            collection_worker = ft.partial(self.process_collection, signal_type=signal_type)
            result.append(
                self.Target(
                    function=collection_worker,
                    interval=DEFAULT_WORKER_INTERVAL,
                    name=signal_type,
                    log_execution=False,
                    stateful=True,
                )
            )

        # workers reading signal groups from Logbroker, one per endpoint
        for signal_group, signal_types in SIGNAL_GROUPS.items():
            endpoints = self.service_config["logbroker"]["endpoints"]
            for endpoint_index in range(len(endpoints)):
                logbroker_worker = ft.partial(
                    self.process_logbroker,
                    signal_group=signal_group,
                    signal_types=signal_types,
                    service_config=self.service_config,
                    endpoint_index=endpoint_index,
                )
                result.append(
                    self.Target(
                        function=logbroker_worker,
                        interval=DEFAULT_WORKER_INTERVAL,
                        name=signal_group,
                        log_execution=False,
                        stateful=True,
                    )
                )
        return result

    def tick(self):
        """
        Call `initialize()` once per processor class (keyed by class name), storing the returned
        context under that name in `self.context`, then run the base class `tick()`.
        """
        for signal_type in KNOWN_SIGNALS:
            for processor_class in self.PROCESSOR_CLASSES[signal_type]:
                class_name = processor_class.__name__
                if class_name not in self.__initialized:
                    # the previously stored context (or a fresh dict) is handed to the processor
                    self.context[class_name] = processor_class.initialize(
                        self.context.setdefault(class_name, {})
                    )
                    self.__initialized.add(class_name)

        super(StatisticsProcessor, self).tick()

    @staticmethod
    def make_db_connection():
        """
        Return the statistics database object taken from the read-write connection,
        with the module logger attached to it.
        """
        # always read from primary to avoid duplicates on consecutive queries
        connection = mp.ensure_connection().rw.connection
        database_name = common.config.Registry().server.services.statistics_processor.database
        database = connection[database_name]
        database.logger = logger
        return database

    @staticmethod
    def patch_in_place(signal_):
        """
        Normalize the timestamp field of a signal to a `datetime` object, modifying the signal in place.

        - a missing or falsy timestamp is replaced with the current UTC time;
        - a string timestamp is decoded with `processors.base.DATETIME_CONVERTER`;
        - an integer timestamp (including Python 2 `long`) is converted with `utcfromtimestamp`;
        - a value of any other type is left untouched.

        :param signal_: mapping with signal fields
        """
        field = processors.base.Processor.TIMESTAMP_FIELD
        timestamp = signal_.get(field)
        if not timestamp:
            timestamp = dt.datetime.utcnow()
        elif isinstance(timestamp, six.string_types):
            timestamp = processors.base.DATETIME_CONVERTER.decode(timestamp)
        elif isinstance(timestamp, six.integer_types):
            # `six.integer_types` also covers `long` on Python 2, which a plain `int` check misses
            timestamp = dt.datetime.utcfromtimestamp(timestamp)
        else:
            return
        signal_[field] = timestamp

    @classmethod
    def _set_proc_title(cls, *parts):
        """
        Append `parts` to the current process title, joined with `PROC_TITLE_DELIMITER`.
        Does nothing if the title already contains the delimiter.
        """
        current_title = setproctitle.getproctitle()
        if cls.PROC_TITLE_DELIMITER in current_title:
            return
        setproctitle.setproctitle(cls.PROC_TITLE_DELIMITER.join((current_title,) + parts))

    @classmethod
    def process_collection(cls, state, signal_type):
        """
        Worker body: move one chunk of signals of the given type from MongoDB to the processors.

        The chunk is read from the collection named after the signal type. If it is smaller than
        `MIN_CHUNK_SZ` and the last successful insert happened less than `MAX_WORKER_IDLE_TIME` seconds ago,
        nothing is done. Otherwise the signals are patched, removed from MongoDB and passed to every
        processor registered for the signal type; the first failing processor stops the chain.

        :param state: worker state; `last_insert_at` and `interval` are updated in place
        :param signal_type: signal type, also used as the collection name
        :return: tuple `(None, [], state)` on every exit path
        """
        cls._set_proc_title("process_collection", signal_type)
        collection = cls.make_db_connection().get_collection(signal_type)
        procs = [class_() for class_ in cls.PROCESSOR_CLASSES[signal_type]]

        state.last_insert_at = state.last_insert_at or dt.datetime.utcnow()
        with common.utils.Timer() as timer:
            with timer["load from MongoDB"]:
                chunk = list(collection.find(limit=cls.CHUNK_SZ[signal_type]))
                is_small_chunk = len(chunk) < cls.MIN_CHUNK_SZ
                # full chunks are processed back to back, small ones with the default delay
                state.interval = DEFAULT_WORKER_INTERVAL if is_small_chunk else 0
                if (
                    is_small_chunk and
                    (dt.datetime.utcnow() - state.last_insert_at).total_seconds() < cls.MAX_WORKER_IDLE_TIME
                ):
                    return None, [], state

            with timer["patch signals"]:
                # An explicit loop is required: `map()` is lazy on Python 3, so calling it
                # only for its side effects would leave the signals unpatched there.
                for signal in chunk:
                    cls.patch_in_place(signal)
                ids = [signal.pop("_id", None) for signal in chunk]

            with timer["remove from MongoDB"]:
                # signals are removed before processing, so a processor failure below does not restore them
                delete_result = collection.delete_many({"_id": {"$in": ids}})
                logger.debug(
                    "(%s): read %d, deleted %s", signal_type, len(chunk),
                    delete_result.deleted_count if delete_result.acknowledged else "unknown amount of items"
                )

            for p in procs:
                processor_name = type(p).__name__
                # noinspection PyBroadException
                try:
                    # noinspection PyUnresolvedReferences
                    p.process(chunk, signal_type, logger, timer)
                    state.last_insert_at = dt.datetime.utcnow()
                except BaseException:
                    # deliberately broad: any failure is logged and the remaining processors are skipped
                    logger.exception(
                        "(from thread) %s: failed to process with %s", signal_type, processor_name
                    )
                    break

        logger.debug("%s: %s", signal_type, timer)
        return None, [], state

    @classmethod
    def _create_logbroker_consumers(cls, endpoints, tvm_id, topic, consumer_name, proc_logger):
        """
        Start a Logbroker API object per endpoint and a consumer on every API that started in time.

        :param endpoints: iterable of "host:port" strings
        :param tvm_id: identifier passed to the TVM credentials provider
        :param topic: topic to read from
        :param consumer_name: consumer name passed to the consumer configurator
        :param proc_logger: logger used for diagnostics
        :return: tuple (all created API objects, successfully started consumers)
        """
        apis = []
        starting_apis = {}
        for endpoint in endpoints:
            host, port = endpoint.split(":")
            api = pq_api.PQStreamingAPI(host, int(port))
            apis.append(api)
            starting_apis[api.start()] = api

        # only APIs whose start future completed within the timeout get a consumer
        started_apis, _ = concurrent.futures.wait(starting_apis, timeout=LOGBROKER_API_TIMEOUT)
        starting_consumers = {}
        for future in started_apis:
            outcome = future.result()
            if isinstance(outcome, Exception):
                proc_logger.error(outcome)
                continue
            credentials_provider = pq_auth.TVMCredentialsProvider(common_tvm.TVMClient, tvm_id)
            configurator = pq_api.ConsumerConfigurator(
                topic, consumer_name, read_only_local=True,
                max_count=cls.LOGBROKER_MAX_COUNT, max_size=cls.LOGBROKER_MAX_SIZE
            )
            consumer = starting_apis[future].create_consumer(
                configurator, credentials_provider=credentials_provider
            )
            starting_consumers[consumer.start()] = consumer

        # only consumers whose start future completed with an "init" response are returned
        started_consumers, _ = concurrent.futures.wait(starting_consumers, timeout=LOGBROKER_API_TIMEOUT)
        consumers = []
        for future in started_consumers:
            outcome = future.result()
            if isinstance(outcome, pq_errors.SessionFailureResult):
                proc_logger.error("Error occurred on start of consumer: %s", outcome)
            elif not outcome.HasField("init"):
                proc_logger.error("Bad consumer start result from server: %s", outcome)
            else:
                consumers.append(starting_consumers[future])
        proc_logger.debug("%s consumer(s) started", len(consumers))
        return apis, consumers

    @classmethod
    def _get_signals_from_logbroker(cls, consumers, signal_group, proc_logger):
        """
        Read signals from the given consumers.

        Reading stops when the amount of collected signals reaches `CHUNK_SZ[signal_group]`, when
        `MAX_WORKER_IDLE_TIME` seconds have passed since the start, or when no consumer produced an event
        within `LOGBROKER_API_TIMEOUT`.

        Every message payload is either raw or gzip-compressed and consists of newline-separated
        JSON documents, each being a mapping (signal type -> list of signals).

        :param consumers: started Logbroker consumers
        :param signal_group: name of the signal group being read
        :param proc_logger: logger used for diagnostics
        :return: tuple (mapping signal type -> list of signals, total amount of signals,
            mapping consumer -> list of cookies to commit); if an error message is received,
            the first two items are `[]` and `0`, i.e. everything read so far is dropped
        """
        chunks = collections.defaultdict(list)
        chunks_total = 0
        cookies_to_commit = collections.defaultdict(list)
        start_time = dt.datetime.utcnow()
        # NOTE(review): `CHUNK_SZ` is documented as keyed by signal type, while it is indexed by group name
        # here, so the default factory value is what gets used unless a group name is added to it -- confirm
        while (
            chunks_total < cls.CHUNK_SZ[signal_group] and
            (dt.datetime.utcnow() - start_time).total_seconds() < cls.MAX_WORKER_IDLE_TIME
        ):
            # request the next event from every consumer and wait for whichever arrive in time
            next_event_futures = {}
            for consumer in consumers:
                future = consumer.next_event()
                next_event_futures[future] = consumer
            done = concurrent.futures.wait(next_event_futures, timeout=LOGBROKER_API_TIMEOUT)[0]
            if not done:
                break
            for future in done:
                consumer = next_event_futures[future]
                result = future.result()
                if result.type == pq_api.ConsumerMessageType.MSG_DATA:
                    for batch in result.message.data.message_batch:
                        for message in batch.message:
                            if message.meta.codec == pq_api.WriterCodec.RAW.value:
                                data = message.data
                            elif message.meta.codec == pq_api.WriterCodec.GZIP.value:
                                # 16 + MAX_WBITS makes zlib expect a gzip header and trailer
                                data = zlib.decompress(message.data, 16 + zlib.MAX_WBITS)
                            else:
                                proc_logger.error("Unsupported codec: %s", message.meta.codec)
                                continue
                            # NOTE(review): splitting by a `str` separator assumes a text-like payload;
                            # with `bytes` on Python 3 this raises TypeError -- confirm the payload type.
                            # The loop variable intentionally reuses (shadows) the `data` name
                            for data in data.split("\n"):
                                if not data:
                                    continue
                                data = json.loads(data)
                                for signal_type, chunk in data.items():
                                    chunks[signal_type].extend(chunk)
                                    chunks_total += len(chunk)
                    # the cookie is remembered even if some messages had an unsupported codec
                    cookies_to_commit[consumer].append(result.message.data.cookie)
                elif result.type == pq_api.ConsumerMessageType.MSG_ERROR:
                    proc_logger.error("Got error message: %s", result.message)
                    return [], 0, cookies_to_commit
                else:
                    proc_logger.error("Got unknown message: %s", result.message)

            # called for every consumer after each round, including those that produced no event
            for consumer in consumers:
                consumer.reads_done()

        return chunks, chunks_total, cookies_to_commit

    @classmethod
    def process_logbroker(cls, state, signal_group, signal_types, service_config, endpoint_index):
        """
        Worker body: read one batch of signals of a signal group from a single Logbroker endpoint,
        commit the read cookies and hand the signals over to the processors.

        :param state: worker state; `last_insert_at` and `interval` are updated in place
        :param signal_group: name of the group, also used as the last component of the topic path
        :param signal_types: signal types expected in this group
        :param service_config: service configuration; its "logbroker" section is read for
            "endpoints", "tvm_id", "topic_dir" and "consumer"
        :param endpoint_index: index of the endpoint in `service_config["logbroker"]["endpoints"]`
        :return: tuple `(None, [], state)` on every exit path
        """
        logbroker_config = service_config["logbroker"]
        endpoint = logbroker_config["endpoints"][endpoint_index]
        cls._set_proc_title("process_logbroker", " ".join(signal_types), endpoint)
        tvm_id = logbroker_config["tvm_id"]
        topic = "{}/{}".format(logbroker_config["topic_dir"], signal_group)
        proc_logger = common.log.MessageAdapter(
            logger.getChild(signal_group), fmt="[{}] %(message)s".format(signal_group)
        )
        # API object and consumer are created on every invocation and stopped in `finally` below
        apis, consumers = cls._create_logbroker_consumers(
            [endpoint], tvm_id, topic, logbroker_config["consumer"], proc_logger
        )
        try:
            if not consumers:
                proc_logger.warning("No consumers created")
                return None, [], state
            signals_procs = {
                signal_type: [class_() for class_ in cls.PROCESSOR_CLASSES[signal_type]]
                for signal_type in signal_types
            }
            state.last_insert_at = state.last_insert_at or dt.datetime.utcnow()
            with common.utils.Timer() as timer:
                with timer["get_signals"]:
                    chunks, chunks_total, cookies_to_commit = cls._get_signals_from_logbroker(
                        consumers, signal_group, proc_logger
                    )
                    proc_logger.debug("%d signal(s) received", chunks_total)
                    # full batches are processed back to back, small ones with the default delay
                    state.interval = DEFAULT_WORKER_INTERVAL if chunks_total < cls.MIN_CHUNK_SZ else 0
                    # too few signals and the last insert was recent: return without committing anything
                    if (
                        chunks_total < cls.MIN_CHUNK_SZ and
                        (dt.datetime.utcnow() - state.last_insert_at).total_seconds() < cls.MAX_WORKER_IDLE_TIME
                    ):
                        return None, [], state

                if not chunks_total:
                    return None, [], state

                with timer["patch_signals"]:
                    for chunk in chunks.values():
                        for signal in chunk:
                            cls.patch_in_place(signal)

                # cookies are committed before the processors run
                with timer["commit_signals"]:
                    last_received_cookie = {}
                    consumers_to_commit = set()
                    for consumer, cookies in cookies_to_commit.items():
                        last_received_cookie[consumer] = cookies[-1]
                        consumers_to_commit.add(consumer)
                        consumer.commit(cookies)
                    # wait until every consumer acknowledges its last committed cookie
                    while consumers_to_commit:
                        next_event_futures = {}
                        for consumer in consumers:
                            future = consumer.next_event()
                            next_event_futures[future] = consumer
                        done = concurrent.futures.wait(next_event_futures, timeout=LOGBROKER_API_TIMEOUT)[0]
                        if not done:
                            # processing continues even though the commit was not acknowledged
                            proc_logger.warning("Commit timed out")
                            break
                        for future in done:
                            consumer = next_event_futures[future]
                            if consumer not in consumers_to_commit:
                                continue
                            result = future.result()
                            if result.type == pq_api.ConsumerMessageType.MSG_COMMIT:
                                if last_received_cookie[consumer] == result.message.commit.cookie[-1]:
                                    consumers_to_commit.remove(consumer)
                            elif result.type == pq_api.ConsumerMessageType.MSG_ERROR:
                                proc_logger.error("Got error message: %s", result.message)
                    proc_logger.debug("%d signal(s) committed", chunks_total)

            # from here on `signal_types` (shadowing the parameter) holds the received types
            # which have not been handed over to any processor yet
            signal_types = set(chunks.keys())
            for signal_type, procs in signals_procs.items():
                chunk = chunks[signal_type]
                if not chunk:
                    continue
                signal_types.discard(signal_type)
                for proc in procs:
                    processor_name = type(proc).__name__
                    proc_logger.debug(
                        "Processing %s signal(s) of type %s by %s", len(chunk), signal_type, processor_name
                    )
                    # NOTE(review): `timer` is passed to the processors after its `with` block has exited,
                    # unlike in `process_collection` -- confirm that `Timer` supports this
                    # noinspection PyBroadException
                    try:
                        # noinspection PyUnresolvedReferences
                        proc.process(chunk, signal_type, proc_logger, timer)
                        state.last_insert_at = dt.datetime.utcnow()
                    except BaseException:
                        # any failure is logged and the remaining processors of this signal type are skipped
                        proc_logger.exception(
                            "failed to process with %s", processor_name
                        )
                        break
            # signal types received from the topic but not expected in this group
            if signal_types:
                proc_logger.warning(
                    "%s are not processed",
                    ", ".join(
                        "{} signal(s) of type {}".format(len(chunks[signal_type]), signal_type)
                        for signal_type in signal_types
                    )
                )
            proc_logger.debug("%s", timer)
            return None, [], state
        finally:
            # NOTE(review): presumably the watchdog terminates the process if the shutdown below takes longer
            # than the given timeout -- confirm against `common.threading.KamikadzeThread`
            watchdog = common.threading.KamikadzeThread(LOGBROKER_API_TIMEOUT * 3, logger=proc_logger)
            watchdog.start()
            proc_logger.debug("Stopping Logbroker consumers")
            for consumer in consumers:
                consumer.stop()
            proc_logger.debug("Stopping Logbroker API")
            for api in apis:
                api.stop()
            proc_logger.debug("Logbroker API stopped")
            watchdog.stop()
