# -*- coding: utf-8 -*-

from sandbox import sdk2
from sandbox.sdk2.vcs.svn import Arcadia
from sandbox.sandboxsdk import environments

from sandbox.projects.common import error_handlers as eh
from sandbox.projects.logs.common import GetArcadia
import sandbox.projects.common.build.parameters as build_parameters

import os
import sys
import time
import shlex
import random
import string
import logging
from os.path import join as pj

# Default values for the task parameters declared in UserSessionsReactorTest.Parameters.
DEFAULT_NIRVANA_QUOTA = 'user-sessions-ci'  # default for Parameters.nirvana_quota_name
DEFAULT_YT_POOL = 'userdata-sessions-build-ci'  # default for Parameters.yt_pool
DEFAULT_YT_PREFIX = '//home/userdata-sessions-build-ci/user-sessions-processes-ci/2016-12-12'  # default for Parameters.yt_root_path
DEFAULT_REACTOR_PREFIX = '/home/robot-ci-sessions/user_sessions'  # default for Parameters.reactor_root_path
DEFAULT_DAYS_TO_STORE_OUTPUT = 3  # default for Parameters.days_to_store (YT expiration of the work dir, in days)


# NOTE(ngc224): unfortunately we cannot use a binary build yet,
# so we cannot reuse the quality/user_sessions/reactor/us_processes/logging.py code here.
def PatchRootLogger():
    """Apply the task-wide log line format to every handler currently on the root logger.

    Handlers attached to the root logger after this call keep whatever
    formatter they are created with.
    """
    log_format = logging.Formatter(
        fmt='%(asctime)s +%(relativeCreated)-8d %(levelname)8s %(name)s (%(threadName)s): %(message)s',
        datefmt='%F %T',
    )
    root_handlers = logging.getLogger().handlers
    for existing_handler in root_handlers:
        existing_handler.setFormatter(log_format)


