from jafar.pipelines.blocks import SingleContextBlock


class ConditionalBlock(SingleContextBlock):
    """
    Switches between multiple alternative blocks
    based on a specified key.

    When applied, the value stored under ``key`` in the pipeline
    environment is looked up (falling back to ``default`` when the key is
    absent), and the block registered under that value in ``choices`` is
    applied in its place.
    """

    def __init__(self, key, choices, default=None):
        """
        :param key: environment key whose value selects the block to run.
        :param choices: container indexed by environment value that yields
            the block to apply; each block must provide
            ``apply(context, train)``.
        :param default: value used in place of the environment value when
            ``key`` is not present in the environment.
        """
        super(ConditionalBlock, self).__init__()
        self.key = key
        self.choices = choices
        self.default = default

    def apply(self, context, train):
        """
        Apply the block selected by the current environment value.

        :param context: pipeline context; its ``pipeline.environment`` is
            consulted for ``self.key`` and it is forwarded to the chosen block.
        :param train: forwarded unchanged to the chosen block's ``apply``.
        :returns: whatever the chosen block's ``apply`` returns.
        :raises KeyError: if the selected value has no entry in ``choices``.
        """
        value = context.pipeline.environment.get(self.key, self.default)
        try:
            block = self.choices[value]
        except KeyError:
            # Re-raise with context: a bare ``KeyError: None`` (e.g. key
            # missing and no default) gives no hint which key was involved.
            # Still a KeyError so existing callers that catch it keep working.
            raise KeyError(
                'No block registered for value %r of environment key %r'
                % (value, self.key)
            )
        return block.apply(context, train)
