from sandbox.projects.common.news.YtScriptTaskV2 import YtScriptV2
from sandbox.projects.common.nanny import nanny

from sandbox import sdk2
from sandbox.sandboxsdk import environments, process
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.sandboxsdk.channel import channel

from sandbox.projects import resource_types
from sandbox.projects.news import resources
from sandbox.projects.common import apihelpers, solomon
from sandbox.common.types import task as ctt

import os
import shutil
import time
import copy
import datetime
import tempfile
import logging

# Per-shard name: "<shard_name_prefix>-<zero-padded shard number>-<task timestamp>".
# Used as both the shard subtask description and its shard_name parameter.
story_template = "{shard_name_prefix}-{shardno_fmt}-{ts}"
# Alternative "arcnews-<NNN>-<ts>" shard name; not referenced in the code shown here.
index_template = "arcnews-{shardno_fmt}-{ts}"
# Formats a shard number as 3 zero-padded digits (7 -> "007") for the templates above.
shardno_fmt = "{:03d}"
# NOTE(review): presumably the per-shard limit on restarts of failed shard subtasks.
max_restarts_count = 5


class PublishNewsArchiveV2(nanny.ReleaseToNannyTask2, YtScriptV2):
    """Build arcnews newsd shards and publish a YP shardmap for them.

    Regular mode:
      1. Read the day -> shard mapping from YT.
      2. Spawn one MakeArcNewsdShardV2 subtask per shard and wait for all of
         them, restarting a failed shard up to ``max_restarts_count`` times.
      3. Write a shardmap that points every shard at the skynet id of its
         latest READY SLAVE_NEWSD_ARCHIVE_SHARD resource.
      4. Publish that shardmap as an ARCNEWS_SHARDMAP_YP resource.

    Shortcut mode (``shard_map_torrent`` is set): download the rbtorrent, take
    the ``shardmap_arcnews_*`` file from it and publish that file as the
    ARCNEWS_SHARDMAP_YP resource. YT is not read and no shards are built.
    """

    class Requirements(YtScriptV2.Requirements):
        environments = [
            environments.SvnEnvironment(),
            environments.PipEnvironment('yandex-yt'),
            environments.PipEnvironment('yandex-yt-yson-bindings-skynet', version='0.3.32-0'),
        ]

    class Parameters(YtScriptV2.Parameters):

        solomon_tvm_vault_selector = sdk2.parameters.YavSecret("Solomon TVM token from YAV", required=True)
        shard_map_torrent = sdk2.parameters.String("Shardmap rbtorrent. Use it if you want just publish this shardmap and don't want to get shards from YT", default_value=None)
        shards_count = sdk2.parameters.Integer("Shards count", required=True, default_value=112)
        newsd_shard_mapping_table = sdk2.parameters.String("slave_newsd shards mapping table", required=True, default_value="//home/news-prod/archive/newsd/shards")
        min_year = sdk2.parameters.Integer("Min year to store in archives", default_value=2010)
        fresh_days_gap = sdk2.parameters.Integer("Number of new days to skip", default_value=3)
        restore_shards = sdk2.parameters.Bool("Restore all shards", default_value=False)
        newsd_table = sdk2.parameters.String("Path to newsd table on YT", required=True, default_value="//home/news-prod/archive/newsd/newsd")
        newsd_statework_binary = sdk2.parameters.Resource("Resource with newsd_statework", required=True, resource_type=resource_types.NEWSD_STATEWORK)
        ttl = sdk2.parameters.Integer("TTL of shard resource", required=True, default_value=14)
        shard_name_prefix = sdk2.parameters.String("Prefix for shard names", required=True, default_value="shardmap_arcnews_yp")

    class Context(sdk2.Context):
        # int(time.time()) of the first execution; embedded in shard and shardmap names.
        timestamp = None

        # File name of the shardmap produced by this task.
        shardmap_name_yp = None

        # None on the first run, 'WAIT_SUBTASKS' once shard subtasks were spawned.
        status = None
        # Restart counters of failed shard subtasks, keyed by str(shard number).
        restarts = {}
        wait_subtasks = False
        # Ids of the shard subtasks currently being waited for.
        subtasks = []
        # Shard numbers that go into the shardmap.
        shards = []

    def _init_context(self):
        """Fix the task timestamp and derive the shardmap file name from it."""
        self.Context.timestamp = int(time.time())
        self.Context.shardmap_name_yp = 'shardmap_arcnews_yp-%s.map' % self.Context.timestamp
        self.Context.save()

    def _get_archive_upper_bound(self, days):
        """Return the date ``days`` days before today as an int in YYYYMMDD form."""
        last_arc_date = datetime.date.today() - datetime.timedelta(days=days)
        return int(last_arc_date.strftime("%Y%m%d"))

    def _get_last_shard_info(self, type, shard_number):
        """Find the latest READY resource of ``type`` for ``shard_number``.

        The resource is matched by its ``shard_number`` attribute and is then
        touched via the resource manager (presumably to keep it from expiring).

        :return: dict with the resource's ``shard_name`` attribute and its skynet id.
        :raises SandboxTaskFailureError: if no such resource exists.
        """
        from sandbox.yasandbox import manager

        last_shard = apihelpers.get_last_resource_with_attribute(
            resource_type=type,
            attribute_name='shard_number',
            attribute_value=str(shard_number),
            status='READY',
        )
        if last_shard is None:
            raise SandboxTaskFailureError("Failed to get shard resource for shard %s" % shard_number)
        manager.resource_manager.touch(last_shard.id)
        shard_name = channel.sandbox.get_resource_attribute(last_shard.id, attribute_name="shard_name")
        return {'name': shard_name, 'skynet_id': last_shard.skynet_id}

    def __add_newsd_task(self, shard_id, since=None, to=None, days=None):
        """Create and enqueue a MakeArcNewsdShardV2 subtask for one shard.

        :param shard_id: shard number.
        :param since: lower day bound; used only when ``days`` is None.
        :param to: upper day bound; used only when ``days`` is None.
        :param days: iterable of days (YYYYMMDD) to put into the shard.
        :return: the enqueued subtask.
        """
        from sandbox.projects.news.MakeArcNewsdShardV2 import MakeArcNewsdShardV2

        shard_name = story_template.format(
            shard_name_prefix=self.Parameters.shard_name_prefix,
            shardno_fmt=shardno_fmt.format(shard_id),
            ts=self.Context.timestamp,
        )
        sub_task = MakeArcNewsdShardV2(
            self,
            description=shard_name,
            yt_proxy=self.Parameters.yt_proxy,
            yt_token=self.Parameters.yt_token,
            yt_token_field=self.Parameters.yt_token_field,
            yt_pool=self.Parameters.yt_pool,
            script_url=self.Parameters.script_url,
            script_cmdline=self.Parameters.script_cmdline,
            script_patch=None,
            news_python=self.Parameters.news_python,
            newsd_table=self.Parameters.newsd_table,
            shard_name=shard_name,
            shard_number=shard_id,
            newsd_statework_binary=self.Parameters.newsd_statework_binary,
            ttl=self.Parameters.ttl,
            restore=self.Parameters.restore_shards,
        )
        if days is not None:
            # str() so integer days coming from YT join cleanly as well.
            sub_task.Parameters.days = " ".join(str(day) for day in days)
        else:
            # No trailing commas here: they would store 1-tuples instead of the bounds.
            sub_task.Parameters.lower_bound = since
            sub_task.Parameters.upper_bound = to

        sub_task.save()
        sub_task.enqueue()
        return sub_task

    def __get_days_to_shards_mapping(self, last_day):
        """Read ``newsd_shard_mapping_table`` and group its days by shard.

        Rows are skipped (and logged) when:
          - the shard id is outside ``[0, shards_count)``;
          - the day is later than ``last_day``;
          - the day is earlier than January 1st of ``min_year``.

        :param last_day: inclusive upper bound as an int YYYYMMDD.
        :return: dict mapping str(shard number) -> list of day values as read
            from the table; every shard in ``[0, shards_count)`` has a key.
        """
        import yt.wrapper as yt
        client = yt.YtClient(proxy=self.get_yt_proxy(), token=self.get_token())

        shards_count = self.Parameters.shards_count
        min_day = self.Parameters.min_year * 10000
        shards = {str(i): [] for i in xrange(shards_count)}

        days_invalid_shard_id = []
        days_too_new = []
        days_too_old = []
        for row in client.read_table(self.Parameters.newsd_shard_mapping_table):
            day = row['day']
            shard = row['shard_id']
            shard_no = int(shard)
            if not 0 <= shard_no < shards_count:
                days_invalid_shard_id.append((day, shard))
                continue
            int_day = int(day)
            if int_day > last_day:
                days_too_new.append((day, shard))
                continue
            if int_day < min_day:
                days_too_old.append((day, shard))
                continue
            # Normalise the key: YT may return the shard id as an int or padded string.
            shards[str(shard_no)].append(day)

        if days_invalid_shard_id:
            logging.error("Days with invalid shard ids: %s", days_invalid_shard_id)
        logging.info("Too new days: %s", days_too_new)
        logging.info("Too old days: %s", days_too_old)

        return shards

    def _report_start_to_solomon(self):
        """Prepare the 'publish start' Solomon sensor.

        The push itself is disabled; the locals are kept so it can be
        re-enabled by uncommenting the last line.
        """
        common_labels = {
            'project': 'news',
            'cluster': 'asp437-test',  # TODO: use from parameter
            'service': 'main',
        }
        now = int(time.time())
        sensors = [
            {
                'labels': {
                    'archive': 'publish',
                    'sensor': 'start',
                },
                'ts': now,
                'value': now,
            }
        ]
        solomon_tvm = self.Parameters.solomon_tvm_vault_selector
        # solomon.push_to_solomon_v2(params=common_labels, sensors=sensors, token=solomon_tvm)

    def _wait_for(self, subtasks):
        """Suspend this task until all ``subtasks`` succeed or fail in the scheduler."""
        raise sdk2.WaitTask(subtasks, ctt.Status.Group.SUCCEED + ctt.Status.Group.SCHEDULER_FAILURE, wait_all=True)

    def _create_shardmap_resource(self, shardmap_filename):
        """Register ``self.path(shardmap_filename)`` as a READY ARCNEWS_SHARDMAP_YP resource."""
        resource = resources.ARCNEWS_SHARDMAP_YP(self, shardmap_filename, self.path(shardmap_filename), arch='any')
        sdk2.ResourceData(resource).ready()

    def _publish_shardmap_from_torrent(self, rbtorrent):
        """Download ``rbtorrent`` and publish its ``shardmap_arcnews_*`` file.

        If several files match, the first one in sorted order is used.

        :raises SandboxTaskFailureError: if the torrent contains no such file.
        """
        shardmap_filename = self.Context.shardmap_name_yp
        temp_dir = tempfile.mkdtemp()
        try:
            process.run_process(['sky', 'get', '-wu', rbtorrent], work_dir=temp_dir, log_prefix='sky_get')
            for obj in sorted(os.listdir(temp_dir)):
                path = os.path.join(temp_dir, obj)
                if os.path.isfile(path) and obj.startswith("shardmap_arcnews_"):
                    shutil.copyfile(path, str(self.path(shardmap_filename)))
                    break
            else:
                raise SandboxTaskFailureError("Failed to load shardmap")
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)
        self._create_shardmap_resource(shardmap_filename)

    def _spawn_shard_subtasks(self):
        """Enqueue one shard subtask per shard and suspend until they finish.

        Returns without waiting only when ``shards_count`` is 0.
        """
        self.Context.restarts = {}
        arc_upper_bound = self._get_archive_upper_bound(self.Parameters.fresh_days_gap)
        logging.info("arc_upper_bound: %s", arc_upper_bound)
        days = self.__get_days_to_shards_mapping(arc_upper_bound)

        shards = list(xrange(self.Parameters.shards_count))
        subtasks = [self.__add_newsd_task(shard_id, days=days[str(shard_id)]) for shard_id in shards]
        if not subtasks:
            return

        self.Context.status = 'WAIT_SUBTASKS'
        self.Context.subtasks = [subtask.id for subtask in subtasks]
        self.Context.shards = shards
        self.Context.save()
        self._wait_for(subtasks)

    def _restart_failed_subtasks(self):
        """Re-enqueue the subtasks in ``Context.subtasks`` that did not succeed.

        Updates ``Context.restarts`` (the caller is responsible for saving).

        :return: list of newly enqueued subtasks; empty when all succeeded.
        :raises SandboxTaskFailureError: if a failed shard has already been
            restarted ``max_restarts_count`` times.
        """
        failed = []
        for task_id in self.Context.subtasks:
            for task in self.find(id=task_id):
                if task.status not in ctt.Status.Group.SUCCEED:
                    failed.append(task)
                break

        restarts = dict(self.Context.restarts or {})
        # Check every limit before enqueuing anything so that giving up never
        # leaves freshly restarted subtasks running.
        for task in failed:
            shard_id = task.Parameters.shard_number
            if restarts.get(str(shard_id), 0) >= max_restarts_count:
                raise SandboxTaskFailureError(
                    "Shard %s failed after %d restarts" % (shard_id, max_restarts_count)
                )

        restarted = []
        for task in failed:
            shard_id = task.Parameters.shard_number
            logging.info("Restarting slave_newsd_task for shard %s" % shard_id)
            restarts[str(shard_id)] = restarts.get(str(shard_id), 0) + 1
            # This task only creates day-list subtasks, so rebuild the restarted
            # one from the failed subtask's own "days" parameter.
            days = (task.Parameters.days or "").split()
            restarted.append(self.__add_newsd_task(shard_id, days=days))
        self.Context.restarts = restarts
        return restarted

    def _publish_shardmap_from_resources(self):
        """Write the YP shardmap for ``Context.shards`` and publish it.

        Each line is ``pod_label:shard_id=<shard + 1>\\t<skynet id>(local_path=shard)``.
        """
        shardmap_filename = self.Context.shardmap_name_yp
        with open(str(self.path(shardmap_filename)), 'w') as shardmap_yp:
            for shard_id in sorted(self.Context.shards):
                shard_info = self._get_last_shard_info(resource_types.SLAVE_NEWSD_ARCHIVE_SHARD, shard_id)
                # Pod labels are 1-based while shard numbers are 0-based.
                shardmap_yp.write(
                    "pod_label:shard_id={}".format(str(int(shard_id) + 1)) + '\t' + shard_info['skynet_id'] + "(local_path=shard)\n"
                )
        self._create_shardmap_resource(shardmap_filename)

    def on_execute(self):
        """Task entry point; Sandbox calls it again after every ``sdk2.WaitTask``.

        Phases:
          - First run: spawn the shard subtasks and wait.
          - Later runs (``status == 'WAIT_SUBTASKS'``): restart failed shards
            and wait again, or publish the shardmap once all shards succeeded.
          - With ``shard_map_torrent`` set: only publish the given shardmap.
        """
        if self.Context.timestamp is None:
            self._init_context()
        logging.info("Just some code to commit into sandbox and continue debug with binary task")
        self._report_start_to_solomon()

        rbtorrent = self.Parameters.shard_map_torrent
        if rbtorrent:
            self._publish_shardmap_from_torrent(rbtorrent)
            return

        if self.Context.status != 'WAIT_SUBTASKS':
            self._spawn_shard_subtasks()
            return

        need_to_wait = self._restart_failed_subtasks()
        if need_to_wait:
            self.Context.subtasks = [task_to_wait.id for task_to_wait in need_to_wait]
            self.Context.save()
            self._wait_for(need_to_wait)  # raises sdk2.WaitTask

        self._publish_shardmap_from_resources()
