import time
import operator
import itertools as it
import datetime as dt
import functools as ft
import collections

import sandbox.common.types.task as ctt
import sandbox.common.types.database as ctd
import sandbox.serviceq.errors

from sandbox.common import log
from sandbox.common import utils
from sandbox.common import config
from sandbox.yasandbox import manager
from sandbox.yasandbox import controller
from sandbox.yasandbox.services import zk
from sandbox.yasandbox.database import mapping
from sandbox.yasandbox.proxy import task as task_proxy


class StatisticsUpdater(zk.ThreadWithZK):
    """
    Service thread for statistics update.

    On every run it collects task/resource status counters, enqueued tasks per client group,
    tasks enqueue time, storage statistics and per task type "weather", and stores them
    in the `Statistics` and `Weather` collections.
    """

    # Shortcut for the keys of the `Statistics` documents maintained by this thread.
    Keys = mapping.Statistics.Keys
    # Lifetime of the cached service data below; compared against `time.time()`, i.e. expressed in seconds.
    SERVICE_DATA_TTL = 3600
    # Cache for `client_groups`: client platform or CPU model -> set of client hostnames.
    _Clients = collections.defaultdict(set)
    # `time.time()` value of the most recent cache refresh; 0 means "never refreshed".
    _last_update_time = 0
    # Cache for `service_tasks`: list of task type names.
    _SERVICE_TASKS = []
    # Final task statuses counted as a failure by `collect_tasks_weather`.
    ERROR_STATUSES = [
        ctt.Status.FAILURE,
        ctt.Status.EXCEPTION,
        ctt.Status.NO_RES,
        ctt.Status.TIMEOUT,
        ctt.Status.NOT_RELEASED,
    ]
    # Final task statuses counted as a success by `collect_tasks_weather`.
    SUCCESS_STATUSES = [
        ctt.Status.SUCCESS,
        ctt.Status.RELEASED,
    ]

    @utils.classproperty
    def service_tasks(self):
        """
        Names of the task types which `controller.TaskWrapper.service` reports as service ones.

        The list is cached in `_SERVICE_TASKS` and recalculated when it is empty or older than
        `SERVICE_DATA_TTL` seconds.

        :return: list of task type names
        """
        # The refresh moment is tracked in an attribute owned by this cache only. Sharing a single
        # timestamp with another cached property lets a refresh of one cache postpone the refresh
        # of the other one past its TTL. A missing attribute means the cache was never refreshed.
        refreshed_at = getattr(self, "_service_tasks_update_time", 0)
        if not self._SERVICE_TASKS or refreshed_at + self.SERVICE_DATA_TTL < time.time():
            self._SERVICE_TASKS = [
                cls.type
                for cls in it.imap(task_proxy.getTaskClass, task_proxy.listTaskTypes())
                if controller.TaskWrapper.service(cls)
            ]
            self._service_tasks_update_time = time.time()
        return self._SERVICE_TASKS

    @utils.classproperty
    def client_groups(self):
        """
        Mapping of client group name (client platform or CPU model) to the set of client hostnames.

        The mapping is cached in `_Clients` and rebuilt from `mapping.Client` when it is empty or
        older than `SERVICE_DATA_TTL` seconds.

        :return: dict-like object `{group name: set(hostnames)}`
        """
        # The refresh moment is tracked in an attribute owned by this cache only. Sharing a single
        # timestamp with another cached property lets a refresh of one cache postpone the refresh
        # of the other one past its TTL. A missing attribute means the cache was never refreshed.
        refreshed_at = getattr(self, "_clients_update_time", 0)
        if not self._Clients or refreshed_at + self.SERVICE_DATA_TTL < time.time():
            # Build a fresh mapping instead of adding to the cached one: otherwise hosts which were
            # removed or changed their platform/CPU model would stay in their old groups forever.
            groups = collections.defaultdict(set)
            for cl in mapping.Client.objects:
                groups[cl.platform].add(cl.hostname)
                groups[cl.hardware.cpu.model].add(cl.hostname)
            # Swap the cache only after the whole collection has been read successfully.
            self._Clients = groups
            self._clients_update_time = time.time()
        return self._Clients

    def __init__(self, *args, **kwargs):
        """Pass all the arguments to the base thread, then set up the logger and the run interval."""
        super(StatisticsUpdater, self).__init__(*args, **kwargs)
        self.logger = log.get_core_log("statistics_updater")
        # `_proc` aligns its runs to multiples of this value on the `time.time()` scale (seconds).
        self.run_interval = 60

    @classmethod
    def collect_enqueued_tasks_statistics(cls):
        """
        Count enqueued tasks per client group.

        A queue item is counted for a group (see `client_groups`) when at least one of the hosts
        listed in the item belongs to that group. The queue is read via
        `controller.TaskQueue.qclient.queue(secondary=True)`; up to 3 attempts are made.

        :return: `collections.Counter` `{group name: amount of queue items}`,
                 or `None` if all the attempts have failed
        """
        # Resolve the cached property once: evaluating it per queue item repeats the TTL check every
        # time and may refresh the groups in the middle of counting.
        client_groups = cls.client_groups
        for _ in xrange(3):
            try:
                cnt = collections.Counter()
                for queue_item in controller.TaskQueue.qclient.queue(secondary=True):
                    hosts = set(host_item.host for host_item in queue_item.hosts)
                    for key, group in client_groups.iteritems():
                        # Same check as a non-empty intersection, without building the intersection.
                        if not group.isdisjoint(hosts):
                            cnt[key] += 1
                return cnt
            except sandbox.serviceq.errors.QRetry:
                # Retry immediately; a partially collected counter is dropped.
                pass
            except sandbox.serviceq.errors.QNeedValidation:
                # Wait a bit before the next attempt.
                time.sleep(1)
        return None

    def _date_bounds(self, field_name):
        """
        Calculate the time window to process on this run.

        :param field_name: key in `self.model.context` which holds the date processed last time
        :return: tuple `(lower bound, upper bound)` of UTC datetimes. The lower bound is the stored date,
                 or 64 seconds ago when nothing is stored; the upper bound is 5 seconds ago,
                 but not later than 12 hours after the lower bound.
        """
        current = dt.datetime.utcnow()
        lower = self.model.context.get(field_name)
        if lower is None:
            lower = current - dt.timedelta(seconds=64)
        upper_candidates = (
            current - dt.timedelta(seconds=5),
            lower + dt.timedelta(hours=12),
        )
        return lower, min(upper_candidates)

    def collect_tasks_weather(self):
        """
        Build updated "weather" documents for the task types which had tasks finished recently.

        Audit records with a status from `ERROR_STATUSES` or `SUCCESS_STATUSES` dated within the window
        given by `_date_bounds("last_weather_date")` are read with SECONDARY read preference.
        For every task the latest such record in the window gives:

        - a success flag: 1 if the status is in `SUCCESS_STATUSES`, 0 otherwise;
        - a duration in seconds between the date reported by `Audit.first_preparing_for_tasks`
          and the date of that record (skipped with a warning if the former is absent).

        These values are grouped by task type and passed to `Weather.update` of the stored
        (or a newly created) document together with `history_len` and `percentile` settings.
        The date of the latest processed record is stored in `self.model.context["last_weather_date"]`.

        :return: list of updated, not yet saved `mapping.Weather` documents,
                 or `None` if there are no matching audit records in the window
        """
        settings = config.Registry().server.services.statistics_updater
        date_gt, date_lt = self._date_bounds("last_weather_date")
        last_processed_date = date_gt
        ro_audit_mapping = mapping.switch_db(mapping.Audit, ctd.ReadPreference.SECONDARY)
        ro_task_mapping = mapping.switch_db(mapping.Task, ctd.ReadPreference.SECONDARY)
        with ro_audit_mapping as Audit, ro_task_mapping as Task:
            # Fetch the records exactly once and while the read preference switch is in effect:
            # the result is checked for emptiness and iterated twice below, partly outside of this block.
            finish_history = list(
                Audit.objects(
                    date__gt=date_gt,
                    date__lte=date_lt,
                    status__in=self.ERROR_STATUSES + self.SUCCESS_STATUSES
                ).order_by("+task_id", "+date").scalar("task_id", "status", "date")
            )

            if not finish_history:
                # NOTE(review): the checkpoint is left at the lower bound here. Since `_date_bounds` caps
                # the window at 12 hours after it, a gap without finished tasks longer than that looks
                # like it would keep the window from ever moving forward - confirm.
                self.model.context["last_weather_date"] = last_processed_date
                return None

            task_ids = list({record[0] for record in finish_history})
            first_preparing_history = Audit.first_preparing_for_tasks(task_ids)
            id2type = dict(Task.objects(id__in=task_ids).scalar("id", "type"))

        finish_dict = collections.defaultdict(list)
        timings = collections.defaultdict(list)
        # The records are ordered by task id, so `groupby` yields exactly one group per task,
        # and by date within a task, so the last record of a group is the latest one.
        for id_, records in it.groupby(finish_history, key=operator.itemgetter(0)):
            last_record = list(records)[-1]
            last_finish_status = last_record[1]
            last_finish_date = last_record[2]
            # Move the checkpoint before the check below: records of deleted tasks are processed too
            # (by being skipped) and must not be fetched again on the next run.
            last_processed_date = max(last_processed_date, last_finish_date)

            task_type = id2type.get(id_)
            if not task_type:  # Task already has been deleted from DB
                continue

            finish_dict[task_type].append(
                int(last_finish_status in self.SUCCESS_STATUSES)
            )

            first_preparing_date = first_preparing_history.get(id_)
            if first_preparing_date:
                timings[task_type].append(
                    (last_finish_date, (last_finish_date - first_preparing_date).total_seconds())
                )
            else:
                self.logger.warning("Audit is inconsistent. Absent PREPARING for task: #%s", id_)

        # [(time1, delta1), (time2, delta2)] -> (delta1, delta2), ordered by task finish time
        durations = {
            task_type: tuple(delta for _, delta in sorted(pairs, key=operator.itemgetter(0)))
            for task_type, pairs in timings.iteritems()
        }

        # Only the types which got new data are updated, so only their documents are needed.
        stored_weather = mapping.Weather.objects(type__in=list(finish_dict))
        stored_weather_dict = {doc.type: doc for doc in stored_weather}

        new_weather = []
        for task_type, success_flags in finish_dict.iteritems():
            doc = stored_weather_dict.get(task_type) or mapping.Weather(type=task_type, data=mapping.Weather.Data())
            doc.update(
                success_flags,
                durations.get(task_type) or [],
                settings.history_len,
                settings.percentile,
            )
            new_weather.append(doc)

        self.model.context["last_weather_date"] = last_processed_date

        return new_weather

    def collect_tasks_enqueue_time(self):
        """
        Sum up the time tasks have spent in the queue, per task type.

        Audit records of switching to EXECUTING dated within the window given by `_date_bounds("last_date")`
        are read with SECONDARY read preference; tasks of the types listed in `service_tasks` are excluded.
        Every such record is paired with an ENQUEUED record of the same task, and the difference
        between their dates is added to the total of the task's type.

        Tasks for which the pairing is impossible are logged, skipped and reported via
        `controller.Notification.report_service_error`.
        The date of the newest processed EXECUTING record is stored in `self.model.context["last_date"]`.

        :return: dict `{task type: total seconds}`; empty if no ENQUEUED records have been found
        """
        date_gt, date_lt = self._date_bounds("last_date")
        ro_audit_mapping = mapping.switch_db(mapping.Audit, ctd.ReadPreference.SECONDARY)
        ro_task_mapping = mapping.switch_db(mapping.Task, ctd.ReadPreference.SECONDARY)
        with ro_audit_mapping as Audit, ro_task_mapping as Task:
            # get events of switching tasks to EXECUTING during last tick
            tasks_execution_dates = Audit.objects(
                date__gt=date_gt,
                date__lte=date_lt,
                status=mapping.Task.Execution.Status.EXECUTING
            ).fast_scalar("task_id", "date")
            # get types for these tasks, excluding service types
            tasks_types = dict(
                Task.objects(
                    id__in=[task_id for task_id, _ in tasks_execution_dates],
                    type__not__in=self.service_tasks
                ).fast_scalar("id", "type")
            )
            # leave only those events which are not of service type
            tasks_execution_dates = [
                (task_id, date) for task_id, date in tasks_execution_dates
                if task_id in tasks_types
            ]
            # get events of switching tasks to ENQUEUED for these tasks
            tasks_enqueued_dates = Audit.objects(
                status=mapping.Task.Execution.Status.ENQUEUED,
                task_id__in=list(tasks_types)
            ).fast_scalar("task_id", "date")

        if not tasks_enqueued_dates:
            # NOTE(review): the checkpoint is not touched here. Since `_date_bounds` caps the window
            # at 12 hours after it, a gap without such events longer than that looks like it would keep
            # the window from ever moving forward - confirm.
            return {}

        def group_task_events_dates(task_events_dates):
            """Group by task_id, sort values by date DESC"""
            # The queries above do not order their results, so events of one task are not guaranteed to be
            # adjacent. `itertools.groupby` would split them into several groups overwriting each other.
            grouped = collections.defaultdict(list)
            for task_id_, date in task_events_dates:
                grouped[task_id_].append(date)
            return {
                task_id_: sorted(dates, reverse=True)
                for task_id_, dates in grouped.iteritems()
            }

        def match_queue_times(execution_dates, enqueued_dates):
            """
            Pair execution dates with enqueue dates (both sorted by date DESC), each enqueue date being used
            at most once. Return the list of time spans in seconds, or `None` if some execution date
            has no enqueue date left to be paired with.
            """
            spans = []
            position = 0
            for execution_date in execution_dates:
                # The enqueued events are not limited by date, so the task could have been enqueued again
                # after this execution. Such events do not belong to it and would give a negative time.
                while position < len(enqueued_dates) and enqueued_dates[position] > execution_date:
                    position += 1
                if position == len(enqueued_dates):
                    return None
                spans.append((execution_date - enqueued_dates[position]).total_seconds())
                position += 1
            return spans

        tasks_execution_history = group_task_events_dates(tasks_execution_dates)
        tasks_enqueued_history = group_task_events_dates(tasks_enqueued_dates)

        enqueue_time = collections.Counter()
        inconsistent_tasks = []
        for task_id, execution_dates in tasks_execution_history.items():
            # go through execution dates from the newest to the oldest ones (there may be several for one task_id),
            # look for corresponding enqueue time for each execution date
            spans = match_queue_times(execution_dates, tasks_enqueued_history.get(task_id, []))
            if spans is None:
                self.logger.warning("It looks like `audit` collection is inconsistent for task id: %d", task_id)
                inconsistent_tasks.append(task_id)
                continue
            # and sum these values
            task_type = tasks_types[task_id]
            for span in spans:
                enqueue_time[task_type] += span
        if inconsistent_tasks:
            controller.Notification.report_service_error(
                comment=(
                    "It looks like `audit` collection is inconsistent\nTask ids with incorrect audit:\n{}"
                ).format("\n".join("  {}".format(task_id) for task_id in sorted(inconsistent_tasks))),
                thread_name=self.__class__.__name__,
                lock_name="jobs/{}".format(self.__class__.__name__)
            )
        # The checkpoint is the newest processed date. Per task dates are sorted DESC, so taking the last one
        # of each list would pick the oldest date and let newer events be counted again on the next run.
        self.model.context["last_date"] = max(date for _, date in tasks_execution_dates)
        return dict(enqueue_time)

    def save_tasks_weather(self, tasks_weather):
        """
        Store weather documents with a single unordered bulk operation.

        Every document is validated and then upserted as a whole, being looked up by its `type` as `_id`.
        Nothing is done if there is nothing to store.

        :param tasks_weather: iterable of `mapping.Weather` documents, may be empty or `None`
        """
        if tasks_weather:
            bulk_op = mapping.Weather._get_collection().initialize_unordered_bulk_op()
            for weather_doc in tasks_weather:
                weather_doc.validate()
                selector = {"_id": weather_doc.type}
                bulk_op.find(selector).upsert().replace_one(weather_doc.to_mongo())
            bulk_op.execute()

    def _proc(self):
        """
        Perform a single run: collect all the statistics first and store them afterwards.

        Collected are: tasks and resources amounts per status, enqueued tasks per client group,
        tasks enqueue time, tasks weather and - provided that the last collection was more than
        600 seconds ago - storage statistics. Enqueue time is kept as a list of the last 60 results.

        :return: `datetime.timedelta` to wait till the next run
        """
        aggregation_started = time.time()
        storage_timing = self.model.context.get("storage_stat_timing", {})
        collect_storages_info = storage_timing.get("last_run", 0) + 600 < aggregation_started

        ro_resource_mapping = mapping.switch_db(mapping.Resource, ctd.ReadPreference.SECONDARY)
        ro_task_mapping = mapping.switch_db(mapping.Task, ctd.ReadPreference.SECONDARY)
        with ro_resource_mapping as Resource, ro_task_mapping as Task:
            collectors = (
                ("tasks", Task.tasks_per_status),
                ("resources", Resource.resources_per_state),
                ("delta", ft.partial(Task.tasks_per_status, exceptional=False)),
                ("absolute", ft.partial(Task.tasks_per_status, delta=None, exceptional=False)),
            )
            statuses = {}
            for name, collector in collectors:
                self.logger.debug("Collecting %s per state statistics", name)
                statuses[name] = collector()

        self.logger.info("Collecting enqueued tasks statistics")
        enqueued_tasks = self.collect_enqueued_tasks_statistics()

        self.logger.info("Collecting tasks enqueue time")
        enqueue_time = self.collect_tasks_enqueue_time()

        storage_info = collections.Counter()
        storage_hosts = config.Registry().server.storage_hosts
        if collect_storages_info:
            self.logger.info("Collecting storage resources statistics")
            for host in storage_hosts:
                storage_info += manager.resource_manager.storage_statistics(host=host)
            for item in mapping.Resource.ready_resources_size():
                size = item["total_size"] * 1024  # convert to Bytes
                storage_info["ready_resources_size"] += size
                if item["_id"] not in storage_hosts:
                    storage_info["client_side_ready_resources_size"] += size
                else:
                    storage_info["storage_ready_resources_size"] += size

        self.logger.info("Collecting tasks weather statistics")
        tasks_weather = self.collect_tasks_weather()
        self.logger.info("Data aggregation taken: %.2fs", time.time() - aggregation_started)

        self.logger.info("Starting to insert data")
        insertion_started = time.time()
        mapping.Statistics.objects(key=self.Keys.STATUS).update(upsert=True, set__data=statuses)
        if enqueued_tasks:
            mapping.Statistics.objects(key=self.Keys.ENQUEUED_TASKS).update(upsert=True, set__data=enqueued_tasks)
        if collect_storages_info:
            if not storage_info:
                self.logger.warning("Empty storage resources statistics")
            mapping.Statistics.objects(key=self.Keys.STORAGE).update(
                upsert=True,
                set__data=dict(storage_info)
            )
            self.model.context["storage_stat_timing"] = {"last_run": insertion_started}

        # Append the fresh enqueue time to the stored history, creating the document on the first run.
        doc = mapping.Statistics.objects(key=self.Keys.ENQUEUE_TIME).first()
        if not doc:
            doc = mapping.Statistics(key=self.Keys.ENQUEUE_TIME, data={"last_hour_timings": [enqueue_time]})
        else:
            doc.data["last_hour_timings"].append(enqueue_time)
        # Keep no more than 60 last results.
        if len(doc.data["last_hour_timings"]) > 60:
            doc.data["last_hour_timings"].pop(0)
        doc.save()

        self.save_tasks_weather(tasks_weather)
        self.model.save()
        self.logger.info("Data insertion taken: %.2fs", time.time() - insertion_started)

        # Return an amount of time to wait till next run: runs are aligned to multiples of the run interval.
        finished_at = time.time()
        next_run_at = (finished_at // self.run_interval + 1) * self.run_interval
        return dt.timedelta(seconds=next_run_at - finished_at)
