# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import os
import subprocess
import tarfile

import requests
import shutil

from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.common.types import task as common_task_types
from sandbox.projects.geobase.Geodata6BinStable.resource import GEODATA6BIN_STABLE
from sandbox.projects.rasp.qloud.UpdateResources import RaspQloudUpdateResources
from sandbox.projects.rasp.resource_types import RaspBanditResource
from sandbox.projects.rasp.utils.email_notifications import (
    EmailNotificationMixin,
    use_email_notification_params,
    TRAIN_GROUP,
)

# Module-level logger used by the task class in this file.
log = logging.getLogger(__name__)


class RaspBanditProcessing(sdk2.Task, EmailNotificationMixin):
    """Run the train bandit processing binary and publish a fresh snapshot.

    The task extracts the bandit tarball, optionally hands the binary the
    previous snapshot, runs ``cmd/processing/processing`` to produce a new
    ``snapshot.zip`` and stores it as a ``RaspBanditResource``. If enabled, it
    then spawns one ``RaspQloudUpdateResources`` subtask per Qloud environment
    (passing the new snapshot plus the latest ``GEODATA6BIN_STABLE``) and
    fails if any of those subtasks does not succeed.
    """

    BINARY_DIR = 'cmd/processing/'
    BINARY_PATH = BINARY_DIR + 'processing'
    BANDIT_LOG_NAME = 'bandit.log'
    # Lookup attributes for the bandit tarball; note these are not scoped by environment.
    BINARY_RESOURCE_ATTRS_TO_SEEK = {'resource_name': 'train-bandit-api'}
    DATA_RESOURCE_NAME = 'snapshot.zip'
    # Base lookup attributes for previous snapshots. on_execute extends a
    # per-run copy with the environment; this class-level dict is never mutated.
    SNAPSHOT_RESOURCE_ATTRS_TO_SEEK = {'resource_name': DATA_RESOURCE_NAME}
    CONFIG_PATH_TEMPLATE = 'docker/api/config.{}.yaml'
    DATA_RESOURCE_TTL = 14
    GEODATA_PATH = '/tmp/geodata6.bin'
    LAST_STABLE_GEODATA_URL = 'https://proxy.sandbox.yandex-team.ru/last/GEODATA6BIN_STABLE'
    OUTPUT_PATH = '/tmp/snapshot.zip'

    class Requirements(sdk2.Task.Requirements):
        # Memory requirement (presumably MiB, per sdk2 convention).
        ram = 3 * 1024

    class Parameters(sdk2.Task.Parameters):
        # Default kill timeout for the task (presumably seconds).
        kill_timeout = 600
        environment = sdk2.parameters.String(
            'Environment',
            default='testing',
        )
        qloud_token_owner = sdk2.parameters.String(
            'Qloud secret vault owner',
            required=True,
        )
        qloud_token_name = sdk2.parameters.String(
            'Qloud secret vault name',
            required=True,
        )
        qloud_environments = sdk2.parameters.List(
            'Qloud environments',
            default=('rasp.train-bandit-api.testing',),
        )
        travel_vault_oauth_name = sdk2.parameters.String(
            'Travel vault oauth name',
            default='travel_vault_oauth_token',
        )
        update_qloud_resources = sdk2.parameters.Bool(
            'Run RaspQloudUpdateResources task with result snapshot',
            default=True
        )
        drop_bandit_state = sdk2.parameters.Bool(
            'WARNING!!! Ignore previous bandit snapshot resource if true',
            default=False,
        )
        initial_snapshot = sdk2.parameters.Resource(
            'Initial bandit snapshot resource. Latest if empty',
            default=None
        )
        bandit_processing = sdk2.parameters.Resource(
            'Bandit tarball resource. Latest if empty',
            default=None
        )
        teach_from = sdk2.parameters.Integer(
            'Override teach from. Unix-timestamp',
            default=0,
        )
        teach_to = sdk2.parameters.Integer(
            'Override teach to. Unix-timestamp',
            default=0,
        )

        _email_notification_params = use_email_notification_params()

    def on_execute(self):
        """Run processing, publish the snapshot and optionally roll it out.

        Stage ``update_resources`` runs the binary, uploads the snapshot and,
        if requested, enqueues the Qloud update subtasks and waits for them.
        Stage ``check_result`` fails the task if any recorded subtask did not
        end in a successful status.

        Raises:
            TaskFailure: if a required resource is missing, the binary writes
                no output, or a Qloud update subtask does not succeed.
        """
        # Use a per-run copy so the shared class-level dict is never mutated.
        snapshot_attrs = dict(self.SNAPSHOT_RESOURCE_ATTRS_TO_SEEK)
        self._prepare_environments(snapshot_attrs)

        with self.memoize_stage.update_resources:
            process_env = os.environ.copy()
            # The secret goes only into the full environment; process_env_open
            # holds the non-secret settings and is the part that gets logged.
            process_env['TRAVEL_VAULT_OAUTH_TOKEN'] = sdk2.Vault.data(
                self.Parameters.travel_vault_oauth_name,
            )
            process_env_open = {
                'YENV_TYPE': self.Parameters.environment,
                'CONFIG_PATH': self.CONFIG_PATH_TEMPLATE.format(self.Parameters.environment),
                'OUTPUT': self.OUTPUT_PATH,
                'GEOBASE_DATAFILE': self.GEODATA_PATH,
                'LOG_BANDIT_LOG_PATH': self.BANDIT_LOG_NAME,
                'MANAGER_RESOURCE_UPLOAD_ON_START': 'false' if self.Parameters.drop_bandit_state else 'true',
            }
            # 0 means "not overridden". str() rather than repr(): under Python 2
            # repr() of a long carries an 'L' suffix that would corrupt the value.
            if self.Parameters.teach_from:
                process_env_open['TEACH_FROM'] = str(self.Parameters.teach_from)
            if self.Parameters.teach_to:
                process_env_open['TEACH_TO'] = str(self.Parameters.teach_to)

            bandit_resource = (
                self.Parameters.bandit_processing
                or self._find_last_resource(RaspBanditResource, self.BINARY_RESOURCE_ATTRS_TO_SEEK)
            )
            self._extract_all_from_resource(resource=bandit_resource)

            if not self.Parameters.drop_bandit_state:
                if self.Parameters.initial_snapshot:
                    # Version validation is disabled for an explicitly chosen
                    # snapshot (presumably it may come from another bandit build).
                    process_env_open['VALIDATE_BANDIT_VERSION'] = 'false'
                    snapshot = self.Parameters.initial_snapshot
                else:
                    snapshot = self._find_last_resource(RaspBanditResource, snapshot_attrs)
                snapshot_data = sdk2.ResourceData(snapshot)
                process_env_open['MANAGER_RESOURCE_FILENAME'] = snapshot_data.path.absolute().as_posix()

            # NOTE(review): geodata is fetched via the "last" proxy URL, while the
            # GEODATA6BIN_STABLE resource sent to Qloud below is looked up
            # separately; a release in between would make the two differ.
            self._download_data(self.LAST_STABLE_GEODATA_URL, self.GEODATA_PATH)

            log.info(process_env_open)
            process_env.update(process_env_open)

            # Remove any stale output so that its presence afterwards proves
            # this run produced it.
            if os.path.exists(self.OUTPUT_PATH):
                os.remove(self.OUTPUT_PATH)

            self._run_binary('./{}'.format(self.BINARY_PATH), process_env)

            if not os.path.exists(self.OUTPUT_PATH):
                raise TaskFailure('Processing finished without writing {}'.format(self.OUTPUT_PATH))
            log.info('Written {} bytes to {}'.format(
                os.stat(self.OUTPUT_PATH).st_size, self.OUTPUT_PATH,
            ))

            resources = [
                self._create_bandit_resource(
                    self.OUTPUT_PATH, self.DATA_RESOURCE_NAME, self.DATA_RESOURCE_TTL,
                ),
                self._find_last_resource(GEODATA6BIN_STABLE),
            ]

            if self.Parameters.update_qloud_resources:
                update_tasks = self._enqueue_tasks(self._update_qloud_resources(resources))
                # Stored in Context so the check_result stage can read them later.
                self.Context.update_task_ids = [t.id for t in update_tasks]
                log.info('Waiting')
                raise sdk2.WaitTask(
                    update_tasks,
                    common_task_types.Status.Group.FINISH | common_task_types.Status.Group.BREAK,
                )
            else:
                self.Context.update_task_ids = []
                log.info('Skip update qloud resources')

        with self.memoize_stage.check_result:
            for task_id in self.Context.update_task_ids:
                task = self.find(id=task_id).first()
                if task is None:
                    raise TaskFailure(
                        'RaspQloudUpdateResources task {} not found'.format(task_id)
                    )
                if task.status not in common_task_types.Status.Group.SUCCEED:
                    raise TaskFailure(
                        'Error at resources update. '
                        'RaspQloudUpdateResources status: {}, task ID: {}'.format(
                            task.status, task.id,
                        )
                    )
            log.info('Done')

    def on_save(self):
        """Register e-mail notifications for the train group on every save."""
        super(RaspBanditProcessing, self).on_save()
        self.add_email_notifications(notifications_group=TRAIN_GROUP)

    def _prepare_environments(self, *envs):
        """Add the task's ``environment`` parameter to each given attrs dict, in place."""
        for env in envs:
            env.update({
                'environment': self.Parameters.environment,
            })

    @classmethod
    def _find_last_resource(cls, resource_class, attrs=None):
        """Return the matching ``resource_class`` resource with the highest id.

        Args:
            resource_class: sdk2 resource class to search.
            attrs: optional attribute filter passed to ``find``.

        Raises:
            TaskFailure: if no resource matches.
        """
        count = resource_class.find(attrs=attrs).count
        resource = None
        if count:
            # Ascending by id, skip all but the last -> the highest id.
            resource = resource_class.find(attrs=attrs).order(resource_class.id).offset(count - 1).first()
        if resource is None:
            raise TaskFailure('No {} resource found by attrs {}'.format(
                resource_class.__name__,
                attrs,
            ))
        log.info('Found {} by attrs {}: id {}'.format(
            resource_class.__name__,
            attrs,
            resource.id
        ))
        return resource

    @classmethod
    def _extract_all_from_resource(cls, resource):
        """Sync ``resource`` locally and extract its tar archive into the CWD."""
        data = sdk2.ResourceData(resource)
        data_path = data.path.absolute().as_posix()
        # The archive comes from a project resource (trusted); close it when done.
        with tarfile.open(data_path) as tar_file:
            tar_file.extractall()
        log.info('Extracted all from {}'.format(data_path))

    @classmethod
    def _download_data(cls, url, destination_file, timeout=60, chunk_size=1024 * 1024):
        """Stream ``url`` into ``destination_file``, creating its directory if needed.

        Args:
            url: HTTP(S) URL to fetch.
            destination_file: local path to (over)write.
            timeout: requests timeout in seconds (connect and per-read).
            chunk_size: bytes written per iteration.

        Raises:
            requests.HTTPError: on a non-2xx response, so an error page is
                never saved in place of the data.
        """
        dst_dir = os.path.dirname(destination_file)
        # dirname() is '' for a bare filename; os.makedirs('') would fail.
        if dst_dir and not os.path.exists(dst_dir):
            os.makedirs(dst_dir)
            log.info('Created directory {}'.format(dst_dir))

        response = requests.get(url, stream=True, timeout=timeout)
        try:
            response.raise_for_status()
            with open(destination_file, 'wb') as fd:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    fd.write(chunk)
        finally:
            response.close()
        log.info('Downloaded from {} to {}'.format(url, destination_file))

    @classmethod
    def _run_binary(cls, binary_path, process_env):
        """Run ``binary_path`` with ``process_env`` and raise if it exits non-zero.

        stderr is merged into stdout. The combined output is logged at DEBUG
        on success and at ERROR on failure so it is visible where it matters.
        """
        log.info('Running {path}'.format(path=binary_path))
        proc = subprocess.Popen(
            [binary_path],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            env=process_env,
        )
        output, _ = proc.communicate()

        if proc.returncode != 0:
            if output:
                log.error(output)
            log.error('{} exited with code {}'.format(binary_path, proc.returncode))
            raise Exception('Error during _run_binary: {}'.format(binary_path))
        if output:
            log.debug(output)

    def _create_bandit_resource(self, filename, resource_name, ttl):
        """Copy ``filename`` into a new READY-marked ``RaspBanditResource``.

        The resource gets ``resource_name`` and the task's ``environment`` as
        attributes, which is what the snapshot lookup filters on.
        """
        log.info('Create bandit resource {} from {}'.format(resource_name, filename))
        resource = RaspBanditResource(
            self,
            description=resource_name,
            path=resource_name,
            ttl=ttl,
        )
        resource.resource_name = resource_name
        resource.environment = self.Parameters.environment
        resource_data = sdk2.ResourceData(resource)
        shutil.copy(
            src=filename,
            dst=resource_data.path.absolute().as_posix(),
        )
        resource_data.ready()

        log.info('Uploaded resource {}'.format(resource.id))
        return resource

    def _update_qloud_resources(self, resources):
        """Build (but do not enqueue) one update subtask per Qloud environment."""
        tasks = [
            RaspQloudUpdateResources(
                self,
                owner=self.Parameters.owner,
                priority=self.Parameters.priority,
                description='Update resource',

                token_owner=self.Parameters.qloud_token_owner,
                token_name=self.Parameters.qloud_token_name,

                resources=resources,
                qloud_environment=qloud_environment,

                enable_email_notifications=False
            )
            for qloud_environment in self.Parameters.qloud_environments
        ]
        return tasks

    @classmethod
    def _enqueue_tasks(cls, tasks):
        """Enqueue every task in ``tasks`` and return the same list."""
        for task in tasks:
            task.enqueue()
            log.info('Enqueued task {}'.format(task.id))
        return tasks
