from __future__ import unicode_literals

import csv
import json
import logging
from collections import namedtuple
from datetime import datetime

import requests
from requests.adapters import HTTPAdapter
from six import StringIO
from urllib3.util.retry import Retry

from sandbox import sdk2
from sandbox.projects.common import binary_task, task_env

# Strict schema of the YT output table: one required column per
# ChannelNormalizedWeight field, listed in the same order.
YT_TABLE_SCHEMA = [
    {"name": column_name, "type": column_type, "required": True}
    for column_name, column_type in (
        ("timestamp", "timestamp"),
        ("channel", "string"),
        ("abc_id", "int32"),
        ("unit", "string"),
        ("weight", "double"),
    )
]

# Attributes passed when the YT output table has to be created.
YT_TABLE_ATTRIBUTES = {
    "compression_codec": "zstd_6",
    "schema": YT_TABLE_SCHEMA,
}

# ClickHouse INSERT statement; ``{table}`` is filled via str.format and the
# rows travel in the request body as CSV with a header line.
CH_QUERY = "INSERT INTO `{table}` FORMAT CSVWithNames"


# One channel's normalized weight taken from a single dump.
ChannelNormalizedWeight = namedtuple("ChannelNormalizedWeight", "timestamp channel abc_id unit weight")


logger = logging.getLogger(__name__)


class StrmBillingChannelsWeights(sdk2.Resource):
    # Sandbox resource type for the channel-weights dump written by
    # StrmBillingSaveChannelsWeights._create_resource: a "weights.json" file
    # holding one JSON object (a ChannelNormalizedWeight as a dict) per line.
    #
    # NOTE(review): the b"" prefix keeps this literal a byte string despite
    # ``unicode_literals`` (presumably for Python 2 consumers of ``__doc__``).
    # On Python 3 a bytes literal is not a docstring, so ``__doc__`` is None
    # there -- confirm which behavior Sandbox relies on before changing it.
    # "form pult" reads like a typo for "from pult"; left untouched because
    # this literal may be shown at runtime.
    b"""
    Channels Weights dump form pult
    """

    # Resources of this type are declared non-releasable.
    releasable = False


class UploadError(Exception):
    """Raised at the end of the task when at least one enabled upload failed."""


def weights_yt_timestamp_format(weights):
    """Yield YT row dicts for ``weights`` with the timestamp scaled to microseconds.

    :param weights: iterable of ``ChannelNormalizedWeight`` (any namedtuple with a
        ``timestamp`` field works). ``timestamp`` is presumably seconds since the
        epoch, as returned by ``datetime.timestamp()``.
    :returns: generator of dicts with the same fields in the same order, where
        ``timestamp`` is replaced by an ``int`` number of microseconds.
    """
    for weight in weights:
        # Round rather than truncate: a float product such as
        # 1073741824.000002 * 10 ** 6 evaluates to 1073741824000001.875, and a
        # bare int() would silently drop a microsecond.
        timestamp_us = int(round(weight.timestamp * 10 ** 6))
        yield weight._replace(timestamp=timestamp_us)._asdict()


