from __future__ import absolute_import

import json
import time
import uuid
import logging
import collections

import pymongo.errors

from sandbox.common import mds as common_mds
from sandbox.common import abc as common_abc
from sandbox.common import patterns
from sandbox.common import config as common_config
from sandbox.common import format as common_format
from sandbox.common import itertools as common_itertools
import sandbox.common.types.misc as ctm
import sandbox.common.types.database as ctd
import sandbox.common.types.resource as ctr
import sandbox.common.types.statistics as ctss
import sandbox.common.types.notification as ctn

from sandbox.services import base
from sandbox.yasandbox import controller
from sandbox.yasandbox.database import mapping
from sandbox.services.modules.statistics_processor import schemas

try:
    import logbroker.unified_agent.client.python as unified_agent
except ImportError:
    unified_agent = None

logger = logging.getLogger(__name__)


# Immutable record of a single billing usage interval
BillingMetric = collections.namedtuple("BillingMetric", ["start_time", "finish_time", "usage"])


class Totals(patterns.Abstract):
    """
    Accumulator of resources amounts and sizes (on storages, on disk and in MDS).

    Every slot defaults (via `__defs__`) to zero.
    """

    __slots__ = tuple(
        "amount size disk_size mds_amount mds_on_storages_amount mds_size mds_on_storages_size".split()
    )
    __defs__ = tuple(0 for _ in __slots__)


class Signal(patterns.Abstract):
    """Signal record whose fields are those of the `schemas.clickhouse.ResourcesOnStorages` schema."""

    __slots__ = schemas.clickhouse.ResourcesOnStorages.fields()
    # NOTE(review): the defaults assume the schema starts with 4 identification fields (presumably date, timestamp,
    # owner and type, i.e. the ones passed on construction), continues with numeric counters and ends with 3 MDS
    # bucket fields (presumably mds_bucket_name / mds_bucket_used / mds_bucket_max_size) -- confirm against the schema
    # noinspection PyTypeChecker
    __defs__ = tuple([None] * 4 + [0] * (len(__slots__) - 7) + [None] * 3)


