# -*- coding: utf-8 -*-
import os
from shutil import move
from sandbox import sdk2, common
from sandbox.sandboxsdk import environments
import sandbox.common.types.task as ctt
from sandbox.projects.irt.common import create_irt_data, IrtData
from sandbox.projects.irt.IrtBuild import IrtBuild


class BenderParameters(sdk2.Parameters):
    """
    Parameters shared by the bender generator task (BmbenderGen) and its
    per-shard YT step subtask (BmbenderGenShardYT); both subclass this set.
    """
    # Total number of shards. BmbenderGen starts one YT step subtask per shard
    # (indices 1..count_shards); the value is passed to yt_prepare_bender_data
    # as --count_shards.
    count_shards = sdk2.parameters.Integer(
        'Number of shards',
        default=4
    )
    # Input YT table, passed to yt_prepare_bender_data as --banners_extended_table.
    banners_extended_table = sdk2.parameters.String(
        'Banners_extended table',
        default='//home/catalogia/banners_extended_test'
    )
    # Per-shard output table path templates: '{}' is replaced with the shard index.
    bender_tokens_template = sdk2.parameters.String(
        'Bender_tokens table template',
        default='//home/catalogia/bender/bender_tokens_{}_test'
    )
    bender_index_template = sdk2.parameters.String(
        'Bender_index table template',
        default='//home/catalogia/bender/bender_index_{}_test'
    )
    # YT cluster and pool; used both by the yt_prepare_bender_data binary
    # (through the YT_PROXY / YT_POOL environment variables) and by yt.wrapper.
    yt_proxy = sdk2.parameters.String(
        'YT proxy',
        default='hahn'
    )
    yt_pool = sdk2.parameters.String(
        'YT pool',
        default='catalogia'
    )
    # Name of the Sandbox vault entry that holds the YT token.
    yt_token_vault_name = sdk2.parameters.String(
        'YT token vault name',
        default='yql_robot_bm_admin'
    )
    # When True, the YT step removes its shard's bender_index / bender_tokens
    # tables after downloading them.
    delete_tables = sdk2.parameters.Bool(
        'Remove yt-tables after generation',
        default=True
    )


class BmbenderGenShardYT(sdk2.Task):
    """
    Catmedia bender data shard generator YT step

    Runs yt_prepare_bender_data for a single shard, downloads the produced
    bender_index and bender_tokens tables into a 'bmbender_shard_data_yt'
    IRT_DATA resource and, if requested, removes those tables afterwards.
    """

    class Requirements(sdk2.Task.Requirements):
        environments = [
            environments.PipEnvironment('yandex-yt'),
            environments.PipEnvironment('yandex-yt-yson-bindings-skynet'),
        ]

    class Parameters(BenderParameters):
        shard_index = sdk2.parameters.Integer(
            'Shard index'
        )

    def on_execute(self):
        import yt.wrapper as yt

        params = self.Parameters

        # The yt_prepare_bender_data binary is published by the parent task.
        binary_query = sdk2.Resource.find(IrtData, task=self.parent, attrs={'sub_type': 'bmbender_binary_yt'})
        binary_dir = sdk2.ResourceData(binary_query.first()).path
        prepare_exe = str(binary_dir.joinpath('yt_prepare_bender_data'))
        token = sdk2.Vault.data(params.yt_token_vault_name)
        shard = str(params.shard_index)

        # The binary talks to YT on its own, so it gets credentials via env vars.
        child_env = os.environ.copy()
        child_env['YT_PROXY'] = params.yt_proxy
        child_env['YT_TOKEN'] = token
        child_env['YT_POOL'] = params.yt_pool

        index_table = params.bender_index_template.format(shard)
        tokens_table = params.bender_tokens_template.format(shard)
        command = [
            prepare_exe,
            '--banners_extended_table={}'.format(params.banners_extended_table),
            '--bender_tokens={}'.format(tokens_table),
            '--bender_index={}'.format(index_table),
            '--count_shards={}'.format(params.count_shards),
            '--shard_index={}'.format(shard),
        ]
        with sdk2.helpers.ProcessLog(self, logger='yt_prepare_bender_data') as process_log:
            sdk2.helpers.subprocess.check_call(
                command,
                stdout=process_log.stdout,
                stderr=process_log.stderr,
                env=child_env
            )

        # Configure the in-process YT client used for the downloads below.
        yt.config['token'] = token
        yt.config['spec_defaults'] = {'pool': params.yt_pool}
        yt.config['proxy']['url'] = params.yt_proxy
        parallel_cfg = yt.config['read_parallel']
        parallel_cfg['enable'] = True
        parallel_cfg['max_thread_count'] = 32

        shard_data = create_irt_data(
            self,
            'bmbender_shard_data_yt',
            'Intermediate resource of bmbender, YT tables, one shard',
            ['encoded_file.tb', 'index_file.tb'],
            auto_backup=True, ttl=3, bender_shard=shard
        )
        encoded_path = shard_data.filenames[0]
        index_path = shard_data.filenames[1]

        # (local file, source table, schemaful_dsv column layout) in download order.
        downloads = (
            (index_path, index_table, '<columns=[token;indexes]>schemaful_dsv'),
            (encoded_path, tokens_table, '<columns=[bid]>schemaful_dsv'),
        )
        for local_path, source_table, table_format in downloads:
            with open(local_path, 'wb') as out:
                out.writelines(yt.read_table(source_table, raw=True, format=table_format))

        if params.delete_tables:
            for source_table in (index_table, tokens_table):
                yt.remove(source_table)


