from sandbox import sdk2
from sandbox.sdk2 import service_resources
from sandbox.common import errors
from sandbox.common.share import skynet_get
from sandbox.projects.common.arcadia import sdk
from sandbox.projects.common.nanny import const
from sandbox.projects.common.nanny import nanny
from sandbox.projects.common import task_env

import sandbox.common.types.task as ctt
import sandbox.projects.goods.resources as goods_resources
import sandbox.projects.market.resources.idx as market_resources

import json
import logging
import shutil
import os


# Number of offer shards ('search-part-<i>' entries) expected in the report
# dist; also the number of experiment shard subtasks (one offer part each).
OFFER_SHARD_COUNT = 16
# Number of model shards ('model-part-<i>' entries) expected in the report
# dist; also the number of regular shard subtasks, where subtask i combines
# offer parts i and i + MODEL_SHARD_COUNT with model part i.
MODEL_SHARD_COUNT = 8


class ResourceInfo:
    """Describe one archive to download and unpack into a target directory.

    :param key_name: archive name; ``create_resource`` looks for the downloaded
        archive at ``<download dir>/<key_name>/<key_name>.tar.zstd_10``.
    :param rbtorrent_id: skynet id the archive is fetched by.
    :param resource_path: optional sub-directory (relative to the target
        folder) to unpack into; ``None`` unpacks into the folder itself.
    :param files_to_remove: file names to delete from the unpacked directory;
        ``None`` means nothing is removed.
    """

    def __init__(self, key_name, rbtorrent_id, resource_path=None, files_to_remove=None):
        self.key_name = key_name
        self.rbtorrent_id = rbtorrent_id
        self.resource_path = resource_path
        # Fresh list per instance: a shared mutable default would leak
        # mutations between unrelated ResourceInfo objects.
        self.files_to_remove = [] if files_to_remove is None else files_to_remove


def create_resource(parent_task, src_info, resource_type, folder_name, shard_id=None, generation=None):
    """Download archives, unpack them into one directory and publish it as a resource.

    For every ``ResourceInfo`` in *src_info* the archive identified by its
    ``rbtorrent_id`` is fetched with ``skynet_get`` into a temporary
    ``./packed`` directory. It is decompressed with the Arcadia ``uc`` tool into
    *folder_name* (or into its ``resource_path`` sub-directory), and the files
    listed in ``files_to_remove`` are deleted from that directory. The result
    is registered as a *resource_type* resource owned by *parent_task* and
    marked ready. The temporary download directory is always removed
    afterwards.

    :param parent_task: owning task; ``parent_task.Context.errors_on_download``
        counts failed downloads and bounds the retries.
    :param src_info: iterable of ``ResourceInfo`` objects to unpack.
    :param resource_type: resource class looked up via ``sdk2.Resource[...]``.
    :param folder_name: directory to unpack into; created here with
        ``os.makedirs``, so it must not exist yet.
    :param shard_id: if not None, stored as the resource ``shard_id`` attribute.
    :param generation: if not None, stored as the resource ``generation`` attribute.
    :raises errors.TemporaryError: when a download fails while fewer than three
        failures have been counted, so the task is retried.
    """
    tmp = './packed'
    try:
        os.makedirs(folder_name)
        with sdk.mount_arc_path("arcadia-arc:/#trunk") as arcadia:
            for item in src_info:
                dst = folder_name
                if item.resource_path:
                    dst = os.path.join(dst, item.resource_path)
                    os.makedirs(dst)
                try:
                    skynet_get(item.rbtorrent_id, tmp)
                except Exception:
                    logging.exception('Failed to download %s (%s)', item.key_name, item.rbtorrent_id)
                    parent_task.Context.errors_on_download += 1
                    if parent_task.Context.errors_on_download < 3:
                        raise errors.TemporaryError('Error during resource download')
                    # Retry budget is spent: fall through and let the unpack
                    # step below decide whether the downloaded data is usable.

                arch_path = os.path.join(tmp, item.key_name, '{}.tar.zstd_10'.format(item.key_name))
                sdk.run_tool(arcadia, 'uc', ['--decompress', '--from', arch_path, '--codec', 'zstd_10', '--to', dst, '-x'], timeout=900)

                for fname in item.files_to_remove:
                    try:
                        os.remove(os.path.join(dst, fname))
                        logging.info('file {} removed from shard'.format(fname))
                    except OSError:
                        # Best effort: the file may simply be absent from this part.
                        pass

        shard_resource = sdk2.Resource[resource_type]
        current_shard_resource = shard_resource(parent_task,
                                                'Shard for goods base search',
                                                folder_name)
        current_shard_resource.backup_task = True
        if shard_id is not None:
            current_shard_resource.shard_id = shard_id
        if generation is not None:
            current_shard_resource.generation = generation
        shard_data = sdk2.ResourceData(current_shard_resource)
        shard_data.ready()
    finally:
        # ignore_errors: './packed' may never have been created (e.g. the first
        # download failed); a cleanup error must not mask the real exception.
        shutil.rmtree(tmp, ignore_errors=True)


