from __future__ import absolute_import, unicode_literals

import logging
import typing
from sandbox import sdk2
from sandbox.projects.common import binary_task, task_env
from sandbox.projects.Strm.common.constants import LB_CLUSTERS, YT_CLUSTERS

if typing.TYPE_CHECKING:
    from typing import Generator, Iterable

# Module-level logger named after this module; used for the task's progress messages.
logger = logging.getLogger(__name__)


class StrmCalculateBillingFromYT(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """
    Task to aggregate billing data from YT tables.

    Pipeline:
      1. Run the bundled YQL query (``query.sql``) on ``yt_cluster`` with
         ``$src_folder`` / ``$dst_folder`` bound to the task parameters. The
         first cell of the query result is taken as the list of result table
         names, relative to ``yt_dst_folder``.
      2. Read the ``value`` column of every row of those tables.
      3. Write each value as one message to the configured Logbroker topic.
    """

    class Requirements(task_env.TinyRequirements):
        pass

    class Parameters(binary_task.LastBinaryReleaseParameters):

        with sdk2.parameters.Group("YT parameters"):
            yql_token = sdk2.parameters.YavSecret("YQL Token", required=True)
            yt_token = sdk2.parameters.YavSecret("YT Token", required=True)
            yt_cluster = sdk2.parameters.String(
                "YT Cluster", default=YT_CLUSTERS[0], choices=[(c, c) for c in YT_CLUSTERS], required=True
            )
            yt_src_folder = sdk2.parameters.String(
                "Source folder on YT cluster", default="//logs/strm-access-log/1h", required=True
            )
            yt_dst_folder = sdk2.parameters.String(
                "Destination folder on YT cluster", default="//home/strm/billing/access-log/1h", required=True
            )

        with sdk2.parameters.Group("Logbroker parameters"):
            lb_token = sdk2.parameters.YavSecret("LB Token", required=True)
            lb_endpoint = sdk2.parameters.String(
                "LB Endpoint",
                default=LB_CLUSTERS[0],
                choices=[(c, c) for c in LB_CLUSTERS],
                required=True,
            )
            lb_topic = sdk2.parameters.String("Topic", required=True)

    def on_execute(self):
        super(StrmCalculateBillingFromYT, self).on_execute()

        tables = self._execute_query()
        if not tables:
            logger.info("No tables to sync.")
            return

        logger.info("Tables to sync: %s", ", ".join(tables))

        # _read_tables is a generator: YT reads happen lazily while the LB
        # producer consumes them, so rows are streamed instead of being
        # collected into a list first.
        lines = self._read_tables(tables)

        self._save_to_lb(lines)

    def _execute_query(self):
        # type: () -> list[str]
        """
        Run the bundled billing YQL query and return the result table names.

        :return: table names (relative to ``yt_dst_folder``) taken from the
            first cell of the first result row. Returns an empty list when the
            query yields no rows or a NULL cell.
        """
        logger.info("Starting YQL query...")

        from yql.api.v1.client import YqlClient
        from yql.client.parameter_value_builder import YqlParameterValueBuilder
        from library.python import resource
        import sandbox.projects.Strm.common.yql_helpers as yql_helpers

        client = YqlClient(db=str(self.Parameters.yt_cluster), token=str(self.Parameters.yql_token.value()))

        query = resource.find('sandbox/projects/Strm/StrmCalculateBillingFromYT/query.sql')

        # Convert parameters with str() like every other parameter use in this task.
        params = YqlParameterValueBuilder.build_json_map(
            {
                "$src_folder": YqlParameterValueBuilder.make_string(str(self.Parameters.yt_src_folder)),
                "$dst_folder": YqlParameterValueBuilder.make_string(str(self.Parameters.yt_dst_folder)),
            }
        )

        result_table = yql_helpers.execute_yql_query(
            client, query, syntax_version=1, title="STRM Billing", parameters=params
        )

        logger.info("Done executing YQL query.")

        rows = result_table.rows
        if not rows:
            # Indexing rows[0] here would raise a bare IndexError; treat an
            # empty result set as "nothing to sync" but leave a trace in the log.
            logger.warning("YQL query returned no rows.")
            return []

        # A NULL cell is normalised to an empty list so the return value
        # always matches the declared list[str] type.
        return list(rows[0][0] or [])

    def _read_tables(self, tables):
        # type: (list[str]) -> Generator[str, None, None]
        """
        Lazily yield the ``value`` column of every row in ``tables``.

        This is a generator. Nothing in it, including the YT client setup and
        the log lines, runs until the caller starts iterating.

        :param tables: table names; each is joined onto ``yt_dst_folder``.
        """
        logger.info("Reading result tables...")

        import yt.wrapper as yt
        import yt.logger as yt_logger

        # Replace the YT wrapper's logger with the standard "yt" logger
        # (presumably so its records end up in the task logs).
        yt_logger.LOGGER = logging.getLogger("yt")

        client = yt.YtClient(proxy=str(self.Parameters.yt_cluster), token=str(self.Parameters.yt_token.value()))

        for table in tables:
            table_path = yt.ypath_join(self.Parameters.yt_dst_folder, table)
            logger.info("Reading table %s...", table_path)

            for row in client.read_table(table_path, format=yt.YsonFormat()):
                yield row["value"]

    def _save_to_lb(self, rows):
        # type: (Iterable[str]) -> None
        """
        Write every item of ``rows`` as one message to the configured LB topic.

        The task id is passed as the producer's fourth argument (presumably
        the source id; confirm in LBProducer).

        :param rows: messages to write; consumed exactly once.
        """
        from sandbox.projects.Strm.common.logbroker.producer import LBProducer

        logger.info("Saving data to LB...")

        written = 0
        lb_producer = LBProducer(
            str(self.Parameters.lb_endpoint),
            str(self.Parameters.lb_token.value()),
            str(self.Parameters.lb_topic),
            self.id,
        )
        with lb_producer as lb:
            for row in rows:
                lb.write(row)
                written += 1

        logger.info("Total messages written: %d", written)

