from sandbox import sdk2
from sandbox.sdk2.vcs.svn import Arcadia
from sandbox.sandboxsdk.process import run_process  # maybe sdk2.helpers.ProcessLog?
from sandbox.projects.common import error_handlers as eh
from sandbox.projects.common.arcadia import sdk as arcadia_sdk

import logging
import time

from os.path import join as pj, dirname
import os
import sys

from sandbox.sandboxsdk import environments

import datetime
import random
import string
from string import Template


def RunProcess(cmd, env, log_prefix=None, exception_if_nonzero_code=True):
    """Run ``cmd`` through the shell and return its captured output.

    :param cmd: sequence of command-line parts; each part is ``str()``-ed and
        the parts are joined with single spaces (no quoting is applied).
    :param env: environment mapping passed to ``run_process``.
    :param log_prefix: prefix for the ``run_process`` log files.
    :param exception_if_nonzero_code: when true, a non-zero exit code raises.
    :return: ``(stdout, stderr)`` as returned by ``communicate()``.
    :raises Exception: with stderr as the message on a non-zero exit code
        (only if ``exception_if_nonzero_code`` is set).
    """
    command_line = ' '.join(str(part) for part in cmd)
    proc = run_process(
        command_line,
        outs_to_pipe=True,
        check=False,
        shell=True,
        wait=True,
        environment=env,
        log_prefix=log_prefix,
    )
    stdout, stderr = proc.communicate()
    if exception_if_nonzero_code and proc.returncode != 0:
        raise Exception(stderr)
    return stdout, stderr


