# -*- coding: utf-8 -*-

import sandbox.projects.common.news.YtScriptTask as ys
from sandbox.projects.common.nanny import nanny
import sandbox.projects.MakeArcNewsdShard as sh
import sandbox.projects.BuildNewsSearchShard as shard_builder

from sandbox.projects import resource_types
from sandbox.projects.news import resources
from sandbox.projects.common import apihelpers
from sandbox.projects.common import solomon

from sandbox.sandboxsdk import task as st
from sandbox.sandboxsdk import parameters as sp
from sandbox.sandboxsdk import process
from sandbox.sandboxsdk import environments
from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.errors import SandboxTaskFailureError

import os
import shutil
import time
import copy
import datetime
import tempfile
import logging

# Name pattern for slave_newsd ("story") archive shards; filled with str.format()
# using the formatted shard number and a timestamp.
story_template = "arcnews_story-{shardno_fmt}-{ts}"
# Name pattern for search (index) archive shards; same placeholders as above.
index_template = "arcnews-{shardno_fmt}-{ts}"
# Shard numbers are rendered as zero-padded three-digit integers.
shardno_fmt = "{:03d}"
# Upper bound the restart counter of a failed subtask is compared against.
max_restarts_count = 5


class VaultTvmEnv(sp.SandboxStringParameter):
    # Mandatory string parameter: "vault_owner:vault_key" pair that locates
    # the Solomon TVM token in Vault.
    required = True
    name = "vault_solomon_tvm_env"
    description = "Solomon TVM token from Vault, format: vault_owner:vault_key"


class DoNotReportMetricsToSolomon(sp.SandboxBoolParameter):
    # Optional switch (off by default): when set, the task sends no
    # progress sensors to Solomon.
    required = False
    default_value = False
    name = "dont_report_metrics_to_solomon"
    description = "Don't push progress metrics to solomon"


class ShardsNumber(sp.SandboxIntegerParameter):
    # Optional integer parameter with no default: how many archive shards
    # the task works with.
    required = False
    default_value = None
    name = "shards_number"
    description = "Number of shards"


class MinYear(sp.SandboxIntegerParameter):
    # Optional integer parameter: days from years before this one are
    # left out of the archives.
    required = False
    default_value = 2010
    name = "min_year"
    description = "Min year to store in archives"


class FreshDaysGap(sp.SandboxIntegerParameter):
    # Optional integer parameter: how many of the most recent days are
    # skipped when the archive upper bound is computed.
    required = False
    default_value = 3
    name = "fresh_days_gap"
    description = "Number of new days to skip"


class NewsdShardsMappingTable(sp.SandboxStringParameter):
    # Optional string parameter: path of the table that maps days to
    # slave_newsd shards.
    required = False
    default_value = "//home/news-prod/archive/newsd/shards"
    name = "newsd_shards_mapping_table"
    description = "slave_newsd shards mapping table"


class ShardMapTorrent(sp.SandboxStringParameter):
    # Optional string parameter: rbtorrent of a ready-made shardmap. When it
    # is set the task only publishes that shardmap.
    required = False
    default_value = None
    name = "shard_map_torrent"
    description = "Shardmap rbtorrent, use it if you want just publish this shardmap and don't want to get shards from YT"


class RestoreShards(sp.SandboxBoolParameter):
    # Optional switch (off by default); its value is forwarded to every
    # shard-building subtask as the "restore" flag.
    required = False
    default_value = False
    name = "restore_shards"
    description = "Restore all shards"


class BuildSearchShards(sp.SandboxBoolParameter):
    # Optional switch (off by default): additionally build a search shard
    # for every archive shard and list it in the shardmap.
    required = False
    default_value = False
    name = "build_search_shards"
    description = "Build legacy search shards"


