from sandbox import sdk2
from sandbox.sdk2 import service_resources
from sandbox.common.share import skynet_get
from sandbox.projects.common.arcadia import sdk
from sandbox.projects.common.nanny import nanny
from sandbox.projects.common import task_env

import sandbox.projects.market.resources.idx as market_resources
import sandbox.projects.goods.resources as goods_resources
import sandbox.common.types.task as ctt

import json
import logging
from collections import OrderedDict
import shutil
import os


# Expected number of offer shards: MakeGoodsShardmap.on_execute asserts that the
# dist metadata contains exactly this many consecutive 'search-part-<N>' keys.
OFFER_SHARD_COUNT = 16
# Expected number of model shards ('model-part-<N>' keys). Offer shard i is
# paired with model shard i % MODEL_SHARD_COUNT.
MODEL_SHARD_COUNT = 8


class ResourceInfo(object):
    """Description of one archive that goes into a resource built by create_resource.

    Attributes:
        key_name: name of the archive; create_resource looks for
            ``<key_name>/<key_name>.tar.zstd_10`` in the downloaded data.
        rbtorrent_id: identifier handed to ``skynet_get`` to fetch the data.
        resource_path: optional sub-directory of the target folder the archive
            is unpacked into; ``None`` means the target folder itself.
        files_to_remove: names of files deleted from the unpacked data.
    """

    def __init__(self, key_name, rbtorrent_id, resource_path=None, files_to_remove=None):
        self.key_name = key_name
        self.rbtorrent_id = rbtorrent_id
        self.resource_path = resource_path
        # A fresh list per instance: a literal ``[]`` default would be shared
        # by every ResourceInfo created without this argument.
        self.files_to_remove = files_to_remove if files_to_remove is not None else []


def create_resource(parent_task, src_info, resource_type, folder_name, shard_id=None):
    """Unpack a set of archives into one folder and publish it as a resource.

    For every item of ``src_info`` the data identified by ``rbtorrent_id`` is
    fetched with ``skynet_get`` into the temporary ``./packed`` directory, and
    the archive ``<key_name>/<key_name>.tar.zstd_10`` from it is passed to the
    ``uc`` tool (``--decompress ... -x``) with ``folder_name`` as destination,
    or its ``resource_path`` sub-directory when the item has one. Files listed
    in the item's ``files_to_remove`` are then deleted on a best-effort basis.
    Finally ``folder_name`` is registered as a resource of ``resource_type``
    owned by ``parent_task`` and marked ready.

    Args:
        parent_task: task that becomes the owner of the created resource.
        src_info: iterable of ResourceInfo objects describing the archives.
        resource_type: resource type used to look up the class in sdk2.Resource.
        folder_name: directory to create and fill; it must not exist yet.
        shard_id: optional shard number stored in the resource's ``shard_id``
            attribute. ``None`` means "do not set it"; ``0`` is a valid id.

    Raises:
        OSError: if ``folder_name`` (or an item sub-directory) already exists
            or cannot be created.
    """
    tmp = './packed'
    try:
        os.makedirs(folder_name)
        with sdk.mount_arc_path("arcadia-arc:/#trunk") as arcadia:
            for item in src_info:
                dst = folder_name
                if item.resource_path:
                    dst = os.path.join(dst, item.resource_path)
                    os.mkdir(dst)
                skynet_get(item.rbtorrent_id, tmp)
                arch_path = tmp + '/{0}/{0}.tar.zstd_10'.format(item.key_name)
                sdk.run_tool(arcadia, 'uc', ['--decompress', '--from', arch_path, '--codec', 'zstd_10', '--to', dst, '-x'])

                for fname in item.files_to_remove:
                    try:
                        os.remove(os.path.join(dst, fname))
                    except OSError as err:
                        # Removal is best-effort (the file may simply be absent),
                        # but leave a trace instead of ignoring it silently.
                        logging.info('file {} not removed from shard: {}'.format(fname, err))
                    else:
                        logging.info('file {} removed from shard'.format(fname))

        shard_resource = sdk2.Resource[resource_type]
        current_shard_resource = shard_resource(parent_task,
                                                'Shard for goods base search',
                                                folder_name)
        current_shard_resource.backup_task = True
        # Compare with None explicitly: shard 0 is a real shard, and a plain
        # truthiness check would leave its resource without the attribute.
        if shard_id is not None:
            current_shard_resource.shard_id = shard_id
        shard_data = sdk2.ResourceData(current_shard_resource)
        shard_data.ready()
    finally:
        # The temporary directory may not exist if we failed before the first
        # download; do not let the cleanup mask the original error.
        shutil.rmtree(tmp, ignore_errors=True)