class RemoveUsYtNodes(sdk2.Task):
    """Remove expired YT nodes declared by the user-sessions reactor code.

    The task:
      1. checks out ``quality/user_sessions/reactor`` from trunk and runs
         ``us_processes.get_production_svn_url`` to learn the production url;
      2. checks out that url and imports ``us_processes`` from it;
      3. replays ``create_sessions`` with ``YTRemovalAccumulator`` in safe mode
         to collect output-table configs with their lifetimes;
      4. removes every templated output node older than its lifetime (within a
         configurable backlog window), plus parent map nodes that become empty.
    """

    class Requirements(sdk2.Requirements):
        environments = [
            environments.PipEnvironment('yandex-yt', version='0.10.8'),
            environments.PipEnvironment('requests'),
            environments.PipEnvironment('networkx', version='2.2', use_wheel=True),
        ]

        cores = 1
        ram = 4096
        disk_space = 4096

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Parameters):
        labels = sdk2.parameters.String("Labels, ','-separated. For example !antifraud", default="")
        filter_clusters = sdk2.parameters.String("Filter YT clusters, ','-separated or empty. For example: hahn,arnold", default="")
        custom_today_datetime_str = sdk2.parameters.String("Custom datetime in format YYYY-MM-DD:HH:MM")

        with sdk2.parameters.Group("Fast sessions parameters") as fast_group:
            period_fast = sdk2.parameters.Bool("Enable time period: 30min", default=True)
            backlog_fast = sdk2.parameters.Integer("Backlog to watch, days", default=2)

        with sdk2.parameters.Group("Daily sessions parameters") as daily_group:
            period_daily = sdk2.parameters.Bool("Enable time period: 1d", default=True)
            backlog_daily = sdk2.parameters.Integer("Backlog to watch, days", default=7)

        with sdk2.parameters.Group("Reactor token") as reactor_token_block:
            reactor_token_secret_owner = sdk2.parameters.String("Owner of sb-vault-secret with reactor token", required=True)  # in practice: USERSESSIONSTOOLS. But do not specify default here for safety
            reactor_token_secret_name = sdk2.parameters.String("Name of sb-vault-secret with reactor token", required=True)  # in practice: robot-make-sessions-reactor-token (TODO: create secret). But do not specify default here for safety
            # TODO: maybe add defaults at least to description?

        with sdk2.parameters.Group("Yt token") as yt_token_block:
            yt_token_secret_owner = sdk2.parameters.String("Owner of sb-vault-secret with yt token", required=True)  # in practice: USERSESSIONSTOOLS. But do not specify default here for safety
            yt_token_secret_name = sdk2.parameters.String("Name of sb-vault-secret with yt token", required=True)  # in practice: robot-make-sessions-yt-token (TODO: create secret). But do not specify default here for safety
            # TODO: maybe add defaults at least to description?

        with sdk2.parameters.Group("Debug") as debug:
            dry_run = sdk2.parameters.Bool("Dry run", default=False)
            debug_yt_prefix = sdk2.parameters.String("Debug yt prefix", default="")

    class Context(sdk2.Context):
        # "today" datetimes are computed only on the first execution, so a
        # restarted task keeps working against the same reference point
        first_time = True

    def GetReactorTokenPath(self):
        """Write the reactor token to an owner-only file in the task directory.

        The token is written directly rather than via ``echo`` in a shell.
        That way it never shows up on a command line and is not
        shell-interpreted.

        :return: path of the token file.
        """
        secret_content = self.GetReactorToken()
        reactor_token_path = pj(str(self.path()), 'reactor_token_file')
        fd = os.open(reactor_token_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w') as token_file:
            # the former `echo` implementation appended a newline; keep it
            token_file.write('{}\n'.format(secret_content))

        return reactor_token_path

    def GetEnv(self, arcadia_user_sessions_local_abs_path):
        """Return a copy of ``os.environ`` with the given path prepended to PYTHONPATH."""
        env = dict(os.environ)
        paths = env['PYTHONPATH'].split(':') if 'PYTHONPATH' in env else []
        paths.insert(0, arcadia_user_sessions_local_abs_path)
        env['PYTHONPATH'] = ':'.join(paths)

        return env

    def CheckoutArcadiaSubfolder(self, arcadia_subfolder, arcadia_src_dir, svn_url, use_cache=True):
        """Materialize ``arcadia_subfolder`` of ``svn_url`` under ``arcadia_src_dir``.

        A trailing ``@<revision>`` in ``svn_url`` is carried over to the subfolder url.

        :param use_cache: mount the path via ``arcadia_sdk.mount_arc_path`` and copy it
            (True), or do a plain svn checkout (False).
        :return: local absolute path of the subfolder.
        """
        pos = svn_url.rfind('@')

        if pos != -1:
            dir_url = pj(svn_url[:pos], arcadia_subfolder) + svn_url[pos:]
        else:
            dir_url = pj(svn_url, arcadia_subfolder)

        arcadia_subfolder_local_abs_path = pj(arcadia_src_dir, arcadia_subfolder)

        if use_cache:
            with arcadia_sdk.mount_arc_path(dir_url, use_arc_instead_of_aapi=True) as p:
                sdk2.paths.copy_path(str(p), arcadia_subfolder_local_abs_path)
        else:
            sdk2.svn.Arcadia.checkout(dir_url, arcadia_subfolder_local_abs_path)

        return arcadia_subfolder_local_abs_path

    def PrepareArcadia(self, svn_url):
        """Check out ``quality/user_sessions/reactor`` for ``svn_url`` (once per url).

        The local path is memoized in ``self.arcadia_key_to_reactor_dir[svn_url]``.
        """
        if svn_url in self.arcadia_key_to_reactor_dir:
            return

        # a random directory suffix keeps checkouts of different urls apart;
        # SystemRandom ignores seeding, so no seed is passed
        rng = random.SystemRandom()
        alphabet = string.ascii_lowercase + string.ascii_uppercase + string.digits
        suffix = ''.join(rng.choice(alphabet) for _ in range(20))
        arcadia_src_dir = pj(str(self.path()), 'local_arcadia_' + suffix)

        self.arcadia_key_to_reactor_dir[svn_url] = self.CheckoutArcadiaSubfolder(
            'quality/user_sessions/reactor', arcadia_src_dir, svn_url=svn_url)

    class ArtifactInstanceInfoAccumulator(object):
        """Cache of reactor artifact-instance lookups per (user timestamp, namespace).

        Failed lookups are not cached; the tail of each error text is collected in
        ``err_accumulator`` so the caller can report all of them at the end.
        """

        def __init__(self, reactor_client):
            self.reactor_client = reactor_client
            # user_timestamp -> {namespace_path: non-empty instance list or None}
            self.user_timestamp_namespace_path_to_instances = {}
            self.err_accumulator = []

        def has_instance_with_target_usertime(self, namespace_path, user_timestamp):
            """Return True if ``namespace_path`` has a CREATED/ACTIVE instance at ``user_timestamp``.

            A failed reactor request is recorded in ``err_accumulator`` and yields False.
            """
            cache = self.user_timestamp_namespace_path_to_instances.setdefault(user_timestamp, {})

            if namespace_path not in cache:
                try:
                    art_insts = self.reactor_client.get_artifact_range(
                        namespace=namespace_path, limit=100, statuses=["CREATED", "ACTIVE"],
                        from_user_ts=user_timestamp, to_user_ts=user_timestamp + 1)["range"]
                except Exception as exc:
                    self.err_accumulator.append(str(exc)[-100:])
                    logging.info('get_artifact_range FAILED on {art}'.format(art=namespace_path))
                    return False

                # an empty range is stored as None ("no instance")
                cache[namespace_path] = art_insts or None

            return bool(cache[namespace_path])

        def raise_for_errors(self):
            """Raise an Exception listing all accumulated lookup errors, one per line."""
            if self.err_accumulator:
                raise Exception("".join(error + "\n" for error in self.err_accumulator))

    def CheckArtifact(self, artifact_path, datetime_dt):
        """Return True if reactor artifact ``artifact_path`` has an instance for ``datetime_dt``."""
        from us_processes.time_util import _convert_to_unixtime
        usertime_ts = _convert_to_unixtime(datetime_dt)
        return self.artifact_instance_info_accumulator.has_instance_with_target_usertime(artifact_path, usertime_ts)

    def GetAllTemplateArgs(self, value):
        """Return placeholder names (``$name`` / ``${name}``) used in ``value``, in order.

        Escaped dollars (``$$``) and invalid placeholders carry no name and are
        skipped; they are left as-is by ``safe_substitute``.
        """
        res = []
        for _escaped, named, braced, _invalid in string.Template.pattern.findall(value):
            name = named or braced
            if name:
                res.append(name)

        return res

    def DoRemove(self, cluster, path, lifetime, period, yt_err_accumulator):
        """Remove expired instances of the templated YT ``path`` on ``cluster``.

        Datetimes from ``today - days_to_store - backlog`` (inclusive) up to
        ``today - days_to_store`` (exclusive), stepping by ``period``, are substituted
        into ``path``. Each existing node is removed recursively, then parent map
        nodes that became empty are removed bottom-up. YT failures are appended to
        ``yt_err_accumulator`` instead of being raised.

        :param cluster: YT proxy to operate on.
        :param path: YT path template with datetime placeholders.
        :param lifetime: lifetime config (``days_to_store``,
            ``drop_cluster_wait_nonversioned_artifact``).
        :param period: time period; it is added to datetimes as the step, so it is
            expected to be a timedelta.
        :param yt_err_accumulator: list collecting error messages.
        """
        import yt.wrapper as yt
        from us_processes import reactor_datetime
        from us_processes.time_util import _convert_to_unixtime, period_name
        from us_processes.time_periods import Periods, DATETIME_FORMAT

        if lifetime.days_to_store == "inf":
            return

        all_template_args = self.GetAllTemplateArgs(path)
        template = Template(path)

        today_dt = datetime.datetime.strptime(self.Context.today_dt[period_name(period)], DATETIME_FORMAT)

        backlog_days = self.Parameters.backlog_fast if period == Periods.FAST else self.Parameters.backlog_daily
        excluded_end_dt = today_dt - datetime.timedelta(days=lifetime.days_to_store)
        included_beg_dt = excluded_end_dt - datetime.timedelta(days=backlog_days)

        all_dts = []
        cur_dt = included_beg_dt
        while cur_dt < excluded_end_dt:
            all_dts.append(cur_dt)
            cur_dt += period

        yt.config["proxy"]["url"] = cluster

        def safe_remove(node_path, recursive):
            # records the error and returns False instead of raising
            try:
                yt.remove(node_path, recursive=recursive)
                return True
            except Exception as ex:
                yt_err_accumulator.append("YT error on cluster {} on path {}: {}".format(cluster, node_path, str(ex)))
                return False

        def log_removed(node_path, remove_dt):
            if period == Periods.FAST:
                logging.info('{}:{} ({}) removed'.format(cluster, node_path, remove_dt.strftime(DATETIME_FORMAT)))
            else:
                logging.info('{}:{} removed'.format(cluster, node_path))

        def safe_empty(directory):
            # never climb to the Cypress root: dirname('//') == '//', so without this
            # guard only a failing remove would stop the parent-cleanup loop
            if directory in ("", "/", "//"):
                return False
            try:
                # NOTE: map_node directory uses 1 node resource, so this way we check that no children are present
                return yt.get_attribute(directory, 'recursive_resource_usage/node_count') == 1
            except Exception:
                return False

        for remove_dt in all_dts:
            args = {}
            for elem in all_template_args:
                # we are interested only in the stripped identifier
                # (without possible add / sub additions)
                alias = reactor_datetime.GetDatetimeAndDeltaFormat(elem)[0]

                if alias == reactor_datetime.UNIXTIME:
                    datetime_value = _convert_to_unixtime(remove_dt)
                else:
                    dt_format = reactor_datetime.DATETIME_ALIAS_TO_REACTOR_AND_PYTHON_FORMAT[alias][1]
                    datetime_value = datetime.datetime.strftime(remove_dt, dt_format)

                args[elem] = datetime_value

            result_path = template.safe_substitute(**args)

            if not result_path.startswith("//"):
                prefix = self.Parameters.debug_yt_prefix if self.Parameters.debug_yt_prefix else "//"
                result_path = pj(prefix, result_path)

            try:
                if not yt.exists(result_path) and not yt.exists(result_path + '&'):
                    continue
            except Exception as ex:
                yt_err_accumulator.append("YT error on table {}:{}: {}".format(cluster, result_path, str(ex)))
                continue

            logging.debug('{}:{} to be removed'.format(cluster, result_path))

            if lifetime.drop_cluster_wait_nonversioned_artifact:
                art_template = Template(lifetime.drop_cluster_wait_nonversioned_artifact)
                result_art = art_template.safe_substitute({"cluster": cluster})
                if not self.CheckArtifact(result_art, remove_dt):
                    logging.info("Reactor artifact '{}' is not instantiated. {}:{} skipped".format(result_art, cluster, result_path))
                    continue

            if self.Parameters.dry_run:
                continue

            if not safe_remove(result_path, recursive=True):
                continue

            log_removed(result_path, remove_dt)

            parent_directory = dirname(result_path)
            while safe_empty(parent_directory):
                # recursive=False: safety measure to prevent non-empty directory removal
                if not safe_remove(parent_directory, recursive=False):
                    break

                log_removed(parent_directory, remove_dt)
                parent_directory = dirname(parent_directory)

    def GetYtRemovalInfo(self):
        """Collect YT removal configs by replaying ``create_sessions`` in safe mode.

        :return: whatever ``YTRemovalAccumulator().GetConfigs()`` returns; ``Remove``
            iterates it as a mapping ``(period, _) -> configs``.
        """
        from us_processes import reactor_tasks, create_sessions
        from us_processes.yt_removal import YTRemovalAccumulator
        from us_processes.time_periods import Periods

        time_periods = []
        if self.Parameters.period_daily:
            time_periods.append(Periods.DAILY)
        if self.Parameters.period_fast:
            time_periods.append(Periods.FAST)
        if not time_periods:
            eh.check_failed("Time period is required")

        token = self.GetReactorToken()
        options = reactor_tasks.Options(
            reactor_server="MOCK",
            time_periods=time_periods,
            is_trouble_mode=False,
            version=None,
            svn_url="FAKE_svn_url",
            token=token,
            quota_project="fake-quota-project",
            graph_owner="robot-make-sessions",
            prod_version_out_file_path="fake_path",
        )

        # strip before testing for emptiness so that blank items (e.g. "a, ,b") are dropped
        labels = set()
        for label in self.Parameters.labels.split(','):
            label = label.strip()
            if label:
                labels.add(label)

        logging.info("Labels: {}".format(labels))
        options.set_labels(list(labels))

        YTRemovalAccumulator().TurnOnSafe()
        try:
            create_sessions.create_sessions(options)
        finally:
            # never leave the accumulator in safe mode, even if replay fails
            YTRemovalAccumulator().TurnOffSafe()

        return YTRemovalAccumulator().GetConfigs()

    def GetReactorToken(self):
        """Read the reactor token from the vault secret configured in Parameters."""
        secret_owner = self.Parameters.reactor_token_secret_owner
        secret_name = self.Parameters.reactor_token_secret_name
        return sdk2.Vault.data(secret_owner, secret_name)

    def Remove(self, removal_configs):
        """Run ``DoRemove`` for every non-infinite output table on every allowed cluster.

        :raises Exception: after processing everything, if any reactor or YT errors
            were accumulated (all of them are listed in the message).
        """
        from us_processes.constants import REACTOR_SERVER
        from us_processes.mr_reaction_lib_config import MRTablesLifetime
        from us_reactor.lib.client import ReactorAPIClient
        import yt.wrapper as yt

        yt_err_accumulator = []

        secret_owner = self.Parameters.yt_token_secret_owner
        secret_name = self.Parameters.yt_token_secret_name
        yt.config["token"] = sdk2.Vault.data(secret_owner, secret_name)

        self.reactor_client = ReactorAPIClient(REACTOR_SERVER, self.GetReactorToken())
        self.artifact_instance_info_accumulator = RemoveUsYtNodes.ArtifactInstanceInfoAccumulator(self.reactor_client)

        # a real list (not a lazy `filter`) so truthiness and `in` work on any Python
        filter_clusters = [c.strip() for c in self.Parameters.filter_clusters.split(',') if c.strip()]

        for key, configs in removal_configs.items():
            period, _ = key
            logging.info('Collected {} configs for period={}'.format(len(configs), period))
            for config in configs:
                all_clusters = config.GetAllClusters()
                for mr_output_path in config.output_tables:
                    if mr_output_path.lifetime == "inf":
                        continue
                    elif isinstance(mr_output_path.lifetime, tuple):
                        # per-cluster lifetimes: (cluster, lifetime) pairs
                        for cluster, lifetime in mr_output_path.lifetime:
                            if lifetime == "inf":
                                continue
                            if cluster not in all_clusters:
                                continue
                            if filter_clusters and cluster not in filter_clusters:
                                continue
                            self.DoRemove(cluster, mr_output_path.path, lifetime, period, yt_err_accumulator)
                    elif isinstance(mr_output_path.lifetime, MRTablesLifetime):
                        for cluster in all_clusters:
                            if filter_clusters and cluster not in filter_clusters:
                                continue
                            self.DoRemove(cluster, mr_output_path.path, mr_output_path.lifetime, period, yt_err_accumulator)
                    else:
                        raise Exception("Strange lifetime {}; type: {}".format(mr_output_path.lifetime, type(mr_output_path.lifetime)))

        reactor_errors = self.artifact_instance_info_accumulator.err_accumulator

        sections = []
        for errors_list, errors_list_description in [(reactor_errors, "REACTOR ERRORS"), (yt_err_accumulator, "YT_ERRORS")]:
            if errors_list:
                sections.append("\n" + errors_list_description + "\n" + "".join(error + "\n" for error in errors_list))

        if sections:
            raise Exception("\n" + "".join(sections))

    def CreateDatetime(self, period):
        """Return "today" for ``period``: the custom parameter if set, else now rounded.

        Fails the task if the custom datetime is later than the rounded current time.
        """
        from us_processes.time_util import datetime_round
        from us_processes.time_periods import parse_dt

        now_dt = datetime_round(datetime.datetime.now(), period)

        if self.Parameters.custom_today_datetime_str:
            custom_dt = parse_dt(self.Parameters.custom_today_datetime_str, period=period)
            if custom_dt > now_dt:
                eh.check_failed("You specified custom today-datetime which is bigger than today!")
                return
            return custom_dt

        return now_dt

    def InitializeTodayDatetime(self):
        """Store "today" for every period in ``Context.today_dt`` (keyed by period name)."""
        from us_processes.time_util import period_name
        from us_processes.time_periods import Periods, DATETIME_FORMAT

        self.Context.today_dt = dict(
            (period_name(period), self.CreateDatetime(period).strftime(DATETIME_FORMAT))
            for period in Periods.ALL
        )

    def on_execute(self):
        logging.info('RemoveUsYtNodesTask: Start')

        self.arcadia_key_to_reactor_dir = {}
        self.PrepareArcadia(Arcadia.ARCADIA_TRUNK_URL)

        reactor_token_path = self.GetReactorTokenPath()
        cmd = [
            sys.executable,
            "-m", "us_processes.get_production_svn_url",
            "-t", reactor_token_path
        ]

        env = self.GetEnv(self.arcadia_key_to_reactor_dir[Arcadia.ARCADIA_TRUNK_URL])
        result, error = RunProcess(cmd, env, log_prefix="get_production_svn_url")
        result = result.strip()

        # leave possibility for hotfix merged into release branch
        # but not in prod yet
        prod_svn_url = result.split('@')[0]
        if not prod_svn_url:
            eh.check_failed("get_production_svn_url returned an empty svn url; stderr: {}".format(error))

        self.PrepareArcadia(svn_url=prod_svn_url)
        sys.path.insert(0, self.arcadia_key_to_reactor_dir[prod_svn_url])
        # further we'll use this code through 'import us_processes'

        if self.Context.first_time:
            self.InitializeTodayDatetime()
            self.Context.first_time = False

        removal_configs = self.GetYtRemovalInfo()
        self.Remove(removal_configs)