class UserSessionsReactorTest(sdk2.Task):
    """Sandbox task that checks out the user-sessions reactor code, creates a CI
    Reactor/Nirvana graph for the requested time periods, launches it and polls
    its progress until the final reactions succeed or fail.
    """

    class Requirements(sdk2.Requirements):
        environments = [
            environments.PipEnvironment('yandex-yt', version='0.10.8'),
            environments.PipEnvironment('requests'),
            environments.PipEnvironment('networkx', version='2.2', use_wheel=True),
        ]

    class Parameters(sdk2.Parameters):
        arcadia_url = sdk2.parameters.ArcadiaUrl('Arcadia URL', required=True, default_value=Arcadia.ARCADIA_TRUNK_URL)
        arcadia_patch = sdk2.parameters.String(
            build_parameters.ArcadiaPatch.description,
            default=build_parameters.ArcadiaPatch.default,
            multiline=True
        )

        with sdk2.parameters.Group("YT params") as yt_group:
            yt_cluster = sdk2.parameters.String('YT cluster', default='hahn')
            yt_token_secret_owner = sdk2.parameters.String("Owner of sb-vault-secret with yt token", required=True, default="USERSESSIONSTOOLS")
            yt_token_secret_name = sdk2.parameters.String("Name of sb-vault-secret with yt token", required=True, default='userdata-sessions-build-ci-token')
            yt_pool = sdk2.parameters.String('YT pool', default=DEFAULT_YT_POOL)
            yt_root_path = sdk2.parameters.String('YT working path', default=DEFAULT_YT_PREFIX)
            days_to_store = sdk2.parameters.Integer('Days to store output', default=DEFAULT_DAYS_TO_STORE_OUTPUT)

        with sdk2.parameters.Group("Reactor params") as reactor_group:
            reactor_token_secret_owner = sdk2.parameters.String("Owner of sb-vault-secret with reactor token", required=True, default="USERSESSIONSTOOLS")
            reactor_token_secret_name = sdk2.parameters.String("Name of sb-vault-secret with reactor token", required=True, default='ci_reactor_token')
            reactor_root_path = sdk2.parameters.String("Reactor working path", required=True, default=DEFAULT_REACTOR_PREFIX)
            reactor_queue_parallelism = sdk2.parameters.Integer('Queue parallelism', default=3)
            reactor_retries_count = sdk2.parameters.Integer('Reactions retries count', default=3)
            reactor_retries_delay = sdk2.parameters.Integer('Reactions retries delay, minutes', default=2)

        with sdk2.parameters.Group("Nirvana params") as nirvana_group:
            nirvana_token_secret_owner = sdk2.parameters.String("Owner of sb-vault-secret with nirvana token", required=True, default="USERSESSIONSTOOLS")
            nirvana_token_secret_name = sdk2.parameters.String("Name of sb-vault-secret with nirvana token", required=True, default='ci_nirvana_token')
            nirvana_yt_token_secret_name = sdk2.parameters.String("Name of nirvana-secret with YT token", required=True, default='robot-ci-sessions-yt-token')
            nirvana_yql_token_secret_name = sdk2.parameters.String("Name of nirvana-secret with YQL token", required=True, default='robot-ci-sessions-yql-token')
            nirvana_quota_name = sdk2.parameters.String("Quota name on nirvana", required=True, default=DEFAULT_NIRVANA_QUOTA)
            nirvana_graph_owner = sdk2.parameters.String("Launched graphs owner", required=True, default='robot-ci-sessions')

        with sdk2.parameters.Group("Build binaries") as build_group:
            bin_use_prod = sdk2.parameters.Bool("Use production binaries", required=True, default=False)

        with sdk2.parameters.Group("Fast sessions parameters") as fast_group:
            period_fast = sdk2.parameters.Bool("Enable time period: 30min", default=True)
            with period_fast.value[True]:
                fast_datetime = sdk2.parameters.String("Datetime in format %Y-%m-%d:%H:%M", required=True, default='2016-12-12:21:00')
                fast_labels = sdk2.parameters.String("Labels, comma-separated")

        with sdk2.parameters.Group("Daily sessions parameters") as daily_group:
            period_daily = sdk2.parameters.Bool("Enable time period: 1d", default=True)
            merge_from_fast = sdk2.parameters.Bool("Create 1d sessions from merged 30min sessions", default=False)
            with period_daily.value[True]:
                daily_datetime = sdk2.parameters.String("Datetime in format %Y-%m-%d", required=True, default='2016-12-12')
                daily_labels = sdk2.parameters.String("Labels, comma-separated")

        with sdk2.parameters.Group("Debug") as debug:
            fail_on_first = sdk2.parameters.Bool("Fail on first error", default=True)
            dry_run = sdk2.parameters.Bool("Dry run", default=False)

        recalc_mode = sdk2.parameters.Bool("Set recalc mode", default=False)
        with recalc_mode.value[True]:
            # NOTE(review): presumably meant to raise the task kill timeout to 10h in
            # recalc mode; confirm that sdk2 honours a plain int assigned here.
            kill_timeout = 10 * 3600

    class Context(sdk2.Context):
        pass

    @staticmethod
    def _SplitCommaSeparated(value):
        """Split a comma-separated parameter value into stripped, non-empty items.

        ``None`` or an empty string yields an empty list.
        """
        return [item.strip() for item in (value or '').split(',') if item.strip()]

    def _WriteSecretFile(self, file_name, secret_content):
        """Write *secret_content* to ``<task dir>/<file_name>`` with owner-only permissions.

        The secret is written directly, not via a shell command line, so it does not
        show up in the process list and shell metacharacters are not interpreted.
        A trailing newline is kept to match the file the former ``echo`` produced.

        Returns the path of the written file.
        """
        secret_path = pj(str(self.path()), file_name)
        fd = os.open(secret_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w') as secret_file:
            secret_file.write('{}\n'.format(secret_content))
        return secret_path

    def GetYTToken(self):
        """Return the YT token stored in the configured Sandbox vault secret."""
        secret_owner = self.Parameters.yt_token_secret_owner
        secret_name = self.Parameters.yt_token_secret_name
        token = sdk2.Vault.data(secret_owner, secret_name)
        return token

    def GetNirvanaTokenPath(self):
        """Store the Nirvana token from the vault in a task-local file and return its path."""
        secret_owner = self.Parameters.nirvana_token_secret_owner
        secret_name = self.Parameters.nirvana_token_secret_name
        secret_content = sdk2.Vault.data(secret_owner, secret_name)
        return self._WriteSecretFile('nirvana_token_file', secret_content)

    def GetReactorToken(self):
        """Return the Reactor token stored in the configured Sandbox vault secret."""
        secret_owner = self.Parameters.reactor_token_secret_owner
        secret_name = self.Parameters.reactor_token_secret_name
        token = sdk2.Vault.data(secret_owner, secret_name)
        return token

    def GetReactorTokenPath(self):
        """Store the Reactor token in a task-local file and return its path."""
        return self._WriteSecretFile('reactor_token_file', self.GetReactorToken())

    def ConfigureEnv(self):
        """Prepare import paths, working paths, YT nodes and environment variables.

        Expects ``Context.reactor_dir`` and ``Context.version`` to be set (see
        ``on_execute``). Sets ``Context.salt``, ``Context.yt_prefix``,
        ``Context.reactor_prefix`` and ``Context.reactor_token_path``, creates the
        YT work directory (with an expiration time unless in recalc mode), links
        every node of ``<yt_root_path>/raw_logs`` into it and exports the ``US_*``
        environment variables.
        """
        # Make the checked-out reactor code importable here and in child processes.
        sys.path.insert(0, self.Context.reactor_dir)
        os.environ['PYTHONPATH'] = os.pathsep.join(filter(bool, sys.path + [os.environ.get('PYTHONPATH')]))

        #  generate paths
        # SystemRandom reads the OS entropy source and ignores any seed, so none is passed.
        rng = random.SystemRandom()
        alphabet = string.ascii_letters + string.digits
        self.Context.salt = ''.join(rng.choice(alphabet) for _ in range(5))

        yt_root_path = '//' + self.Parameters.yt_root_path.strip('/')

        # A patched run must not collide with an unpatched run of the same version.
        if self.Parameters.arcadia_patch:
            working_path = 'r{}_{}'.format(self.Context.version, self.Context.salt)
        else:
            working_path = 'r{}'.format(self.Context.version)

        self.Context.yt_prefix = pj(yt_root_path, 'work', working_path)
        self.Context.reactor_prefix = '/' + pj(self.Parameters.reactor_root_path.strip('/'), working_path)
        self.Context.reactor_token_path = self.GetReactorTokenPath()

        #  prepare YT env
        import yt.wrapper as yt

        yt.config["proxy"]["url"] = self.Parameters.yt_cluster
        yt.config["token"] = self.GetYTToken()

        attributes = {}
        if not self.Parameters.recalc_mode:
            # YT expects expiration_time as milliseconds since the epoch.
            attributes['expiration_time'] = int(time.time() + self.Parameters.days_to_store * 24 * 60 * 60) * 1000

        yt.create('map_node', self.Context.yt_prefix, recursive=True, force=True, attributes=attributes)

        #  TODO hacks for simultaneous run create_sessions/create_scarab_sessions
        yt.create('map_node', pj(self.Context.yt_prefix, 'build', 'logs'),
                  recursive=True, force=True)

        raw_logs_path = pj(yt_root_path, 'raw_logs')
        for name in yt.list(raw_logs_path):
            yt.link(pj(raw_logs_path, name), pj(self.Context.yt_prefix, name))

        os.environ.update(dict(
            YT_PREFIX=self.Context.yt_prefix,
            US_YT_ROOT=self.Context.yt_prefix,
            US_YT_CLUSTER=self.Parameters.yt_cluster,
            US_YT_POOL=self.Parameters.yt_pool or self.author,
            US_NIRVANA_VAULT_YT_TOKEN=self.Parameters.nirvana_yt_token_secret_name,
            US_NIRVANA_VAULT_YQL_TOKEN=self.Parameters.nirvana_yql_token_secret_name,
        ))

    def CreateBuildSubtasksIfNeeded(self):
        """Do nothing when production binaries are used; building binaries is not implemented.

        Raises:
            NotImplementedError: if ``bin_use_prod`` is False.
        """
        if self.Parameters.bin_use_prod:
            return

        raise NotImplementedError()

    def CreateReactorGraph(self):
        """Parse period parameters, build reactor ``Options`` and create the CI sessions graph.

        Sets ``self.dt_fast`` / ``self.dt_daily`` (lists of parsed datetimes, or None
        when the period is disabled), ``self.sessions_options`` and
        ``self.sessions_tasks`` for ``LaunchReactorGraph``.
        """
        logging.info('CreateReactorGraph start')

        from us_reactor.lib import model
        import us_processes.sessions_processes_config as config

        from us_processes.constants import REACTOR_SERVER
        from us_processes.time_periods import Periods, parse_dt
        from us_processes.reactor_tasks import Options
        from us_processes.create_sessions_ci import create_sessions_ci

        time_periods = []
        if self.Parameters.period_daily:
            time_periods.append(Periods.DAILY)
        if self.Parameters.period_fast:
            time_periods.append(Periods.FAST)
        eh.ensure(bool(time_periods), "Time period is required")

        self.dt_daily = None
        self.dt_fast = None
        merge_intervals = None

        if self.Parameters.period_fast:
            self.dt_fast = [
                parse_dt(fast_dt, Periods.FAST)
                for fast_dt in self._SplitCommaSeparated(self.Parameters.fast_datetime)
            ]
            eh.ensure(bool(self.dt_fast), "No fast datetimes parsed")

        if self.Parameters.period_daily:
            self.dt_daily = [
                parse_dt(day, Periods.DAILY)
                for day in self._SplitCommaSeparated(self.Parameters.daily_datetime)
            ]
            eh.ensure(bool(self.dt_daily), "No daily datetimes parsed")
            if self.Parameters.merge_from_fast:
                eh.ensure(self.Parameters.period_fast, "Fast period is not enabled")
                # Distinct minute-of-day offsets of the fast datetimes, ascending.
                merge_intervals = sorted(set(dt.hour * 60 + dt.minute for dt in self.dt_fast))

        # Override the retry policy used by the sessions config with the task parameters.
        config.getRetriesSpec = lambda: model.NirvanaReactionAutoRetries(retry_number=self.Parameters.reactor_retries_count,
                                                                         minutes_delay=self.Parameters.reactor_retries_delay)

        options = self.sessions_options = Options(
            reactor_server=REACTOR_SERVER,
            time_periods=time_periods,
            token=self.Context.reactor_token_path,
            is_trouble_mode=False,
            is_ci_mode=True,
            test_reactor_prefix=self.Context.reactor_prefix,
            version=self.Context.version,
            svn_url=self.Context.arcadia_url,
            quota_project=self.Parameters.nirvana_quota_name,
            graph_owner=self.Parameters.nirvana_graph_owner,
            parallelism=self.Parameters.reactor_queue_parallelism,
            recalc_mode=self.Parameters.recalc_mode,
            merge_intervals=merge_intervals,
        )

        options.set_verbose_mode(True)
        options.set_debug_mode(True)

        labels = set(
            self._SplitCommaSeparated(self.Parameters.daily_labels)
            + self._SplitCommaSeparated(self.Parameters.fast_labels)
        )
        logging.info("Labels: {}".format(labels))
        options.set_labels(list(labels))

        logging.info('create_sessions start')
        self.sessions_tasks = create_sessions_ci(options)

        logging.info('create_sessions end')

    def LaunchReactorGraph(self):
        """Launch the graph created by ``CreateReactorGraph``.

        Returns:
            dict mapping ``Periods.FAST`` / ``Periods.DAILY`` to the launched datetimes
            (a list, or None for a disabled period).
        """
        logging.info('LaunchTasks start')
        from us_processes.time_periods import Periods

        if not self.Parameters.recalc_mode:
            logging.info('mandatory delay before artifacts instantiation due to reactions activation process')
            time.sleep(180)

        self.sessions_tasks.LaunchTasks(
            self.sessions_options,
            dt_daily=self.dt_daily,
            dt_fast=self.dt_fast,
        )
        logging.info('LaunchTasks end')

        return {Periods.FAST: self.dt_fast, Periods.DAILY: self.dt_daily}

    def MonitorProgress(self, cluster, period, dt):
        """Run ``us_processes.check_progress`` once for the final reaction of *period* at *dt*.

        Returns:
            ``(status, errors)``: the code returned by ``check_progress.main`` and the
            error lines it appended to the accumulator.
        """
        logging.info('Monitor progress for cluster={}, period={}, dt={}'.format(cluster, period, dt))

        from us_processes import check_progress, time_util
        from us_processes.time_periods import DATETIME_FORMAT
        from us_processes.reactions_graph_visitors import getFinalReactionName

        final_artifact = getFinalReactionName(period, cluster)
        logging.info("Checking final artifact: {}".format(final_artifact))

        # Build argv directly rather than formatting a string and re-splitting it with
        # shlex, so values containing spaces or quotes are passed through intact.
        argv = [str(arg) for arg in (
            '-tp', self.Context.reactor_prefix,
            '-t', self.Context.reactor_token_path,
            '-c', cluster,
            '-l', 5,
            '-vh', self.Context.version,
            '-ns', final_artifact,
            '-dt', dt.strftime(DATETIME_FORMAT),
            '-p', time_util.period_name(period),
        )]

        if self.Parameters.fail_on_first:
            argv.append('--fail-on-first')

        argv.append('--always-show-duration')
        argv.append('--verbose')

        logging.debug("python -m us_processes.check_progress " + ' '.join(argv))

        errors = []
        result = check_progress.main(argv, errors_accumulator=errors)

        return result, errors

    def WaitReactor(self, cluster, periods):
        """Poll every launched (period, datetime) once a minute until all are done.

        Sets ``Context.progress_status`` to ``'OK'`` when every target reports
        ``CODE_STATUS_OK``. Fails the task via ``eh.ensure`` on the first
        ``CODE_STATUS_FAILED`` (including the last 20 error lines), or when there is
        nothing to wait for. Raises ``Exception`` on an unknown status code.
        """
        logging.info('WaitReactor - check_progress start')

        from us_processes.check_progress import CODE_STATUS_OK, CODE_STATUS_IN_PROGRESS, CODE_STATUS_FAILED

        targets = [(period, dt) for period, dts in periods.items() for dt in (dts or [])]
        eh.ensure(bool(targets), 'No launched datetimes to wait for')

        while True:
            all_done = True
            for period, dt in targets:
                status, errors = self.MonitorProgress(cluster, period, dt)

                if status == CODE_STATUS_FAILED:
                    eh.ensure(False, 'Reactor tasks FAILED:\n{}\n{}'.format(
                        'cluster={} period={} dt={}'.format(cluster, period, dt), '\n'.join(errors[-20:])
                    ))
                    return
                elif status == CODE_STATUS_IN_PROGRESS:
                    all_done = False
                elif status != CODE_STATUS_OK:
                    raise Exception("Unknown check_progress status code: {}".format(status))

            if all_done:
                self.Context.progress_status = 'OK'
                return

            logging.debug('... sleeping')
            time.sleep(60)

    def CheckSessions(self, cluster, periods):
        """Placeholder: output-table validation is not implemented yet; always returns False."""
        self.Context.progress_status = 'TODO check non-empty optput tables'
        return False

    def on_execute(self):
        """Check out Arcadia, optionally patch it, then create, launch and wait for the graph."""
        PatchRootLogger()

        #  make RM retries idempotent: reuse the timestamp saved by a previous attempt
        if not getattr(self.Context, 'timestamp', None):
            self.Context.timestamp = int(time.time())

        # speedup arcadia checkout
        path = 'quality/user_sessions/reactor' if self.Parameters.bin_use_prod else None

        with GetArcadia(self.Parameters.arcadia_url, path=path, fuse=False) as arcadia_dir:
            self.Context.revision = Arcadia.get_revision(arcadia_dir)
            self.Context.version = '{}.{}'.format(self.Context.revision, self.Context.timestamp)

            if self.Parameters.arcadia_patch:
                Arcadia.apply_patch(arcadia_dir, self.Parameters.arcadia_patch, self.path())

            self.Context.reactor_dir = pj(arcadia_dir, 'quality/user_sessions/reactor')
            # Pin the URL to the exact revision that was checked out.
            self.Context.arcadia_url = '{}@{}'.format(self.Parameters.arcadia_url.split('@')[0], self.Context.revision)

            self.ConfigureEnv()
            self.CreateBuildSubtasksIfNeeded()

            self.CreateReactorGraph()

            if self.Parameters.dry_run:
                logging.info("DRY RUN. No launch was initiated")
            else:
                periods = self.LaunchReactorGraph()
                time.sleep(60)  # wait until all reactor instances instantiated
                self.WaitReactor(self.Parameters.yt_cluster, periods)
                #  self.CheckSessions(self.Parameters.yt_cluster, periods)
