from __future__ import absolute_import, unicode_literals

import hashlib
import json
import logging
import time
import typing
from itertools import chain

import six
from sandbox import sdk2
from sandbox.projects.common import binary_task, task_env
from sandbox.projects.Strm.common.constants import LB_CLUSTERS, YT_CLUSTERS

if typing.TYPE_CHECKING:
    from typing import Any, Iterable

logger = logging.getLogger(__name__)


class StrmCalculateBillingForLiveChannels(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """
    Task that turns per-ABC stream quotas into hourly billing events and writes them to Logbroker.

    Every quota parameter maps an ABC id to a number of streams. For each entry
    a single "delta" usage event covering the previous full hour is built, with
    a quantity of ``quota * 3600`` expressed in ``stream*second`` units. The
    events are serialized to JSON, logged, and written to the configured topic.
    """

    class Requirements(task_env.TinyRequirements):
        pass

    class Parameters(binary_task.LastBinaryReleaseParameters):
        # Keys are ABC ids, values are the number of streams; both must parse as integers.
        with sdk2.parameters.Group("Quotas [abc_id:max_channels]"):
            v1_fullhd_streams = sdk2.parameters.Dict("v1 FullHD Streams")
            v1_audio_streams = sdk2.parameters.Dict("v1 Audio Streams")
            v2_fullhd_streams = sdk2.parameters.Dict("v2 FullHD Streams")

        with sdk2.parameters.Group("Logbroker parameters"):
            lb_token = sdk2.parameters.YavSecret("LB Token", required=True)
            lb_endpoint = sdk2.parameters.String(
                "LB Endpoint",
                default=LB_CLUSTERS[0],
                choices=[(c, c) for c in LB_CLUSTERS],
                required=True,
            )
            lb_topic = sdk2.parameters.String("Topic", required=True)

    def on_save(self):
        """
        Validate that every quota key (ABC id) and value (quota) parses as an integer.

        :raises ValueError: if any key or value of the quota dicts cannot be converted to int.
        """
        super(StrmCalculateBillingForLiveChannels, self).on_save()

        all_quotas = chain(
            six.iteritems(self.Parameters.v1_fullhd_streams),
            six.iteritems(self.Parameters.v1_audio_streams),
            six.iteritems(self.Parameters.v2_fullhd_streams),
        )
        for key, value in all_quotas:
            # TypeError is caught as well so that a None key/value is reported
            # with the same descriptive ValueError instead of a bare TypeError.
            try:
                int(key)
            except (TypeError, ValueError) as e:
                six.raise_from(ValueError("can't parse abc_id to int"), e)

            try:
                int(value)
            except (TypeError, ValueError) as e:
                six.raise_from(ValueError("can't parse quota as int"), e)

    def on_execute(self):
        """Build one billing event per quota entry, log the batch and write it to Logbroker."""
        super(StrmCalculateBillingForLiveChannels, self).on_execute()

        # Sample the clock once so that every event of this run refers to the
        # same hour even if the run crosses an hour boundary.
        run_time = time.time()

        # (event type, event version, quotas). All event types are emitted with version "v1".
        quota_groups = (
            ("strm.live.fullhd_streams.v1", "v1", self.Parameters.v1_fullhd_streams),
            ("strm.live.audio_streams.v1", "v1", self.Parameters.v1_audio_streams),
            ("strm.live.fullhd_streams.v2", "v1", self.Parameters.v2_fullhd_streams),
        )

        lines = []
        for event_type, event_version, quotas in quota_groups:
            for abc_id, quota in six.iteritems(quotas):
                event = self._make_billing_event(
                    event_type,
                    event_version,
                    int(abc_id),
                    int(quota) * 3600,  # streams -> stream*seconds for one hour
                    timestamp=run_time,
                )
                lines.append(json.dumps(event))

        logger.info("Data to send:\n%s", "\n".join(lines))

        self._save_to_lb(lines)

    def _make_billing_event(self, type_, version, abc_id, quantity, timestamp=None):
        # type: (str, str, int, int, typing.Optional[float]) -> dict[str, Any]
        """
        Build a billing event dict for the full hour preceding ``timestamp``.

        :param type_: value stored in the event's ``tags.type`` field.
        :param version: value stored in the event's ``version`` field.
        :param abc_id: ABC id the usage is attributed to.
        :param quantity: usage amount for the hour, in ``stream*second`` units.
        :param timestamp: Unix time the event is computed relative to;
            the current time is used when omitted.
        :return: event dict ready to be serialized with ``json.dumps``.
        """
        if timestamp is None:
            timestamp = time.time()

        now = int(timestamp) - 3600  # send metrics for previous hour
        start = now // 3600 * 3600  # round down to the beginning of that hour
        finish = start + 3599  # last second of the same hour (inclusive)

        schema = "strm.live.streams.v1"
        # The id depends only on schema, type, version, ABC id and hour start, so
        # re-running the task for the same hour yields the same id for the same entry.
        id_source = "/".join((schema, type_, version, str(abc_id), str(start)))
        # hashlib requires bytes on Python 3; encode explicitly instead of
        # relying on Python 2 implicit ASCII coercion.
        id_ = hashlib.sha1(id_source.encode("utf-8")).hexdigest()

        return {
            "abc_id": abc_id,
            "id": id_,
            "schema": schema,
            "source_id": "sandbox-{}".format(self.id),
            "source_wt": now,
            "tags": {
                "type": type_,
            },
            "usage": {
                "start": start,
                "finish": finish,
                "quantity": quantity,
                "type": "delta",
                "unit": "stream*second",
            },
            "version": version,
        }

    def _save_to_lb(self, rows):
        # type: (Iterable[str]) -> None
        """
        Write every row to the Logbroker topic configured in the task parameters.

        :param rows: serialized messages; each one is passed to the producer's ``write`` as is.
        """
        # Imported lazily, as in the original code, rather than at module import time.
        from sandbox.projects.Strm.common.logbroker.producer import LBProducer

        logger.info("Saving data to LB...")

        n = 0
        lb_producer = LBProducer(
            str(self.Parameters.lb_endpoint),
            str(self.Parameters.lb_token.value()),
            str(self.Parameters.lb_topic),
            self.id,
        )
        with lb_producer as lb:
            for row in rows:
                n += 1
                lb.write(row)

        logger.info("Total messages written: %d", n)
