# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import time
from datetime import timedelta  # noqa
from logging import Logger, getLogger  # noqa
from typing import Any, AnyStr, Dict, List, Optional  # noqa

from django.conf import settings

from travel.rasp.library.python.api_clients.sandbox import SandboxClient

from common.data_api.sandbox.errors import (
    SandboxFailTaskException, SandboxNotFinishedTaskException, SandboxNotFoundResourceException,
    SandboxNotReadyResourceException
)
from common.data_api.sandbox.sandbox_public_api import sandbox_public_api
from travel.rasp.library.python.common23.date import environment
from common.utils.metrics import task_progress_report


class SandboxTaskRunner(object):
    """Create, start and (optionally) wait for Sandbox tasks via a Sandbox API client.

    Every task is created with ``owner=settings.SANDBOX_OWNER``.
    """

    def __init__(self, api, environment, logger):
        # type: (SandboxClient, Any, Logger) -> None
        """
        :param api: Sandbox API client used to create, start and poll tasks and list their resources.
        :param environment: object exposing ``now()``; the difference of two ``now()`` values is
            compared with a ``timedelta`` (presumably ``now()`` returns a ``datetime`` — confirm).
        :param logger: logger receiving progress messages.
        """
        self._api = api
        self._environment = environment
        self._logger = logger

    def run(self, task_type, resource_type, description, custom_fields, max_wait_time, check_delay_time):
        # type: (AnyStr, AnyStr, AnyStr, List[Dict], timedelta, timedelta) -> Optional[int]
        """Start a task, block until it succeeds and return the id of its ``resource_type`` resource.

        The whole operation runs inside ``task_progress_report`` under the name
        ``sandbox_run_<task_type or resource_type, lower-cased>``.

        :param task_type: Sandbox task type to create.
        :param resource_type: type of the resource to look up once the task succeeds; if falsy,
            no lookup is done and ``None`` is returned.
        :param description: task description.
        :param custom_fields: custom fields passed to the task draft.
        :param max_wait_time: how long to keep polling the task status.
        :param check_delay_time: pause between two status polls.
        :return: resource id, or ``None`` when ``resource_type`` is falsy.
        :raises SandboxFailTaskException: the task reported ``FAILURE``.
        :raises SandboxNotFinishedTaskException: the task did not succeed within ``max_wait_time``.
        :raises SandboxNotReadyResourceException: the matching resource is not ``READY``.
        :raises SandboxNotFoundResourceException: the task has no resource of ``resource_type``.
        """
        metric_name = 'sandbox_run_{}'.format((task_type or resource_type).lower())
        with task_progress_report(metric_name):
            task_id = self._start_task(task_type, description, custom_fields)
            self._wait_for_task(task_id, max_wait_time, check_delay_time)

            return self._get_task_resource_id(task_id, resource_type) if resource_type else None

    def async_run(self, task_type, description, custom_fields):
        # type: (AnyStr, AnyStr, List[Dict]) -> None
        """Create and start a task without waiting for it to finish.

        Runs inside ``task_progress_report`` under ``sandbox_async_run_<task_type lower-cased>``.
        """
        with task_progress_report('sandbox_async_run_{}'.format(task_type.lower())):
            self._start_task(task_type, description, custom_fields)

    def _start_task(self, task_type, description, custom_fields):
        # type: (AnyStr, AnyStr, List[Dict]) -> int
        """Create a task draft owned by ``settings.SANDBOX_OWNER``, start it and return its id."""
        task_id = self._api.create_task_draft(
            task_type,
            owner=settings.SANDBOX_OWNER,
            description=description,
            custom_fields=custom_fields
        )
        self._api.start_task(task_id)
        return task_id

    def _get_task_resource_id(self, task_id, resource_type):
        # type: (int, AnyStr) -> int
        """Return the id of the first resource of ``task_id`` whose ``type`` is ``resource_type``.

        :raises SandboxNotReadyResourceException: that resource's ``state`` is not ``READY``.
        :raises SandboxNotFoundResourceException: no resource of ``resource_type`` exists.
        """
        for resource in self._api.get_task_resources(task_id):
            if resource.get('type') == resource_type:
                if resource.get('state') != 'READY':
                    raise SandboxNotReadyResourceException(task_id, resource_type)
                resource_id = resource['id']
                self._logger.info('Received Task #%d resource [%d].', task_id, resource_id)
                return resource_id

        raise SandboxNotFoundResourceException(task_id, resource_type)

    def _wait_for_task(self, task_id, max_wait_time, task_check_delay):
        # type: (int, timedelta, timedelta) -> None
        """Poll the task status until it is ``SUCCESS``, ``FAILURE`` or ``max_wait_time`` elapses.

        The status is always checked at least once, and it is checked again after every sleep.
        The deadline is evaluated only after a non-terminal status was observed.

        :raises SandboxFailTaskException: the task reported ``FAILURE``.
        :raises SandboxNotFinishedTaskException: no ``SUCCESS`` before ``max_wait_time`` elapsed.
        """
        self._logger.info('Waiting Task #%d.', task_id)
        # timedelta.seconds is only the seconds *component* (days and microseconds are
        # dropped, and negative deltas wrap to ~86399), so use total_seconds() and clamp at 0.
        delay_seconds = max(0.0, task_check_delay.total_seconds())
        start_time = self._environment.now()
        while True:
            status = self._api.get_task_status(task_id)
            if status == 'SUCCESS':
                break
            if status == 'FAILURE':
                raise SandboxFailTaskException(task_id)
            if self._environment.now() - start_time >= max_wait_time:
                raise SandboxNotFinishedTaskException(task_id)
            self._logger.info('Sleeping %s seconds to wait task #%d status.', delay_seconds, task_id)
            time.sleep(delay_seconds)

        self._logger.info('Task #%d has finished.', task_id)


# Module-wide default runner, wired to the public Sandbox API client, the shared
# date environment and this module's logger (arguments follow __init__ order:
# api, environment, logger).
sandbox_task_runner = SandboxTaskRunner(
    sandbox_public_api,
    environment,
    getLogger(__name__),
)
