from sandbox import sdk2
from sandbox.projects.common.nanny import nanny
from sandbox.sandboxsdk.paths import get_logs_folder

import sandbox.projects.goods.resources as goods_resources
from sandbox.projects.resource_types import OTHER_RESOURCE

import logging
import os
import subprocess
from time import sleep


def get_token(token_path):
    """Fetch a secret from Sandbox Vault.

    :param token_path: vault reference of the form ``<owner>:<vault name>``.
    :return: the value returned by ``sdk2.Vault.data(owner, name)``.
    :raises ValueError: if ``token_path`` is not exactly two colon-separated parts.
    """
    token_info = token_path.split(':')
    # An explicit check instead of ``assert``: asserts are stripped under ``python -O``,
    # which would turn a malformed reference into an IndexError or a silently truncated name.
    if len(token_info) != 2:
        raise ValueError(
            'Vault reference must look like <owner>:<vault name>, got {!r}'.format(token_path))
    owner, name = token_info
    return sdk2.Vault.data(owner, name)


class BuildGoodsIndexShard(sdk2.Task):
    """Build goods base-search shards for one worker and publish them as resources.

    Runs the ``shard_builder`` binary against a preprocessed working directory and,
    for every ``search-part-base-<id>`` part it produces, creates a
    GoodsBasesearchShard resource holding the unpacked base and additions archives
    plus the unpacked model shard.
    """

    class Parameters(sdk2.task.Parameters):
        push_tasks_resource = True

        binaries_bundle = sdk2.parameters.Resource('Binaries bundle', resource_type=goods_resources.GoodsIndexerBinaries)
        configs_bundle = sdk2.parameters.Resource('Configs bundle', resource_type=goods_resources.GoodsIndexerConfigs)
        working_dir = sdk2.parameters.Resource('Preprocessed working directory', resource_type=OTHER_RESOURCE)
        model_shard = sdk2.parameters.Resource('Model shard', resource_type=OTHER_RESOURCE)
        worker_id = sdk2.parameters.Integer('Worker id')

        yt_proxy = sdk2.parameters.String('Yt cluster', default_value='arnold')
        yt_token = sdk2.parameters.String('Yt token vault: <user>:<vault name>', default_value='aleksko:yt-token')

        debug_mode = sdk2.parameters.Bool('Sleep after indexer process completed or failed', default_value=True)

    class Requirements(sdk2.Task.Requirements):
        cores = 4
        ram = 131072  # 128Gb
        disk_space = 500 * 1024  # 500Gb

        class Caches(sdk2.Requirements.Caches):
            pass

    def _check_call_logged(self, cmd, logger_name):
        """Run ``cmd`` with ``self.proc_env``, sending its output to a ProcessLog named ``logger_name``.

        :raises subprocess.CalledProcessError: if the command exits non-zero (logged first).
        """
        with sdk2.helpers.ProcessLog(self, logger=logger_name) as pl:
            try:
                subprocess.check_call(cmd, stdout=pl.stdout, stderr=pl.stderr, env=self.proc_env)
            except subprocess.CalledProcessError:
                logging.exception('%s command failed', cmd)
                raise

    def _sleep_if_debug(self):
        """Sleep for one hour when the ``debug_mode`` parameter is set."""
        if self.Parameters.debug_mode:
            sleep(3600)

    @staticmethod
    def _dist_subdir(root):
        """Return the path of the entry inside ``root`` that holds the build output.

        ``os.listdir`` order is arbitrary, so entries are sorted to keep the choice
        deterministic; several entries are reported with a warning.

        :raises RuntimeError: if ``root`` is empty.
        """
        entries = sorted(os.listdir(root))
        if not entries:
            raise RuntimeError('No build output found in {}'.format(root))
        if len(entries) > 1:
            logging.warning('Several entries in %s (%s), using %s', root, entries, entries[0])
        return os.path.join(root, entries[0])

    def _create_shard_resource(self, dist_arch_path, shard_id, model_shard_path):
        """Unpack one shard's archives into ``./shard_<shard_id>`` and publish it.

        The base and additions archives are extracted into ``offer/`` and the model
        shard archive into ``model/``. A fixed list of files is then removed from
        ``offer/`` if present. The directory is registered as a GoodsBasesearchShard
        resource with ``shard_id`` set and ``backup_task`` enabled, and marked ready.
        """
        dst = './shard_{}'.format(shard_id)
        os.makedirs(dst)
        offer_dst = os.path.join(dst, 'offer')
        model_dst = os.path.join(dst, 'model')
        os.makedirs(offer_dst)
        os.makedirs(model_dst)

        files_to_remove = ['content-offer.tsv', 'delivery_shop_ids.mmap']
        data_parts = [
            (os.path.join(dist_arch_path, 'search-part-base-{}'.format(shard_id),
                          'search-part-base-{}.tar.gz'.format(shard_id)), offer_dst),
            (os.path.join(dist_arch_path, 'search-part-additions-{}'.format(shard_id),
                          'search-part-additions-{}.tar.gz'.format(shard_id)), offer_dst),
            (model_shard_path, model_dst),
        ]
        for archive, target in data_parts:
            self._check_call_logged(['tar', '-xvzf', archive, '-C', target], 'untar')

        for fname in files_to_remove:
            try:
                os.remove(os.path.join(offer_dst, fname))
                logging.info('file {} removed from shard'.format(fname))
            except OSError:
                # Best effort: a file that cannot be removed (e.g. absent) is ignored.
                pass

        shard_resource = sdk2.Resource[goods_resources.GoodsBasesearchShard]
        current_shard_resource = shard_resource(self,
                                                'Shard for goods base search',
                                                dst)
        current_shard_resource.backup_task = True
        current_shard_resource.shard_id = shard_id
        shard_data = sdk2.ResourceData(current_shard_resource)
        shard_data.ready()

    def _build_shards(self):
        """Run shard_builder for this worker and publish every shard it produced."""
        self.proc_env = os.environ.copy()
        self.proc_env['YT_PROXY'] = self.Parameters.yt_proxy
        self.yt_token = get_token(self.Parameters.yt_token)
        self.proc_env['YT_TOKEN'] = self.yt_token
        self.proc_env['YQL_TOKEN'] = ''

        binaries_path = str(sdk2.ResourceData(self.Parameters.binaries_bundle).path)
        logging.info('Binaries path: {}'.format(binaries_path))

        configs_path = str(sdk2.ResourceData(self.Parameters.configs_bundle).path)
        logging.info('Configs path: {}'.format(configs_path))

        working_dir_path = str(sdk2.ResourceData(self.Parameters.working_dir).path)
        model_shard_path = str(sdk2.ResourceData(self.Parameters.model_shard).path)

        shard_builder_cmd = [
            os.path.join(binaries_path, 'shard_builder'),
            '--binaries_path', binaries_path,
            '--configs_path', configs_path,
            '--data_path', working_dir_path,
            '--logs_path', get_logs_folder(),
            '--cluster', self.Parameters.yt_proxy,
            '--worker_id', str(self.Parameters.worker_id),
        ]
        self._check_call_logged(shard_builder_cmd, 'mifd_log')

        dist_arch_path = self._dist_subdir('./workdir/mif/dist')
        shard_ids = sorted(
            int(name.split('-')[-1])
            for name in os.listdir(dist_arch_path)
            if name.startswith('search-part-base-')
        )
        for shard_id in shard_ids:
            logging.info('Create resource for shard {}'.format(shard_id))
            self._create_shard_resource(dist_arch_path, shard_id, model_shard_path)

    def on_execute(self):
        # ``debug_mode`` is documented as sleeping after the indexer "completed or
        # failed", so the pause also happens when the build raises. Only Exception
        # is caught: SystemExit/KeyboardInterrupt still propagate without a pause.
        try:
            self._build_shards()
        except Exception:
            self._sleep_if_debug()
            raise
        self._sleep_if_debug()


