# -*- coding: utf-8 -*-

import aniso8601
import datetime
import json
import logging
import os

from sandbox import sdk2

from sandbox.common.errors import TaskFailure, VaultNotFound
from sandbox.common.types.resource import State
from sandbox.common.types.task import Status

from sandbox.projects.common.nanny import nanny
from sandbox.projects.mssngr.common.util import fetch_package
from sandbox.projects.mssngr.runtime import resources
from sandbox.projects.mssngr.runtime import MssngrDeployStateCache as mdsc

from sandbox.sandboxsdk import environments, process


def postproc_state_cacher(parent_task):
    """Record the state-cache version of the cacher built by ``parent_task``.

    Finds the MssngrRouterStateCacher resource of ``parent_task``, runs its
    ``mssngr-state-cacher --version`` and stores the parsed report, re-encoded
    as JSON, in the resource's ``version`` attribute.

    Best effort: if the binary cannot be fetched or run, or its output is not
    valid JSON, a warning is logged and ``version`` stays unset. Does nothing
    when the task has no such resource or ``version`` is already set.
    """
    cacher_resource = sdk2.Resource.find(
        task=parent_task,
        resource_type=resources.MssngrRouterStateCacher,
    ).first()

    if cacher_resource is None:
        logging.info("No state cacher resource found for task {}".format(parent_task))
        return

    if cacher_resource.version:
        return

    try:
        cacher_bin_path = os.path.join(fetch_package(parent_task, cacher_resource), "mssngr-state-cacher")
        out, _ = process.run_process([cacher_bin_path, "--version"], outs_to_pipe=True).communicate()
        # Parsing belongs to the best-effort part too: garbage output must not
        # fail the calling task any more than a missing binary would.
        data = json.loads(out)
    except Exception as e:
        logging.warning("Failed to determine state version of {}: {}".format(cacher_resource.id, e))
        return

    # XXX compat: a report with "common" and "sharded" sections (each a
    # {type: version} mapping) is flattened into a sorted JSON list of
    # "type_version" strings; any other report is stored as-is.
    if "common" in data and "sharded" in data:
        versions = []
        for val in data.values():
            versions += ["{}_{}".format(k, v) for k, v in val.iteritems()]
        cacher_resource.version = json.dumps(sorted(versions))
    else:
        cacher_resource.version = json.dumps(data)

    logging.info("Set version for {}".format(cacher_resource.id))


def release_state_cacher(parent_task, release_status):
    """Start a state-cache deploy for the cacher released by ``parent_task``.

    The target env is the key of ``mdsc.ENVS`` whose ``release_status`` equals
    ``release_status``; if several match, the last one iterated wins. Unknown
    release statuses are logged and ignored. The READY
    MssngrRouterStateCacher resource of ``parent_task`` (possibly None) is
    passed to a MssngrAutoDeployStateCache child task, which is enqueued.
    """
    matching_envs = [
        name
        for name, env_cfg in mdsc.ENVS.iteritems()
        if env_cfg.release_status == release_status
    ]
    env = matching_envs[-1] if matching_envs else None

    if not env:
        logging.info("Unsupported release status: {}".format(release_status))
        return

    ready_cacher = sdk2.Resource.find(
        task=parent_task,
        resource_type=resources.MssngrRouterStateCacher,
        state=State.READY,
    ).first()

    deploy_task = MssngrAutoDeployStateCache(
        parent_task,
        create_sub_task=False,
        description=parent_task.Parameters.description,
        env=env,
        cacher=ready_cacher,
    )
    deploy_task.enqueue()