class BmbenderGenShard(sdk2.Task):
    """
    Catmedia bender data shard generator second step

    Feeds the files downloaded by the YT step (encoded_file.tb, index_file.tb)
    to the bender binary, pointing its --binary-file at the banners_bin path
    of a freshly created per-shard IRT_DATA resource.
    """

    class Parameters(sdk2.Parameters):
        shard_index = sdk2.parameters.Integer(
            'Shard index'
        )
        yt_shard_data = sdk2.parameters.Resource(
            'bmbender_shard_data_yt resorce'
        )
        resource_banners_bin_path_template = sdk2.parameters.String(
            'path to banners_bin_<shard_index> in IRT_DATA resource',
            default='rt-research/broadmatching/work/bender/banners_bin_{}'
        )

    def on_execute(self):
        params = self.Parameters
        shard = str(params.shard_index)

        # The bender binary is published by the parent task.
        binaries = sdk2.ResourceData(
            sdk2.Resource.find(IrtData, task=self.parent, attrs={'sub_type': 'bmbender_binary_bender'}).first()
        )
        bender_path = str(binaries.path.joinpath('bender'))

        # Inputs produced by the YT step for this shard.
        yt_data_dir = sdk2.ResourceData(params.yt_shard_data).path
        encoded_path = str(yt_data_dir.joinpath('encoded_file.tb'))
        index_path = str(yt_data_dir.joinpath('index_file.tb'))

        output = create_irt_data(
            self,
            'bmbender_shard_data_{}'.format(shard),
            'Intermediate resource bmbender, one shard',
            [params.resource_banners_bin_path_template.format(shard)],
            auto_backup=True, ttl=3, bender_shard=shard
        )
        command = [
            bender_path,
            '--binary-file={}'.format(output.filenames[0]),
            '--encoded-file={}'.format(encoded_path),
            '--index-file={}'.format(index_path),
        ]
        with sdk2.helpers.ProcessLog(self, logger='bender') as process_log:
            sdk2.helpers.subprocess.check_call(
                command,
                stdout=process_log.stdout,
                stderr=process_log.stderr
            )


def str2mb(s):
    """Convert a human-readable size such as ``'120 GB'`` to megabytes.

    The last two characters select the unit: ``KB``, ``MB``, ``GB`` or ``TB``
    (case-insensitive, binary multiples of 1024). Whitespace between the
    number and the unit, and around the whole string, is ignored. A string
    without a recognised unit suffix is taken to be in megabytes already.

    :param s: size string, e.g. ``'850 GB'``, ``'512MB'`` or ``'2048'``
    :return: size in megabytes as a float rounded to 4 decimal places
    :raises ValueError: if the numeric part is not a valid float
    """
    # Multipliers to megabytes; all are powers of two, so scaling is exact.
    factors = {
        'KB': 1.0 / 1024,
        'MB': 1.0,
        'GB': 1024.0,
        'TB': 1024.0 * 1024,
    }
    text = s.strip()
    factor = factors.get(text[-2:].upper())
    if factor is None:
        # No unit suffix: the whole string is already a megabyte count.
        return round(float(text), 4)
    return round(float(text[:-2]) * factor, 4)