class StrmBillingSaveChannelsWeights(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """
    Task to save normalized channel weights to YT and ClickHouse from pult.
    """

    class Requirements(task_env.TinyRequirements):
        pass

    class Parameters(sdk2.Parameters):

        dump_endpoint_url = sdk2.parameters.Url(
            "Dump endpoint URL", default="https://cfg.strm.yandex.net/scheduler/v1/dump", required=True
        )

        with sdk2.parameters.Group("YT output settings") as yt_settings:
            upload_to_yt = sdk2.parameters.Bool("Upload to YT")
            with upload_to_yt.value[True]:
                yt_cluster = sdk2.parameters.String("YT cluster", default="hahn")
                yt_token_secret = sdk2.parameters.YavSecret("Secret with YT token")
                yt_token_key = sdk2.parameters.String("Key in secret with token", default="YT_TOKEN")
                yt_table = sdk2.parameters.String(
                    "Path to table in YT", default="//home/strm/billing/channels_normalized_weights"
                )

        with sdk2.parameters.Group("ClickHouse output settings") as ch_settings:
            upload_to_ch = sdk2.parameters.Bool("Upload to ClickHouse")
            with upload_to_ch.value[True]:
                ch_endpoint = sdk2.parameters.Url("CH HTTP endpoint", default="https://clickhouse.strm.yandex.net")
                ch_credentials = sdk2.parameters.YavSecret("Secret with clickhouse username and password")
                ch_credentials_user_key = sdk2.parameters.String(
                    "Key in secret with username", default="clickhouse_user"
                )
                ch_credentials_pass_key = sdk2.parameters.String(
                    "Key in secret with password", default="clickhouse_password"
                )
                ch_database = sdk2.parameters.String("Database in ClickHouse", default="logs")
                ch_table = sdk2.parameters.String("Table in ClickHouse", default="trns_channels_normalized_weights")

        with sdk2.parameters.Group("Other settings") as other_settings:
            bin_params = binary_task.LastBinaryReleaseParameters()

    def _upload_to_yt(self, weights):
        """Append ``weights`` to the configured YT table, creating the table if it is missing.

        :param weights: iterable of ``ChannelNormalizedWeight``.
        """
        import yt.logger
        from yt.wrapper import YtClient

        # Replace the yt library's module-level LOGGER with the stdlib "yt" logger
        # (presumably so its records go through the task's logging setup).
        yt.logger.LOGGER = logging.getLogger("yt")

        cluster = self.Parameters.yt_cluster
        table = self.Parameters.yt_table
        token = self.Parameters.yt_token_secret.data()[self.Parameters.yt_token_key]

        client = YtClient(proxy=cluster, token=token)

        if not client.exists(table):
            client.create("table", table, attributes=YT_TABLE_ATTRIBUTES)

        weights_yt = list(weights_yt_timestamp_format(weights))
        client.write_table(client.TablePath(table, append=True), weights_yt)

    def _upload_to_ch(self, session, weights):
        """Insert ``weights`` into the configured ClickHouse table over HTTP.

        Rows are serialized as CSV with a header line (``CSVWithNames``) and POSTed
        to ``ch_endpoint`` with credentials taken from the ``ch_credentials`` secret.

        :param session: ``requests.Session`` used for the POST.
        :param weights: iterable of ``ChannelNormalizedWeight``.
        :raises requests.HTTPError: if ClickHouse answers with a 4xx/5xx status.
        """
        credentials = self.Parameters.ch_credentials.data()
        endpoint = self.Parameters.ch_endpoint

        auth = (
            credentials[self.Parameters.ch_credentials_user_key],
            credentials[self.Parameters.ch_credentials_pass_key],
        )

        params = {
            "database": self.Parameters.ch_database,
            "query": CH_QUERY.format(table=self.Parameters.ch_table),
        }

        csv_buffer = StringIO(newline="")
        writer = csv.DictWriter(csv_buffer, ChannelNormalizedWeight._fields)
        writer.writeheader()
        writer.writerows(w._asdict() for w in weights)
        # Send explicit UTF-8 bytes instead of the text buffer itself. For a text
        # stream, requests sizes Content-Length in characters, and the HTTP layer
        # encodes text streams as ISO-8859-1. Non-Latin-1 channel/unit names would
        # then fail to encode, and Latin-1 ones would arrive as non-UTF-8 bytes.
        insert_data = csv_buffer.getvalue().encode("utf-8")

        req = session.post(endpoint, auth=auth, params=params, data=insert_data)

        logger.info("Dumping ClickHouse response headers")
        for header, value in req.headers.items():
            logger.info("%s: %s", header, value)

        if not req.ok:
            # Log the response body, which raise_for_status() would otherwise drop.
            logger.error("ClickHouse response body: %s", req.text)
        req.raise_for_status()

    def _create_resource(self, weights):
        """Write ``weights`` as JSON lines to ``weights.json`` and register it as a resource.

        :param weights: iterable of ``ChannelNormalizedWeight``; one JSON object per line.
        """
        logger.info("Saving weights to file...")
        with open("weights.json", "w") as f:
            for weight in weights:
                f.write(json.dumps(weight._asdict()))
                f.write("\n")

        logger.info("Creating resource...")
        resource = StrmBillingChannelsWeights(self, "Channels Weights dump", "weights.json", ttl=7)
        sdk2.ResourceData(resource).ready()

    def _fetch_weights(self, session):
        """Download the weights dump and convert it to ``ChannelNormalizedWeight`` rows.

        :param session: ``requests.Session`` used for the GET.
        :returns: ``(dump_timestamp, weights)``. ``dump_timestamp`` is the parsed
            ``updated_at`` datetime. ``weights`` is a list sorted by ``(timestamp, channel)``.
        :raises requests.HTTPError: if the dump endpoint answers with a 4xx/5xx status.
        """
        req = session.get(self.Parameters.dump_endpoint_url)
        req.raise_for_status()

        dump = req.json()
        # NOTE(review): if ``updated_at`` carries no UTC offset, ``.timestamp()``
        # interprets it in the host's local time zone -- confirm the dump format.
        dump_timestamp = datetime.fromisoformat(dump["updated_at"])
        timestamp = dump_timestamp.timestamp()

        weights = [
            ChannelNormalizedWeight(
                timestamp=timestamp,
                channel=channel,
                abc_id=data["abc_id"],
                unit=data["unit"],
                weight=data["weight"],
            )
            for channel, data in dump["channels_normalized_weights"].items()
        ]
        weights.sort(key=lambda weight: (weight.timestamp, weight.channel))
        return dump_timestamp, weights

    def on_execute(self):
        """Fetch the weights dump, save it as a resource and upload it to the enabled sinks.

        The ClickHouse and YT uploads are attempted independently. A failure in
        one is logged and does not prevent the other. If any enabled upload
        failed, ``UploadError`` is raised after all of them were attempted.
        """
        # Up to 5 retries (backoff_factor=1) on 502/503/504 responses.
        retries = Retry(total=5, backoff_factor=1, status_forcelist=(502, 503, 504))

        # The context manager closes the session and its pooled connections.
        with requests.Session() as session:
            session.mount("http://", HTTPAdapter(max_retries=retries))
            session.mount("https://", HTTPAdapter(max_retries=retries))

            dump_timestamp, weights = self._fetch_weights(session)

            logger.info("Weights were updated at %s", dump_timestamp.isoformat())
            logger.info("Got weights for %d channels", len(weights))

            self._create_resource(weights)

            upload_failed = False

            if self.Parameters.upload_to_ch:
                logger.info("Starting upload to ClickHouse")
                try:
                    self._upload_to_ch(session, weights)
                except Exception:
                    upload_failed = True
                    logger.exception("Got exception while uploading data to ClickHouse")

            if self.Parameters.upload_to_yt:
                logger.info("Starting upload to YT")
                try:
                    self._upload_to_yt(weights)
                except Exception:
                    upload_failed = True
                    logger.exception("Got exception while uploading data to YT")

        logger.info("Task complete!")

        if upload_failed:
            raise UploadError("One or more upload was failed")