class BuildGoodsIndex(nanny.ReleaseToNannyTask2, sdk2.Task):
    """Run the goods indexer and fan out per-worker shard-building subtasks.

    Publishes the mifd working directory, the model shards and the metasearch
    report data as resources. Enqueues ``worker_count`` BuildGoodsIndexShard
    subtasks that consume them.
    """

    class Parameters(sdk2.task.Parameters):
        push_tasks_resource = True

        binaries_bundle = sdk2.parameters.Resource('Binaries bundle', resource_type=goods_resources.GoodsIndexerBinaries)
        configs_bundle = sdk2.parameters.Resource('Configs bundle', resource_type=goods_resources.GoodsIndexerConfigs)
        svn_data = sdk2.parameters.Resource('Indexer svn data', resource_type=goods_resources.GoodsIndexerSvnData)
        input_data = sdk2.parameters.Resource('Indexer input data', resource_type=goods_resources.GoodsIndexerInputData)
        input_mbo_data = sdk2.parameters.Resource('Indexer input mbo_stuff data', resource_type=goods_resources.GoodsIndexerInputMboData)
        bids_data = sdk2.parameters.Resource('Indexer bids data', resource_type=goods_resources.GoodsIndexerBidsData)

        yt_proxy = sdk2.parameters.String('Yt cluster', default_value='arnold')
        yt_token = sdk2.parameters.String('Yt token vault: <user>:<vault name>', default_value='aleksko:yt-token')
        yql_token = sdk2.parameters.String('Yql token vault: <user>:<vault name>', default_value='aleksko:yql-token')
        yt_home_dir = sdk2.parameters.String('Yt home dir', default_value='//tmp/TV_indexer')

        worker_count = sdk2.parameters.Integer('Build shard tasks count', default_value=16, required=True)

        debug_mode = sdk2.parameters.Bool('Sleep after indexer process completed or failed', default_value=True)

    class Requirements(sdk2.Task.Requirements):
        cores = 32
        ram = 131072  # 128Gb
        disk_space = 500 * 1024  # 500Gb

        class Caches(sdk2.Requirements.Caches):
            pass

