# -*- coding: utf-8 -*-
import threading
from itertools import chain
import time
from intranet.yandex_directory.src.yandex_directory.directory_logging.logger import log

from intranet.yandex_directory.src.yandex_directory.core.monitoring.stats import (
    build_closed_service_stats,
    build_is_analytics_data_saved_to_yt,
    build_task_queue_stats,
    build_event_stats,
    build_sso_delay,
    build_cron_results,
)
from intranet.yandex_directory.src.yandex_directory.common.utils import (
    import_from_string,
    stopit, utcnow,
)
from intranet.yandex_directory.src.yandex_directory.core.utils import thread_log

# Metric builders that, when COMMON_METRIC_CALC_ENABLED is set, are computed by
# the background updater (common_stats_updater -> calc_common_metrics) and
# served from the common_metrics_data cache; when app.testing is set,
# setup_stats_aggregator registers them with the aggregator directly instead.
common_metrics = [
    build_task_queue_stats,
    build_event_stats,
    build_is_analytics_data_saved_to_yt,
    build_sso_delay,
    build_cron_results,
]

# Latest non-None result of each common metric builder, keyed by the builder's
# __name__. Written by calc_common_metrics, flattened by get_common_metric.
common_metrics_data = {}
# utcnow() value recorded by common_stats_updater after each successful
# calc_common_metrics run; None until the first run completes.
last_calc = None


def build_common_metrics_calc_delay():
    """Report how many seconds have passed since the common metrics were last recalculated.

    :returns: ``[]`` until the background updater has completed its first
        successful run (``last_calc`` is None); afterwards a single metric row
        ``[["common-metrics-calc-delay_axxx", <seconds since last_calc>]]``.
    """
    # Read the shared module-level timestamp once so the None check and the
    # arithmetic below are guaranteed to see the same value.
    calculated_at = last_calc
    if calculated_at is None:
        return []
    # Subtract the datetimes directly instead of differencing .timestamp():
    # for naive datetimes .timestamp() assumes local time, so a local DST
    # change between the two moments would skew the result.
    # timedelta.total_seconds() is independent of the local timezone.
    delay = (utcnow() - calculated_at).total_seconds()
    return [["common-metrics-calc-delay_axxx", delay]]


def get_common_metric():
    """Return the cached results of every common metric as one flat list.

    Each value in ``common_metrics_data`` is an iterable of metric rows (as
    returned by its builder); the rows are concatenated in the dict's
    insertion order.
    """
    # Take a snapshot of the values before flattening: the cache is filled in
    # by the background updater thread, while the aggregator may call this.
    cached_batches = list(common_metrics_data.values())
    return list(chain.from_iterable(cached_batches))


class ThreadWithReturn(threading.Thread):
    """A thread that runs ``func()`` and hands its return value back from :meth:`join`.

    :param func: zero-argument callable executed in the new thread.
    :param logger: logger used for progress/error messages; must provide
        ``debug`` and ``trace().error``.
    """

    def __init__(self, func, logger):
        super().__init__()
        self.func = func
        self.log = logger
        # Result of func(); stays None until func returns successfully.
        self._return = None

    def run(self):
        """Call ``func`` and store its result; log and re-raise any exception."""
        self.log.debug('Start metric calc')
        try:
            outcome = self.func()
        except Exception:
            self.log.trace().error('Error during calc metric')
            # Re-raised inside this worker thread only: the threading module
            # reports it, but it never propagates to the caller of join().
            raise
        self._return = outcome
        self.log.debug('End metric calc')

    def join(self, timeout=None):
        """Wait for the thread (at most *timeout* seconds) and return ``func``'s result.

        Returns None if ``func`` raised, returned None, or had not finished
        when *timeout* expired.
        """
        super().join(timeout)
        return self._return