class MakeGoodsShard(sdk2.Task):
    """Task that builds one goods base search shard resource.

    The shard is assembled by create_resource from two archives: the offer
    part with the task's shard id, unpacked into the ``offer`` sub-directory,
    and the model part with id ``shard_id % MODEL_SHARD_COUNT``, unpacked into
    the ``model`` sub-directory.
    """

    class Requirements(task_env.TinyRequirements):
        disk_space = 40 * 1024   # 40 Gb

    class Parameters(sdk2.task.Parameters):
        push_tasks_resource = True
        shard_id = sdk2.parameters.Integer('shard id', required=True)
        rbtorrent_id = sdk2.parameters.String('offers data rbtorrent id', required=True)
        model_rbtorrent_id = sdk2.parameters.String('models data rbtorrent id', required=True)

    def on_execute(self):
        """Unpack the offer and model parts and publish them as a GoodsBasesearchShard."""
        offer_shard_id = self.Parameters.shard_id
        model_shard_id = offer_shard_id % MODEL_SHARD_COUNT

        # These files are deleted from the unpacked offer data.
        offer_part = ResourceInfo(
            'search-part-{}'.format(offer_shard_id),
            self.Parameters.rbtorrent_id,
            'offer',
            ["content-offer.tsv", "delivery_shop_ids.mmap"],
        )
        model_part = ResourceInfo(
            'model-part-{}'.format(model_shard_id),
            self.Parameters.model_rbtorrent_id,
            'model',
        )

        create_resource(
            self,
            [offer_part, model_part],
            goods_resources.GoodsBasesearchShard,
            './shard',
            offer_shard_id,
        )


