# -*- coding: utf-8 -*-

from __future__ import absolute_import, print_function, division

import os
import datetime
import json
import logging
import random
import time
import contextlib

import requests

from sandbox import common
from sandbox import sdk2

from sandbox.sandboxsdk.environments import PipEnvironment

from sandbox.common.types import misc as ctm

from .supervisor import ProcessSupervisor
from .. import (
    STATBOX_ABT_METRICS_BIN,
    STATBOX_ABT_METRICS_RESULT,
    STATBOX_ABT_BASE_LXC_IMAGE
)

logger = logging.getLogger(__name__)


# String constants shared with the heartbeat backend. Each namespace is a bare
# instance of an ad-hoc type; its instance attributes hold the wire values.

# Keys of the JSON payload posted to the server with every heartbeat.
REQ_KEYS = type('REQ_KEYS', (), {})()
vars(REQ_KEYS).update(
    SANDBOX_ID='sandbox_id',
    ABT_ID='abt_id',
    STATUS='status',
    PROGRESS='progress',
    INFO='info',
)

# Keys read from the JSON body the server returns for a heartbeat.
RES_KEYS = type('RES_KEYS', (), {})()
vars(RES_KEYS).update(
    COMMAND='command',
)

# Values of RES_KEYS.COMMAND: what the server tells the task to do next.
CMD = type('CMD', (), {})()
vars(CMD).update(
    CONTINUE='CONTINUE',
    ABORT='ABORT',
    TERMINATE='TERMINATE',
)

# Task statuses: taken from Sandbox and forwarded to the server.
STATUS = type('STATUS', (), {})()
vars(STATUS).update(
    QUEUE='QUEUE',
    EXECUTE='EXECUTE',
    TEMPORARY='TEMPORARY',
    WAIT='WAIT',
    SUCCESS='SUCCESS',
    FAILURE='FAILURE',
    EXCEPTION='EXCEPTION',
    TIMEOUT='TIMEOUT',
    STOPPED='STOPPED',
)