class MakeGoodsShardGencfg(sdk2.Task):
    """Assemble one goods basesearch shard as a GoodsBasesearchShard resource.

    Regular mode unpacks offer parts ``model_shard_id`` and
    ``model_shard_id + MODEL_SHARD_COUNT`` together with model part
    ``model_shard_id``. Experiment mode unpacks only offer part
    ``model_shard_id`` together with model part
    ``model_shard_id % MODEL_SHARD_COUNT``.
    """

    class Context(sdk2.Context):
        errors_on_download = 0

    class Requirements(task_env.TinyRequirements):
        disk_space = 80 * 1024

    class Parameters(sdk2.task.Parameters):
        push_tasks_resource = True
        model_shard_id = sdk2.parameters.Integer('model shard id', required=True)
        first_shard_rbtorrent_id = sdk2.parameters.String('first offers shard data rbtorrent id', required=True)
        second_shard_rbtorrent_id = sdk2.parameters.String('second offers shard data rbtorrent id', required=True)
        model_rbtorrent_id = sdk2.parameters.String('models data rbtorrent id', required=True)
        experiment = sdk2.parameters.Bool("Don't use second offer shard", required=False, default=False)

    def on_execute(self):
        """Describe the parts of this shard and unpack them into ``./shard``."""
        params = self.Parameters
        lead_id = params.model_shard_id
        paired_id = lead_id + MODEL_SHARD_COUNT
        # Files dropped from every offer part after unpacking.
        dropped_files = ["content-offer.tsv", "delivery_shop_ids.mmap"]

        def offer_part(part_no, rbtorrent):
            return ResourceInfo('search-part-{}'.format(part_no), rbtorrent,
                                'part-{}'.format(part_no), dropped_files)

        def model_part(part_no):
            return ResourceInfo('model-part-{}'.format(part_no), params.model_rbtorrent_id,
                                'model/part-{}'.format(part_no))

        if not params.experiment:
            parts = [
                offer_part(lead_id, params.first_shard_rbtorrent_id),
                offer_part(paired_id, params.second_shard_rbtorrent_id),
                model_part(lead_id),
            ]
            shard_label = '{}-{}'.format(lead_id, paired_id)
        else:
            parts = [
                offer_part(lead_id, params.first_shard_rbtorrent_id),
                model_part(lead_id % MODEL_SHARD_COUNT),
            ]
            shard_label = '{}'.format(lead_id)

        create_resource(self, parts, goods_resources.GoodsBasesearchShard, './shard', shard_label)


