# -*- coding: utf-8 -*-

import multiprocessing
import os
import re
import sys
import time
import traceback

from sandbox import sdk2
from sandbox.common.types.client import Tag
from sandbox.sdk2 import yav

from sandbox.sandboxsdk import environments
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.sandboxsdk.parameters import SandboxStringParameter
from sandbox.sandboxsdk.task import SandboxTask

from sandbox.projects.resource_types import TASK_LOGS
from sandbox.projects.common import utils as sb_utils
from sandbox.projects.common.apihelpers import list_task_resources
from sandbox.projects.common.yabs.graphite import Graphite, five_sec_metric, metric_point

from .job_yt_client import JobYtClient, YtQuorumLock
from .lb_client import LbClient
from .utils import get_job_logger
from .config import (
    MAX_LOCK_FAILURES, NO_LOCK_SLEEP,
    YAV_SECRET_ID
)


class YtPathPrefix(SandboxStringParameter):
    # Sandbox string input parameter ``yt_path_prefix``: common prefix for the
    # task's YT paths (handed to JobYtClient as ``path_prefix``).
    description = 'Common prefix for all Yt paths'
    default_value = '//home/yabs/lbyt-reader-testing'
    name = 'yt_path_prefix'


class JobsRegex(SandboxStringParameter):
    # Sandbox string input parameter ``jobs_regex``: only jobs whose whole
    # name matches this pattern are started (the task appends '$' and uses
    # re.match). The default pattern selects every job.
    default_value = '.*'
    description = 'Regular expression to filter jobs'
    name = 'jobs_regex'