class MssngrAutoDeployStateCache(sdk2.Task):
    """Builds and deploys actual versions of mssngr-worker state cache.

    First execution picks the cacher binaries to build with:

    - the last deployed release;
    - the release before it, while the last deploy is younger than
      ``version_ttl`` days;
    - the requested or newest released cacher.

    It enqueues one MssngrDeployStateCache child task per cacher, drops YT
    tables whose version is no longer in use, and waits for the children.
    On re-execution (after WaitTask) it checks the children's statuses and
    prunes stale timestamped tables.
    """

    class Requirements(sdk2.Task.Requirements):
        environments = [
            environments.PipEnvironment("yandex-yt", use_wheel=True),
        ]
        cores = 1
        ram = 2048
        disk_space = 2048

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(mdsc.BaseParameters):
        cacher = sdk2.parameters.LastReleasedResource(
            "State cacher binary to check",
            resource_type=resources.MssngrRouterStateCacher,
            # NOTE(review): parenthesised scalar, not a one-element tuple
            # ``(State.READY,)`` -- confirm which form is intended.
            state=(State.READY),
            required=False,
        )

        support_prev_versions = sdk2.parameters.Bool(
            "Build previous cache versions for quick rollback",
            default=True,
        )

        version_ttl = sdk2.parameters.Integer(
            "Previous versions support time (days)",
            default=1,
        )

    def _fetch_releases(self, statuses, sort_by, limit):
        """Query Nanny for state-cacher releases of the current env.

        Returns a list of ``(release_id, resource_id, end_time)`` tuples, one
        per MssngrRouterStateCacher resource in each found release.
        ``end_time`` is the release's ``status.endTime``, or None when absent.
        """
        releases = self.nanny_client.find_releases({
            "limit": limit,
            "sort": {"field": sort_by},
            "query": {
                "sandboxResourceType": [resources.MssngrRouterStateCacher.name],
                "sandboxReleaseType": [self.env.release_status],
                "status": statuses,
            }
        })

        result = []
        for rel in releases:
            for res in rel["spec"]["sandboxRelease"]["resources"]:
                if res["type"] == resources.MssngrRouterStateCacher.name:
                    result.append((rel["id"], res["id"], rel["status"].get("endTime")))
        return result

    def _filter_deployed(self, releases):
        """Keep releases with at least 3 DEPLOY_SUCCESS tickets (one per location)."""
        result = []
        for rel in releases:
            resp = self.nanny_client.filter_tickets({
                "query": {
                    "release_id": rel[0],
                    "status": ["DEPLOY_SUCCESS"],
                }
            })
            if len(resp) >= 3:  # 3 locations
                result.append(rel)
        return result

    def _get_cacher_version(self, resource):
        """Parse the JSON ``version`` attribute of a cacher resource.

        Returns ``(versions, modes)``:

        - for a JSON list, ``(set(list), None)``;
        - for a JSON object ``{type: {"version": v, "mode": m}}``, the set of
          ``"type_v"`` strings and a ``{type: m}`` dict;
        - ``(None, None)`` when the attribute is unset.
        """
        if resource.version:
            data = json.loads(resource.version)
            if isinstance(data, list):
                return set(data), None
            else:
                modes = {}
                versions = []
                for k, v in data.iteritems():
                    versions.append("{}_{}".format(k, v["version"]))
                    modes[k] = v["mode"]
                return set(versions), modes
        return None, None

    def _build_exec_plan(self, prev, current, new, auto_mode):
        """Decide which cachers to run and in which mode.

        Args:
            prev, current, new: cacher resources; any of them may be None.
            auto_mode: when true, the base cacher (``new``, else ``current``,
                else ``prev``) is built for all types. Every other cacher is
                built only for the cache versions the base (plus earlier aux
                cachers) does not already cover; it is skipped when there is
                nothing new or its versions are unknown. When false, only
                ``new`` is built.

        Returns:
            ``(actual_versions, cacher2mode)``. ``actual_versions`` is the
            union of ``"type_version"`` strings of all given cachers, or None
            when some cacher has no recorded version; tables must not be pruned
            by an incomplete set. ``cacher2mode`` maps each cacher to run to
            its mode, where None means all types.
        """
        type2mode = dict()
        versions = dict()

        actual_versions = set()
        versions_complete = True
        cacher2mode = dict()

        # A list rather than a lazy iterator: it is walked twice below.
        cachers = list(filter(None, (prev, current, new)))
        for c in cachers:
            version, modes = self._get_cacher_version(c)
            if version:
                actual_versions.update(version)
                versions[c] = version
            else:
                versions_complete = False
                logging.warning("Cacher {} has no recorded cache versions".format(c.id))
            if modes:
                type2mode.update(modes)

        if auto_mode:
            # XXX compat: when per-type modes are known, turn "type_version"
            # strings into (mode, version) pairs so a diff yields the modes.
            if type2mode:
                def remap(x):
                    t, v = x.rsplit("_", 1)
                    return type2mode[t], v

                for k, v in versions.iteritems():
                    versions[k] = set(map(remap, v))

            base = new or current or prev
            # Versions produced so far. An unversioned base proves nothing, so
            # aux cachers are then diffed against an empty set.
            covered = versions.get(base, set())
            for c in cachers:
                if c == base:
                    cacher2mode[c] = None  # collect all types
                    logging.debug("Collecting cache for base version: {}".format(c.id))
                elif c not in versions:
                    logging.warning("Skipping aux version {}: cache versions unknown".format(c.id))
                else:
                    diff = versions[c].difference(covered)
                    if diff:
                        mode = sum(m for m, _ in diff) if type2mode else None
                        cacher2mode[c] = mode
                        covered.update(diff)
                        logging.debug("Collecting cache for aux version: {}, mode: {}".format(c.id, mode))
        else:
            cacher2mode[new] = None

        return (actual_versions if versions_complete else None), cacher2mode

    def _run_subtask(self, cacher_resource, mode=None):
        """Enqueue a MssngrDeployStateCache child task that builds with ``cacher_resource``.

        A truthy ``custom_mode`` parameter of this task overrides ``mode``.
        Returns the result of ``enqueue()``.
        """
        description = "{}, run with cacher {}, mode {}".format(self.Parameters.description, cacher_resource.id, mode)
        return mdsc.MssngrDeployStateCache(
            self,
            create_sub_task=False,

            description=description,
            notifications=self.Parameters.notifications,
            kill_timeout=self.Parameters.kill_timeout,

            env=self.Parameters.env,
            resource_ttl=self.Parameters.resource_ttl,
            yt_proxy=self.Parameters.yt_proxy,
            yt_path=self.Parameters.yt_path,
            bstr=self.Parameters.bstr,
            pusher=self.Parameters.pusher,
            custom_mode=self.Parameters.custom_mode or mode,

            cacher=cacher_resource
        ).enqueue()

    def _get_from_vault(self, key):
        """Read vault secret ``key`` of the task owner, falling back to owner "MSSNGR"."""
        try:
            value = sdk2.Vault.data(self.owner, key)
        except VaultNotFound:
            value = sdk2.Vault.data("MSSNGR", key)
        return str(value)

    def _cleanup_old_versions(self, versions=None):
        """Drop stale state-cache nodes under ``<yt_path>/<env>`` on YT.

        With ``versions`` given, removes every node whose version name is not
        in it. An empty ``versions`` is refused, because it would wipe the
        directory. Without ``versions``, ``*.lz4`` nodes are grouped by the
        name before their last "_"; only the lexicographically last node of
        each group is kept.
        """
        if versions is not None and not versions:
            logging.warning("No actual cache versions given, skipping cleanup")
            return

        import yt.wrapper as yt

        yt.config["proxy"]["url"] = self.Parameters.yt_proxy
        yt.config["token"] = self._get_from_vault("yt-token-{}".format(self.Parameters.yt_proxy))

        path = "{}/{}".format(self.Parameters.yt_path, self.Parameters.env)
        groups = {}

        for t in yt.list(path):
            has_ts_in_name = str(t).endswith('.lz4')
            name = str(t).rsplit("_", 1)[0] if has_ts_in_name else str(t).split(".")[0]
            if versions is not None:
                if name not in versions:
                    table = "{}/{}".format(path, t)
                    logging.info("Gonna drop {}".format(table))
                    yt.remove(table)
            elif has_ts_in_name:
                groups.setdefault(name, []).append(t)

        for files in groups.itervalues():
            files.sort()
            for t in files[:-1]:
                table = "{}/{}".format(path, t)
                logging.info("Gonna drop {}".format(table))
                yt.remove(table, recursive=True)

    def on_execute(self):
        """Plan and launch cache builds, or finalize them after WaitTask.

        Raises:
            TaskFailure: a child build did not succeed, or no released cacher
                could be found when none was given explicitly.
        """
        subtasks = list(self.find(task_type=mdsc.MssngrDeployStateCache))
        if subtasks:
            # Re-execution after WaitTask: verify the builds, then prune.
            for task in subtasks:
                if task.status not in Status.Group.SUCCEED:
                    raise TaskFailure("Subtask failed with status {}".format(task.status))
            # Housekeeping only, best effort just like the first-phase cleanup.
            try:
                self._cleanup_old_versions()
            except Exception as e:
                logging.warning("Unable to clean up bstr tables: {}".format(e))
            return

        self.env = mdsc.ENVS[self.Parameters.env]
        self.nanny_client = nanny.NannyClient(
            api_url="http://nanny.yandex-team.ru/",
            oauth_token=self._get_from_vault("nanny_oauth_token")
        )

        closed_releases = self._fetch_releases(statuses=["CLOSED"], sort_by=["-status.end_time"], limit=10)
        deployed = self._filter_deployed(closed_releases)

        logging.debug("Cacher deployed releases: {}".format(deployed))

        current_cacher = None
        prev_cacher = None

        if deployed:
            release_id, res_id, deploy_time = deployed[0]
            current_cacher = sdk2.Resource[res_id]

            if deploy_time is None:
                logging.warning("Deploy end time of release {} is unknown, not keeping previous version".format(release_id))
            else:
                parsed = aniso8601.parse_datetime(deploy_time)
                # Normalize to naive UTC so it is comparable with utcnow();
                # dropping tzinfo alone would ignore a non-UTC offset.
                deploy_time = parsed.replace(tzinfo=None) - (parsed.utcoffset() or datetime.timedelta(0))
                time_since_deploy = datetime.datetime.utcnow() - deploy_time
                logging.debug("Time since deploy: {}".format(time_since_deploy))

                # Keep the previous release's cache buildable while a rollback
                # of the last deploy is still likely.
                if len(deployed) > 1 and time_since_deploy < datetime.timedelta(days=self.Parameters.version_ttl):
                    prev_cacher = sdk2.Resource[deployed[1][1]]

        if self.Parameters.cacher:
            new_cacher = self.Parameters.cacher
        else:
            latest = self._fetch_releases(statuses=["OPEN", "CLOSED"], sort_by=["-meta.creation_time"], limit=1)
            if not latest:
                raise TaskFailure("No {} releases found for release type {}".format(
                    resources.MssngrRouterStateCacher.name, self.env.release_status))
            new_cacher = sdk2.Resource[latest[0][1]]

        auto_mode = self.Parameters.support_prev_versions and self.Parameters.cacher is None
        actual_versions, cacher2mode = self._build_exec_plan(prev_cacher, current_cacher, new_cacher, auto_mode)

        subtasks = [self._run_subtask(c, m) for c, m in cacher2mode.iteritems()]

        if actual_versions:
            try:
                self._cleanup_old_versions(actual_versions)
            except Exception as e:
                logging.warning("Unable to clean up bstr tables: {}".format(e))
        else:
            # Pruning by an incomplete version set could delete tables that an
            # unversioned (possibly deployed) cacher still relies on.
            logging.warning("Cache versions are not fully known, skipping cleanup")

        if subtasks:
            raise sdk2.WaitTask(subtasks, Status.Group.FINISH | Status.Group.BREAK, wait_all=True)