def calc_common_metrics(app):
    """Run every builder in ``common_metrics`` once and cache non-None results.

    Results are stored in ``common_metrics_data`` under the builder's
    ``__name__``. Each builder runs in its own ThreadWithReturn, which is joined
    immediately and without a timeout. Builders therefore run one after
    another, and each call blocks until its builder finishes.

    An exception raised by a builder happens in the worker thread.
    ThreadWithReturn.run logs it there and ``join()`` returns None, so the
    previously cached value for that metric is kept. The ``except`` below only
    sees failures from wrapping or starting the thread.

    :param app: currently unused.
    """
    # NOTE(review): the thread adds neither parallelism nor a time limit here;
    # presumably it exists for the per-thread context set up by thread_log —
    # confirm before changing.
    for metric_func in common_metrics:
        metric_name = metric_func.__name__
        with log.fields(metric_func=metric_name):
            try:
                worker = ThreadWithReturn(thread_log(metric_func), log)
                worker.start()
                value = worker.join()
                if value is not None:
                    common_metrics_data[metric_name] = value
            except Exception:
                log.trace().error('Unable to collect metric')


def common_stats_updater(app):
    """Recalculate the common metrics forever, pausing between runs.

    Used as the target of a background daemon thread. Each iteration runs
    calc_common_metrics(app) and, on success, records the completion time in
    the module-level ``last_calc``. Exceptions from the calculation are logged
    and the loop continues. If reading the sleep period from config itself
    raises, that error escapes and ends the loop.

    :param app: application whose ``config['GOLOVAN_STATS_AGGREGATOR_PERIOD']``
        is the pause, in seconds, between iterations.
    """
    global last_calc
    while True:
        with log.name_and_fields('common_stats_calculator'):
            try:
                calc_common_metrics(app)
                # Only advanced after a successful run, so
                # build_common_metrics_calc_delay keeps growing if runs fail.
                last_calc = utcnow()
            except Exception:
                log.trace().error('Unable to calc common stats')
            finally:
                # Sleep in `finally` so the pause also happens after a failure.
                # The period is re-read from config on every iteration.
                time.sleep(app.config['GOLOVAN_STATS_AGGREGATOR_PERIOD'])


def _resolve_aggregator_class(app):
    """Return the aggregator class named by ``GOLOVAN_STATS_AGGREGATOR_CLASS``.

    Falls back to ``golovan_stats_aggregator.MemoryStatsAggregator``, logging a
    warning, when the configured class raises ImportError on import.
    """
    class_path = app.config['GOLOVAN_STATS_AGGREGATOR_CLASS']
    try:
        return import_from_string(class_path, 'GOLOVAN_STATS_AGGREGATOR_CLASS')
    except ImportError:
        from golovan_stats_aggregator import MemoryStatsAggregator
        log.warning(
            'Can\'t import %s for monitoring. Fallback to MemoryStatsAggregator' % class_path
        )
        return MemoryStatsAggregator


def _start_common_stats_updater(app):
    """Launch ``common_stats_updater(app)`` (wrapped by thread_log) in a daemon thread."""
    updater = threading.Thread(
        target=thread_log(common_stats_updater),
        args=(app,),
    )
    # Daemon, so the endless updater loop does not keep the process alive.
    updater.daemon = True
    updater.start()


def setup_stats_aggregator(app):
    """Create the stats aggregator for *app* and store it as ``app.stats_aggregator``.

    Always registers build_closed_service_stats and
    build_common_metrics_calc_delay. When ``COMMON_METRIC_CALC_ENABLED`` is set:

    - in testing mode, the ``common_metrics`` builders are registered with the
      aggregator directly;
    - otherwise, get_common_metric (which serves cached results) is registered
      and a background thread running common_stats_updater is started.

    Every function in the registered list is wrapped with ``stopit``, using
    ``GOLOVAN_STATS_AGGREGATOR_FUNCTION_TIMEOUT`` as the timeout and ``[]`` as
    the default. get_common_metric is registered unwrapped.
    """
    aggregator = _resolve_aggregator_class(app)()

    metric_functions = [build_closed_service_stats, build_common_metrics_calc_delay]
    if app.config['COMMON_METRIC_CALC_ENABLED']:
        if not app.testing:
            aggregator.add_metric_func(get_common_metric)
            _start_common_stats_updater(app)
        else:
            metric_functions.extend(common_metrics)

    timeout = app.config['GOLOVAN_STATS_AGGREGATOR_FUNCTION_TIMEOUT']
    for metric_func in metric_functions:
        aggregator.add_metric_func(
            stopit(metric_func, timeout=timeout, default=[], raise_timeout=False)
        )

    app.stats_aggregator = aggregator