class ResourcesStatistics(base.SingletonService):
    """
    Service process for pushing of resources statistics on storages to Clickhouse.

    Every tick it aggregates resources by (owner, type) both by their copies on storage hosts and by their
    presence in MDS, raises Juggler notifications about MDS bucket quotas of the owners, pushes the aggregated
    signals via the unified agent and, on production installation, sends storage billing metrics.
    """

    SIGNALS_CHUNK_SIZE = 1000
    BILLING_METRICS_CHUNK_SIZE = 1000
    BILLING_METRICS_FRAME_SIZE = 3600  # in seconds, used for hour discounts in billing
    AGGREGATE_ATTEMPTS = 5  # how many times to try starting an aggregation before giving up
    MDS_FREE_SPACE_WARN_RATIO = 0.1  # share of free bucket space below which the MDS quota check is WARN

    tick_interval = 1800

    @staticmethod
    def _get_cursor(resource_mapping, pipeline, attempts=AGGREGATE_ATTEMPTS):
        """
        Start an aggregation over resources, retrying on MongoDB cursor/operation errors.

        Only errors raised while the aggregation is being started are retried;
        errors raised later, while the returned cursor is iterated, propagate to the caller.

        :param resource_mapping: resource document class to run the aggregation on
        :param pipeline: aggregation pipeline
        :param attempts: total number of attempts; the error of the last one is re-raised
        :return: aggregation result cursor
        """
        attempts = max(1, attempts)
        for attempt in range(1, attempts + 1):
            try:
                return resource_mapping.aggregate(pipeline, allowDiskUse=True)
            except (pymongo.errors.CursorNotFound, pymongo.errors.OperationFailure) as ex:
                if attempt == attempts:
                    raise
                logger.warning("Failed to start aggregation (attempt %s of %s): %s", attempt, attempts, ex)

    @staticmethod
    def update_mds_stats_in_signal(res_owner, groups, signal, owner_stats):
        """
        Fill MDS bucket fields of `signal` with the bucket statistics of the resources owner.

        Bucket statistics are requested from S3 once per owner and cached in `owner_stats`.

        :param res_owner: name of the group owning the resources
        :param groups: mapping of group name to its (abc, bucket) pair
        :param signal: signal dict, updated in place
        :param owner_stats: cache of bucket statistics keyed by owner name, updated in place
        """
        if res_owner not in owner_stats:
            group_abc, group_bucket = groups.get(res_owner, (None, None))
            # TODO: remove condition after SANDBOX-8766
            if group_bucket:
                bucket_name, bucket_info = common_mds.S3.check_bucket(bucket=group_bucket)
            else:
                bucket_name, bucket_info = common_mds.S3.check_bucket(common_abc.abc_service_id(group_abc))
            owner_stats[res_owner] = {
                "mds_bucket_name": bucket_name,
                "mds_bucket_used": bucket_info["used_space"],
                "mds_bucket_max_size": bucket_info["max_size"]
            }

        stats = owner_stats[res_owner]
        signal["mds_bucket_name"] = stats["mds_bucket_name"]
        signal["mds_bucket_used"] = stats["mds_bucket_used"]
        signal["mds_bucket_max_size"] = stats["mds_bucket_max_size"]

    def metric_intervals(self, prev_timestamp, timestamp):
        """
        Split the time range [prev_timestamp, timestamp) into intervals aligned to billing frames.

        Every interval is a pair of INCLUSIVE bounds (start, finish) that does not cross a boundary of
        `BILLING_METRICS_FRAME_SIZE`-second frames. Together the intervals cover every second of the range
        exactly once, so their total length is `timestamp - prev_timestamp`.

        :param prev_timestamp: start of the range, unix time in seconds (included)
        :param timestamp: end of the range, unix time in seconds (excluded)
        :return: list of (start, finish) tuples; empty if `timestamp <= prev_timestamp`
        """
        frame_size = self.BILLING_METRICS_FRAME_SIZE
        intervals = []
        start = prev_timestamp
        frame_start = prev_timestamp // frame_size * frame_size
        last_frame_start = timestamp // frame_size * frame_size
        while frame_start < last_frame_start:
            frame_start += frame_size
            intervals.append((start, frame_start - 1))
            start = frame_start
        # the tail may be a single second (start == timestamp - 1); it must not be dropped,
        # since the next range starts at `timestamp`
        if start < timestamp:
            intervals.append((start, timestamp - 1))
        return intervals

    @staticmethod
    def _send_to_unified_agent(uri, messages):
        """
        Send each of `messages` as a separate record via a unified agent session opened for `uri`.

        The session is closed even if sending fails.

        :raises RuntimeError: if the unified agent client library is not installed
        """
        if unified_agent is None:
            raise RuntimeError("Unified agent client library is not available")
        # keep a reference to the client for as long as its session is in use
        ua_client = unified_agent.Client(uri, log_level=logging.DEBUG)
        ua_session = ua_client.create_session()
        try:
            for message in messages:
                ua_session.send(message, time.time())
        finally:
            ua_session.close()

    def send_billing_metrics(self, config):
        """
        Send storage usage billing metrics for every MDS bucket (on production installation only).

        The usage of a bucket is its quota in megabytes multiplied by the number of seconds elapsed since
        the previous run, split into frames by `metric_intervals`. The time of the previous run is kept in
        the service context under "billing_metrics_timestamp"; it is advanced before sending, so metrics
        of a failed send are not resent.

        :param config: configuration registry
        """
        if config.common.installation != ctm.Installation.PRODUCTION:
            return

        timestamp = int(time.time())
        prev_timestamp = self.context.setdefault("billing_metrics_timestamp", timestamp)
        if prev_timestamp == timestamp:
            return

        # the intervals do not depend on a bucket, so compute them once
        intervals = self.metric_intervals(prev_timestamp, timestamp)
        if not intervals:
            return

        billing_metrics = []
        for bucket in mapping.Bucket.objects():
            abc_id = common_abc.abc_service_id(bucket.abc)
            _, bucket_info = common_mds.S3.check_bucket(bucket=bucket.name)
            quota = bucket_info["max_size"] >> 20  # bytes to megabytes
            metric_base = dict(
                abc_id=str(abc_id),
                schema="sandbox.storage.v1",
                source_wt=int(time.time()),
                source_id=config.this.fqdn,
                version="v1alpha1",
                labels=dict(account=bucket.name)
            )
            for start_time, finish_time in intervals:
                # interval bounds are inclusive, so a full frame lasts exactly BILLING_METRICS_FRAME_SIZE seconds
                duration = finish_time - start_time + 1
                billing_metrics.append(dict(
                    metric_base,
                    id=str(uuid.uuid1()),
                    tags={},
                    usage=dict(
                        quantity=str(quota * duration),
                        unit="mbyte*second",
                        start=start_time,
                        finish=finish_time,
                        type="delta",
                    ),
                ))

        if not billing_metrics:
            return

        logger.info("Sending %s billing metrics", len(billing_metrics))
        self.context["billing_metrics_timestamp"] = timestamp
        self._send_to_unified_agent(
            config.common.unified_agent["billing_metrics"]["uri"],
            (
                json.dumps(chunk) + "\n"
                for chunk in common_itertools.chunker(billing_metrics, self.BILLING_METRICS_CHUNK_SIZE)
            ),
        )

    @staticmethod
    def _pipeline_on_storages(storages):
        """
        Aggregation pipeline over resources having an OK copy on any of `storages`.

        Yields one document per (owner, type, state) with the amount of resources, their total size,
        total size of their copies on the storages (`disk_size`), and amount/size of those having
        a non-null `mds` field.
        """
        return [
            {"$match": {
                "state": {"$in": [ctr.State.READY, ctr.State.BROKEN, ctr.State.DELETED]},
                "hosts.h": {"$in": storages},
            }},
            {"$unwind": "$hosts"},
            {"$match": {
                "hosts.st": ctr.HostState.OK,
                "hosts.h": {"$in": storages},
            }},
            {"$project": {
                "_id": 1,
                "owner": 1,
                "type": 1,
                "state": 1,
                "size": 1,
                "mds_amount": {"$cond": [{"$gt": ["$mds", None]}, 1, 0]},
                "mds_size": {"$cond": [{"$gt": ["$mds", None]}, "$size", 0]},
            }},
            # collapse copies back into one document per resource, summing sizes of the copies
            {"$group": {
                "_id": "$_id",
                "owner": {"$last": "$owner"},
                "type": {"$last": "$type"},
                "state": {"$last": "$state"},
                "disk_size": {"$sum": "$size"},
                "size": {"$last": "$size"},
                "mds_amount": {"$last": "$mds_amount"},
                "mds_size": {"$last": "$mds_size"},
            }},
            {"$group": {
                "_id": {"owner": "$owner", "type": "$type", "state": "$state"},
                "amount": {"$sum": 1},
                "size": {"$sum": "$size"},
                "disk_size": {"$sum": "$disk_size"},
                "mds_amount": {"$sum": "$mds_amount"},
                "mds_size": {"$sum": "$mds_size"},
            }},
        ]

    @staticmethod
    def _pipeline_in_mds():
        """
        Aggregation pipeline over all resources in READY, BROKEN or DELETED state.

        Yields one document per (owner, type, state) with the amount of resources and the total size
        of those having a non-null `mds` field.
        """
        return [
            {"$match": {
                "state": {"$in": [ctr.State.READY, ctr.State.BROKEN, ctr.State.DELETED]},
            }},
            {"$project": {
                "owner": 1,
                "type": 1,
                "state": 1,
                "size": {"$cond": [{"$gt": ["$mds", None]}, "$size", 0]},
                "_id": 1
            }},
            {"$group": {
                "_id": {"owner": "$owner", "type": "$type", "state": "$state"},
                "amount": {"$sum": 1},
                "size": {"$sum": "$size"},
            }},
        ]

    @classmethod
    def _mds_quota_status(cls, stats):
        """
        Juggler status of an owner's MDS bucket quota, or None if the bucket is not checked.

        Unnamed buckets and the default bucket are not checked.

        :param stats: bucket statistics as cached by `update_mds_stats_in_signal`
        """
        bucket_name = stats["mds_bucket_name"]
        if not bucket_name or bucket_name == ctr.DEFAULT_S3_BUCKET:
            return None
        max_size = stats["mds_bucket_max_size"]
        free_space = max_size - stats["mds_bucket_used"]
        if free_space <= 0:
            return ctn.JugglerStatus.CRIT
        # float() is required: without `from __future__ import division` int / int is floored on Python 2
        if max_size and float(free_space) / max_size < cls.MDS_FREE_SPACE_WARN_RATIO:
            return ctn.JugglerStatus.WARN
        return ctn.JugglerStatus.OK

    def _notify_mds_quotas(self, owners_stats):
        """Save Juggler notifications about MDS bucket quotas, one per status, for owners with checked buckets."""
        juggler_checks = collections.defaultdict(list)
        for owner, stats in owners_stats.items():
            status = self._mds_quota_status(stats)
            if status is not None:
                juggler_checks[status].append(owner)

        for status, check_owners in juggler_checks.items():
            recipients = controller.Notification.juggler_expanded_recipients(
                check_owners, ctn.JugglerCheck.MDS_QUOTA_EXCEEDED
            )
            if recipients:
                controller.Notification.save(
                    transport=ctn.Transport.JUGGLER, send_to=recipients, send_cc=[], subject=None, body="Mds quota",
                    check_status=status
                )

    def _send_signals(self, config, signals):
        """Push the aggregated signals via the unified agent in chunks of `SIGNALS_CHUNK_SIZE`."""
        logger.info("Sending statistics")
        uri = config.common.unified_agent["infrequent_statistics"]["uri"]
        values = list(signals.values())

        def messages():
            for chunk in common_itertools.chunker(values, self.SIGNALS_CHUNK_SIZE):
                logger.info("Sending %s of %s signal(s)", len(chunk), len(values))
                yield json.dumps({ctss.SignalType.RESOURCES_ON_STORAGES: chunk}) + "\n"

        self._send_to_unified_agent(uri, messages())

    def tick(self):
        """
        Collect resources statistics, check MDS bucket quotas, push the signals and billing metrics.

        Resources in READY, BROKEN and DELETED states are aggregated by (owner, type): once by their copies
        on storage hosts and once by their presence in MDS. DELETED resources are accounted in the
        "*_deleted" signal fields and in separate totals.
        """
        config = common_config.Registry()
        storages = config.server.storage_hosts

        groups = {
            name: (abc, bucket)
            for name, abc, bucket in mapping.Group.objects().fast_scalar("name", "abc", "bucket")
        }

        timestamp = int(time.time())
        signals = {}
        all_owners = set()
        all_types = set()
        totals_ready = Totals()
        totals_deleted = Totals()
        owners_stats = {}

        def signal_for(res_owner, res_type):
            # one signal per (owner, type) accumulates counters from both aggregations
            key = (res_owner, res_type)
            signal = signals.get(key)
            if signal is None:
                signal = signals[key] = dict(Signal(
                    date=timestamp,
                    timestamp=timestamp,
                    owner=res_owner,
                    type=res_type,
                ))
            return signal

        def totals_and_suffix(res_state):
            if res_state == ctr.State.DELETED:
                return totals_deleted, "_deleted"
            return totals_ready, ""

        with mapping.switch_db(mapping.Resource, ctd.ReadPreference.SECONDARY) as Resource:
            logger.info("Collecting resources on storages")
            for item in self._get_cursor(Resource, self._pipeline_on_storages(storages)):
                res_owner = item["_id"]["owner"]
                res_type = item["_id"]["type"]
                # `<< 10` converts sizes to bytes (presumably they are stored in KiB)
                size = int(item["size"]) << 10
                disk_size = int(item["disk_size"]) << 10
                mds_size = int(item["mds_size"]) << 10
                signal = signal_for(res_owner, res_type)
                totals, suffix = totals_and_suffix(item["_id"]["state"])
                signal["amount" + suffix] += item["amount"]
                signal["size" + suffix] += size
                signal["disk_size" + suffix] += disk_size
                signal["mds_on_storages_amount" + suffix] += item["mds_amount"]
                signal["mds_on_storages_size" + suffix] += mds_size
                self.update_mds_stats_in_signal(res_owner, groups, signal, owners_stats)

                all_owners.add(res_owner)
                all_types.add(res_type)
                totals.amount += item["amount"]
                totals.size += size
                totals.disk_size += disk_size
                totals.mds_on_storages_amount += item["mds_amount"]
                totals.mds_on_storages_size += mds_size

            logger.info("Collecting resources in MDS")
            for item in self._get_cursor(Resource, self._pipeline_in_mds()):
                res_owner = item["_id"]["owner"]
                res_type = item["_id"]["type"]
                size = int(item["size"]) << 10
                signal = signal_for(res_owner, res_type)
                totals, suffix = totals_and_suffix(item["_id"]["state"])
                signal["mds_amount" + suffix] += item["amount"]
                signal["mds_size" + suffix] += size
                self.update_mds_stats_in_signal(res_owner, groups, signal, owners_stats)

                all_owners.add(res_owner)
                all_types.add(res_type)
                totals.mds_amount += item["amount"]
                totals.mds_size += size

        self._notify_mds_quotas(owners_stats)

        # resources both on storages and in MDS are counted by both aggregations, so subtract the overlap
        total_amount_ready = totals_ready.amount + totals_ready.mds_amount - totals_ready.mds_on_storages_amount
        total_amount_deleted = totals_deleted.amount + totals_deleted.mds_amount - totals_deleted.mds_on_storages_amount
        total_size_ready = totals_ready.size + totals_ready.mds_size - totals_ready.mds_on_storages_size
        total_size_deleted = totals_deleted.size + totals_deleted.mds_size - totals_deleted.mds_on_storages_size
        logger.info(
            "Collected %s (%s / %s) resources of %s types of %s owners "
            "with total size %s (%s / %s) and %s (%s / %s) on disk and %s (%s / %s) in MDS",
            total_amount_ready + total_amount_deleted,
            total_amount_ready,
            total_amount_deleted,
            len(all_types), len(all_owners),
            common_format.size2str(total_size_ready + total_size_deleted),
            common_format.size2str(total_size_ready),
            common_format.size2str(total_size_deleted),
            common_format.size2str(totals_ready.disk_size + totals_deleted.disk_size),
            common_format.size2str(totals_ready.disk_size),
            common_format.size2str(totals_deleted.disk_size),
            common_format.size2str(totals_ready.mds_size + totals_deleted.mds_size),
            common_format.size2str(totals_ready.mds_size),
            common_format.size2str(totals_deleted.mds_size),
        )

        self._send_signals(config, signals)

        self.send_billing_metrics(config)
