import logging
import queue
import threading
import time
from typing import List

from nirvana_api import NirvanaApi
from yweb.video.faas.graphs.ott.common import Priority
from yweb.video.faas.outputs.error import ExceptionInfo
from yweb.video.faas.outputs.result import OutputLabel

from ott.drm.library.python.packager_task.clients import PackagerTasksApiClient, PackagerTasksApiException
from ott.drm.library.python.packager_task.models import (
    TaskStatus,
    TaskExecutionState,
    PackagerTask,
    PackagerOutput,
    PackagerOutputStatus,
    PackagerOutputBlockMeta,
    GraphCreatingError,
    TaskErrorType
)
from sandbox.projects.ott.packager_management_system.lib.graph_creator.ott_packager import (
    OttPackager,
    OttPackagerGraphCreationResult,
    OttPackagerRepository,
    OttPackagerService,
    OttPackagerStage
)
from sandbox.projects.ott.packager_management_system.lib.graph_creator.quota_manager import QuotaManager
from yweb.video.faas.proto.common.trace_pb2 import TError

_SERVICE = 'GRAPH_CREATOR'


class GraphCreator:
    """Create packaging graphs (via ``OttPackager.create_graph``) for NEW packager tasks, bounded by quota.

    Task statuses are moved in the backend (``tasks_client``) as follows:
    NEW -> ENQUEUED_FOR_CREATE_GRAPH (task put into ``tasks_queue``)
        -> GRAPH_CREATING (task picked up by a handling worker)
        -> GRAPH_CREATED or CREATE_GRAPH_FAILED.
    At start-up, tasks left in GRAPH_CREATING or ENQUEUED_FOR_CREATE_GRAPH by a previous
    run are reset to NEW.

    Threads:
      * one ``tasks_retrieving_worker`` fills ``tasks_queue`` with NEW tasks;
      * the thread calling :meth:`run` is the only consumer of ``tasks_queue``; it hands a
        task over to ``workers_queue`` once ``QuotaManager.can_create_graph`` allows it;
      * ``max_graph_creating_workers`` ``task_handling_worker`` threads create the graphs.
    """

    def __init__(self, tasks_client: PackagerTasksApiClient, ott_packager_repository: OttPackagerRepository,
                 max_graph_creating_workers, max_not_launched_graphs, nirvana_oauth_token, nirvana_quota,
                 vod_providers: List[str], s3_creds_nirvana_secret_name: str, s3_creds: str,
                 sandbox_task_id=None, retrieve_tasks_interval_secs=15, full_quota_check_interval_secs=15,
                 graph_creation_retries=5, **_):
        """
        :param tasks_client: client of the packager tasks backend.
        :param ott_packager_repository: passed to ``OttPackagerService``, which is used to look up
            the ott_packager build that creates a graph.
        :param max_graph_creating_workers: number of graph-creating threads; also the capacity of
            ``tasks_queue``. Must be an int (it is passed to ``range``).
        :param max_not_launched_graphs: passed to ``QuotaManager``.
        :param nirvana_oauth_token: token for ``NirvanaApi``; also forwarded to ``create_graph``.
        :param nirvana_quota: passed to ``QuotaManager`` and used to filter tasks fetched from the backend.
        :param vod_providers: passed to ``QuotaManager`` and used to filter tasks fetched from the backend.
        :param s3_creds_nirvana_secret_name: forwarded to ``OttPackager.create_graph``.
        :param s3_creds: forwarded to ``OttPackager.create_graph``.
        :param sandbox_task_id: recorded in every ``TaskExecutionState`` and ``GraphCreatingError``.
        :param retrieve_tasks_interval_secs: back-off between backend polls, and how long :meth:`run`
            waits for a task before re-checking whether it can finish.
        :param full_quota_check_interval_secs: sleep between quota checks while quota is exhausted.
        :param graph_creation_retries: forwarded to ``OttPackager.create_graph``.
        :param _: any extra keyword arguments are accepted and ignored.
        """
        self.tasks_client = tasks_client
        self.nirvana = NirvanaApi(nirvana_oauth_token)
        self.sandbox_task_id = sandbox_task_id
        self.max_graph_creating_workers = max_graph_creating_workers
        self.quota_manager = QuotaManager(self.tasks_client, max_not_launched_graphs, nirvana_quota, vod_providers)
        # Capacity 1: run() can only hand over the next task once a worker has taken the previous one.
        self.workers_queue = queue.Queue(1)
        self.tasks_queue = queue.Queue(max_graph_creating_workers)
        self.retrieve_tasks_interval_secs = retrieve_tasks_interval_secs
        self.full_quota_check_interval_secs = full_quota_check_interval_secs
        self.nirvana_quota = nirvana_quota
        self.vod_providers = vod_providers
        self.graph_creation_retries = graph_creation_retries
        self.ott_packager_service = OttPackagerService(ott_packager_repository)
        self.s3_creds_nirvana_secret_name = s3_creds_nirvana_secret_name
        self.s3_creds = s3_creds

    def run(self):
        """Process tasks until nothing is left or quota is exhausted, then wait for in-progress tasks.

        The loop ends when no task is being handled and either ``tasks_queue`` stayed empty for
        ``retrieve_tasks_interval_secs``, or the quota manager refused the next task.
        """
        self._init_tasks_queue()

        threading.Thread(target=self.tasks_retrieving_worker, daemon=True).start()
        for _ in range(self.max_graph_creating_workers):
            threading.Thread(target=self.task_handling_worker, daemon=True).start()

        while True:
            try:
                task = self._peek_task(timeout=self.retrieve_tasks_interval_secs)
            except queue.Empty:
                # unfinished_tasks counts tasks handed to workers that have not called task_done() yet.
                if self.workers_queue.unfinished_tasks == 0:
                    logging.info('No tasks left. Finishing...')
                    break
                logging.info('Empty tasks_queue. Retrying...')
                continue

            if self.quota_manager.can_create_graph(task, self.workers_queue.unfinished_tasks):
                # This thread is the only consumer of tasks_queue, so this removes exactly the peeked task.
                removed_task = self.tasks_queue.get_nowait()
                # Blocks while a previously handed-over task has not been taken by a worker yet.
                self.workers_queue.put(removed_task, block=True)
            elif self.workers_queue.unfinished_tasks == 0:
                logging.info('Quota limit is reached. No tasks in progress. Finishing...')
                break
            else:
                logging.info(f'No quota for task {task}. Sleeping for {self.full_quota_check_interval_secs} seconds...')
                time.sleep(self.full_quota_check_interval_secs)

        # to ensure all tasks are done
        self.workers_queue.join()

    def _fetch_new_tasks(self, limit) -> List[PackagerTask]:
        """Fetch up to ``limit`` NEW tasks for this instance's Nirvana quota and VOD providers."""
        new_tasks = self.tasks_client.get_tasks(
            TaskStatus.NEW,
            nirvana_quota=self.nirvana_quota,
            vod_providers=self.vod_providers,
            limit=limit
        )
        logging.info(f'Retrieve {len(new_tasks)} new tasks')
        # Lazy formatting: the task list is only rendered when DEBUG logging is enabled.
        logging.debug('Tasks: %s', new_tasks)
        return new_tasks

    def _enqueue_task(self, task: PackagerTask):
        """Mark ``task`` ENQUEUED_FOR_CREATE_GRAPH in the backend, then put it into ``tasks_queue``.

        Blocks while ``tasks_queue`` is full. Backend errors propagate to the caller.
        """
        task.status = TaskStatus.ENQUEUED_FOR_CREATE_GRAPH
        self.tasks_client.update(task, TaskExecutionState(self.sandbox_task_id, _SERVICE))
        self.tasks_queue.put(task, block=True)

    def _init_tasks_queue(self):
        """Reset tasks suspended by a previous run, then pre-fill ``tasks_queue`` with NEW tasks.

        Unlike :meth:`tasks_retrieving_worker`, backend errors here propagate and abort :meth:`run`.
        """
        self._handle_suspended_tasks()

        for task in self._fetch_new_tasks(self.tasks_queue.maxsize):
            self._enqueue_task(task)

    def tasks_retrieving_worker(self):
        """Poll the backend for NEW tasks forever and enqueue them into ``tasks_queue``.

        Backend errors are logged and retried; sleeps ``retrieve_tasks_interval_secs`` after
        an error or when no NEW tasks were returned.
        """
        # trade-off between tasks order relevance and number of requests to backend
        tasks_limit = max(self.max_graph_creating_workers // 2, 1)

        while True:
            try:
                new_tasks = self._fetch_new_tasks(tasks_limit)
            except PackagerTasksApiException:
                logging.exception(f'Exception while retrieving tasks. '
                                  f'Sleeping for {self.retrieve_tasks_interval_secs} seconds...')
                time.sleep(self.retrieve_tasks_interval_secs)
                continue

            for task in new_tasks:
                try:
                    self._enqueue_task(task)
                except PackagerTasksApiException:
                    logging.exception(f'Exception while update task: {task}')

            if not new_tasks:
                logging.info(f'No NEW tasks. Sleeping for {self.retrieve_tasks_interval_secs} seconds...')
                time.sleep(self.retrieve_tasks_interval_secs)

    def task_handling_worker(self):
        """Take tasks from ``workers_queue`` forever and create a graph for each.

        Failures of a single task are logged and do not stop the worker. ``task_done()`` is
        called in ``finally`` so ``workers_queue.unfinished_tasks`` (read by :meth:`run` to
        decide when to finish) and ``workers_queue.join()`` can never get stuck.
        """
        while True:
            task = self.workers_queue.get(block=True)
            try:
                logging.info(f'{task.task_id} - start handling enqueued task')
                self._handle_enqueued_task(task)
                logging.info(f'{task.task_id} - enqueued task handling succeed')
            except Exception:
                logging.exception(f'{task.task_id} - enqueued task handling failed')
            finally:
                self.workers_queue.task_done()

    def _peek_task(self, timeout: float) -> PackagerTask:
        """Return the head of ``tasks_queue`` without removing it.

        Mirrors ``queue.Queue.get(timeout=...)`` using the queue's own condition variable, but
        leaves the item in place so :meth:`run` can check quota before taking it.

        :raises ValueError: if ``timeout`` is negative.
        :raises queue.Empty: if no task arrives within ``timeout`` seconds.
        """
        if timeout < 0:
            raise ValueError("'timeout' must be a non-negative number")
        with self.tasks_queue.not_empty:
            endtime = time.monotonic() + timeout
            while not self.tasks_queue.queue:
                remaining = endtime - time.monotonic()
                if remaining <= 0.0:
                    raise queue.Empty
                self.tasks_queue.not_empty.wait(remaining)
            return self.tasks_queue.queue[0]

    def _handle_suspended_tasks(self):
        """Reset tasks left in GRAPH_CREATING or ENQUEUED_FOR_CREATE_GRAPH back to NEW.

        Such tasks were left behind by a previous run that stopped before finishing them.
        """
        graph_creating_tasks = self.tasks_client.get_tasks(
            TaskStatus.GRAPH_CREATING,
            nirvana_quota=self.nirvana_quota,
            vod_providers=self.vod_providers
        )
        logging.debug('Suspended GRAPH_CREATING tasks: %s', graph_creating_tasks)

        enqueued_for_create_graph_tasks = self.tasks_client.get_tasks(
            TaskStatus.ENQUEUED_FOR_CREATE_GRAPH,
            nirvana_quota=self.nirvana_quota,
            vod_providers=self.vod_providers
        )
        logging.debug('Suspended ENQUEUED_FOR_CREATE_GRAPH tasks: %s', enqueued_for_create_graph_tasks)

        for task in graph_creating_tasks + enqueued_for_create_graph_tasks:
            task.status = TaskStatus.NEW
            self.tasks_client.update(task, TaskExecutionState(self.sandbox_task_id, _SERVICE))

    def _handle_enqueued_task(self, task: PackagerTask):
        """Create a graph for ``task`` and persist the outcome.

        On success (result code 0), does the following:
          * bumps the Nirvana workflow priority for ``Priority.MAX`` tasks;
          * stores ``graph_meta`` on the task;
          * registers the packager outputs;
          * marks the task GRAPH_CREATED.

        Otherwise it marks the task CREATE_GRAPH_FAILED with a ``GraphCreatingError``.

        Exceptions propagate to the caller. If one is raised after the GRAPH_CREATING update,
        the task stays GRAPH_CREATING until :meth:`_handle_suspended_tasks` resets it on the next start.
        """
        task.status = TaskStatus.GRAPH_CREATING
        self.tasks_client.update(task, TaskExecutionState(self.sandbox_task_id, _SERVICE))

        graph_creation_result = self._create_graph(task)

        task_error = None
        if graph_creation_result.code == 0:
            if task.priority == Priority.MAX:
                # TODO: replace with valhalla option after https://st.yandex-team.ru/VALHALLA-199
                self.nirvana.edit_workflow(
                    workflow_id=graph_creation_result.graph_meta['nirvana_workflow_id'],
                    workflow_instance_id=graph_creation_result.graph_meta['nirvana_workflow_instance_id'],
                    execution_params={'workflowPriority': 'high'}
                )

            task.status = TaskStatus.GRAPH_CREATED
            task.graph_meta = graph_creation_result.graph_meta

            packager_outputs = self.build_packager_outputs(task, graph_creation_result.graph_meta)
            logging.info(f'Packager outputs: {packager_outputs}')
            self.tasks_client.create_packager_outputs(task.task_id, packager_outputs)
        else:
            task.status = TaskStatus.CREATE_GRAPH_FAILED
            error = graph_creation_result.error
            ex_info = ExceptionInfo(error.Exception, error.Message, '', error.Code)
            task_error = GraphCreatingError(TaskErrorType.GRAPH_CREATING, self.sandbox_task_id, ex_info)

        self.tasks_client.update(task, TaskExecutionState(self.sandbox_task_id, _SERVICE), task_error)

        logging.info(f'{task.task_id} - graph creation result: {graph_creation_result}')

    def build_packager_outputs(self, task, graph_meta):
        """Build one NOT_READY ``PackagerOutput`` per entry of ``graph_meta['result_metas']``.

        Each output is bound to the Nirvana block (by name) of the created workflow instance.

        :param task: task whose ``input_params`` provide ``ott_content_uuid`` and ``content_version_id``.
        :param graph_meta: graph metadata with ``nirvana_workflow_instance_id`` and ``result_metas``.
        :raises RuntimeError: if a result block name matches zero or several workflow blocks.
        """
        packager_outputs = []

        blocks = self.nirvana.get_block_meta_data(workflow_instance_id=graph_meta['nirvana_workflow_instance_id'])

        for result_meta in graph_meta['result_metas']:
            block_name = result_meta['block_name']
            block_meta = PackagerOutputBlockMeta(
                block_name,
                self._get_block_guid(blocks, block_name),
                result_meta['output_name']
            )
            packager_output = PackagerOutput(
                task_id=task.task_id,
                label=OutputLabel(result_meta['label']),
                content_group_uuid=str(task.input_params['ott_content_uuid']),
                content_version_id=int(task.input_params['content_version_id']),
                block_meta=block_meta,
                status=PackagerOutputStatus.NOT_READY,
                activate_content_version=result_meta['activate_content_version']
            )
            packager_outputs.append(packager_output)
        return packager_outputs

    def _create_graph(self, task: PackagerTask) -> OttPackagerGraphCreationResult:
        """Pick the ott_packager build for ``task`` and create its graph.

        An explicit ``arc_revision`` input param selects the build by revision. A missing build
        yields a failed result instead of an exception. Otherwise the build is chosen by stage:
        TESTING for the 'ott-packager-testing' VOD provider, STABLE for any other.
        """
        is_testing_provider = task.vod_provider == 'ott-packager-testing'

        if arc_revision := task.input_params.get('arc_revision'):
            packager = self.ott_packager_service.find_by_arc_revision(arc_revision)
            if not packager:
                return self._build_ott_packager_not_found_result(arc_revision)
        else:
            stage = OttPackagerStage.TESTING if is_testing_provider else OttPackagerStage.STABLE
            # NOTE(review): unlike find_by_arc_revision, the result is not checked for None here;
            # confirm find_by_stage always returns a packager (else AttributeError below).
            packager = self.ott_packager_service.find_by_stage(stage)

        logging.info(f'{task.task_id} - graph will be created by {packager.attrs.release_status.value} ott_packager '
                     f'({packager.attrs.arc_revision} revision)')

        return packager.create_graph(
            str(task.task_id),
            task.nirvana_quota,
            self.nirvana.oauth_token,
            'ott_packager_testing' if is_testing_provider else 'ott_packager',
            self.graph_creation_retries,
            self.s3_creds_nirvana_secret_name,
            self.s3_creds,
            task.input_params
        )

    @staticmethod
    def _build_ott_packager_not_found_result(arc_revision) -> OttPackagerGraphCreationResult:
        """Build a failed (code 1) result with an EEC_BAD_INPUT error for an unknown ``arc_revision``."""
        error = TError()
        error.Code = TError.EEC_BAD_INPUT
        error.Message = f'Not found ott_packager build by arc_revision={arc_revision}'
        error.Exception = 'GraphCreatorError'
        return OttPackagerGraphCreationResult(1, error, {})

    @staticmethod
    def _get_block_guid(blocks: list, block_name: str) -> str:
        """Return the ``blockGuid`` of the single block whose ``blockName`` equals ``block_name``.

        :raises RuntimeError: if no block or more than one block has that name.
        """
        block_guid = None
        for block in blocks:
            if block['blockName'] == block_name:
                if block_guid is not None:
                    raise RuntimeError(f'Ambiguous result block name={block_name}')

                block_guid = block['blockGuid']

        if block_guid is None:
            raise RuntimeError(f'No block with name: {block_name}')

        return block_guid
