from __future__ import print_function

from sandbox import sdk2
from sandbox.sdk2 import parameters
from sandbox.sdk2.helpers import subprocess, ProcessLog
from sandbox import common

from concurrent import futures

import datetime
import logging
import os


class SnapshotMeta(object):
    """
    Result of dumping one (YP cluster, node segment) pair.

    ``snapshot_path`` is the local path of the dumped snapshot file, or
    ``None`` when the dump failed.
    """
    __slots__ = ["cluster", "segment", "snapshot_path"]

    def __init__(self, cluster, segment, snapshot_path):
        self.cluster = cluster
        self.segment = segment
        self.snapshot_path = snapshot_path

    def __repr__(self):
        # Readable form for logs/debugging instead of the default object address.
        return "{}(cluster={!r}, segment={!r}, snapshot_path={!r})".format(
            type(self).__name__, self.cluster, self.segment, self.snapshot_path
        )


class RunDumpAllocationSnapshot(sdk2.Task):
    """
    Dump allocation snapshot and upload to YT

    For every configured YP cluster and node segment the released
    ``simtool make-snapshot`` binary is run; each resulting ``.json.gz`` file
    is uploaded to every configured YT destination cluster under
    ``<yt_destination_path>/<cluster>/<segment>/``.
    """

    # Declared at task level (sibling of Parameters), which is where sdk2
    # tasks define their Requirements.
    class Requirements(sdk2.Requirements):
        cores = 5
        ram = 2048

    class Parameters(sdk2.Parameters):
        description = "Snapshot YP cluster states"

        with parameters.Group("YT parameters") as yt_params:
            yt_snapshot_ttl = parameters.Timedelta(
                "Snapshot ttl",
                required=True,
                default=datetime.timedelta(days=180),
            )
            yt_destination_clusters = parameters.List(
                "YT destination clusters",
                required=True,
                default=["freud"]
            )
            yt_destination_path = parameters.String(
                "YT destination path",
                required=True,
                default="//home/alximik/yp_snapshots"
            )
            yt_token_secret_id = parameters.YavSecret(
                label="YT token secret id",
                required=True,
                description="secret should contain keys: yt_token",
                default="sec-01ec5jtcg3z83w2w10vgnpj27y",
            )

        with parameters.Group("YP parameters") as yp_params:
            yp_token_secret_id = parameters.YavSecret(
                label="YP token secret id",
                required=True,
                description="secret should contain keys: yp_token",
                default="sec-01ec5jtcg3z83w2w10vgnpj27y",
            )
            yp_clusters = parameters.List(
                "YP clusters",
                required=True,
                default=["man-pre", "sas-test"],
            )
            yp_segments = parameters.List(
                "YP node segments",
                required=True,
                default=["default"],
            )

    def _dump_cluster_segment(self, yp_cluster, yp_segments, simtool_path, output_dir):
        """
        Dump every requested node segment of one YP cluster, sequentially.

        Each segment is written to ``<output_dir>/<cluster>_<segment>.json.gz``.
        Exactly one SnapshotMeta is returned per segment; its ``snapshot_path``
        is ``None`` when simtool exited with a non-zero code.

        :type yp_cluster: str
        :type yp_segments: list[str]
        :type simtool_path: str
        :type output_dir: str
        :rtype: list[SnapshotMeta]
        """
        # simtool reads the YP token from the environment; the same env is
        # reused for every segment of this cluster.
        env = os.environ.copy()
        env["YP_TOKEN"] = self.Parameters.yp_token_secret_id.data()["yp_token"]

        result = []
        for yp_segment in yp_segments:
            snapshot_path = os.path.join(output_dir, "{}_{}.json.gz".format(yp_cluster, yp_segment))
            cmd = [
                simtool_path,
                "make-snapshot",
                "--node-segment", yp_segment,
                "--output", snapshot_path,
                yp_cluster,
            ]
            with ProcessLog(self, logger="dump_cluster_snapshot_{}_{}".format(yp_cluster, yp_segment)) as pl:
                run = subprocess.Popen(cmd, stdout=pl.stdout, stderr=pl.stderr, env=env)
                run.communicate()

            if run.returncode != 0:
                logging.error(
                    "Dump failed. Cluster={}, Segment={}. See logs for details.".format(yp_cluster, yp_segment)
                )
                # Only the failure marker: the output file must not be uploaded.
                result.append(SnapshotMeta(yp_cluster, yp_segment, None))
            else:
                result.append(SnapshotMeta(yp_cluster, yp_segment, snapshot_path))
        return result

    def _dump_snapshots(self, output_dir, simtool_path):
        """
        Dump all configured clusters in parallel (one worker per cluster).

        This is a generator: snapshots are yielded as soon as a whole cluster
        finishes, so uploading can start before other clusters are done.
        Exceptions raised inside a worker propagate from ``future.result()``.

        :type output_dir: str
        :type simtool_path: str
        :rtype: collections.Iterable[SnapshotMeta]
        """
        snapshots_dir = os.path.join(output_dir, "yp_snapshots")
        os.mkdir(snapshots_dir)

        yp_clusters = list(self.Parameters.yp_clusters)
        if not yp_clusters:
            # ThreadPoolExecutor rejects max_workers=0.
            return

        with futures.ThreadPoolExecutor(max_workers=len(yp_clusters)) as executor:
            dump_futures = [
                executor.submit(
                    self._dump_cluster_segment,
                    yp_cluster,
                    self.Parameters.yp_segments,
                    simtool_path,
                    snapshots_dir,
                )
                for yp_cluster in yp_clusters
            ]

            for snapshots_future in futures.as_completed(dump_futures):
                for snapshot in snapshots_future.result():
                    yield snapshot

    # noinspection PyMethodMayBeStatic
    def _locate_simtool(self):
        """
        Find the stable YP_SIMTOOL_BINARY resource and return the simtool path.

        :raises common.errors.TaskFailure: no stable resource was found.
        :rtype: str
        """
        yp_simtool_resource = sdk2.Resource.find(
            sdk2.Resource["YP_SIMTOOL_BINARY"],
            attrs={"released": "stable"}
        ).first()
        if yp_simtool_resource is None:
            raise common.errors.TaskFailure("No stable YP_SIMTOOL_BINARY resource found")
        logging.debug("Using simtool. Resource Id: {}".format(yp_simtool_resource))

        resource_data = sdk2.ResourceData(yp_simtool_resource)
        result = os.path.join(str(resource_data.path), "simtool")
        # Executable for everyone, but not world-writable.
        os.chmod(result, 0o755)
        return result

    def _upload_snapshot_to_cluster(self, snapshot, yt_cluster, now):
        """
        Upload one snapshot file to a single YT cluster.

        The file is created, written (with MD5) and given ``expiration_time``
        (``now`` + snapshot TTL) and ``dump_timestamp`` attributes inside one
        YT transaction. Parent map nodes are created outside the transaction.

        :type snapshot: SnapshotMeta
        :type yt_cluster: str
        :type now: datetime.datetime
        """

        # NB: https://wiki.yandex-team.ru/sandbox/faq/
        import yt.wrapper

        yt_token = self.Parameters.yt_token_secret_id.data()["yt_token"]
        client = yt.wrapper.YtClient(proxy=yt_cluster, token=yt_token)

        cluster_path = yt.wrapper.ypath_join(self.Parameters.yt_destination_path, snapshot.cluster)
        client.create("map_node", cluster_path, ignore_existing=True)
        segment_path = yt.wrapper.ypath_join(cluster_path, snapshot.segment)
        client.create("map_node", segment_path, ignore_existing=True)

        yt_snapshot_path = yt.wrapper.ypath_join(
            segment_path,
            "snapshot_{}_{}_{}.json.gz".format(
                snapshot.cluster,
                snapshot.segment,
                now.strftime("%Y-%m-%d %H:%M:%S")
            )
        )

        ttl = self.Parameters.yt_snapshot_ttl
        if not isinstance(ttl, datetime.timedelta):
            # Non-timedelta values are interpreted as a number of seconds.
            ttl = datetime.timedelta(seconds=ttl)
        expiration_time = (now + ttl).isoformat()

        logging.info("Uploading \"{}\" -> \"{}\"".format(
            snapshot.snapshot_path,
            yt_snapshot_path
        ))
        with client.Transaction():
            client.create(
                "file",
                yt_snapshot_path,
            )

            with open(snapshot.snapshot_path, "rb") as f:
                client.write_file(yt_snapshot_path, f, compute_md5=True)

            client.set_attribute(
                yt_snapshot_path,
                "expiration_time",
                expiration_time
            )
            client.set_attribute(
                yt_snapshot_path,
                "dump_timestamp",
                now.isoformat()
            )

        logging.info("Uploaded {}".format(snapshot.snapshot_path))

    def _upload_snapshot(self, snapshot):
        """
        Upload a snapshot to every destination YT cluster.

        Per-cluster errors are logged and counted, not raised.

        :type snapshot: SnapshotMeta
        :return: True if at least one destination cluster succeeded
            (False when there are no destination clusters).
        :rtype: bool
        """
        # One timestamp per snapshot so all destinations get the same name.
        # NOTE(review): naive local time is used for both the file name and the
        # expiration_time attribute -- confirm how YT interprets it.
        now = datetime.datetime.now()
        fail_count = 0
        for yt_cluster in self.Parameters.yt_destination_clusters:
            try:
                self._upload_snapshot_to_cluster(snapshot, yt_cluster, now)
            except Exception as err:
                logging.error("Failed uploading snapshot to cluster: {}".format(err), exc_info=True)
                fail_count += 1

        return fail_count != len(self.Parameters.yt_destination_clusters)

    def on_execute(self):
        """
        Dump all snapshots and upload them as they are produced.

        :raises common.errors.TaskFailure: some dumps failed, or some snapshots
            could not be uploaded to any destination cluster (both are reported
            together).
        """
        simtool_path = self._locate_simtool()

        failed_dumps = []
        failed_uploads = []
        for snapshot in self._dump_snapshots(str(self.path()), simtool_path):
            if snapshot.snapshot_path is None:
                failed_dumps.append(snapshot)
            elif not self._upload_snapshot(snapshot):
                failed_uploads.append(snapshot)

        def _describe(snapshots):
            return ", ".join("{0.cluster}@{0.segment}".format(x) for x in snapshots)

        errors = []
        if failed_uploads:
            errors.append("Failed uploading snapshots for segments: {}".format(_describe(failed_uploads)))
        if failed_dumps:
            errors.append("Failed dumping some segments: {}".format(_describe(failed_dumps)))
        if errors:
            raise common.errors.TaskFailure("; ".join(errors))