class BmbenderGen(IrtBuild):
    """
    Catmedia bender data generator

    on_execute is re-entered each time a ``sdk2.WaitTask`` it raised is
    satisfied, and advances a small state machine driven by the subtasks that
    already exist:

    1. First run only (memoized): build the yt_prepare_bender_data and bender
       binaries, publish them as intermediate IRT_DATA resources, start one
       BmbenderGenShardYT subtask per shard and wait until any of them finishes.
    2. For a successfully finished YT step whose shard has no second step yet,
       start one BmbenderGenShard subtask and wait for it. Second steps
       therefore run one at a time, while YT steps may still be running.
    3. Once every shard has a second step, copy each shard's banners_bin file
       into a single IRT_DATA resource of sub_type ``resource_sub_type``.
    """

    class Parameters(BenderParameters):
        # Size strings (e.g. '120 GB') converted with str2mb into the
        # subtasks' ram / disk_space requirements.
        shard_YT_RAM = sdk2.parameters.String(
            'Execution RAM of shard subtask (YT step)',
            default='120 GB'
        )
        shard_YT_HDD = sdk2.parameters.String(
            'Execution space of shard subtask (YT step)',
            default='850 GB'
        )
        shard_RAM = sdk2.parameters.String(
            'Execution RAM of shard subtask (second step)',
            default='250 GB'
        )
        shard_HDD = sdk2.parameters.String(
            'Execution space of shard subtask (second step)',
            default='850 GB'
        )
        # sub_type of the final IRT_DATA resource holding all shards.
        resource_sub_type = sdk2.parameters.String(
            'sub_type for IRT_DATA resource',
            default='banners_bender_data_test'
        )
        # Relative banners_bin path inside IRT_DATA resources; '{}' is the shard index.
        resource_banners_bin_path_template = sdk2.parameters.String(
            'path to banners_bin_<shard_index> in IRT_DATA resource',
            default='rt-research/broadmatching/work/bender/banners_bin_{}'
        )

    def on_execute(self):
        count_shards = self.Parameters.count_shards
        with self.memoize_stage.shards:
            # Stage 1 (executed once): build binaries and launch the YT steps.
            # NOTE(review): the unpacked build outputs are unused; the binaries are
            # taken from the hard-coded 'build/bin/...' paths below.
            (yt_prepare_bender_data_exe, bender_exe), build_revision = self.build(
                ['rt-research/broadmatching/scripts/bender/yt_prepare_bender_data', 'rt-research/broadmatching/scripts/cpp-source/bender/bin'],
                'build'
            )
            # Subtasks look these binaries up among their parent's IRT_DATA resources by sub_type.
            binary_yt_resource = create_irt_data(self, 'bmbender_binary_yt', 'Intermediate resource of bmbender-gen, yt binary', ['yt_prepare_bender_data'], auto_backup=True, ttl=2)
            binary_bender_resource = create_irt_data(self, 'bmbender_binary_bender', 'Intermediate resource of bmbender-gen, bender binary', ['bender'], auto_backup=True, ttl=3)
            move('build/bin/yt_prepare_bender_data', binary_yt_resource.filenames[0])
            move('build/bin/bender', binary_bender_resource.filenames[0])
            binary_yt_resource.ready()
            binary_bender_resource.ready()
            children = []
            for shard_index in range(1, count_shards+1):
                child = BmbenderGenShardYT(
                    self,
                    owner=self.owner,
                    description='Catmedia bender data generator {} shard YT step subtask'.format(str(shard_index)),
                    count_shards=count_shards,
                    shard_index=shard_index,
                    banners_extended_table=self.Parameters.banners_extended_table,
                    bender_tokens_template=self.Parameters.bender_tokens_template,
                    bender_index_template=self.Parameters.bender_index_template,
                    yt_proxy=self.Parameters.yt_proxy,
                    yt_pool=self.Parameters.yt_pool,
                    yt_token_vault_name=self.Parameters.yt_token_vault_name,
                    delete_tables=self.Parameters.delete_tables,
                )
                child.Requirements.disk_space = str2mb(self.Parameters.shard_YT_HDD)
                child.Requirements.ram = str2mb(self.Parameters.shard_YT_RAM)
                child.Parameters.kill_timeout = self.Parameters.kill_timeout
                child.save().enqueue()
                children.append(child.id)
            # Wake up as soon as any YT step reaches a finished or broken status.
            raise sdk2.WaitTask(children, ctt.Status.Group.FINISH + ctt.Status.Group.BREAK, wait_all=False)

        second_steps = list(self.find(BmbenderGenShard))
        # A broken second step makes the final assembly fail anyway, so fail now
        # instead of first launching the remaining (large-RAM) second steps.
        self._fail_on_broken_second_step(second_steps)
        # Shards that already have a second-step subtask.
        yt_data_used = {task.Parameters.shard_index for task in second_steps}
        if len(yt_data_used) == count_shards:
            # Stage 3: every shard has had its second step launched and awaited;
            # gather their banners_bin files into one resource.
            shard_resources = []
            for task in self.find(BmbenderGenShard, status=ctt.Status.SUCCESS):
                for resource in sdk2.Resource.find(IrtData, task=task).limit(1):
                    shard_resources.append(resource)
            if len(shard_resources) != count_shards:
                raise common.errors.TaskFailure('Only {} of {} shards successfully generated'.format(len(shard_resources), count_shards))
            result_resource = create_irt_data(self, self.Parameters.resource_sub_type, 'Catmedia Bender data (all shards)', [
                self.Parameters.resource_banners_bin_path_template.format(i) for i in range(1, count_shards+1)
            ], auto_backup=True)
            for resource in shard_resources:
                resource_data = sdk2.ResourceData(resource)  # synchronizing resource data on disk
                shard_index = int(resource.bender_shard)
                # Shard indices start at 1, result filenames are ordered by shard.
                sdk2.paths.copy_path(
                    str(resource_data.path.joinpath(self.Parameters.resource_banners_bin_path_template.format(shard_index))),
                    result_resource.filenames[shard_index-1]
                )
            # finish
        else:
            # Stage 2: start the second step for one finished YT step, or keep
            # waiting on the YT steps that are still running.
            children = []
            for task in self.find(BmbenderGenShardYT):
                if task.Parameters.shard_index in yt_data_used:
                    continue
                if task.status == ctt.Status.SUCCESS:
                    # NOTE(review): BmbenderGenShard.Parameters declares no count_shards
                    # field; confirm how sdk2 treats this extra keyword argument.
                    child = BmbenderGenShard(
                        self,
                        owner=self.owner,
                        description='Catmedia bender data generator {} shard second step subtask'.format(str(task.Parameters.shard_index)),
                        count_shards=count_shards,
                        shard_index=task.Parameters.shard_index,
                        yt_shard_data=sdk2.Resource.find(IrtData, task=task).first(),
                        resource_banners_bin_path_template=self.Parameters.resource_banners_bin_path_template
                    )
                    child.Requirements.ram = str2mb(self.Parameters.shard_RAM)
                    child.Requirements.disk_space = str2mb(self.Parameters.shard_HDD)
                    child.Parameters.kill_timeout = self.Parameters.kill_timeout
                    child.save().enqueue()
                    # Second steps are awaited one at a time.
                    raise sdk2.WaitTask([child.id], ctt.Status.Group.FINISH + ctt.Status.Group.BREAK)
                if task.status in ctt.Status.Group.FINISH + ctt.Status.Group.BREAK:
                    raise common.errors.TaskFailure('YT step of generation for shard {} finished unsuccessfully'.format(str(task.Parameters.shard_index)))
                # Still running: wait for it below.
                children.append(task.id)
            if not children:
                raise common.errors.TaskError('Impossible, BMBENDER_GEN_SHARD_YT tasks less then shards ({})'.format(count_shards))
            raise sdk2.WaitTask(children, ctt.Status.Group.FINISH + ctt.Status.Group.BREAK, wait_all=False)

    @staticmethod
    def _fail_on_broken_second_step(second_steps):
        """Raise TaskFailure if any finished second-step subtask did not succeed.

        :param second_steps: BmbenderGenShard subtasks of this task
        :raises common.errors.TaskFailure: listing the shard indices whose
            second step ended in a finished/broken status other than SUCCESS
        """
        finished = ctt.Status.Group.FINISH + ctt.Status.Group.BREAK
        broken = sorted(
            task.Parameters.shard_index
            for task in second_steps
            if task.status != ctt.Status.SUCCESS and task.status in finished
        )
        if broken:
            raise common.errors.TaskFailure(
                'Second step of generation finished unsuccessfully for shard(s) {}'.format(
                    ', '.join(str(index) for index in broken)
                )
            )
