# -*- coding: utf-8 -*-

import sandbox.projects.dj.BuildRecommenderShard as build_shard_task
import sandbox.common.types.task as ctt
from sandbox.sandboxsdk.errors import SandboxTaskFailureError

from sandbox import sdk2
import string

'''
    Build and publish recommender index shards.
'''


class PublishRecommenderShards(sdk2.Task):
    """Build recommender index shards in child tasks and publish a shardmap.

    One ``BuildRecommenderShard`` subtask is spawned per shard. After all of
    them have finished (or broken), the task checks that every subtask
    succeeded and writes a shardmap resource describing all shards.
    """

    class Parameters(sdk2.Task.Parameters):
        yt_proxy = sdk2.parameters.String('YT proxy', required=True, default='banach')
        yt_token = sdk2.parameters.String('YT token vault', required=True)
        shard_count = sdk2.parameters.Integer('Shard count', required=True)
        service_name = sdk2.parameters.String('Service name', required=True)
        index_state_dir = sdk2.parameters.String('Index state dir', required=True)
        shard_resource = sdk2.parameters.String('Shard resource type', required=True)
        shardmap_resource = sdk2.parameters.String('Shardmap resource type', required=True)

    @staticmethod
    def get_index_state_name(state_path):
        """Return the last path component of ``state_path``.

        Trailing slashes are ignored, so ``'//home/index/state/'`` yields
        ``'state'`` instead of an empty string.
        """
        # str.rsplit works on both Python 2 and 3; the module-level
        # string.rsplit used previously exists only in Python 2.
        return state_path.rstrip('/').rsplit('/', 1)[-1]

    @staticmethod
    def _format_shard_name(service_name, shard_number, state):
        """Return ``<service>-<NNN>-<state>`` with a zero-padded 3-digit shard number.

        The name is formatted in a single pass. A service name that contains
        ``{`` or ``}`` therefore cannot break the template.
        """
        return '%s-%03d-%s' % (service_name, shard_number, state)

    def on_enqueue(self):
        # Derive the state name once at enqueue time. Shard and shardmap
        # names are built from it.
        self.Context.state_name = self.get_index_state_name(self.Parameters.index_state_dir)

    def on_execute(self):
        """Spawn the shard-building subtasks, then verify them and publish the shardmap.

        On the first pass the subtasks are created and the task suspends via
        ``sdk2.WaitTask``. When it wakes up, the memoized creation stage is
        skipped and the verification/publication step runs.
        """
        with self.memoize_stage.create_subtasks:
            self._spawn_shard_subtasks()
            raise sdk2.WaitTask(
                self.Context.subtasks,
                ctt.Status.Group.FINISH | ctt.Status.Group.BREAK,
                wait_all=True,
            )

        # NOTE(review): deliberately not wrapped in a memoize stage. Memoize
        # stages appear to be committed on entrance, so a memoized final
        # stage would be skipped on a restart after a transient error here.
        # The task would then finish without publishing a shardmap. Confirm
        # against the sdk2 memoize_stage semantics.
        self._check_shard_subtasks()
        self._write_shardmap()

    def _spawn_shard_subtasks(self):
        """Create, save and enqueue one BuildRecommenderShard task per shard.

        Records ``(shard_number, shard_name)`` pairs in ``Context.shards`` and
        the child task ids in ``Context.subtasks``.
        """
        self.Context.subtasks = []
        self.Context.shards = []
        for shard_number in range(self.Parameters.shard_count):
            shard_name = self._format_shard_name(
                self.Parameters.service_name, shard_number, self.Context.state_name)
            self.Context.shards.append((shard_number, shard_name))

            child_task = build_shard_task.BuildRecommenderShard(
                self,
                description=shard_name,
                notifications=self.Parameters.notifications,
            )
            child_task.Parameters.yt_token = self.Parameters.yt_token
            child_task.Parameters.yt_proxy = self.Parameters.yt_proxy
            child_task.Parameters.shard_name = shard_name
            child_task.Parameters.shard_number = shard_number
            child_task.Parameters.index_state_dir = self.Parameters.index_state_dir
            child_task.Parameters.shard_resource = self.Parameters.shard_resource
            child_task.save()
            child_task.enqueue()

            self.Context.subtasks.append(child_task.id)

    def _check_shard_subtasks(self):
        """Raise SandboxTaskFailureError if any subtask did not succeed.

        All unsuccessful subtasks, with their statuses, are reported in a
        single error rather than only the first one found.
        """
        failed = []
        for task_id in self.Context.subtasks:
            status = sdk2.Task[task_id].status
            if status not in ctt.Status.Group.SUCCEED:
                failed.append('%s (%s)' % (task_id, status))
        if failed:
            raise SandboxTaskFailureError('Subtask(s) failed: %s' % ', '.join(failed))

    def _write_shardmap(self):
        """Create the shardmap resource and mark it ready.

        The file has one tab-separated line per shard: a mask (the shard name
        with the state replaced by '00-00-00_00:00'), the real shard name, and
        '<service>Tier0'.
        """
        service_name = self.Parameters.service_name
        filename = 'shardmap_%s_iss-%s.map' % (service_name, self.Context.state_name)
        resource = sdk2.Resource[self.Parameters.shardmap_resource](self, filename, filename)
        data = sdk2.ResourceData(resource)

        with open(str(data.path), 'w') as shardmap:
            for shard_number, shard_name in self.Context.shards:
                mask = self._format_shard_name(service_name, shard_number, '00-00-00_00:00')
                shardmap.write('%s\t%s\t%sTier0\n' % (mask, shard_name, service_name))

        data.ready()