class MakeGoodsShardmap(nanny.ReleaseToNannyTask2, sdk2.Task):
    """Task that builds all goods shards and the shardmap referencing them.

    Execution has two memoized stages:

    * PREPARE_SHARDS reads the market report dist metadata (JSON), spawns one
      MakeGoodsShard child task per offer shard, builds the GoodsReportData
      resource and waits for the children.
    * PREPARE_SHARDMAP collects the shard resources of the succeeded children
      and writes the shardmap file published as GoodsBasesearchShardmap.
    """

    class Parameters(sdk2.task.Parameters):
        push_tasks_resource = True
        market_report_dist = sdk2.parameters.Resource('Market report dist', resource_type=market_resources.MARKET_REPORT_DIST_META)

    class Requirements(task_env.TinyRequirements):
        disk_space = 30 * 1024   # 30 Gb

    def on_create(self):
        """Pin the tasks binary to the first stable MAKE_GOODS_SHARDMAP binary found."""
        self.Requirements.tasks_resource = service_resources.SandboxTasksBinary.find(owner="GOODS-RUNTIME",
                                                                                     attrs={"task_type": "MAKE_GOODS_SHARDMAP", "released": "stable"}).first()

    def _get_shard_rbtorrent(self, task):
        """Return the skynet id of the first GoodsBasesearchShard resource of ``task``."""
        return sdk2.Resource[goods_resources.GoodsBasesearchShard].find(task=task).first().skynet_id

    def _get_shards(self):
        """Return ``{shard_id: skynet id}`` for the succeeded MakeGoodsShard child tasks."""
        tasks = self.find(task_type=sdk2.Task[MakeGoodsShard.type],
                          status=ctt.Status.Group.SUCCEED)
        logging.info('Child tasks: %s' % [t.id for t in tasks])
        shards = {task.Parameters.shard_id: self._get_shard_rbtorrent(task) for task in tasks}
        return shards

    def _make_yp_shardmap_dict(self, shards):
        """Map consecutive labels 0..N-1 to ``shards``, preserving their order.

        The caller is responsible for passing the shards ordered by shard id.
        """
        return OrderedDict(enumerate(shards))

    def _make_yp_shardmap(self):
        """Write the shardmap file and return its path.

        Each line binds the pod label ``shard_id=<N>`` to the rbtorrent of
        shard N.

        Raises:
            RuntimeError: if some of the OFFER_SHARD_COUNT shards have no
                succeeded child task. A shardmap built from an incomplete set
                would bind labels to the wrong shards.
        """
        shards_by_id = self._get_shards()
        missing = [shard_id for shard_id in range(OFFER_SHARD_COUNT) if shard_id not in shards_by_id]
        if missing:
            raise RuntimeError('No succeeded shard task for shard ids {}'.format(missing))

        # Order by the real shard id: dict iteration order says nothing about
        # which shard a value belongs to, and the label must match the shard.
        ordered_shards = [shards_by_id[shard_id] for shard_id in range(OFFER_SHARD_COUNT)]
        shardmap_dict = self._make_yp_shardmap_dict(ordered_shards)
        shardmap = [
            'pod_label:shard_id={label_id}    {rbtorrent}(local_path=shard)\n'.format(label_id=label_id,
                                                                                      rbtorrent=rbtorrent)
            for label_id, rbtorrent in shardmap_dict.items()
        ]
        logging.info('Shardmap content: %s' % shardmap)
        shardmap_path = './goods_shardmap'
        with open(shardmap_path, 'w+') as f:
            f.writelines(shardmap)
        logging.info('Shardmap written to %s file' % shardmap_path)
        return shardmap_path

    def on_execute(self):
        """Run the PREPARE_SHARDS and PREPARE_SHARDMAP stages (see the class docstring)."""
        with self.memoize_stage.PREPARE_SHARDS(commit_on_entrance=False):
            # Child tasks are already recorded in the context: do not spawn
            # them a second time.
            # NOTE(review): this returns from on_execute entirely, so no
            # shardmap is built in this run - confirm this is intended.
            if self.Context.make_shard_tasks:
                return

            # Use the explicitly given dist resource, otherwise the first
            # MARKET_REPORT_DIST_META resource with is_turbo=True.
            dist_resource = self.Parameters.market_report_dist
            if not dist_resource:
                dist_resource = sdk2.Resource[market_resources.MARKET_REPORT_DIST_META].find(attrs={'is_turbo': 'True'}).first()
            dist_path = str(sdk2.ResourceData(dist_resource).path)
            with open(dist_path, 'r') as dist:
                data = json.load(dist)

            report_data = data['search-report-data']
            search_stats = data['search-stats']

            def get_rbtorrent_ids(key_name):
                """Collect data['<key_name>-0'], data['<key_name>-1'], ... up to the first missing key."""
                shards = []
                i = 0
                while True:
                    key = '{0}-{1}'.format(key_name, i)
                    if key not in data:
                        break
                    shards.append(data[key])
                    i += 1
                return shards

            shards = get_rbtorrent_ids('search-part')
            model_shards = get_rbtorrent_ids('model-part')
            assert len(shards) == OFFER_SHARD_COUNT, '{} offer shards expected, {} found'.format(OFFER_SHARD_COUNT, len(shards))
            assert len(model_shards) == MODEL_SHARD_COUNT, '{} model shards expected, {} found'.format(MODEL_SHARD_COUNT, len(model_shards))

            self.Context.make_shard_tasks = []
            for shard_id, shard in enumerate(shards):
                # Several offer shards share one model shard.
                model_shard = model_shards[shard_id % MODEL_SHARD_COUNT]
                shard_task = MakeGoodsShard(self,
                                            description='Create goods shard #{shard_id} (#{parent_id} subtask)'.format(shard_id=shard_id, parent_id=self.id),
                                            owner=self.owner,
                                            shard_id=shard_id,
                                            rbtorrent_id=shard,
                                            model_rbtorrent_id=model_shard,
                                            push_tasks_resource=True,
                                            create_sub_task=True).enqueue()
                self.Context.make_shard_tasks.append(shard_task.id)

            # Build the report data resource here, after the children have
            # been enqueued and before waiting for them.
            src_info = [ResourceInfo('search-report-data', report_data), ResourceInfo('search-stats', search_stats)]
            create_resource(self,
                            src_info,
                            goods_resources.GoodsReportData,
                            './report_data')

            raise sdk2.WaitTask(self.Context.make_shard_tasks,
                                ctt.Status.Group.FINISH | ctt.Status.Group.BREAK,
                                wait_all=True)

        with self.memoize_stage.PREPARE_SHARDMAP(commit_on_entrance=False):
            yp_shardmap_path = self._make_yp_shardmap()
            yp_shardmap_resource = sdk2.Resource[goods_resources.GoodsBasesearchShardmap]
            current_yp_shardmap_resource = yp_shardmap_resource(self, 'Goods shardmap', yp_shardmap_path)
            current_yp_shardmap_resource.backup_task = True

    def on_release(self, params):
        """Run the base release logic, then mark the released resources with a ttl of 30."""
        super(MakeGoodsShardmap, self).on_release(params)
        self.mark_released_resources(params["release_status"], ttl=30)