#    def on_create(self):
#        self.Requirements.tasks_resource = service_resources.SandboxTasksBinary.find(owner="GOODS-RUNTIME",
#        attrs={"task_type": "MAKE_GOODS_SHARDMAP", "released": "stable"}).first()

    def _check_call_logged(self, cmd, logger_name):
        """Run ``cmd`` with ``self.proc_env``, sending its output to a ProcessLog named ``logger_name``.

        :raises subprocess.CalledProcessError: if the command exits non-zero (logged first).
        """
        with sdk2.helpers.ProcessLog(self, logger=logger_name) as pl:
            try:
                subprocess.check_call(cmd, stdout=pl.stdout, stderr=pl.stderr, env=self.proc_env)
            except subprocess.CalledProcessError:
                logging.exception('%s command failed', cmd)
                raise

    def _sleep_if_debug(self):
        """Sleep for one hour when the ``debug_mode`` parameter is set."""
        if self.Parameters.debug_mode:
            sleep(3600)

    @staticmethod
    def _dist_subdir(root):
        """Return the path of the entry inside ``root`` that holds the build output.

        ``os.listdir`` order is arbitrary, so entries are sorted to keep the choice
        deterministic; several entries are reported with a warning.

        :raises RuntimeError: if ``root`` is empty.
        """
        entries = sorted(os.listdir(root))
        if not entries:
            raise RuntimeError('No build output found in {}'.format(root))
        if len(entries) > 1:
            logging.warning('Several entries in %s (%s), using %s', root, entries, entries[0])
        return os.path.join(root, entries[0])

    @staticmethod
    def _param_or_found(resource, resource_type):
        """Return ``resource`` if set, else the first ``resource_type`` resource from ``find()``.

        :raises RuntimeError: if the parameter is empty and no resource is found, instead
            of passing ``None`` on to ``sdk2.ResourceData``.
        """
        if resource:
            return resource
        found = sdk2.Resource[resource_type].find().first()
        if found is None:
            raise RuntimeError('No {} resource specified or found'.format(resource_type.__name__))
        return found

    def _create_report_data_resource(self, dist_arch_path):
        """Unpack report data and stats archives into ``./report_data`` and publish it.

        The directory is registered as a GoodsReportData resource with
        ``backup_task`` enabled and marked ready.
        """
        dst = './report_data'
        os.makedirs(dst)

        data_parts = [
            os.path.join(dist_arch_path, 'search-report-data', 'search-report-data.tar.gz'),
            os.path.join(dist_arch_path, 'search-stats', 'search-stats.tar.gz'),
        ]
        for archive in data_parts:
            self._check_call_logged(['tar', '-xvzf', archive, '-C', dst], 'untar')

        data_resource = sdk2.Resource[goods_resources.GoodsReportData]
        current_data_resource = data_resource(self,
                                              'Goods metasearch data',
                                              dst)
        current_data_resource.backup_task = True
        report_data = sdk2.ResourceData(current_data_resource)
        report_data.ready()

    def _build_index(self):
        """Prepare inputs, run the indexer, publish its outputs and enqueue shard subtasks."""
        self.proc_env = os.environ.copy()
        self.proc_env['YT_PROXY'] = self.Parameters.yt_proxy
        self.yt_token = get_token(self.Parameters.yt_token)
        self.proc_env['YT_TOKEN'] = self.yt_token
        self.yql_token = get_token(self.Parameters.yql_token)
        self.proc_env['YQL_TOKEN'] = self.yql_token

        working_copy_data_path = os.path.abspath('input_data')
        os.makedirs(working_copy_data_path)

        binaries_bundle_res = self._param_or_found(self.Parameters.binaries_bundle, goods_resources.GoodsIndexerBinaries)
        binaries_path = str(sdk2.ResourceData(binaries_bundle_res).path)
        logging.info('Binaries path: {}'.format(binaries_path))

        market_svn_data_res = self._param_or_found(self.Parameters.svn_data, goods_resources.GoodsIndexerSvnData)
        market_svn_data_path = str(sdk2.ResourceData(market_svn_data_res).path)
        logging.info('Market svn data path: {}'.format(market_svn_data_path))

        # Assemble indexer_data/ from symlinks to every entry of the input data resource
        # plus the mbo_stuff directory of the mbo data resource.
        input_data_path = os.path.join(working_copy_data_path, 'indexer_data')
        os.makedirs(input_data_path)
        input_data_arch_res = self._param_or_found(self.Parameters.input_data, goods_resources.GoodsIndexerInputData)
        input_data_arch_path = str(sdk2.ResourceData(input_data_arch_res).path)
        for dirname in os.listdir(input_data_arch_path):
            os.symlink(os.path.join(input_data_arch_path, dirname), os.path.join(input_data_path, dirname))

        input_mbo_data_res = self._param_or_found(self.Parameters.input_mbo_data, goods_resources.GoodsIndexerInputMboData)
        input_mbo_data_path = str(sdk2.ResourceData(input_mbo_data_res).path)
        os.symlink(os.path.join(input_mbo_data_path, 'mbo_stuff'), os.path.join(input_data_path, 'mbo_stuff'))

        logging.info('Input data path: {}'.format(input_data_path))

        configs_res = self._param_or_found(self.Parameters.configs_bundle, goods_resources.GoodsIndexerConfigs)
        configs_path = str(sdk2.ResourceData(configs_res).path)
        logging.info('Configs path: {}'.format(configs_path))

        bids_res = self._param_or_found(self.Parameters.bids_data, goods_resources.GoodsIndexerBidsData)
        bids_path = str(sdk2.ResourceData(bids_res).path)
        logging.info('Bids path: {}'.format(bids_path))

        runner_path = os.path.join(binaries_path, 'indexer_runner')
        logging.info('Runner path: {}'.format(runner_path))

        os.symlink(binaries_path, os.path.join(working_copy_data_path, 'binaries'))
        os.symlink(market_svn_data_path, os.path.join(working_copy_data_path, 'package-data'))
        os.symlink(bids_path, os.path.join(working_copy_data_path, 'bids_data'))
        os.symlink(configs_path, os.path.join(working_copy_data_path, 'configs'))

        runner_cmd = [
            runner_path,
            '--data_path', working_copy_data_path,
            '--yt_path', os.path.join(self.Parameters.yt_home_dir, str(self.id)),
            '--logs_path', get_logs_folder(),
            '--cluster', self.Parameters.yt_proxy,
            '--workers', str(self.Parameters.worker_count),
            '--sandbox-run', '1',
        ]
        self._check_call_logged(runner_cmd, 'runner')

        mifd_path = './workdir/mif'
        mifd_workdir_resource = sdk2.Resource[OTHER_RESOURCE]
        current_mifd_workdir_resource = mifd_workdir_resource(self, 'mifd working dir', mifd_path)
        work_data = sdk2.ResourceData(current_mifd_workdir_resource)
        work_data.ready()

        dists_path = self._dist_subdir('./workdir/dists')
        model_shards = []
        MODEL_SHARD_COUNT = 8
        model_shard_resource = sdk2.Resource[OTHER_RESOURCE]
        for i in range(MODEL_SHARD_COUNT):
            model_archive = os.path.join(dists_path, 'model-part-{}'.format(i), 'model-part-{}.tar.gz'.format(i))
            current_model_shard_resource = model_shard_resource(self, 'model shard #{}'.format(i), model_archive)
            model_shard_data = sdk2.ResourceData(current_model_shard_resource)
            model_shard_data.ready()
            model_shards.append(current_model_shard_resource)

        # Each worker gets one of the model shards, assigned round-robin.
        self.Context.make_shard_tasks = []
        for worker_id in range(self.Parameters.worker_count):
            shard_task = BuildGoodsIndexShard(self,
                                              description='Create goods shard #{shard_id} (#{parent_id} subtask)'.format(shard_id=worker_id, parent_id=self.id),
                                              owner=self.owner,
                                              binaries_bundle=binaries_bundle_res,
                                              configs_bundle=configs_res,
                                              working_dir=current_mifd_workdir_resource,
                                              worker_id=worker_id,
                                              yt_proxy=self.Parameters.yt_proxy,
                                              yt_token=self.Parameters.yt_token,
                                              debug_mode=self.Parameters.debug_mode,
                                              model_shard=model_shards[worker_id % MODEL_SHARD_COUNT],
                                              push_tasks_resource=True,
                                              create_sub_task=True).enqueue()
            self.Context.make_shard_tasks.append(shard_task.id)

        self._create_report_data_resource(dists_path)

        # raise sdk2.WaitTask(self.Context.make_shard_tasks,
        #                         ctt.Status.Group.FINISH | ctt.Status.Group.BREAK,
        #                         wait_all=True)

    def on_execute(self):
        # ``debug_mode`` is documented as sleeping after the indexer "completed or
        # failed", so the pause also happens when the build raises. Only Exception
        # is caught: SystemExit/KeyboardInterrupt still propagate without a pause.
        try:
            self._build_index()
        except Exception:
            self._sleep_if_debug()
            raise
        self._sleep_if_debug()
