# coding: utf-8

import os
import logging
import requests

from sandbox import sdk2
from sandbox import common
from sandbox.sandboxsdk import environments
import sandbox.common.types.task as ctt
from sandbox.common.types import client as ctc
from sandbox.common.types.resource import State
from sandbox.projects.porto import BuildPortoLayer
from sandbox.projects.common.nanny import nanny
from sandbox.common.share import skynet_get


class YtUpdateUserLayers(nanny.ReleaseToNannyTask2, sdk2.Task):
    """Rebuild YT user porto layers and upload them to every out-of-date YT cluster.

    For each environment in ``ENVIRONMENTS`` the task compares, per cluster, the
    ``packages`` / ``geobases`` / ``parent_layer_resource_id`` attributes stored on
    the cluster's layer file (under ``PORTO_LAYERS_ROOT``) with the desired state:
    packages from ``PACKAGES_META_PATH`` and geobases from ``GEOBASE_META_PATH`` on
    ``META_CLUSTER``, plus the ``parent_layer`` parameter. For every environment
    with at least one stale cluster a BUILD_PORTO_LAYER child task is spawned.
    After all children finish, each child is released as stable and its layer is
    written to the stale clusters together with the attributes used for the next
    comparison.
    """
    # type = 'YT_UPDATE_USER_LAYERS'
    class Parameters(sdk2.Task.Parameters):
        # Evaluated while the class body executes: lists resource types whose
        # names start with any of the given prefixes.
        def resource_types_with_prefixes(*args):
            return [r for r in sdk2.Resource if any(r.name.startswith(arg) for arg in args)]

        layer_type = sdk2.parameters.String('Layer type', required=True, default='PORTO_LAYER_YT')
        with sdk2.parameters.String("Compress", multiline=True, required=True) as compress:
            for choice in ['tar.gz', 'tar.xz', 'tar']:
                if choice == 'tar.gz':
                    compress.values[choice] = compress.Value(default=True)
                else:
                    compress.values[choice] = None

        parent_layer = sdk2.parameters.Resource(
            "Parent layer", required=True, resource_type=resource_types_with_prefixes('PORTO_LAYER'))
        script_url = sdk2.parameters.ArcadiaUrl('Setup script URL', default_value='arcadia:/arc/trunk/arcadia/yt/scripts/sandbox/build_yt_layer.sh')
        space_limit = sdk2.parameters.Integer('Max disk usage (MB)', default=8192)
        memory_limit = sdk2.parameters.Integer('Max memory usage (MB)', default=4096)
        debug_build = sdk2.parameters.Bool("Debug build", default=False)
        yt_token_vault_name = sdk2.parameters.String('YT token vault name', required=True)

    class Requirements(sdk2.Task.Requirements):
        environments = (environments.PipEnvironment("yandex-yt"),)
        client_tags = ctc.Tag.GENERIC

    class Context(sdk2.Task.Context):
        # environment name -> id of the BUILD_PORTO_LAYER child task built for it
        build_tasks = dict()

    description = "User layer updater: YTADMIN-9927"
    META_CLUSTER = 'locke'
    GEOBASE_META_PATH = '//home/geobase'
    PACKAGES_META_PATH = '//sys/admin/user_packages'
    ENVIRONMENTS = ['prod', 'testing']
    PORTO_LAYERS_ROOT = '//porto_layers'
    # Attributes stored on each uploaded layer file and compared on the next run.
    LAYER_ATTRIBUTES = ("packages", "geobases", "parent_layer_resource_id")
    # Geobase entries (name with any "-stable..." suffix stripped) baked into the layer.
    GEOBASE_NAMES = ("geodata6", "geodata5", "geodata4", "geodata-treeling", "tzdata")

    def _get_yt_client(self, cluster):
        """Return a cached YtClient for ``cluster``, creating it on first use.

        ``yt.wrapper`` is imported lazily. Presumably this is because the
        ``yandex-yt`` package comes from the ``PipEnvironment`` requirement rather
        than the module's import environment. Requires ``self.wrapper`` and
        ``self.yt_clients`` to be initialised (done in ``on_execute``).
        """
        if cluster not in self.yt_clients:
            if self.wrapper is None:
                from yt import wrapper as yt_wrapper
                self.wrapper = yt_wrapper
            token = sdk2.Vault.data(self.Parameters.yt_token_vault_name)
            self.yt_clients[cluster] = self.wrapper.YtClient(cluster, token=token)
        return self.yt_clients[cluster]

    def _discover_clusters(self):
        """Return the names of all clusters listed in ``//sys/clusters_config`` on the meta cluster."""
        return list(self.meta_client.get("//sys/clusters_config").keys())

    def _get_layer_path_for_environment(self, environment):
        """Return the YT path of ``environment``'s layer file, e.g. ``//porto_layers/prod.tar.gz``."""
        return "{}/{}.{}".format(self.PORTO_LAYERS_ROOT, environment, self.Parameters.compress)

    def _fetch_cluster_packages(self):
        """Read the current layer attributes from every cluster for every environment.

        Returns ``{cluster: {environment: {attribute: value}}}``, with one entry
        per name in ``LAYER_ATTRIBUTES``. A value is ``None`` in three cases:
        the layer or attribute is missing; the read hit a tolerated YT error
        (queue limit, resolve error, timeout); or the proxy was unavailable.
        ``layer_needs_update`` treats ``None`` as stale. When the connection is
        refused, that (cluster, environment) pair is omitted entirely.
        """
        res = {}
        for cluster in self.Context.clusters:
            for environment in self.ENVIRONMENTS:
                layer_node = None
                try:
                    layer_node = self._get_yt_client(cluster).get(
                        self._get_layer_path_for_environment(environment),
                        attributes=list(self.LAYER_ATTRIBUTES)
                    )
                except self.wrapper.YtHttpResponseError as e:
                    if not (e.is_request_queue_size_limit_exceeded() or
                            e.is_resolve_error() or
                            e.is_request_timed_out()):
                        raise
                    logging.exception("Error getting layer attributes from \"{}\"".format(cluster))
                except self.wrapper.YtProxyUnavailable:
                    logging.exception("Error getting layer attributes from \"{}\"".format(cluster))
                except requests.ConnectionError:
                    logging.exception("Skipping unavailable cluster \"{}\"".format(cluster))
                    continue
                attributes = layer_node.attributes if layer_node is not None else {}
                res.setdefault(cluster, {})[environment] = {
                    k: attributes.get(k) for k in self.LAYER_ATTRIBUTES
                }
        return res

    def _fetch_defined_packages(self):
        """Return ``{environment: packages}`` read from ``PACKAGES_META_PATH`` on the meta cluster."""
        meta = self._get_yt_client(self.META_CLUSTER)
        return {
            environment: meta.get("{}/{}".format(self.PACKAGES_META_PATH, environment))
            for environment in self.ENVIRONMENTS
        }

    def layer_needs_update(self, cluster, environment):
        """Return True if ``environment``'s layer on ``cluster`` differs from the desired state.

        A layer is stale in three cases: it was not read at all; any of its
        attributes is ``None``; or its packages, geobases or parent layer id
        differ from the desired values held in the task context and parameters.
        """
        cluster_layers = self.Context.cluster_layer_packages[cluster]
        if environment not in cluster_layers:
            return True
        current = cluster_layers[environment]

        if current["packages"] is None or \
                dict(current["packages"]) != dict(self.Context.defined_packages[environment]):
            return True
        if current["geobases"] is None or \
                set(current["geobases"]) != set(self.Context.defined_geobases):
            return True
        if current["parent_layer_resource_id"] is None or \
                current["parent_layer_resource_id"] != self.Parameters.parent_layer.id:
            return True
        return False

    def _fetch_defined_geobases(self):
        """Return the geobase entries from ``GEOBASE_META_PATH`` whose names are in ``GEOBASE_NAMES``."""
        return [
            value for name, value in self.meta_client.get(self.GEOBASE_META_PATH).items()
            if name.split("-stable")[0] in self.GEOBASE_NAMES
        ]

    def _create_build_task(self, environment):
        """Create and enqueue a BUILD_PORTO_LAYER child task for ``environment``; return its id."""
        layer_build_params = {
            BuildPortoLayer.ParentLayer.name: self.Parameters.parent_layer,
            BuildPortoLayer.LayerType.name: self.Parameters.layer_type,
            BuildPortoLayer.LayerName.name: "{}.{}".format(environment, self.Parameters.compress),
            BuildPortoLayer.Compress.name: self.Parameters.compress,
            BuildPortoLayer.ScriptUrl.name: self.Parameters.script_url,
            # TODO: cluster -> environment
            BuildPortoLayer.ScriptEnv.name: {'YT_ENVIRONMENT': environment},
            BuildPortoLayer.MergeLayers.name: True,
            BuildPortoLayer.DebugBuild.name: self.Parameters.debug_build,
            BuildPortoLayer.SpaceLimit.name: self.Parameters.space_limit,
            BuildPortoLayer.MemoryLimit.name: self.Parameters.memory_limit,
        }

        logging.debug("BUILD_PORTO_LAYER build params: %s", str(layer_build_params))
        task_class = sdk2.Task['BUILD_PORTO_LAYER']
        layer_build_task = task_class(
            self,
            description="User Layer for {} environment.".format(environment),
            # Resource parameters are passed to the child by id.
            **{
                key: value.id if isinstance(value, sdk2.Resource) else value
                for key, value in layer_build_params.iteritems()
            }
        ).enqueue()
        logging.debug("BUILD_PORTO_LAYER build task: %d", layer_build_task.id)
        return layer_build_task.id

    def _get_released_parent_layer(self):
        """Return the parent layer resource.

        Raises TaskFailure unless the resource is READY and released as stable.
        """
        parent_layer = sdk2.Resource.find(id=self.Parameters.parent_layer.id).first()
        if parent_layer.state != State.READY:
            raise common.errors.TaskFailure("Parent layer: {} is not ready.".format(parent_layer.id))
        if parent_layer.released != ctt.ReleaseStatus.STABLE:
            raise common.errors.TaskFailure("Parent layer: {} is not released stable.".format(parent_layer.id))
        return parent_layer

    def _release_and_upload(self, environment, task_id, parent_layer):
        """Release child ``task_id`` as stable and upload its layer to every stale cluster.

        Writes the layer file plus its ``sandbox_resource_id``,
        ``sandbox_task_id``, ``parent_layer_resource_id``, ``packages`` and
        ``geobases`` attributes.
        """
        # One memoize stage per environment: with a single shared stage name,
        # only the first environment's task would ever be released.
        release_stage = getattr(self.memoize_stage, "release_task_{}".format(environment))
        with release_stage(commit_on_entrance=False):
            self.server.release(task_id=task_id, type=ctt.ReleaseStatus.STABLE, subject="Porto Layer YT [{}]".format(environment))

        layer_resource = sdk2.Resource.find(task_id=task_id, type=self.Parameters.layer_type).first()
        if layer_resource is None:
            raise common.errors.TaskFailure(
                "No {} resource found in task {}.".format(self.Parameters.layer_type, task_id))
        skynet_get(layer_resource.skynet_id, '.')

        layer_file = os.path.join(".", os.path.basename("{}.{}".format(environment, self.Parameters.compress)))
        yt_path = self._get_layer_path_for_environment(environment)
        for cluster in self.Context.cluster_layer_packages:
            if not self.layer_needs_update(cluster, environment):
                logging.debug("Skipping up to date cluster \"%s\".", cluster)
                continue

            yt_client = self._get_yt_client(cluster)
            # The layer is a binary archive; open it in binary mode.
            with open(layer_file, "rb") as f:
                yt_client.write_file(yt_path, f)
            yt_client.set("{}/@sandbox_resource_id".format(yt_path), layer_resource.id)
            yt_client.set("{}/@sandbox_task_id".format(yt_path), task_id)
            yt_client.set("{}/@parent_layer_resource_id".format(yt_path), parent_layer.id)
            yt_client.set("{}/@packages".format(yt_path), self.Context.defined_packages[environment])
            yt_client.set("{}/@geobases".format(yt_path), self.Context.defined_geobases)

    def on_execute(self):
        """Spawn layer builds for stale environments, then release and upload the results.

        The body runs again from the top after ``sdk2.WaitTask``. The
        ``create_children`` memoize stage makes sure child tasks are spawned
        only once.
        """
        self.wrapper = None
        self.yt_clients = {}
        self.meta_client = self._get_yt_client(self.META_CLUSTER)
        self.Context.clusters = self._discover_clusters()
        self.Context.cluster_layer_packages = self._fetch_cluster_packages()
        self.Context.defined_packages = self._fetch_defined_packages()
        self.Context.defined_geobases = self._fetch_defined_geobases()

        with self.memoize_stage.create_children:
            for environment in self.ENVIRONMENTS:
                # Skip the environment if no cluster needs an update.
                if not any(self.layer_needs_update(cluster, environment)
                           for cluster in self.Context.cluster_layer_packages):
                    continue
                self.Context.build_tasks[environment] = self._create_build_task(environment)

            # Wait once for ALL children. Raising inside the loop would commit
            # this stage after the first environment, and the remaining
            # environments would never be built.
            if self.Context.build_tasks:
                raise sdk2.WaitTask(
                    self.Context.build_tasks.values(),
                    ctt.Status.Group.FINISH | ctt.Status.Group.BREAK,
                    wait_all=True
                )

        build_tasks = self.Context.build_tasks
        if not build_tasks:
            return

        # Verify every build before releasing anything, so a failure does not
        # leave only some environments updated.
        for task_id in build_tasks.values():
            if not all(task.status in ctt.Status.Group.SUCCEED for task in self.find(id=task_id)):
                raise common.errors.TaskFailure("Build porto layer failed")

        parent_layer = self._get_released_parent_layer()
        for environment, task_id in build_tasks.items():
            self._release_and_upload(environment, task_id, parent_layer)