class MakeGoodsShardmapGencfg(nanny.ReleaseToNannyTask2, sdk2.Task):
    """Build goods basesearch shards from a report dist and publish gencfg shardmaps.

    The task runs in two memoized stages:

    * ``PREPARE_SHARDS`` picks a MARKET_REPORT_DIST_META resource and unpacks
      its report data and search stats into a GoodsReportData resource. It then
      spawns MakeGoodsShardGencfg subtasks: MODEL_SHARD_COUNT regular shards
      (two offer parts + one model part each) and OFFER_SHARD_COUNT experiment
      shards (one offer part + one model part each), and waits for all of them.
    * ``PREPARE_SHARDMAP`` collects the shard resources of the successful
      subtasks and publishes the host (optional), beta and exp shardmaps.
    """

    class Context(sdk2.Context):
        generation = "unknown"
        errors_on_download = 0

    class Parameters(sdk2.task.Parameters):
        push_tasks_resource = True
        market_report_dist = sdk2.parameters.Resource('Market report dist', resource_type=market_resources.MARKET_REPORT_DIST_META)
        indexer = sdk2.parameters.String(
            'Specific indexer',
            choices=[('turbo.stratocaster', 'turbo.stratocaster'), ('turbo.gibson', 'turbo.gibson'), ('any', 'any')],
            default='any'
        )
        generation = sdk2.parameters.String('Specific generation')
        use_production_check = sdk2.parameters.Bool('Use production check on dist attributes', default=True)
        use_backends_config = sdk2.parameters.Bool('Use backends_config file to get hosts', default=True)

    class Requirements(task_env.TinyRequirements):
        disk_space = 30 * 1024  # 30 Gb

    def on_create(self):
        # Run from the latest stable-released tasks binary of this task type.
        self.Requirements.tasks_resource = service_resources.SandboxTasksBinary.find(
            owner="GOODS-RUNTIME",
            attrs={"task_type": "MAKE_GOODS_SHARDMAP_GENCFG", "released": "stable"}
        ).first()

    def _get_shard_rbtorrent(self, task):
        """Return the skynet id of the GoodsBasesearchShard produced by *task*.

        :raises RuntimeError: if *task* has no such resource.
        """
        resource = sdk2.Resource[goods_resources.GoodsBasesearchShard].find(task=task).first()
        if resource is None:
            raise RuntimeError('Task {} has no GoodsBasesearchShard resource'.format(task.id))
        return resource.skynet_id

    def _get_shards(self, experiment=False):
        """Map shard id (``model_shard_id``) to shard rbtorrent for successful subtasks.

        :param experiment: select experiment (True) or regular (False) subtasks.
        :raises RuntimeError: if the number of distinct shards differs from
            OFFER_SHARD_COUNT (experiment) or MODEL_SHARD_COUNT (regular).
        """
        tasks = list(self.find(task_type=sdk2.Task[MakeGoodsShardGencfg.type],
                               status=ctt.Status.Group.SUCCEED,
                               input_parameters={'experiment': experiment}))
        logging.info('Child tasks: %s', [t.id for t in tasks])
        shards = {task.Parameters.model_shard_id: self._get_shard_rbtorrent(task) for task in tasks}
        tasks_cnt = OFFER_SHARD_COUNT if experiment else MODEL_SHARD_COUNT
        if len(shards) != tasks_cnt:
            raise RuntimeError("Success children task number is {} instead of {}".format(len(shards), tasks_cnt))
        return shards

    @staticmethod
    def _pod_label_shardmap(shards):
        """Return ``pod_label:shard_id=<id>`` shardmap lines ordered by shard id."""
        return ['pod_label:shard_id={shard_id}    {rbtorrent}\n'.format(shard_id=shard_id, rbtorrent=rbtorrent)
                for shard_id, rbtorrent in sorted(shards.items())]

    @staticmethod
    def _write_shardmap_file(path, lines, title):
        """Write shardmap *lines* to *path*, log them under *title* and return *path*."""
        logging.info('%s content: %s', title, lines)
        with open(path, 'w+') as f:
            f.writelines(lines)
        logging.info('%s written to %s file', title, path)
        return path

    def _make_gencfg_shardmap_exp(self):
        """Write the experiment (pod label) shardmap and return its path."""
        # Keyed lookup: indexing dict.values() by position is a TypeError on
        # Python 3 and only matched shard ids by accident of key ordering.
        exp_shards = self._get_shards(experiment=True)
        return self._write_shardmap_file('./goods_exp_shardmap', self._pod_label_shardmap(exp_shards), 'Exp shardmap')

    def _make_gencfg_shardmap(self):
        """Write the host shardmap (optional) and the beta shardmap.

        :returns: ``(shardmap_path, beta_shardmap_path)``; ``shardmap_path`` is
            None when ``use_backends_config`` is off.
        """
        shards = self._get_shards(experiment=False)
        shardmap_path = None

        if self.Parameters.use_backends_config:
            report_data = sdk2.Resource[goods_resources.GoodsReportData].find(task_id=self.id).first()
            config_path = os.path.join(str(sdk2.ResourceData(report_data).path), "backends", "backends_config.json")
            with open(config_path, 'r') as config_file:
                backend_config = json.load(config_file)

            shardmap = []
            for service in backend_config:
                for host in service['hosts']:
                    if 'goods' not in host['rtc_service']:
                        continue
                    model_shards = host['shards']['model-part']
                    if not model_shards:
                        continue
                    # A goods host gets the shard built for its first model part.
                    shardmap.append('host_port:{rtc_host}:{rtc_port}    {rbtorrent}\n'.format(
                        rtc_host=host['rtc_host'],
                        rtc_port=host['rtc_port'],
                        rbtorrent=shards[model_shards[0]]))
            shardmap_path = self._write_shardmap_file('./goods_shardmap', shardmap, 'Shardmap')

        beta_shardmap_path = self._write_shardmap_file('./goods_beta_shardmap', self._pod_label_shardmap(shards), 'Beta shardmap')
        return shardmap_path, beta_shardmap_path

    def _find_market_dist(self):
        """Return the MARKET_REPORT_DIST_META resource to build shards from.

        Uses the explicit ``market_report_dist`` parameter when set. Otherwise,
        searches by attributes derived from the other parameters.

        :raises ValueError: if no dist matches, or (with production check) its
            generation is not newer than the last stable gencfg shardmap.
        """
        if self.Parameters.market_report_dist:
            return self.Parameters.market_report_dist

        attrs = {'is_turbo': 'True'}
        if self.Parameters.indexer != 'any':
            attrs['mi_type'] = self.Parameters.indexer
        if self.Parameters.generation:
            attrs['generation'] = self.Parameters.generation
        if self.Parameters.use_production_check:
            attrs['env_type'] = 'production'
            attrs['is_blue'] = False
            attrs['is_fresh'] = False
            attrs['is_half'] = False
            attrs['is_planeshift'] = False
            attrs['is_scaled'] = False
            attrs['not_for_publish'] = False

        logging.info('Attributes for MARKET_REPORT_DIST_META: {}'.format(attrs))
        market_dist = sdk2.Resource[market_resources.MARKET_REPORT_DIST_META].find(attrs=attrs).first()
        if market_dist is None:
            raise ValueError('No MARKET_REPORT_DIST_META resource with given attributes')

        if self.Parameters.use_production_check:
            last_released_shardmap = sdk2.Resource[goods_resources.GoodsBasesearchShardmap].find(attrs={'released': 'stable', 'service_type': 'gencfg'}).first()
            # No stable shardmap yet (first release): nothing to compare against.
            if last_released_shardmap is not None and last_released_shardmap.generation >= market_dist.generation:
                raise ValueError('Last MARKET_REPORT_DIST_META ${} has {} generation older or equal last released generation {}'.format(market_dist.id,
                                                                                                                                        market_dist.generation,
                                                                                                                                        last_released_shardmap.generation))
        return market_dist

    @staticmethod
    def _collect_part_rbtorrents(data, key_name):
        """Return ``data['<key_name>-0']``, ``data['<key_name>-1']``, ... up to the first missing index."""
        parts = []
        while True:
            key = '{0}-{1}'.format(key_name, len(parts))
            if key not in data:
                return parts
            parts.append(data[key])

    def _enqueue_shard_tasks(self, shards, model_shards):
        """Spawn regular and experiment MakeGoodsShardGencfg subtasks.

        Each task id is appended to ``self.Context.make_shard_tasks`` as soon
        as the task is enqueued.
        """
        # Regular shards: offer parts i and i + MODEL_SHARD_COUNT with model part i.
        for shard_id, model_shard in enumerate(model_shards):
            shard_task = MakeGoodsShardGencfg(self,
                                              description='Create goods shard #{shard_id} (#{parent_id} subtask)'.format(shard_id=shard_id, parent_id=self.id),
                                              owner=self.owner,
                                              model_shard_id=shard_id,
                                              first_shard_rbtorrent_id=shards[shard_id],
                                              second_shard_rbtorrent_id=shards[shard_id + MODEL_SHARD_COUNT],
                                              model_rbtorrent_id=model_shard,
                                              push_tasks_resource=True,
                                              create_sub_task=True).enqueue()
            self.Context.make_shard_tasks.append(shard_task.id)

        # Experiment shards: 1 offer part + 1 model part per instance.
        for shard_id, offer_shard in enumerate(shards):
            shard_task = MakeGoodsShardGencfg(self,
                                              description='Create exp goods shard #{shard_id} (#{parent_id} subtask)'.format(shard_id=shard_id, parent_id=self.id),
                                              owner=self.owner,
                                              model_shard_id=shard_id,
                                              first_shard_rbtorrent_id=offer_shard,
                                              second_shard_rbtorrent_id=None,
                                              model_rbtorrent_id=model_shards[shard_id % MODEL_SHARD_COUNT],
                                              push_tasks_resource=True,
                                              create_sub_task=True,
                                              experiment=True).enqueue()
            self.Context.make_shard_tasks.append(shard_task.id)

    def _register_shardmap_resource(self, resource_type, description, path):
        """Create a gencfg shardmap resource of *resource_type* for the file at *path*."""
        resource = sdk2.Resource[resource_type](self, description, path)
        resource.backup_task = True
        resource.generation = self.Context.generation
        resource.service_type = 'gencfg'
        return resource

    def _wait_shard_tasks(self):
        """Suspend the task until every recorded shard subtask finishes or breaks."""
        raise sdk2.WaitTask(self.Context.make_shard_tasks,
                            ctt.Status.Group.FINISH | ctt.Status.Group.BREAK,
                            wait_all=True)

    def on_execute(self):
        with self.memoize_stage.PREPARE_SHARDS(commit_on_entrance=False):
            if self.Context.make_shard_tasks:
                # A previous attempt already spawned subtasks (after creating
                # the report data) but failed before waiting on them. Resume the
                # wait instead of returning: a bare return would finish the task
                # without any shardmap. Missing shards are reported by the count
                # check in _get_shards.
                self._wait_shard_tasks()

            market_dist = self._find_market_dist()
            dist_path = str(sdk2.ResourceData(market_dist).path)
            with open(dist_path, 'r') as dist:
                data = json.load(dist)

            self.Context.generation = market_dist.generation
            self.set_info(info="Index generation: {}".format(market_dist.generation))

            report_data = data['search-report-data']
            search_stats = data['search-stats']
            shards = self._collect_part_rbtorrents(data, 'search-part')
            model_shards = self._collect_part_rbtorrents(data, 'model-part')
            if len(shards) != OFFER_SHARD_COUNT:
                raise ValueError('{} offer shards expected, {} found'.format(OFFER_SHARD_COUNT, len(shards)))
            if len(model_shards) != MODEL_SHARD_COUNT:
                raise ValueError('{} model shards expected, {} found'.format(MODEL_SHARD_COUNT, len(model_shards)))

            # Publish report data before spawning subtasks: create_resource may
            # raise TemporaryError, and doing it first means a retry does not
            # leave already-spawned subtasks without report data.
            src_info = [ResourceInfo('search-report-data', report_data), ResourceInfo('search-stats', search_stats)]
            create_resource(self,
                            src_info,
                            goods_resources.GoodsReportData,
                            './report_data',
                            generation=self.Context.generation)

            self.Context.make_shard_tasks = []
            self._enqueue_shard_tasks(shards, model_shards)
            self._wait_shard_tasks()

        with self.memoize_stage.PREPARE_SHARDMAP(commit_on_entrance=False):
            shardmap_path, beta_shardmap_path = self._make_gencfg_shardmap()

            if shardmap_path is not None:
                self._register_shardmap_resource(goods_resources.GoodsBasesearchShardmap, 'Goods shardmap', shardmap_path)

            self._register_shardmap_resource(goods_resources.GoodsBasesearchBetaShardmap, 'Goods beta shardmap', beta_shardmap_path)

            exp_shardmap_path = self._make_gencfg_shardmap_exp()
            self._register_shardmap_resource(goods_resources.GoodsBasesearchExpShardmap, 'Goods exp shardmap', exp_shardmap_path)

    def on_release(self, params):
        params[const.RELEASE_SUBJECT_KEY] = "Release goods gencfg shardmap {}".format(self.Context.generation)
        super(MakeGoodsShardmapGencfg, self).on_release(params)
        self.mark_released_resources(params["release_status"], ttl=30)
