
from sandbox.common.types import task
from sandbox.projects.crypta.common.task import CryptaBaseYqlTask
from sandbox.projects.crypta.graph.households.bundles import CryptaHouseholdsBundle


class HHBaseTask(CryptaBaseYqlTask):

    """Base households task.

    Holds the settings shared by the households task classes defined in this
    file. The attributes below are only declared here; how they are consumed
    is defined by ``CryptaBaseYqlTask``, which is not visible in this file.
    """

    # Presumably the name of the executable shipped in the bundle — confirm
    # against CryptaBaseYqlTask.
    BIN_FILE_NAME = "crypta-households"
    # Presumably the YT pool the task runs in — confirm against the base class.
    YT_POOL = "crypta_households"

    class CryptaOptions(CryptaBaseYqlTask.CryptaOptions):

        # Resource type of the households bundle.
        bundle_resource_type = CryptaHouseholdsBundle
        # NOTE(review): the exact effect of these two flags is defined by the
        # base class; several subclasses below replace the semaphore setup in
        # their own ``on_enqueue``.
        use_semaphore = True
        report_status_to_crypta_api = True

    class Parameters(CryptaBaseYqlTask.Parameters):
        # Overrides ``tags`` with a single "households" tag.
        tags = ["households"]


class CryptaHhDataImportWatchlog(HHBaseTask):
    """Households task bound to the ``WatchLogParser`` target."""

    # Dotted path of the target; presumably resolved by the base task
    # machinery (not visible in this file) — confirm.
    TASK = "crypta.graph.households.data_import.watchlog.lib.WatchLogParser"


class CryptaHhDataImportIncDay(HHBaseTask):
    """Households task bound to the ``IncrementDay`` target.

    On enqueue it requests two semaphores: its own one and the one named
    after ``CryptaHhDataImportWatchlog``.
    """

    TASK = "crypta.graph.households.data_import.increment_day.lib.IncrementDay"

    def get_semaphore_name_for_task(self, task_cls):
        """Return ``"<task_cls.__name__>_<environment>"`` for the given class.

        The environment is read from ``self.Parameters.environment``.
        """
        class_name = task_cls.__name__
        environment = self.Parameters.environment
        return "{}_{}".format(class_name, environment)

    def on_enqueue(self):
        """Replace ``Requirements.semaphores`` with this task's two acquires.

        The parent ``on_enqueue`` is not called. Both semaphores are requested
        with ``weight=1`` and ``capacity=1``.
        """
        acquire = task.Semaphores.Acquire

        # Semaphore of this very task.
        own_semaphore = self._get_semaphore_name()
        # Semaphore of the watchlog import task.
        # NOTE(review): assumes the watchlog task registers its semaphore
        # under the same "<ClassName>_<environment>" name — confirm against
        # the base class.
        watchlog_semaphore = self.get_semaphore_name_for_task(CryptaHhDataImportWatchlog)

        self.Requirements.semaphores = task.Semaphores(
            acquires=[
                acquire(own_semaphore, weight=1, capacity=1),
                acquire(watchlog_semaphore, weight=1, capacity=1),
            ]
        )


class CryptaHhHhMatch(HHBaseTask):
    """Households task bound to the ``HHMatch`` target.

    On enqueue it requests the semaphore named after ``CryptaHhHhMatch`` and
    the one named after ``CryptaHhDataImportIncDay``.
    """

    TASK = "crypta.graph.households.hh_match.lib.HHMatch"

    def get_semaphore_name_for_task(self, task_cls):
        """Return ``"<task_cls.__name__>_<environment>"`` for the given class.

        The environment is read from ``self.Parameters.environment``.
        """
        class_name = task_cls.__name__
        environment = self.Parameters.environment
        return "{}_{}".format(class_name, environment)

    def on_enqueue(self):
        """Replace ``Requirements.semaphores`` with this task's two acquires.

        The parent ``on_enqueue`` is not called. Both semaphores are requested
        with ``weight=1`` and ``capacity=1``.
        """
        acquire = task.Semaphores.Acquire

        # The "self" semaphore is named after CryptaHhHhMatch explicitly, so
        # subclasses that inherit this method request the same semaphore.
        match_semaphore = self.get_semaphore_name_for_task(CryptaHhHhMatch)
        # Semaphore of the increment-day import task.
        inc_day_semaphore = self.get_semaphore_name_for_task(CryptaHhDataImportIncDay)

        self.Requirements.semaphores = task.Semaphores(
            acquires=[
                acquire(match_semaphore, weight=1, capacity=1),
                acquire(inc_day_semaphore, weight=1, capacity=1),
            ]
        )


class CryptaHhHhMatchPrepare(CryptaHhHhMatch):
    """Run the hh match "prepare" subtask.

    Uses the ``PrepareHH`` target and adds ``--kind prepare`` to the command
    built by the parent class.
    """

    TASK = "crypta.graph.households.hh_match.lib.PrepareHH"

    def get_cmd(self):
        """Return the parent command with ``--kind prepare`` inserted.

        The two arguments are spliced in at index 1, i.e. right after the
        first element of the list returned by the parent ``get_cmd``.
        """
        cmd = super(CryptaHhHhMatchPrepare, self).get_cmd()
        cmd[1:1] = ["--kind", "prepare"]
        return cmd


class CryptaHhHhMatchFinish(CryptaHhHhMatch):
    """Run the hh match "finish" subtask.

    Adds ``--kind finish`` to the command built by the parent class.
    """

    # NOTE(review): the trailing "x" in "FinishHHx" looks like a typo next to
    # the sibling "PrepareHH" target; left unchanged because the real target
    # name is not visible from this file — confirm.
    TASK = "crypta.graph.households.hh_match.lib.FinishHHx"

    def get_cmd(self):
        """Return the parent command with ``--kind finish`` inserted.

        The two arguments are spliced in at index 1, i.e. right after the
        first element of the list returned by the parent ``get_cmd``.
        """
        cmd = super(CryptaHhHhMatchFinish, self).get_cmd()
        cmd[1:1] = ["--kind", "finish"]
        return cmd


class CryptaHhHhComposition(HHBaseTask):
    """Households task bound to the ``CompositionHH`` target.

    On enqueue it requests two semaphores: its own one and the one named
    after ``CryptaHhHhMatch``.
    """

    TASK = "crypta.graph.households.hh_composition.lib.CompositionHH"

    def get_semaphore_name_for_task(self, task_cls):
        """Return ``"<task_cls.__name__>_<environment>"`` for the given class.

        The environment is read from ``self.Parameters.environment``.
        """
        class_name = task_cls.__name__
        environment = self.Parameters.environment
        return "{}_{}".format(class_name, environment)

    def on_enqueue(self):
        """Replace ``Requirements.semaphores`` with this task's two acquires.

        The parent ``on_enqueue`` is not called. Both semaphores are requested
        with ``weight=1`` and ``capacity=1``.
        """
        acquire = task.Semaphores.Acquire

        # Semaphore of this very task.
        own_semaphore = self._get_semaphore_name()
        # Semaphore named after the hh match task class.
        match_semaphore = self.get_semaphore_name_for_task(CryptaHhHhMatch)

        self.Requirements.semaphores = task.Semaphores(
            acquires=[
                acquire(own_semaphore, weight=1, capacity=1),
                acquire(match_semaphore, weight=1, capacity=1),
            ]
        )
