# coding: utf-8

######################################################################
#
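# Есть две задачи — TravelRunBinary и TravelRunBinaryTesting.
# Чтобы не ломать основную TravelRunBinary, просьба все эксперименты/изменения
# сначала коммитить в TravelRunBinaryTesting, а потом, убедившись,
# что всё работает, — в TravelRunBinary.
#
# There are two tasks: TravelRunBinary and TravelRunBinaryTesting.
# To avoid breaking the main TravelRunBinary, commit all experiments/changes
# to TravelRunBinaryTesting first, and only after making sure everything
# works, port them to TravelRunBinary.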
#
############################################################################

import json
import logging
import os
import re
import stat
import time
import traceback
from contextlib import contextmanager
from datetime import datetime, timedelta

import yaml

from sandbox import sdk2, common
from sandbox.sdk2.helpers import subprocess
from sandbox.sdk2 import yav
from sandbox.projects.common import solomon
from sandbox.projects.common.arcadia import sdk as arcadia_sdk
from sandbox.projects.common.ya_deploy import release_integration
import sandbox.common.types.client as ctc
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt

# Placeholders recognised inside task arguments (see resolve_parameters).
# Secret reference: sec-<id>[.ver-<version>][.key-<key>]; groups 1/3/5 = id/version/key.
SECRET_PATTERN = re.compile(r"(sec-[0-9a-zA-Z]+)(\.(ver-[0-9a-zA-Z]+))?(\.key-([0-9a-zA-Z\-_]+))?")
# Resource reference: resource-<TYPE>[.released_at=stable|testing]; groups 2/4 = type/release status.
RESOURCE_PATTERN = re.compile(r"(resource-([0-9A-Z_]+))(\.released_at=(stable|testing))?")
# Literal placeholder replaced with the mounted Arcadia path.
ARCADIA_MOUNT_POINT_PATTERN = re.compile(r"ARCADIA_MOUNT_POINT")
# Literal placeholder replaced with the current Sandbox task id.
SANDBOX_TASK_ID_PATTERN = re.compile(r"SANDBOX_TASK_ID")

# Logger name given to sdk2.helpers.ProcessLog; also the prefix of the subprocess log file names.
SUBPROCESS_NAME = 'travel_binary'


class SolomonProgress(object):
    """Report the lifecycle of a long-running job to Solomon.

    While the job runs, ``update`` pushes a ``running`` = 1 point and remembers
    its timestamp. ``finish`` then pushes the final status sensor (e.g.
    ``success`` / ``failed``) = 1 for every remembered timestamp, and closes
    both the status and ``running`` series with a 0 point.

    When ``enabled`` is false nothing is pushed, but timestamps are still
    recorded and ``update_callback`` is still invoked.
    """

    def __init__(self, solomon_token, params, labels, enabled, update_callback, points):
        """
        :param solomon_token: token passed to ``solomon.push_to_solomon_v2``.
        :param params: push parameters passed through to ``solomon.push_to_solomon_v2``.
        :param labels: extra labels added to every sensor (may be empty or None).
        :param enabled: when false, nothing is sent to Solomon.
        :param update_callback: called with the full list of points after every
            ``update``, or None.
        :param points: timestamps of previously pushed ``running`` points; this
            list is appended to in place.
        """
        self.solomon_token = solomon_token
        self.params = params
        self.labels = labels
        self.enabled = enabled
        self.update_callback = update_callback
        self.points = points

    @staticmethod
    def _timestamp(dt):
        """Return ``dt`` as integer Unix seconds, interpreting it as local time.

        ``time.mktime`` is used rather than ``strftime('%s')``. ``%s`` is a
        non-standard platform extension (glibc) and is not available everywhere.
        """
        return int(time.mktime(dt.timetuple()))

    def _make_sensor(self, ts, sensor_name, sensor_value):
        """Build one sensor dict: the common labels plus ``sensor=<sensor_name>``."""
        labels = dict(self.labels or {})
        # Set after the common labels, so a common label called 'sensor' cannot
        # overwrite the series name ('running' / status).
        labels['sensor'] = sensor_name
        return {
            'labels': labels,
            'ts': ts,
            'value': sensor_value,
        }

    def _push(self, sensors):
        """Send ``sensors`` to Solomon; no-op when disabled. Push errors are logged, not raised."""
        if not self.enabled:
            return
        logging.info('Going to push sensors to solomon: %s', sensors)
        try:
            solomon.push_to_solomon_v2(self.solomon_token, self.params, sensors)
        except Exception as e:
            # Best effort: a monitoring failure must not break the job itself.
            logging.error('Failed to push to solomon: %s', str(e))

    def update(self, dt):
        """Push a ``running`` = 1 point at ``dt``, record it and notify ``update_callback``."""
        ts = self._timestamp(dt)
        self._push([self._make_sensor(ts, 'running', 1)])
        self.points.append(ts)

        if self.update_callback is not None:
            self.update_callback(self.points)

    def finish(self, status):
        """Push ``status`` = 1 for every recorded point, then close ``status`` and ``running`` with 0.

        :param status: sensor name of the final status, e.g. 'success' or 'failed'.
        """
        if not self.enabled:
            return
        logging.info('Sending %s status to Solomon', status)
        sensors = [self._make_sensor(ts, status, 1) for ts in self.points]
        last_point = self._timestamp(datetime.now())
        # Keep the closing point strictly after the last recorded one.
        if self.points and self.points[-1] == last_point:
            last_point += 1
        sensors.append(self._make_sensor(last_point, status, 0))
        sensors.append(self._make_sensor(last_point, 'running', 0))
        self._push(sensors)

    @contextmanager
    def context(self):
        """Wrap a job: push an initial point, then finish with 'success' or 'failed'.

        An ``Exception`` raised inside the block is reported as 'failed' and re-raised.
        """
        self.update(datetime.now())
        try:
            yield
        except Exception:
            self.finish('failed')
            raise
        self.finish('success')


def all_subclasses(cls):
    """Return the set of all direct and indirect subclasses of ``cls`` (excluding ``cls`` itself)."""
    found = set()
    pending = list(cls.__subclasses__())
    while pending:
        candidate = pending.pop()
        if candidate in found:
            # Reachable through several parents (multiple inheritance); already expanded.
            continue
        found.add(candidate)
        pending.extend(candidate.__subclasses__())
    return found


# Resource classes that are already defined when this module is imported, keyed by class name (type string -> class).
RESOURCES_DICT = dict((resource_cls.__name__, resource_cls) for resource_cls in all_subclasses(sdk2.Resource))


class TravelRunBinaryTesting(release_integration.ReleaseToNannyAndYaDeployTask2, sdk2.Task):
    """Run a released binary resource as a Sandbox task.

    The binary is taken either by explicit resource id or as the newest resource
    of ``binary_resource_type`` released to ``released_at`` (see
    ``_get_last_released_resource``). Task arguments are passed to the binary
    unchanged. Placeholders found in them (secrets, resources,
    ``ARCADIA_MOUNT_POINT``, ``SANDBOX_TASK_ID``) are resolved into a JSON map
    that the child receives in the ``ARG_REPLACES`` environment variable.
    Progress is optionally reported to Solomon, and outcomes optionally
    e-mailed. After a successful run, resources (and an optional self-release)
    are read from the YAML file named by ``output_resources_file``.
    """
    subtask_binary_path = None

    class Requirements(sdk2.Task.Requirements):
        # https://wiki.yandex-team.ru/sandbox/clients/#client-tags-multislot
        # Requirements to fit to multislot
        cores = 1  # exactly 1 core
        ram = 8192  # 8GiB or less
        client_tags = ctc.Tag.Group.LINUX

        class Caches(sdk2.Requirements.Caches):
            pass  # means that task do not use any shared caches

    class Parameters(sdk2.Task.Parameters):
        """Task parameters; see the group and field captions."""

        with sdk2.parameters.Group("Binary resource") as resource_bin:
            binary_resource_id = sdk2.parameters.String("Resource ID", required=False)
            binary_resource_type = sdk2.parameters.String("Resource Type", required=False)
            with sdk2.parameters.String("Released at", required=False, default_value="testing", hint=True) as released_at:
                released_at.values.testing = released_at.Value('testing')
                released_at.values.stable = released_at.Value('stable')
            do_extract = sdk2.parameters.Bool("Perform extract from archive", required=False, default=False)
            use_shell = sdk2.parameters.Bool("Use Shell", required=False, default=True)
            requires_dns64 = sdk2.parameters.Bool("Requires DNS64", default=False)
            use_utf8_locale = sdk2.parameters.Bool("Use UTF-8 Locale", default=True)

        with sdk2.parameters.Group("Arguments") as params:
            command = sdk2.parameters.String("Command (autodetect, if empty)")
            args = sdk2.parameters.List("Arguments (secrets as sec-xxxx[.ver-yyyy][.key-zzzz])", value_type=sdk2.parameters.String)

        with sdk2.parameters.Group("Solomon") as solomon:
            enable_solomon = sdk2.parameters.Bool("Push status to solomon")
            with enable_solomon.value[True]:
                solomon_token = sdk2.parameters.String('Token', default_value="sec-01d1z7bzb38cvzk7zsxakgvrbf", required=True)
                solomon_project = sdk2.parameters.String('Project', default_value="travel")
                solomon_cluster = sdk2.parameters.String('Cluster (empty -> auto)', default_value="")
                solomon_service = sdk2.parameters.String('Service', default_value="sandbox")
                solomon_common_labels = sdk2.parameters.Dict('Common labels', default={})

        with sdk2.parameters.Group("Concurrency") as concurrency:
            semaphore_name = sdk2.parameters.String("Semaphore Name", default="")
            semaphore_capacity = sdk2.parameters.Integer("Max concurrent runs, (0 = inf)", default=0)

        with sdk2.parameters.Group("Output Resources") as output_resources:
            output_resources_file = sdk2.parameters.String("Output Resources File", default="")

        with sdk2.parameters.Group("Custom notifications") as custom_notifications:
            notify_email = sdk2.parameters.String("EMail Address", default="")
            notify_caption = sdk2.parameters.String("Caption", default="")
            notify_on_success = sdk2.parameters.Bool("Notify on success")

        with sdk2.parameters.Group("Arcadia parameters") as arcadia_parameters:
            mount_arcadia = sdk2.parameters.Bool("Mount arcadia")
            with mount_arcadia.value[True]:
                arcadia_url = sdk2.parameters.String("Arcadia URL", default="")

        with sdk2.parameters.Group("Internal parameters") as internal_parameters:
            plan_item_hash = sdk2.parameters.String("Plan Item Hash")

    class Context(sdk2.Context):
        # JSON-encoded list of timestamps of the Solomon "running" points pushed so far.
        # Persisted (see on_solomon_update) so that on_timeout can close them with "failed".
        solomon_points = '[]'

    def on_enqueue(self):
        """Set semaphore and DNS requirements before the task is queued."""
        if self.Parameters.semaphore_capacity > 0:
            semaphore_name = self.Parameters.semaphore_name
            if not semaphore_name:
                # Default: one semaphore per (release status, binary type or id).
                semaphore_name = "travel_sem_%s_%s" % (self.Parameters.released_at,
                                                       self.Parameters.binary_resource_type or self.Parameters.binary_resource_id)
            logging.info("Will use semaphore(s) %s of capacity %s", semaphore_name, self.Parameters.semaphore_capacity)
            # A comma-separated name acquires each listed semaphore with the same capacity.
            self.Requirements.semaphores = ctt.Semaphores(
                acquires=[ctt.Semaphores.Acquire(name=name,
                                                 weight=1, capacity=self.Parameters.semaphore_capacity)
                          for name in semaphore_name.split(',')],
            )
        if self.Parameters.requires_dns64:
            self.Requirements.dns = ctm.DnsType.DNS64

    def do_work(self, arcadia_src_dir=None):
        """Download the binary, resolve its arguments, run it and register output resources.

        :param arcadia_src_dir: mounted Arcadia path substituted for
            ``ARCADIA_MOUNT_POINT``; None when Arcadia is not mounted.
        :raises Exception: if ``released_at`` is neither 'testing' nor 'stable'.
        """
        if self.Parameters.released_at == 'testing':
            env = 'testing'
        elif self.Parameters.released_at == 'stable':
            env = 'prod'
        else:
            raise Exception("Unknown value for released_at: %r" % (self.Parameters.released_at,))

        resource_path = self.download_main_resource()

        command = self.determine_command(resource_path)
        params, environ = self.resolve_parameters(arcadia_src_dir)
        args = [command] + params

        # Replace any ENV:/TYPE: tags from earlier runs with ones describing this run; keep the rest.
        tags = [
            'ENV:' + env,
            'TYPE:' + self.Context.actual_resource_type
        ]
        for tag in self.Parameters.tags:
            if not (tag.startswith('ENV:') or tag.startswith('TYPE:')):
                tags.append(tag)
        self.Parameters.tags = tags

        solomon_progress = self.new_solomon_progress()

        with sdk2.helpers.ProcessLog(self, logger=SUBPROCESS_NAME) as process_log:
            with solomon_progress.context():
                self.run_binary(args, environ, process_log, solomon_progress)
                self.create_resources()
        if self.Parameters.notify_on_success:
            self.send_custom_notification("Success")

    def do_work_wrapped(self):
        """Run do_work (inside an Arcadia mount if requested) and notify about failures.

        :raises common.errors.TaskFailure: if the binary exits with a non-zero code.
        Any other exception is reported by e-mail and re-raised unchanged.
        """
        try:
            if self.Parameters.mount_arcadia:
                logging.info("Mounting arcadia by url %s", self.Parameters.arcadia_url)
                with arcadia_sdk.mount_arc_path(self.Parameters.arcadia_url, use_arc_instead_of_aapi=True) as arcadia_src_dir:
                    logging.info("Arcadia mounted to %s", arcadia_src_dir)
                    self.do_work(arcadia_src_dir=arcadia_src_dir)
            else:
                self.do_work()
        except subprocess.CalledProcessError as e:
            self.handle_subprocess_exception(e)
            # Explicitly put to failure state
            raise common.errors.TaskFailure("Subprocess failed: %s" % str(e))
        except Exception:
            self.handle_other_exception()
            raise

    def on_execute(self):
        self.do_work_wrapped()

    def on_success(self, prev_status):
        """Self-release the task if the output resources file has a ``release`` section.

        The section is read as a mapping with keys ``env``, ``title`` and optional ``comments``.
        """
        if not self.Parameters.output_resources_file:
            return
        resources = self._load_output_resources()
        release_to = resources.get('release')
        if release_to is not None:
            self._self_release(release_to["env"], release_to["title"], release_to.get("comments"))

    def get_logs_tails(self):
        """Return the tails of the subprocess stderr/stdout logs, formatted for a notification body.

        For each existing ``<SUBPROCESS_NAME>.err.log`` / ``.out.log`` in the task
        log directory, adds its last 20 lines between marker lines, plus a proxy
        URL of the full log.
        """
        lines = []

        def process_log(suffix):
            filename = '%s.%s.log' % (SUBPROCESS_NAME, suffix)
            full_filename = str(self.log_path(filename))
            if not os.path.exists(full_filename):
                return
            output = subprocess.check_output(['tail', '-20', full_filename])
            if not isinstance(output, str):
                # On Python 3 check_output returns bytes; splitting bytes by '\n' would raise TypeError.
                # Decode leniently so stray non-UTF-8 bytes in the log cannot break the notification.
                output = output.decode('utf-8', 'replace')
            tail_lines = output.split('\n')
            while tail_lines and not tail_lines[-1]:
                tail_lines.pop()
            lines.append("")
            lines.append("------------------ %s tail     ------------------" % filename)
            lines.extend(tail_lines)
            lines.append("------------------ End %s tail ------------------" % filename)
            rel_path = os.path.relpath(full_filename, os.getcwd())
            lines.append("Full log here: https://proxy.sandbox.yandex-team.ru/task/%s/%s" % (self.id, rel_path))

        process_log('err')
        process_log('out')
        return lines

    def handle_subprocess_exception(self, e):
        """E-mail the failed command's return code together with the log tails."""
        lines = [str(e)]
        lines.extend(self.get_logs_tails())
        self.send_custom_notification("Failed with retcode %s" % e.returncode, lines)

    def handle_other_exception(self):
        """E-mail the traceback of the exception currently being handled."""
        lines = ['Exception:']
        lines.extend(traceback.format_exc().split('\n'))
        self.send_custom_notification("Unexpected exception", lines)

    def on_timeout(self, prev_status):
        """Close the Solomon series as 'failed', notify, then run the base timeout handler."""
        logging.info('Triggered on_timeout handler')
        try:
            solomon_progress = self.new_solomon_progress()
            logging.debug('solomon_progress.points = %r', solomon_progress.points)
            solomon_progress.finish('failed')
        except Exception:
            # Best effort: a Solomon/secret problem must not suppress the notification or the base handler.
            logging.exception('Failed to report timeout to Solomon')
        self.send_custom_notification("Timeout", add_lines=self.get_logs_tails())
        super(TravelRunBinaryTesting, self).on_timeout(prev_status)

    def on_solomon_update(self, points):
        """Persist the pushed Solomon timestamps in the task context."""
        logging.debug('Triggered on_solomon_update handler')
        self.Context.solomon_points = json.dumps(points)
        self.Context.save()

    def new_solomon_progress(self):
        """Create a SolomonProgress from task parameters and the points saved in the context.

        :raises RuntimeError: if Solomon is enabled and the token is not a ``sec-...`` reference.
        """
        solomon_token = None
        if self.Parameters.enable_solomon:
            m = SECRET_PATTERN.match(self.Parameters.solomon_token)
            if m is None:
                raise RuntimeError('Solomon token should be passed as secret!')
            solomon_token = self.resolve_secret(m.group(0), m.group(1), m.group(3), m.group(5))
        cluster = self.Parameters.solomon_cluster
        if not cluster:
            cluster = 'push_prod' if self.Parameters.released_at == 'stable' else 'push_testing'
        solomon_params = {
            'project': self.Parameters.solomon_project,
            'cluster': cluster,
            'service': self.Parameters.solomon_service,
        }
        return SolomonProgress(
            solomon_token=solomon_token,
            params=solomon_params,
            labels=self.Parameters.solomon_common_labels,
            enabled=self.Parameters.enable_solomon,
            update_callback=self.on_solomon_update,
            points=json.loads(self.Context.solomon_points),
        )

    def _load_output_resources(self):
        """Parse the YAML file named by ``output_resources_file``, closing the file afterwards."""
        with open(self.Parameters.output_resources_file, 'rt') as f:
            return yaml.safe_load(f)

    def create_resources(self):
        """Register and mark ready every entry of the ``resources`` list in the output resources file.

        Each entry is read with keys ``type`` (a name in RESOURCES_DICT), ``path``,
        and optional ``description`` and ``attrs``.
        """
        if not self.Parameters.output_resources_file:
            return
        resources = self._load_output_resources()
        for r in resources['resources']:
            attrs = r.get('attrs', dict())
            ResourceClass = RESOURCES_DICT[r['type']]
            created_res = ResourceClass(self, r.get('description', "No descr"), r['path'], **attrs)
            sdk2.ResourceData(created_res).ready()

    def run_binary(self, args, environ, process_log, solomon_progress):
        """Run the binary and block until it exits, pushing a Solomon "running" point about once a minute.

        :param args: command followed by its arguments.
        :param environ: extra environment variables for the child, merged over the locale settings.
        :param process_log: ``sdk2.helpers.ProcessLog`` whose stdout/stderr receive the child's output.
        :param solomon_progress: ``SolomonProgress`` updated on every tick.
        :raises subprocess.CalledProcessError: if the binary exits with a non-zero code.
        """
        step = timedelta(minutes=1)
        next_point = datetime.now()

        # Shell command line; newlines are flattened, presumably so a multi-line argument
        # is not interpreted by the shell as several commands.
        joint_args = ' '.join(args).replace('\n', ' ')

        # NOTE(review): the child receives ONLY these variables (no PATH/HOME etc. from the
        # task environment) — confirm this isolation is intended.
        full_environ = dict()
        if self.Parameters.use_utf8_locale:
            full_environ['LC_ALL'] = 'en_US.UTF-8'
            full_environ['LANG'] = 'en_US.UTF-8'
            full_environ['LANGUAGE'] = 'en_US.UTF-8'
        full_environ.update(environ)

        process = subprocess.Popen(
            joint_args if self.Parameters.use_shell else args,
            shell=self.Parameters.use_shell,
            stdout=process_log.stdout,
            stderr=process_log.stderr,
            env=full_environ,
        )

        while True:
            solomon_progress.update(datetime.now())
            next_point += step
            # Clamp at zero: if a tick took longer than `step`, just poll instead of passing a negative timeout.
            timeout = max(0.0, (next_point - datetime.now()).total_seconds())
            try:
                returncode = process.wait(timeout=timeout)
            except subprocess.TimeoutExpired:
                continue
            break

        if returncode:
            raise subprocess.CalledProcessError(returncode, joint_args)

    def download_main_resource(self):
        """Download the binary resource (by id, else newest released of the given type) and return its local path.

        Records the chosen resource's id, type and release status in the task context.

        :raises RuntimeError: if neither a resource id nor a resource type is set.
        """
        if self.Parameters.binary_resource_id:
            resource = sdk2.Resource[self.Parameters.binary_resource_id]
        elif self.Parameters.binary_resource_type:
            ResourceClass = RESOURCES_DICT[self.Parameters.binary_resource_type]
            resource = self._get_last_released_resource(ResourceClass, self.Parameters.released_at)
        else:
            raise RuntimeError("Required resource id or type")
        self.Context.actual_resource_id = resource.id
        self.Context.actual_resource_type = str(resource.type)
        self.Context.actual_resource_released_at = str(resource.released)
        resource_path = sdk2.ResourceData(resource).path.as_posix()
        return resource_path

    def download_resource(self, resource_type, resource_release_status):
        """Download a resource of ``resource_type`` and return its local path.

        :param resource_type: resource class name (key of RESOURCES_DICT).
        :param resource_release_status: 'stable'/'testing' to pick the newest released
            resource, or falsy to pick the newest READY resource regardless of release.
        :raises RuntimeError: if no matching resource exists.
        """
        ResourceClass = RESOURCES_DICT[resource_type]
        if resource_release_status:
            resource = self._get_last_released_resource(ResourceClass, resource_release_status)
        else:
            logging.info("Looking for resource %s", ResourceClass)
            resource = ResourceClass.find(state='READY').order(-sdk2.Resource.id).first()
            if not resource:
                raise RuntimeError("Resource not found by type %s" % ResourceClass)
            logging.info("Got resource by type %s: %s", ResourceClass, resource)
        resource_path = sdk2.ResourceData(resource).path.as_posix()
        return resource_path

    def _get_last_released_resource(self, resource_type, resource_release_status):
        """
        Returns last released resource with given type and with
        release status that is less or equals to the given one.

        Release statuses have the hierarchy:
        stable < testing.

        So 'stable' considers only stable releases, while 'testing' considers both
        stable and testing releases and returns the one with the largest id.

        :raises RuntimeError: if no such released resource exists.
        """
        logging.info("Looking for resource %s with release status %s", resource_type, resource_release_status)

        last_released_resources = []

        for release_status in ["stable", "testing"]:
            attrs = dict(released=release_status)
            resource = resource_type.find(state='READY', attrs=attrs).order(-sdk2.Resource.id).first()
            if resource:
                last_released_resources.append(resource)
            if release_status == resource_release_status:
                break

        logging.info("Found released resources: %s", last_released_resources)
        if not last_released_resources:
            raise RuntimeError("Failed to find resource of type %s with status %s" % (resource_type, resource_release_status))
        last_released_resource = max(last_released_resources, key=lambda x: x.id)
        logging.info("Last released to %s resource of type %s: %s", resource_release_status, resource_type, last_released_resource)
        return last_released_resource

    def determine_command(self, resource_path):
        """Work out the executable to run for the downloaded resource and make it executable.

        - Directory resource: ``command`` inside it, or its only file if ``command`` is empty.
        - Archive (``do_extract``): extracted into the cwd; ``command`` is required.
        - Plain file: ``command`` if set, otherwise the file itself.

        :raises RuntimeError / Exception: if the command cannot be determined unambiguously.
        """
        if os.path.isdir(resource_path):
            files = os.listdir(resource_path)
            logging.info("Resource is dir %s, contents: %s", resource_path, files)
            if self.Parameters.do_extract:
                raise Exception("Cannot perform extraction on directory")
            if not self.Parameters.command:
                if len(files) == 1:
                    command = os.path.join(resource_path, files[0])
                else:
                    raise RuntimeError("Resource is a directory with %s files, do not know which to run" % files)
            else:
                command = os.path.join(resource_path, self.Parameters.command)
        else:
            logging.info("Resource is a file: %s", resource_path)
            if self.Parameters.do_extract:
                logging.info("Extracting...")
                work_dir = os.getcwd()
                subprocess.check_call(['tar', '-xzf', resource_path], shell=False, cwd=work_dir)
                logging.info("Extracted. Current : %s", os.listdir(work_dir))
                if not self.Parameters.command:
                    raise RuntimeError("Resource is archive (%s), and command is not specified - it is wrong" % resource_path)
                command = os.path.join(work_dir, self.Parameters.command)
            else:
                if self.Parameters.command:
                    # NOTE(review): here the downloaded file is not put on the command line and the
                    # command is chmod-ed below as given — confirm this branch is used intentionally.
                    command = self.Parameters.command
                else:
                    command = resource_path
        logging.info("Resulting command is %s", command)
        os.chmod(command, os.stat(command).st_mode | stat.S_IEXEC)
        return command

    def resolve_secret(self, sec_arg, secret_id, version, key):
        """Fetch a secret value from YAV and register it with the log VaultFilter so it is masked in logs.

        :param sec_arg: the full placeholder text (e.g. ``sec-...[.ver-...][.key-...]``).
        :param key: key inside the secret; if None the secret must contain exactly one key.
        :raises RuntimeError: if the key is ambiguous or the VaultFilter is not configured.
        """
        secret_data = yav.Secret(secret_id, version=version, default_key=key).data()
        # See https://wiki.yandex-team.ru/sandbox/yav/#permissions
        if key is None:
            if len(secret_data) == 1:
                value = list(secret_data.values())[0]
            else:
                raise RuntimeError("Secret %s has multiple keys" % sec_arg)
        else:
            value = secret_data[key]
        vault_filter = common.log.VaultFilter.filter_from_logger(logging.getLogger())
        if vault_filter:
            vault_filter.add_record(sec_arg, value)
        else:
            raise RuntimeError("Failed to setup VaultFilter")
        return value

    def resolve_parameters(self, arcadia_src_dir):
        """Collect the binary's arguments and resolve placeholders found in them.

        Arguments are returned unchanged. Each placeholder (secret, resource,
        ``ARCADIA_MOUNT_POINT``, ``SANDBOX_TASK_ID``) is mapped to its value, and the map is
        passed to the child as JSON in the ``ARG_REPLACES`` environment variable.

        :returns: ``(params, environ)``.
        """
        params = []
        replaces = dict()
        for arg_string in self.Parameters.args:
            params.append(arg_string)
            for m in SECRET_PATTERN.finditer(arg_string):
                replaces[m.group(0)] = self.resolve_secret(m.group(0), m.group(1), m.group(3), m.group(5))
            for m in RESOURCE_PATTERN.finditer(arg_string):
                replaces[m.group(0)] = self.download_resource(m.group(2), m.group(4))
            for m in ARCADIA_MOUNT_POINT_PATTERN.finditer(arg_string):
                # NOTE(review): without mount_arcadia this is None (JSON null) — confirm the binary handles it.
                replaces[m.group(0)] = arcadia_src_dir
            for m in SANDBOX_TASK_ID_PATTERN.finditer(arg_string):
                replaces[m.group(0)] = str(self.id)
        environ = {
            'ARG_REPLACES': json.dumps(replaces, indent=None)
        }
        return params, environ

    def send_custom_notification(self, status, add_lines=None):
        """E-mail ``notify_email`` (if set) with the task link, its tags and optional extra lines."""
        if not self.Parameters.notify_email:
            return
        subject = '%s: %s' % (self.Parameters.notify_caption, status)
        lines = []
        lines.append('Task: https://sandbox.yandex-team.ru/task/%s/view : %s' % (self.id, status))
        lines.append('')
        lines.append("Tags:")
        for tag in sorted(self.Parameters.tags):
            lines.append('  ' + tag)
        if add_lines is not None:
            lines.append('')
            lines.extend(add_lines)
        logging.info("Sending notification to %s, subject '%s'", self.Parameters.notify_email, subject)
        self.server.notification(
            subject=subject,
            body='\n'.join(lines),
            recipients=[self.Parameters.notify_email],
            transport=common.types.notification.Transport.EMAIL
        )

    def _get_release_environment(self, env):
        """Map an environment name to a release status; unknown names fall back to TESTING."""
        mapping = {
            'testing': ctt.ReleaseStatus.TESTING,
            'stable': ctt.ReleaseStatus.STABLE,
            'prod': ctt.ReleaseStatus.STABLE,
            'unstable': ctt.ReleaseStatus.UNSTABLE
        }
        return mapping.get(env, ctt.ReleaseStatus.TESTING)

    def _self_release(self, env, title, comments):
        """Release this task's resources via on_release, authored by the task author, with no e-mail recipients."""
        self.on_release(dict(
            releaser=self.author,
            release_status=self._get_release_environment(env),
            release_subject=title,
            email_notifications=dict(to=[], cc=[]),
            release_comments=comments,
        ))