class StatboxAbtJob(sdk2.Task):
    """Sandbox task that computes Statbox ABT metrics for a single ABT job.

    Workflow (see on_execute / do_run):

    * report EXECUTE to the backend and obey the command it returns;
    * create the STATBOX_ABT_METRICS_RESULT resource for the output file;
    * unpack the metrics binaries and run them with a decrypted base config
      under ProcessSupervisor;
    * keep posting heartbeats with the subprocess progress until it exits.

    Final statuses are reported from the on_success / on_failure / on_break /
    on_wait hooks.
    """

    class Requirements(sdk2.Requirements):
        cores = 1
        dns = ctm.DnsType.DNS64
        disk_space = 5 * 1024
        ram = 1 * 1024
        # 'cryptography' provides Fernet, used in setup_config_file.
        environments = [PipEnvironment('cryptography')]

        class Caches(sdk2.Requirements.Caches):
            pass

    Requirements = Requirements  # type: StatboxAbtJob.Requirements

    class Context(sdk2.Task.Context):
        # Written by raise_stop / raise_fail / raise_exception and read by
        # format_message when building the final message for the backend.
        fail_type = ''
        fail_message = ''
        fail_context = ''

    Context = Context  # type: StatboxAbtJob.Context

    class Parameters(sdk2.Task.Parameters):
        kill_timeout = 48 * 60 * 60
        max_restarts = 10

        # custom parameters
        service = sdk2.parameters.String(
            'Service', required=True
        )  # type: str
        metric_groups = sdk2.parameters.List(
            'Metric groups', required=True
        )  # type: list
        experiments = sdk2.parameters.List(
            'Experiments', required=True
        )  # type: list
        start_date = sdk2.parameters.String(
            'Start date', required=True,
            description='In format of YYYY-MM-DD'
        )  # type: str
        end_date = sdk2.parameters.String(
            'End date', required=True,
            description='In format of YYYY-MM-DD'
        )  # type: str
        scale = sdk2.parameters.String(
            'Scale', required=True, default='daily',
        )  # type: str

        abt_id = sdk2.parameters.String(
            'Abt id', required=True,
            description='Id generated by adminka'
        )  # type: str

        heartbeat_url = sdk2.parameters.Url(
            'Heartbeat url', required=True,
            description='Url used for sending regular heartbeats'
        )  # type: str
        heartbeat_rate = sdk2.parameters.Float(
            'Heartbeat rate', default=5,
            description='Rate of heartbeats in seconds'
        )  # type: float
        # NOTE: despite the name, the jitter applied in _next_heartbeat_delay
        # is uniform over a window of this width, not a normal distribution.
        heartbeat_sd = sdk2.parameters.Float(
            'Heartbeat standard deviation', default=3,
            description='Standard deviation of a random value added to '
                        'heartbeat rate to avoid load peaks'
        )  # type: float
        max_silent_period = sdk2.parameters.Float(
            'Maximal silent period', default=600,
            description='Task will fail if it can\'t send a heartbeat '
                        'during this time in seconds'
        )  # type: float

        config_url = sdk2.parameters.Url(
            'Config url', required=True,
            description='Url used for receiving base config'
        )  # type: str

        # Vault item (owner, name) holding the Fernet key for the config.
        statbox_abt_key_vault_item_owner = sdk2.parameters.String(
            'ABT key vault item owner',
            default='STATBOX_ABT', required=True
        )  # type: str

        statbox_abt_key_vault_item_name = sdk2.parameters.String(
            'ABT key vault item name',
            default='statbox_abt_key', required=True
        )  # type: str

        statbox_abt_metrics_bin = sdk2.parameters.Resource(
            'Metrics binaries',  # required=True,
            resource_type=STATBOX_ABT_METRICS_BIN
        )

        container = sdk2.parameters.Container(
            'LXC Container',
            resource_type=STATBOX_ABT_BASE_LXC_IMAGE,
            required=True,
        )

    Parameters = Parameters  # type: StatboxAbtJob.Parameters

    def on_execute(self):
        """Run the metrics calculation and publish its result resource.

        Sends an initial EXECUTE heartbeat (obeying the returned command),
        creates the result resource, unpacks the metrics binaries, runs them
        with a decrypted config file and marks the resource ready once the
        subprocess has finished successfully.
        """
        self.set_info('Starting task {}'.format(self.Parameters.abt_id))
        logger.info('starting task %r', self.Parameters.abt_id)

        ok, result = self.send_heartbeat(STATUS.EXECUTE, progress=0, delay=15)
        self.process_heartbeat_result(ok, result)

        # Output file of the metrics binary (relative to the working
        # directory); it is also the path of the result resource.
        res_path = 'resource.json'

        resource = STATBOX_ABT_METRICS_RESULT(
            self,
            'Calculation result for task {}'.format(self.Parameters.abt_id),
            res_path,
            r_service=self.Parameters.service,
            r_metric_groups=', '.join(self.Parameters.metric_groups),
            r_experiments=', '.join(self.Parameters.experiments),
            r_start_date=self.Parameters.start_date,
            r_end_date=self.Parameters.end_date,
            r_scale=self.Parameters.scale,
            abt_id=self.Parameters.abt_id,
        )
        resource_data = sdk2.ResourceData(resource)

        # Unpack the metrics binaries into the working directory; presumably
        # the archive also provides the UDF files named in _build_command.
        bins = sdk2.ResourceData(self.Parameters.statbox_abt_metrics_bin)
        sdk2.helpers.subprocess.check_call(['tar', '-xf', str(bins.path)])

        with self.setup_config_file() as config_file:
            command = self._build_command(res_path, config_file)

            process_env = os.environ.copy()
            # Presumably tells the binary where to write its logs.
            process_env['STATBOX_LOG_PATH'] = sdk2.paths.get_logs_folder()

            process = ProcessSupervisor(command, process_env)

            try:
                self.do_run(process)
            except BaseException:
                # Any error (including the TaskStop/TaskFailure/TaskError
                # raised by this class) asks the supervisor to terminate the
                # subprocess before propagating.
                logger.exception('captured error, cleaning up')
                process.terminate_process()
                raise

        resource_data.ready()

        self.set_info('Task {} finished'.format(self.Parameters.abt_id))
        logger.info('task %r finished', self.Parameters.abt_id)

    def _build_command(self, res_path, config_file):
        """Return the command line that runs the unpacked metrics binary.

        :param res_path: path the binary writes its JSON result to.
        :param config_file: path of the decrypted base config file.
        """
        # NOTE(review): task parameters are interpolated without quoting.
        # Metric groups and experiments are space-joined, which relies on
        # ProcessSupervisor splitting the string into separate arguments;
        # confirm how it runs the command (shell or not) before accepting
        # untrusted values here.
        command_parameters = dict(
            res_path=res_path,
            config_file=config_file,
            scale=self.Parameters.scale,
            s=self.Parameters.service,
            m=' '.join(self.Parameters.metric_groups),
            e=' '.join(self.Parameters.experiments),
            start_date=self.Parameters.start_date,
            end_date=self.Parameters.end_date,
            yql_udf_bin='libsabt-yql.so',
            run_python_udf_bin='run_python_udf',
        )

        command_template = (
            './bin run'
            '    --format json_production'
            '    --output {res_path}'
            '    --default-config-file {config_file}'
            '    --reporter machine_readable'
            '    --start-date {start_date}'
            '    --end-date {end_date}'
            '    --scale {scale}'
            '    -s {s} -m {m} -e {e}'
            '    --yt-yql-udf-bin {yql_udf_bin}'
            '    --yt-run-python-udf-bin {run_python_udf_bin} '
        )

        return command_template.format(**command_parameters)

    def do_run(self, process):
        """Start *process* and send progress heartbeats until it exits.

        Returns normally when the subprocess exits with code 0.

        Raises:
            TaskFailure: (via raise_fail) if the subprocess exits non-zero.
            TaskError: (via raise_exception) if no heartbeat succeeded for
                longer than ``max_silent_period`` seconds, or the backend
                sent TERMINATE or an invalid command.
            TaskStop: (via raise_stop) if the backend sent ABORT.
        """
        process.run()

        last_heartbeat = datetime.datetime.now()
        max_silent_period = datetime.timedelta(seconds=self.Parameters.max_silent_period)

        while True:
            if not process.is_running():
                return_code = process.get_return_code()
                logger.info('subprocess has finished with code %s', return_code)
                if return_code != 0:
                    context = 'Subprocess exited with code {}.\n\n' \
                              'Captured stderr follows:\n\n{}\n' \
                              .format(return_code, process.get_output())
                    self.raise_fail('An error occurred when executing task.',
                                    context)
                return

            progress = process.get_progress()
            ok, result = self.send_heartbeat(STATUS.EXECUTE, progress=progress)

            if ok:
                last_heartbeat = datetime.datetime.now()
                self.process_heartbeat_result(ok, result)

            if datetime.datetime.now() - last_heartbeat > max_silent_period:
                self.raise_exception('Task was not able to reach server for more than {} seconds.'.format(max_silent_period.total_seconds()))

            time.sleep(self._next_heartbeat_delay())

    def _next_heartbeat_delay(self):
        """Return the pause before the next heartbeat, in seconds.

        ``heartbeat_rate`` plus uniform jitter in
        ``[-heartbeat_sd / 2, heartbeat_sd / 2)`` so that many tasks do not
        hit the backend in lockstep. Clamped at zero because time.sleep
        rejects negative values (possible when heartbeat_sd > 2 * rate).
        """
        jitter = (random.random() - 0.5) * self.Parameters.heartbeat_sd
        return max(0.0, self.Parameters.heartbeat_rate + jitter)

    @contextlib.contextmanager
    def setup_config_file(self, timeout=60):
        """Download and decrypt the base config; yield the path of its file.

        The body at ``config_url`` is Fernet-decrypted with the key from the
        vault item given by the task parameters and written to
        ``~/config.yaml`` with 0600 permissions. The file is removed when the
        context exits, also when writing it or the ``with`` body fails.

        :param timeout: timeout, in seconds, for the config download request.

        Raises:
            requests.HTTPError: if the config download returns an error code.
            RuntimeError: if the downloaded config cannot be decrypted.
        """
        from cryptography.fernet import Fernet

        response = requests.get(self.Parameters.config_url, timeout=timeout)
        response.raise_for_status()

        key = sdk2.Vault.data(
            self.Parameters.statbox_abt_key_vault_item_owner,
            self.Parameters.statbox_abt_key_vault_item_name
        )

        cipher = Fernet(key)

        try:
            config = cipher.decrypt(response.content)
        except Exception:
            # Keep the original cause in the log; the re-raise below drops it
            # (no exception chaining on Python 2).
            logger.exception('config decryption failed')
            raise RuntimeError('error when decrypting config')

        filename = os.path.expanduser('~/config.yaml')

        logger.info('setting up config %s', filename)

        # Create the file owner-only from the start so the decrypted secret is
        # never readable by others; the chmod below additionally covers a
        # pre-existing file, for which O_CREAT's mode is ignored.
        fd = os.open(filename, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        try:
            # decrypt() returns bytes, hence binary mode.
            with os.fdopen(fd, 'wb') as f:
                os.chmod(filename, 0o600)
                f.write(config)
            yield filename
        finally:
            os.remove(filename)

    def process_heartbeat_result(self, ok, result):
        """Apply the command returned by the backend for a heartbeat.

        Does nothing when the heartbeat failed (*ok* is false). CONTINUE is a
        no-op, ABORT stops the task via raise_stop, and TERMINATE or a
        missing/unknown command raise via raise_exception.
        """
        if not ok:
            return
        # A body without the command key (or not a JSON object at all) is
        # reported as an invalid command instead of a bare KeyError/TypeError.
        command = result.get(RES_KEYS.COMMAND) if isinstance(result, dict) else None
        if command == CMD.CONTINUE:
            pass
        elif command == CMD.ABORT:
            self.raise_stop('Task was aborted by user request.')
        elif command == CMD.TERMINATE:
            self.raise_exception('Task got command to terminate itself.')
        else:
            self.raise_exception('Backend responded with an invalid command {!r}.'.format(command))

    def _save_fail_context(self, fail_type, label, message, context):
        """Log, show in task info and persist a failure before raising it."""
        logger.warning(label + ': %r -- %r', message, context)
        self.set_info('{}: {}\n\n{}'.format(label, message, context))

        self.Context.fail_type = fail_type
        self.Context.fail_message = message
        self.Context.fail_context = context
        self.Context.save()

    def raise_stop(self, message, context=''):
        """Record a STOPPED failure and raise TaskStop."""
        self._save_fail_context(STATUS.STOPPED, 'stopping', message, context)
        raise common.errors.TaskStop(message)

    def raise_fail(self, message, context=''):
        """Record a FAILURE and raise TaskFailure."""
        self._save_fail_context(STATUS.FAILURE, 'failing', message, context)
        raise common.errors.TaskFailure(message)

    def raise_exception(self, message, context=''):
        """Record an EXCEPTION failure and raise TaskError."""
        self._save_fail_context(STATUS.EXCEPTION, 'raising an exception', message, context)
        raise common.errors.TaskError(message)

    def format_message(self, status):
        """Build the final human-readable message reported for *status*.

        Uses the message saved by raise_* when it was saved for the same
        status, otherwise a default per status. Statuses without a default
        are reported as an internal error (EXCEPTION). The saved context is
        appended when it matches, and every non-success message points to the
        sandbox task logs.
        """
        if self.Context.fail_type == status and self.Context.fail_message:
            message = self.Context.fail_message
        elif status == STATUS.SUCCESS:
            message = 'Task succeed'
        elif status == STATUS.STOPPED:
            message = 'Task was hard-stopped.'
        elif status == STATUS.FAILURE:
            message = 'An error occurred when executing task.'
        elif status == STATUS.TIMEOUT:
            message = 'Task was aborted by timeout.'
        else:
            message = 'An internal error occurred when executing this task.'
            status = STATUS.EXCEPTION

        if self.Context.fail_type == status and self.Context.fail_context:
            message += '\n\n{}\n'.format(self.Context.fail_context)

        if status != STATUS.SUCCESS:
            message += '\nCheck logs of the sandbox task {} for more info.'.format(self.id)

        return message

    # Note: we cannot send heartbeats in 'on_enqueue' as it would cause
    # deadlock. That is, backend has locked row for update, sent a request to
    # start task. This request waits till on_enqueue finishes. on_enqueue waits
    # for heartbeat and heartbeat waits to take exclusive lock on row.

    def on_success(self, prev_status):
        message = self.format_message(STATUS.SUCCESS)
        self.send_heartbeat(STATUS.SUCCESS, message=message, progress=1, retry=5, delay=5)

    def on_failure(self, prev_status):
        message = self.format_message(STATUS.FAILURE)
        self.send_heartbeat(STATUS.FAILURE, message=message, progress=1, retry=5, delay=5)

    def on_break(self, prev_status, status):
        message = self.format_message(status)
        self.send_heartbeat(status, message=message, progress=1, retry=5, delay=5)

    def on_wait(self, prev_status, status):
        self.send_heartbeat(status=STATUS.WAIT, retry=5, delay=5)

    def send_heartbeat(self, status, progress=0, message=None, retry=5, delay=1,
                       timeout=30):
        """POST a status heartbeat to ``heartbeat_url``.

        Connection errors, timeouts and the HTTP codes in ``retry_codes`` are
        retried up to *retry* times, sleeping *delay* seconds between
        attempts; each request is bounded by *timeout* seconds so a hung
        connection cannot block the heartbeat loop indefinitely.

        Returns:
            ``(True, parsed JSON body)`` on success, or ``(False, None)`` when
            every attempt hit a retryable failure.

        Raises:
            requests.HTTPError: on a non-retryable HTTP error status.
            ValueError: if a successful response body is not valid JSON.
        """
        logger.info('sending heartbeat -- %s %s %r', status, progress, message)

        retry_codes = (408, 429, 502, 503, 504, 598, 599)
        payload = {
            REQ_KEYS.SANDBOX_ID: self.id,
            REQ_KEYS.ABT_ID: self.Parameters.abt_id,
            REQ_KEYS.STATUS: status,
            REQ_KEYS.PROGRESS: progress,
            REQ_KEYS.INFO: message,
        }
        # A negative count would skip the loop and return None, which callers
        # cannot unpack into (ok, result).
        retry = max(retry, 0)

        for i in range(retry + 1):
            try:
                # NOTE(review): verify=False disables TLS certificate checks
                # for the heartbeat endpoint; confirm this is still required.
                result = requests.post(
                    self.Parameters.heartbeat_url,
                    json=payload,
                    verify=False,
                    timeout=timeout,
                )
            except (requests.ConnectionError, requests.Timeout) as e:
                # Network-level failures are treated like HTTP 503 so they go
                # through the same retry path.
                logger.warning('heartbeat request error: %s', e)
                result = requests.Response()
                result.status_code = 503

            if result.status_code in retry_codes and i != retry:
                logger.warning('heartbeat failed with code %r, going to retry in %s seconds', result.status_code, delay)
                time.sleep(delay)
            elif result.status_code in retry_codes:
                logger.warning('heartbeat failed with code %r, no more retries left', result.status_code)
                return False, None
            else:
                result.raise_for_status()
                logger.info('heartbeat succeeded with code %r, response is %r', result.status_code, result.text)
                return True, json.loads(result.text)
