# coding: utf-8
from typing import Type

from .errors import MissingArgumentError, NotSupportedError


class BaseTask(object):
    """Base class for tasks.

    Subclasses implement :meth:`run` and may declare ``required_args``;
    construction raises ``MissingArgumentError`` when any of those keys is
    absent from ``task_args``.
    """

    # Keys that must be present in ``task_args`` (validated by check_args()).
    # Treat as read-only: this list object is shared by every subclass that
    # does not override it, so mutating it in place would affect all of them.
    required_args = []

    def __init__(self, app, transfer_id, uid, task_args, tvm=None):
        """Store the task context and validate ``task_args``.

        :param app: application object; its ``args`` attribute is exposed
            as :attr:`config`.
        :param transfer_id: value shown in ``repr``.
        :param uid: value shown in ``repr``.
        :param task_args: mapping of task arguments; a falsy value (e.g.
            ``None``) is replaced with an empty dict.
        :param tvm: optional value stored as ``self.tvm``; unused here.
        :raises MissingArgumentError: if a ``required_args`` key is absent.
        """
        self.app = app
        self.transfer_id = transfer_id
        self.uid = uid
        self.task_args = task_args or {}
        self.tvm = tvm
        self.check_args()

    @property
    def config(self):
        """Return ``self.app.args``."""
        return self.app.args

    def get_arg(self, name, default=None):
        """Return ``task_args[name]``, or ``default`` if the key is absent."""
        return self.task_args.get(name, default)

    def __repr__(self):
        return (
            '{0.__class__.__name__}'
            '(transfer_id={0.transfer_id},uid={0.uid})'
        ).format(self)

    def run(self):
        """Execute the task; subclasses must implement this."""
        raise NotImplementedError()

    def _missing_args(self):
        """Return the ``required_args`` names absent from ``task_args``.

        Names keep their ``required_args`` order with duplicates removed, so
        the error message is stable across runs. Iterating a set of strings
        is not stable, because string hashing is randomized per process.
        """
        return [
            arg for arg in dict.fromkeys(self.required_args)
            if arg not in self.task_args
        ]

    def check_args(self):
        """Raise MissingArgumentError if a required arg is absent."""
        missing_args = self._missing_args()
        if missing_args:
            raise MissingArgumentError(
                "Task {0} required: {1} args, found: {2}, missing: {3}".format(
                    self, self.required_args,
                    list(self.task_args), missing_args
                )
            )

    @property
    def loaded_shard_id(self):
        """Return None in the base class (presumably overridden; confirm)."""
        return None


def task_handlers():
    """Map each direct ``BaseTask`` subclass's ``name`` to that subclass.

    Only immediate subclasses are considered, because ``__subclasses__()``
    is not recursive. Subclasses without a ``name`` attribute are skipped.
    If two subclasses share a name, the later one in ``__subclasses__()``
    order wins.
    """
    return {
        handler.name: handler
        for handler in BaseTask.__subclasses__()
        if hasattr(handler, 'name')
    }


def get_handler(task) -> Type[BaseTask]:
    """Look up the task class registered under the name ``task``.

    :raises NotSupportedError: if no handler is registered for ``task``.
    """
    registry = task_handlers()
    # Registry values are always classes, so None can only mean "not found".
    handler = registry.get(task)
    if handler is None:
        raise NotSupportedError('Task {0} not supported'.format(task))
    return handler


def get_tvm_ticket(tvm_tool, tvm_id):
    """Return ``tvm_tool.get(tvm_id)``, skipping the lookup for a falsy id.

    When ``tvm_id`` is falsy, ``tvm_tool`` is not touched and ``tvm_id``
    itself is returned unchanged (e.g. ``None`` in, ``None`` out).
    """
    if not tvm_id:
        return tvm_id
    return tvm_tool.get(tvm_id)