class YabsLbYtSandboxTask(SandboxTask):
    """Base Sandbox task that runs "lbyt" jobs, one process per job.

    Every job name from ``self._lbyt_jobs`` that matches the ``jobs_regex``
    input parameter is executed by ``job_worker`` in its own
    ``multiprocessing.Process``.  A worker builds YT/LB clients from the job
    config, takes a ``YtQuorumLock`` and then loops over ``do_job_work``
    until another task starts waiting for the same lock.  The exit codes of
    the worker processes are summarised in the task description.

    NOTE(review): ``_lbyt_jobs`` (mapping job name -> config dict; keys read
    here are 'job', 'yt' and 'lb') and ``need_work()`` are not defined in this
    class; presumably a subclass provides them -- confirm.
    """

    MIN_ITERATION_SLEEP_TIME = 10  # must be not less than metric frequency
    # Seconds to wait before re-checking a job disabled via job control.
    NOTHING_TO_DO_SLEEP_TIME = 60
    # Worker exit code: the job did no work (another task was already queued
    # for the lock, or the job was disabled on every iteration).
    JOB_NOT_WORKED_CODE = 123
    # Worker exit code: controlled failure (an init hook returned a false
    # value, lock quorum was missing too many times in a row, or no
    # iteration succeeded).
    JOB_FAILED_CODE = 111

    input_parameters = (
        YtPathPrefix,
        JobsRegex,
    )
    environment = (
        environments.PipEnvironment('requests'),
        environments.PipEnvironment('yandex-yt', '0.8.38a1', use_wheel=True),
        environments.PipEnvironment('yandex-yt-yson-bindings-skynet', use_wheel=True),
    )
    client_tags = (Tag.CUSTOM_XURMA | Tag.GENERIC) & Tag.SSD & Tag.LINUX_PRECISE
    required_ram = 120 * 1024

    def on_execute(self):
        """Run all matching jobs in parallel processes and report the results.

        Sets ``self.descr`` to HTML links to the per-job logs.

        Raises:
            SandboxTaskFailureError: if at least one job process exited with a
                non-zero code other than ``JOB_NOT_WORKED_CODE``.
        """
        logs_resource = sdk2.Resource.find(task_id=self.id, type='TASK_LOGS').first()
        # Base URL used by parse_results to link each job's '<job>.log'.
        self.logs_url = logs_resource.http_proxy

        processes = []
        # re.match anchors at the start; the appended '$' anchors at the end,
        # so the whole job name has to match the user-supplied pattern.
        jobs_regex = re.compile(sb_utils.get_or_default(self.ctx, JobsRegex) + '$')
        for job in self._lbyt_jobs:
            if not jobs_regex.match(job):
                continue
            proc = multiprocessing.Process(
                name=job,
                target=self.job_worker,
                args=(job,)
            )
            processes.append(proc)
            proc.start()
        for proc in processes:
            proc.join()

        self.descr, failed = self.parse_results(processes)
        if failed:
            raise SandboxTaskFailureError('Some jobs failed, see description for details')

    # Job functions

    def job_worker(self, job):
        """Entry point of a per-job worker process.

        Initializes logging, metrics and clients for ``job``, acquires the
        job's ``YtQuorumLock`` and runs iterations of ``do_job_work`` until
        another task starts waiting for the lock.

        The outcome is reported through the process exit code:
          * 0 -- some iteration did work and at least one succeeded;
          * JOB_NOT_WORKED_CODE -- another task was already queued for the
            lock, or ``need_work()`` never returned a true value;
          * JOB_FAILED_CODE -- an init hook returned a false value, the lock
            quorum was missing more than MAX_LOCK_FAILURES consecutive times,
            or no iteration succeeded;
          * 1 -- any other exception (or KeyboardInterrupt).
        """
        self.job = job
        # sys.exit() raises SystemExit, which is not an Exception subclass,
        # so the deliberate exits below bypass the handler at the end and
        # keep their specific exit codes.
        try:
            self.logger = get_job_logger(self.job)
            self.config = self._lbyt_jobs[self.job]
            self.logger.info('Started process %s for job %s', os.getpid(), self.job)

            # NOTE(review): imported here rather than at module level,
            # presumably because yt.wrapper comes from the task's pip
            # environment -- confirm.
            from yt.wrapper.version import VERSION as wrapper_version
            self.logger.info('Using yt.wrapper version %s', wrapper_version)

            graphite_namespace = 'lbyt'
            if os.path.isfile('/etc/testing'):
                graphite_namespace = 'lbyt-testing'
            # Metric name must be consistent with MIN_ITERATION_SLEEP_TIME (see above)
            self.base_metric = five_sec_metric(graphite_namespace, job)
            self.graphite_metrics = []

            yt_path_prefix = sb_utils.get_or_default(self.ctx, YtPathPrefix).strip()
            # Sets self.yt_clusters (and self.yt_cluster for a single cluster).
            self.init_clusters('yt', lambda cluster: JobYtClient(
                job=self.config.get('job', self.job),
                cluster=cluster,
                path_prefix=yt_path_prefix,
                # Passed as a callable: the client decides when to read the
                # YT token from YAV.
                get_token_func=(lambda: yav.Secret(YAV_SECRET_ID).data()["yt"]),
            ))
            self.init_clusters('lb', LbClient)

            if not self.job_init_before_lock():
                sys.exit(self.JOB_FAILED_CODE)

            lock = YtQuorumLock(self.yt_clusters, self.type)
            if lock.have_queue_on_quorum():
                # Another task is already queued for this lock; step aside.
                self.logger.warning('Other task for job %s is already waiting for lock', self.job)
                sys.exit(self.JOB_NOT_WORKED_CODE)

            self.logger.info('Waiting for lock for job %s from the previous task', self.job)
            have_successful_iteration = False
            have_essential_iteration = False
            # Consecutive iterations without lock quorum; reset on success.
            lock_failures = 0
            with lock:
                if not self.job_init_after_lock():
                    sys.exit(self.JOB_FAILED_CODE)
                while True:
                    self.start_timestamp = time.time()
                    lock.check_quorum_state()
                    # Hand the lock over as soon as another task waits for it.
                    if lock.have_waiting_task():
                        self.logger.info(
                            'Other task is waiting for lock for job %s, releasing it',
                            self.job,
                        )
                        break
                    if not lock.have_quorum():
                        self.logger.warning(
                            'No lock quorum for job %s (retries: %d of %d)',
                            self.job, lock_failures, MAX_LOCK_FAILURES,
                        )
                        lock_failures += 1
                        if lock_failures > MAX_LOCK_FAILURES:
                            self.logger.error('No lock quorum tries left, failing job')
                            sys.exit(self.JOB_FAILED_CODE)
                        time.sleep(NO_LOCK_SLEEP)
                        continue
                    else:
                        lock_failures = 0
                    if self.need_work():
                        have_essential_iteration = True
                        self.logger.info('Starting iteration for job %s', self.job)
                        success, next_iteration_since = self.do_job_work()
                        have_successful_iteration = (have_successful_iteration or success)
                        self.logger.info('Iteration for job %s is finished', self.job)
                        self.flush_metrics(timeout=1)
                    else:
                        self.logger.warning('Job %s is disabled via job control', self.job)
                        next_iteration_since = self.start_timestamp + self.NOTHING_TO_DO_SLEEP_TIME
                    # Enforce a minimum iteration period of
                    # MIN_ITERATION_SLEEP_TIME seconds.
                    next_iteration_since = max(
                        next_iteration_since,
                        self.start_timestamp + self.MIN_ITERATION_SLEEP_TIME,
                    )
                    self.logger.info(
                        'Waiting for next iteration at ts >= %d',
                        next_iteration_since,
                    )
                    time.sleep(max(0, next_iteration_since - time.time()))
            self.flush_metrics(timeout=10)
            if not have_essential_iteration:
                sys.exit(self.JOB_NOT_WORKED_CODE)
            if not have_successful_iteration:
                sys.exit(self.JOB_FAILED_CODE)
        except (Exception, KeyboardInterrupt):
            # hasattr guards: the exception may have been raised before these
            # attributes were set in the try block above.
            if hasattr(self, 'logger'):
                self.logger.error(
                    'Job %s raised an exception. %s',
                    self.job, traceback.format_exc(),
                )
            if hasattr(self, 'graphite_metrics'):
                self.flush_metrics(timeout=10)
            sys.exit(1)

    def job_init_before_lock(self):
        """Hook run before taking the lock; a false result fails the job."""
        return True

    def job_init_after_lock(self):
        """Hook run right after taking the lock; a false result fails the job."""
        return True

    def do_job_work(self):
        """Run one iteration of the job; must be overridden.

        Returns:
            A ``(success, next_iteration_since)`` pair: whether the iteration
            succeeded, and the earliest Unix timestamp (compared against
            ``time.time()``) at which the next iteration may start.
        """
        raise NotImplementedError()

    # Metrics

    def add_metric(self, name, value, **kwargs):
        """Queue a Graphite point stamped with the current iteration start.

        ``name`` and ``kwargs`` are forwarded to ``self.base_metric``.
        Requires ``self.start_timestamp``, which ``job_worker`` sets at the
        start of every iteration.
        """
        self.graphite_metrics.append(
            metric_point(self.base_metric(name, **kwargs), value, self.start_timestamp)
        )

    def flush_metrics(self, timeout):
        """Send queued metrics to Graphite, adding a ``work_time`` point.

        Does nothing if no metrics are queued.  If sending fails, the error is
        logged and the queue is kept, so the points are retried on the next
        flush.
        """
        if not self.graphite_metrics:
            return
        self.add_metric('work_time', time.time() - self.start_timestamp)
        self.logger.info('Sending metrics for job %s to Graphite', self.job)
        try:
            Graphite(timeout=timeout, logger=self.logger).send(self.graphite_metrics)
            self.graphite_metrics = []
        except Exception as err:
            # Best effort: a metrics outage must not break the job itself.
            self.logger.warning('Failed to send metrics to Graphite: %s', err)

    # Helpers

    def init_clusters(self, subject, client_creator):
        """Create clients for the clusters listed under ``self.config[subject]``.

        Sets ``self.<subject>_clusters`` to the list of created clients.  When
        the config value is a single string, ``self.<subject>_cluster`` is set
        to that one client as well.  Does nothing if ``subject`` is absent.
        """
        if subject not in self.config:
            return
        clusters = self.config[subject]
        if isinstance(clusters, basestring):
            client = client_creator(clusters)
            setattr(self, subject + '_cluster', client)
            setattr(self, subject + '_clusters', [client])
        else:
            setattr(self, subject + '_clusters', [client_creator(cluster) for cluster in clusters])

    def parse_results(self, processes):
        """Build the task description from finished worker processes.

        Exit-code mapping: JOB_NOT_WORKED_CODE -> "Boring logs",
        JOB_FAILED_CODE -> "Err logs" tagged [ERR], any other non-zero code
        -> "Err logs" tagged [EXC], 0 -> "Logs".

        Returns:
            ``(description, failed)``: the text with per-job log links, and
            whether any job ended up in the failed group.
        """
        done_jobs = []
        failed_jobs = []
        empty_jobs = []

        job_states = {proc.name: proc.exitcode for proc in processes}
        for job in sorted(job_states):
            job_string = '- <a href="{}/{}.log">{}</a>'.format(self.logs_url, job, job)
            if job_states[job] == self.JOB_NOT_WORKED_CODE:
                empty_jobs.append(job_string)
            elif job_states[job] == self.JOB_FAILED_CODE:
                failed_jobs.append(job_string + ' [ERR]')
            elif job_states[job] != 0:
                failed_jobs.append(job_string + ' [EXC]')
            else:
                done_jobs.append(job_string)

        description = []
        if failed_jobs:
            description.append('Err logs\n' + '\n'.join(failed_jobs))
        if done_jobs:
            description.append('Logs\n' + '\n'.join(done_jobs))
        if empty_jobs:
            description.append('Boring logs\n' + '\n'.join(empty_jobs))

        return '\n&\n'.join(description), bool(failed_jobs)