class PublishNewsArchive(nanny.ReleaseToNannyTask, st.SandboxTask):
    '''
    Download, register and publish all archive shards
    '''

    # Task type identifier.
    type = 'PUBLISH_NEWS_ARCHIVE'

    # Own parameters first, then the shared YT script parameters, then the
    # parameters of the two subtask kinds this task spawns (newsd shard
    # builder and search shard builder). Subtasks receive a copy of this
    # task's ctx, so their parameters are collected here as well.
    input_parameters = [ShardMapTorrent,
                        ShardsNumber,
                        NewsdShardsMappingTable,
                        VaultTvmEnv,
                        DoNotReportMetricsToSolomon,
                        MinYear,
                        FreshDaysGap,
                        RestoreShards
                        ] + ys.get_base_params() + [sh.NewsdTable,
                                                    sh.Binary,
                                                    sh.Ttl,
                                                    BuildSearchShards,
                                                    shard_builder.Builder,
                                                    shard_builder.CypressPath,
                                                    shard_builder.GrattrConf,
                                                    ]

    # Runtime environments requested for the task; the YT client packages
    # are needed because the task imports yt.wrapper at run time.
    environment = (
        environments.SvnEnvironment(),
        environments.PipEnvironment('yandex-yt'),
        environments.PipEnvironment('yandex-yt-yson-bindings-skynet', version="0.3.32-0"),
    )

    # Resource requirements of the task itself.
    cores = 1
    required_ram = 1024  # 1GB

    def on_enqueue(self):
        '''
        Stamp the run and pre-register both shardmap resources.

        Stores the current unix time in ctx['timestamp'], then creates the
        ISS and the YP shardmap resources (file names are derived from that
        timestamp) and remembers each file name and resource id in ctx.
        '''
        now = int(time.time())
        self.ctx['timestamp'] = now

        # (file name pattern, resource type, ctx key for the name, ctx key for the id)
        shardmap_specs = (
            ('shardmap_arcnews_iss-%s.map', resource_types.ARCNEWS_SHARDMAP,
             'shardmap_name', 'resource_id'),
            ('shardmap_arcnews_yp-%s.map', resources.ARCNEWS_SHARDMAP_YP,
             'shardmap_name_yp', 'resource_id_yp'),
        )
        for pattern, shardmap_type, name_key, id_key in shardmap_specs:
            map_name = pattern % now
            registered = self.create_resource(
                description=map_name,
                resource_path=self.path(map_name),
                resource_type=shardmap_type,
                arch='any',
            )
            self.ctx[name_key] = map_name
            self.ctx[id_key] = registered.id

    def _get_archive_upper_bound(self, days):
        last_arc_date = datetime.date.today() - datetime.timedelta(days=days)
        last_arc_day_str = last_arc_date.strftime("%Y%m%d")
        return int(last_arc_day_str)

    def _get_last_shard_info(self, type, shard_number):
        '''
        Look up the latest READY resource built for the given shard.

        :param type: resource type to search for
        :param shard_number: shard index; matched (as a string) against the
                             'shard_number' resource attribute
        :return: dict with 'name' (the resource's 'shard_name' attribute) and
                 'skynet_id' of the found resource
        :raises SandboxTaskFailureError: when no such resource exists
        '''
        from sandbox.yasandbox.api.xmlrpc.resource import touch_resource

        found = apihelpers.get_last_resource_with_attribute(
            resource_type=type,
            attribute_name='shard_number',
            attribute_value=str(shard_number),
            status='READY',
        )
        if found is None:
            raise SandboxTaskFailureError("Failed to get shard resource for shard %s" % shard_number)

        # NOTE(review): touching presumably refreshes the resource's last-usage
        # time so it is not cleaned up -- confirm against the Sandbox API.
        touch_resource(found.id)

        info = {}
        info['name'] = channel.sandbox.get_resource_attribute(found.id, attribute_name="shard_name")
        info['skynet_id'] = found.skynet_id
        return info

    def __add_newsd_task(self, shard_id, since=None, to=None, days=None):
        '''
        Spawn a MAKE_ARCHIVE_NEWSD_SHARD subtask for one shard.

        The subtask context is a deep copy of this task's context extended
        with the shard name, the shard number and the restore flag. The set
        of days is passed either explicitly (`days`, joined with spaces) or,
        when `days` is None, as a lower/upper bound pair.

        :param shard_id: shard index
        :param since: lower bound, used only when `days` is None
        :param to: upper bound, used only when `days` is None
        :param days: iterable of day strings, or None
        :return: id of the created subtask
        '''
        shard_tag = shardno_fmt.format(shard_id)
        shard_name = story_template.format(shardno_fmt=shard_tag, ts=self.ctx['timestamp'])

        child_ctx = copy.deepcopy(self.ctx)
        child_ctx[sh.ShardName.name] = shard_name
        child_ctx[sh.ShardNumber.name] = shard_id
        child_ctx[sh.Restore.name] = self.ctx.get(RestoreShards.name)

        if days is None:
            child_ctx[sh.LowerBound.name] = since
            child_ctx[sh.UpperBound.name] = to
        else:
            child_ctx[sh.Days.name] = " ".join(days)

        return self.create_subtask(
            'MAKE_ARCHIVE_NEWSD_SHARD',
            shard_name,
            child_ctx,
            inherit_notifications=False,
        ).id

    def __add_search_shard(self, shard_id):
        '''
        Spawn a BUILD_NEWS_SEARCH_SHARD subtask for one shard.

        The subtask context is a deep copy of this task's context extended
        with the shard name, the shard number and the restore flag.

        :param shard_id: shard index
        :return: id of the created subtask
        '''
        shard_tag = shardno_fmt.format(shard_id)
        shard_name = index_template.format(shardno_fmt=shard_tag, ts=self.ctx['timestamp'])

        child_ctx = copy.deepcopy(self.ctx)
        child_ctx[shard_builder.ShardName.name] = shard_name
        child_ctx[shard_builder.ShardNumber.name] = shard_id
        child_ctx[shard_builder.Restore.name] = self.ctx.get(RestoreShards.name)

        return self.create_subtask('BUILD_NEWS_SEARCH_SHARD', shard_name, child_ctx).id

    def __get_days_to_shards_mapping(self, last_day):
        '''
        Read the day -> shard mapping table from YT and group days by shard.

        Every row is expected to carry a 'day' (YYYYMMDD) and a 'shard_id'
        field. A row is skipped, and reported in the log, when

        * its shard id is outside [0, shards_number),
        * its day is later than `last_day`, or
        * its day is earlier than January 1st of the configured min year.

        :param last_day: newest day to keep, as a YYYYMMDD integer
        :return: dict mapping str(shard index) -> list of day values exactly
                 as they were read from the table; every shard index in
                 [0, shards_number) is present, possibly with an empty list
        '''
        import yt.wrapper as yt

        vault_env = self.ctx[ys.VaultEnv.name]
        (owner, key) = vault_env.split(':')
        token = self.get_vault_data(owner, key)

        yt_proxy = self.ctx[ys.YtProxy.name]

        client = yt.YtClient(proxy=yt_proxy, token=token)

        # Loop invariants, read from ctx once.
        shards_number = self.ctx[ShardsNumber.name]
        min_day = self.ctx[MinYear.name] * 10000

        shards = dict((str(index), []) for index in range(shards_number))

        days_invalid_shard_id = []
        days_too_new = []
        days_too_old = []
        for row in client.read_table(self.ctx[NewsdShardsMappingTable.name]):
            day = row['day']
            shard = row['shard_id']
            shard_index = int(shard)
            # Negative ids are invalid too; they used to slip through the
            # upper-bound check and fail with KeyError below.
            if not 0 <= shard_index < shards_number:
                days_invalid_shard_id.append((day, shard))
                continue
            int_day = int(day)
            if int_day > last_day:
                days_too_new.append((day, shard))
                continue
            if int_day < min_day:
                days_too_old.append((day, shard))
                continue
            # Buckets are keyed by str(index); normalise the id so that an
            # integer-typed or zero-padded shard_id lands in the right bucket.
            shards[str(shard_index)].append(day)

        if days_invalid_shard_id:
            logging.error("Days with invalid shard ids: %s", days_invalid_shard_id)
        logging.info("Too new days: %s", days_too_new)
        logging.info("Too old days: %s", days_too_old)

        return shards

    def on_execute(self):
        '''
        Build (or fetch) the archive shardmaps and mark them ready.

        Three paths are possible:

        * A shardmap rbtorrent is given: download it, copy the shardmap file
          into the pre-registered resource path and return.
        * First pass ('WAIT_SUBTASKS' not in ctx): spawn one newsd subtask
          (plus an optional search subtask) per shard and wait for them.
        * Later passes: restart subtasks that did not finish, each chain of
          restarts being limited by max_restarts_count; once nothing is left
          to wait for, write both shardmaps and mark the resources ready.

        Unless disabled by the parameter, a 'start' sensor is pushed to
        Solomon on entry and a 'finish' sensor at the end of the method.

        :raises SandboxTaskFailureError: when the torrent contains no shardmap
            file, a subtask ran out of restarts, or a shard resource is missing
        '''
        common_labels = {
            'project': 'news',
            'cluster': 'main',
            'service': 'main',
        }

        report_metrics = not self.ctx.get(DoNotReportMetricsToSolomon.name)
        solomon_tvm = None
        if report_metrics:
            sensors = self.__make_publish_sensors('start')
            (owner, key) = self.ctx.get(VaultTvmEnv.name).split(':')
            solomon_tvm = self.get_vault_data(owner, key)
            solomon.push_to_solomon_v2(params=common_labels, sensors=sensors, token=solomon_tvm)

        rbtorrent = self.ctx.get(ShardMapTorrent.name)
        if rbtorrent:
            # This mode has always returned without a 'finish' sensor; kept as is.
            self.__load_shardmap_from_torrent(rbtorrent)
            return

        wait_statuses = self.Status.Group.SUCCEED + self.Status.Group.SCHEDULER_FAILURE

        if 'WAIT_SUBTASKS' not in self.ctx:
            self.ctx['restarts'] = {}
            # prepare shards
            subtasks = []
            shards = []
            arc_upper_bound = self._get_archive_upper_bound(self.ctx[FreshDaysGap.name])
            logging.info("arc_upper_bound: %s", arc_upper_bound)
            days = self.__get_days_to_shards_mapping(arc_upper_bound)
            build_search = self.ctx.get(BuildSearchShards.name)
            for shard_id in xrange(0, self.ctx[ShardsNumber.name]):
                shards.append(shard_id)
                subtasks.append(self.__add_newsd_task(shard_id, days=days[str(shard_id)]))
                if build_search:
                    subtasks.append(self.__add_search_shard(shard_id))

            if subtasks:
                self.ctx['WAIT_SUBTASKS'] = True
                self.ctx['subtasks'] = subtasks
                self.ctx['shards'] = shards
                self.wait_tasks(subtasks, wait_statuses,
                                wait_all=True, state="Waiting for shards preparing")

        else:
            # check subtasks status
            need_to_wait = self.__restart_unfinished_subtasks()
            if need_to_wait:
                self.ctx['subtasks'] = need_to_wait
                self.wait_tasks(need_to_wait, wait_statuses,
                                wait_all=True, state="Waiting for failed shards preparing")

            # prepare shardmap
            self.__write_shardmaps()

            self.mark_resource_ready(self.ctx['resource_id'])
            self.mark_resource_ready(self.ctx['resource_id_yp'])

        if report_metrics:
            sensors = self.__make_publish_sensors('finish')
            solomon.push_to_solomon_v2(params=common_labels, sensors=sensors, token=solomon_tvm)

    def __make_publish_sensors(self, sensor_name):
        '''
        Build the one-element sensor list for an 'archive: publish' event.

        :param sensor_name: value of the 'sensor' label ('start' or 'finish')
        :return: list with a single sensor whose timestamp and value are both
                 the current unix time
        '''
        now = int(time.time())
        return [
            {
                'labels': {
                    'archive': 'publish',
                    'sensor': sensor_name,
                },
                'ts': now,
                'value': now,
            }
        ]

    def __load_shardmap_from_torrent(self, rbtorrent):
        '''
        Download `rbtorrent` with sky and copy the first shardmap file found
        into the path of the shardmap resource registered in on_enqueue.

        :raises SandboxTaskFailureError: when the download contains no regular
            file whose name starts with "shardmap_arcnews_"
        '''
        temp_dir = tempfile.mkdtemp()
        try:
            process.run_process(['sky', 'get', '-wu', rbtorrent], work_dir=temp_dir, log_prefix='sky_get')
            for obj in os.listdir(temp_dir):
                path = os.path.join(temp_dir, obj)
                if os.path.isfile(path) and obj.startswith("shardmap_arcnews_"):
                    # Write exactly where the resource was registered rather
                    # than relying on the current working directory.
                    shutil.copyfile(path, self.path(self.ctx['shardmap_name']))
                    return
            raise SandboxTaskFailureError('Failed to load shardmap')
        finally:
            # mkdtemp() leaves removal to the caller.
            shutil.rmtree(temp_dir, ignore_errors=True)

    def __restart_unfinished_subtasks(self):
        '''
        Re-create every tracked subtask that did not finish.

        The restart counter is carried over to the id of the replacement
        subtask, so a chain of restarts for one shard is limited by
        max_restarts_count. Keys are stored as strings so that they look the
        same before and after the context is serialised.

        :return: list of ids of the newly created subtasks
        :raises SandboxTaskFailureError: when a subtask has used up its restarts
        '''
        restarts = self.ctx['restarts']
        restarted = []
        for task_id in self.ctx['subtasks']:
            task = channel.sandbox.get_task(task_id)
            if task.is_finished():
                continue

            attempts = restarts.get(str(task_id), 0)
            if attempts >= max_restarts_count:
                raise SandboxTaskFailureError("Subtask %s failed" % task_id)

            # A non-empty 'days' value in the subtask ctx marks a newsd task;
            # anything else is treated as a search shard task.
            if task.ctx.get('days'):
                logging.info("Restarting slave_newsd task for shard %s", task.ctx['shard_number'])
                subtask = self.create_subtask('MAKE_ARCHIVE_NEWSD_SHARD',
                                              "restarting shard %s" % task.ctx['shard_number'], task.ctx)
            else:
                logging.info("Restarting search task for shard %s", task.ctx['shard_number'])
                subtask = self.create_subtask('BUILD_NEWS_SEARCH_SHARD', task.ctx['shard_name'], task.ctx)

            restarted.append(subtask.id)
            # Count against the replacement: it is the id checked on the next pass.
            restarts[str(subtask.id)] = attempts + 1
        return restarted

    def __write_shardmaps(self):
        '''
        Write the ISS and YP shardmap files for all shards stored in ctx['shards'].

        The ISS map gets one line per newsd shard, preceded by a line for the
        search shard when search shards are being built; the YP map gets one
        line per newsd shard, with shard ids shifted to start from 1.
        '''
        build_search = self.ctx.get(BuildSearchShards.name)
        with open(self.path(self.ctx['shardmap_name']), 'w') as shardmap, \
                open(self.path(self.ctx['shardmap_name_yp']), 'w') as shardmap_yp:
            for shard_id in sorted(self.ctx['shards']):
                # index
                if build_search:
                    mask = index_template.format(shardno_fmt=shardno_fmt.format(shard_id), ts="0000000000")
                    search_shard_info = self._get_last_shard_info(resource_types.NEWS_SEARCH_SHARD, shard_id)
                    shardmap.write(mask + "\t" + search_shard_info['name'] + "\tArcnewsRusTier0\n")
                # newsd
                shard_info = self._get_last_shard_info(resource_types.SLAVE_NEWSD_ARCHIVE_SHARD, shard_id)
                mask = story_template.format(shardno_fmt=shardno_fmt.format(shard_id), ts="0000000000")
                shardmap.write(mask + "\t" + shard_info['name'] + "\tArcnewsRusTier0\n")
                shardmap_yp.write("pod_label:shard_id={}".format(str(int(shard_id) + 1)) + "\t" + shard_info['skynet_id'] + "(local_path=shard)\n")


# NOTE(review): presumably the module-level hook the Sandbox task loader uses
# to find the task class defined in this file -- confirm.
__Task__ = PublishNewsArchive
