# -*- coding: utf-8 -*-

import json
import logging
from sandbox import sdk2

from sandbox.projects import resource_types
from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk import environments
from sandbox.sandboxsdk import parameters
from sandbox.common import errors
import sandbox.common.types.task as ctt

from sandbox.projects.common import error_handlers as eh
from sandbox.projects.common import utils
from sandbox.projects.common import apihelpers
from sandbox.projects.common.search.database import iss_shards
from sandbox.projects.common.search import shards as search_shards
from sandbox.projects.common.search import bugbanner
from sandbox.projects import MakeIdxOpsPool

from sandbox.projects.judtier import shards

# Context keys under which the ids of the output factor resources are stored;
# the names are taken from MakeIdxOpsPool.
OUT_FACTOR_NAMES = MakeIdxOpsPool.OUT_FACTOR_NAMES
OUT_FACTOR_BORDERS = MakeIdxOpsPool.OUT_FACTOR_BORDERS

# Name of the parameter group used by this task's own parameters
# and used to split the MakeIdxOpsPool parameters into "main" and "other" ones.
_main = "Main"
# Suffix appended to the output table name to get the intermediate table
# that is handed to per-shard subtasks as their output table.
_tmp_table_suffix = "_raw"


class ShardsList(parameters.ResourceSelector):
    """
        Required selector of a shards list resource (OTHER_RESOURCE).
        It is currently commented out in the task's input parameters list.
    """
    # FIXME(mvel): Remove it
    name = 'shards_list_resource_id'
    description = 'Shards list'
    resource_type = [
        resource_types.OTHER_RESOURCE,
    ]
    group = _main
    required = True


class ShardCount(parameters.SandboxIntegerParameter):
    """
        Upper bound on the number of shards to process.
        A zero value (the default) disables the limit.
    """
    name = 'shard_count'
    description = 'Number of shards to process (all by default)'
    default_value = 0


class UseLastImportedState(parameters.SandboxBoolParameter):
    """
        Use the Jupiter state that has already been downloaded into Sandbox by the regular process
    """
    name = 'use_last_imported_state'
    description = 'Use last imported JudTier'
    default_value = True
    group = _main


class JupiterState(parameters.SandboxStringParameter):
    """
        Jupiter state or an alias, e.g. 20161231-123456 or production_current_state
    """
    name = 'jupiter_state'
    description = 'Jupiter state to work with'
    default_value = 'production_current_state'
    group = _main


# Binds JupiterState to the 'false' value of UseLastImportedState
# (presumably the state field is shown only when the checkbox is off -- TODO confirm sub_fields semantics).
UseLastImportedState.sub_fields = {'false': [JupiterState.name]}


class CanonizeUrls(parameters.SandboxBoolParameter):
    """
        When enabled, a JudTier dups table resource is found (or requested) for every shard
        and passed to the pool-making subtask as its ranking URLs map.
    """
    name = 'canonize_urls'
    description = 'Canonize URLs'
    default_value = True
    group = _main


# Context key holding the per-shard results; its presence means that
# shard subtasks have already been created and must not be created again.
_PROCESS_SHARDS = 'process_shards'

# Shortcut for the MakeIdxOpsPool parameters container.
_idx_ops_params = MakeIdxOpsPool.MakePoolParams


# build proper params list
def _get_make_pool_main_params():
    """
        Collect MakeIdxOpsPool parameters of the main group that this task exposes as its own.

        The search database and ranking URLs map parameters are skipped:
        this task fills them in itself for every shard subtask.

        @return: list of parameter classes
    """
    skipped_names = (
        _idx_ops_params.SearchDatabaseResource.name,
        _idx_ops_params.RankingUrlsMap.name,
    )
    selected = []
    for candidate in _idx_ops_params.params:
        if candidate.group != _main:
            continue
        if candidate.name in skipped_names:
            continue
        selected.append(candidate)
    return selected


def _get_make_pool_other_params():
    """
        Collect MakeIdxOpsPool parameters that do not belong to the main group.

        @return: list of parameter classes
    """
    others = []
    for candidate in _idx_ops_params.params:
        if candidate.group != _main:
            others.append(candidate)
    return others


# Both lists are computed once at import time; they are used for the task's
# input parameters and for cloning parameter values into subtasks.
_make_pool_main_params = _get_make_pool_main_params()
_make_pool_other_params = _get_make_pool_other_params()


class MakeEstimatedPool(bugbanner.BugBannerTask):
    """
        **Description**

            Builds an estimated pool from the estimated JudTier tier.

            * Obtains shards of the estimated tier
            * Runs MakeIdxOpsPool for every shard
    """

    type = 'MAKE_ESTIMATED_POOL'

    # yt.wrapper is imported inside on_execute, the package comes from this pip environment
    environment = (environments.PipEnvironment('yandex-yt', use_wheel=True),)

    # Own parameters are interleaved with the parameters cloned from MakeIdxOpsPool
    input_parameters = [
        # FIXME(mvel): Remove it
        # ShardsList,
        ShardCount,
    ] + _make_pool_main_params + [
        UseLastImportedState,
        JupiterState,
        CanonizeUrls,
    ] + _make_pool_other_params
    execution_space = 1000  # 1 Gb

    # Passed as kill_timeout when a shard database resource has to be requested
    db_shard_timeout = 18000

    def on_enqueue(self):
        """
            Create the output resources for factor names and factor borders
            and remember their ids in the task context.
        """
        SandboxTask.on_enqueue(self)
        outputs = (
            (OUT_FACTOR_NAMES, ': factor names', 'factor_names.txt', resource_types.FACTOR_NAMES_TXT),
            (OUT_FACTOR_BORDERS, ': factor borders', 'factor_borders.txt', resource_types.FACTOR_BORDERS_TXT),
        )
        for ctx_key, title_suffix, file_name, resource_type in outputs:
            created = self.create_resource(self.descr + title_suffix, file_name, resource_type)
            self.ctx[ctx_key] = created.id

    def on_execute(self):
        """
            Task entry point.

            Configures the YT token from the vault, creates shard subtasks
            unless the context already holds their results, then checks the results.
        """
        import yt.wrapper as yt

        vault_owner = utils.get_or_default(self.ctx, _idx_ops_params.VaultOwner)
        vault_name = utils.get_or_default(self.ctx, _idx_ops_params.VaultName)
        yt.config['token'] = self.get_vault_data(vault_owner, vault_name)

        shards_already_processed = _PROCESS_SHARDS in self.ctx
        if not shards_already_processed:
            self.ctx[_PROCESS_SHARDS] = self._process_shards(yt)

        self._check_results(yt)

    def _split_table_name(self):
        """
            Split the output table parameter into a YT proxy and a table path.

            The value is colon-separated: the first two fields make up the proxy,
            the remaining ones make up the path, which gets '//' prepended
            unless it already starts with it.

            @return: (proxy, table_name) tuple
        """
        fields = self.ctx[_idx_ops_params.OutputTable.name].split(':')
        proxy = ':'.join(fields[:2])
        path = ':'.join(fields[2:])
        prefix = '' if path.startswith('//') else '//'
        return proxy, prefix + path

    def _resolve_state(self, used_state, yt):
        """
            Resolve a Jupiter state alias into a concrete state.

            Reads the `@jupiter_meta` attribute of `//home/jupiter` via the "arnold" proxy.
            If `used_state` is a key of that attribute, the mapped value is returned
            (and reported via set_info); otherwise `used_state` is returned unchanged.

            @param used_state: state or alias to resolve
            @param yt: imported yt.wrapper module, its config must already contain the token
            @return: resolved state
        """
        yt_client = yt.YtClient(proxy="arnold", token=yt.config["token"])
        jupiter_metadata = yt_client.get("//home/jupiter/@jupiter_meta")
        logging.info("Jupiter metadata:\n%s", json.dumps(jupiter_metadata, indent=4))
        # The argument is honoured as is: it used to be silently overwritten with the context value,
        # which made the parameter meaningless.
        if used_state not in jupiter_metadata:
            return used_state
        resolved_state = jupiter_metadata[used_state]
        self.set_info("Jupiter state {} resolved to {}".format(used_state, resolved_state))
        return resolved_state

    def _get_last_imported_state(self):
        """
            Take state and shards from the most recent successful JUD_TIER_IMPORT_SHARD_SET task.

            @return: (state, shards) tuple read from the context of that task
            @raise errors.TaskError: if no successful import task is found
        """
        successful_imports = sdk2.Task["JUD_TIER_IMPORT_SHARD_SET"].find(status=(ctt.Status.SUCCESS,))
        import_task = successful_imports.order(-sdk2.Task.id).first()
        if import_task is None:
            # fail with a clear message instead of an AttributeError on None below
            raise errors.TaskError("No successful JUD_TIER_IMPORT_SHARD_SET task found")
        return import_task.Context.state, import_task.Context.shards

    def _find_imported_shard_names(self, used_state):
        """
            Find names of shards already imported for the given Jupiter state.

            Looks for the latest JUDTIER_DUPS_TABLE resource with a matching 'jupiter_state'
            attribute (READY first, then DELETED) and reads shard names from the context
            of the task that owns the resource.

            @return: sorted list of shard names, or None if no such resource exists
        """
        for resource_status in ('READY', 'DELETED'):
            found = apihelpers.get_last_resource_with_attribute(
                resource_types.JUDTIER_DUPS_TABLE,
                'jupiter_state', used_state,
                status=resource_status
            )
            if not found:
                continue
            owner_task = sdk2.Task[found.task_id]
            shard_resource_ids = owner_task.Context.shard_resource_ids
            # 'primus-JudTier-0-10-timestamp' is listed before 'primus-JudTier-0-2-timestamp', but who cares?
            return sorted(shard_resource_ids.keys())
        return None

    def _process_shards(self, yt):
        """
            Determine shards to work with, create the intermediate YT table
            and create tasks for every shard.

            Shard names come either from the last imported state or from the requested
            Jupiter state; in the latter case names are generated when the state
            has not been imported yet.

            @param yt: imported yt.wrapper module
            @return: dict shard_name -> info about tasks created for the shard
        """
        if not utils.get_or_default(self.ctx, UseLastImportedState):
            used_state = self._resolve_state(utils.get_or_default(self.ctx, JupiterState), yt)
            production_timestamp = shards.state_to_timestamp(used_state)
            shard_names = self._find_imported_shard_names(used_state)
            if shard_names is None:  # not yet imported
                shards_count = shards.detect_shards_count(self, used_state, yt_wrapper=yt)
                shard_names = []
                for shard_num in range(shards_count):
                    generated_name = "primus-JudTier-{}-{}-{}".format(0, shard_num, production_timestamp)
                    logging.debug("Shard name: %s", generated_name)
                    shard_names.append(generated_name)
        else:
            used_state, shard_names = self._get_last_imported_state()
            self.set_info("Using Jupiter state {}".format(used_state))

        # create output table in YT
        proxy, table_name = self._split_table_name()
        yt_client = yt.YtClient(proxy=proxy, token=yt.config["token"])
        yt_client.create("table", table_name + _tmp_table_suffix, recursive=True)

        if utils.get_or_default(self.ctx, CanonizeUrls):
            judtier_dups = self._get_or_create_judtier_dups(used_state, shard_names)
        else:
            judtier_dups = None

        shard_limit = int(utils.get_or_default(self.ctx, ShardCount))
        tasks_by_shard = {}  # map shard_name -> tasks info
        for processed, shard_name in enumerate(shard_names, 1):
            name_fields = shard_name.split('-')
            logging.debug("Name fields: %s", name_fields)
            tasks_by_shard[shard_name] = self._create_tasks_for_shard(
                shard_name,
                int(name_fields[-1]),  # the last field of a shard name is its timestamp
                self.ctx[_idx_ops_params.Ratings.name],
                judtier_dups[shard_name] if judtier_dups else None,
            )
            # zero limit means "no limit"
            if shard_limit and processed >= shard_limit:
                break

        return tasks_by_shard

    def _get_judtier_dups(self, shard_names):
        """
            Find the latest JudTier dups table resource for every shard.

            For each shard a READY resource is tried first, then a NOT_READY one.

            @param shard_names: iterable of shard names
            @return: dict shard_name -> resource id, or None as soon as
                     no resource is found for some shard
        """
        found_ids = {}
        for shard_name in shard_names:
            dups_resource = None
            for resource_status in ('READY', 'NOT_READY'):
                dups_resource = apihelpers.get_last_resource_with_attribute(
                    resource_types.JUDTIER_DUPS_TABLE,
                    'shard_name', shard_name,
                    status=resource_status)
                if dups_resource:
                    break
            if not dups_resource:
                # neither status gave a resource: the whole lookup is considered failed
                return None
            logging.info('Using existing snapshot of JudTier dups table for shard {}: {}'.format(shard_name, dups_resource.id))
            found_ids[shard_name] = dups_resource.id
        return found_ids

    def _get_or_create_judtier_dups(self, used_state, shard_names):
        """
            Return JudTier dups table resources for the shards, requesting them if needed.

            If the lookup finds nothing, a JUD_TIER_GET_DUPS_TABLE task is enqueued and
            the lookup is repeated right away. The lookup accepts NOT_READY resources,
            so presumably the enqueued task pre-creates them (see precreate_count) -- TODO confirm.

            @return: dict shard_name -> resource id
            @raise errors.TaskError: if resources are still not found after enqueueing the task
        """
        existing = self._get_judtier_dups(shard_names)
        if existing:
            return existing

        fetch_task_class = sdk2.Task["JUD_TIER_GET_DUPS_TABLE"]
        fetch_task = fetch_task_class(
            None,
            description=self.descr+", fetch JudTier dups",
            owner=self.owner,
            priority=ctt.Priority(ctt.Priority.Class.SERVICE, ctt.Priority.Subclass.NORMAL),
            jupiter_state=used_state,
            split_shards=True,
            precreate_count=len(shard_names)
        )
        fetch_task.enqueue()

        requested = self._get_judtier_dups(shard_names)
        if requested:
            return requested
        raise errors.TaskError("Failed to fetch JudTier dups table")

    def _create_tasks_for_shard(self, shard_name, timestamp, ratings, canonized):
        """
            Obtain the shard database resource and create a pool-making task for the shard.

            The tier is the second dash-separated field of the shard name.
            An existing shard resource is reused; otherwise a database shard resource is requested.

            @param shard_name: shard name, e.g. primus-JudTier-1-1-1475521237
            @param timestamp: shard timestamp, set on the shard download descriptor
            @param ratings: ratings value passed through to the pool-making task
            @param canonized: ranking URLs map for the shard, or None
            @return: dict with shard name, tier, shard resource id, id of the task
                     owning that resource and id of the pool-making task
        """
        result = {
            'shard_name': shard_name,
        }

        shard = search_shards.ShardDownloadTask(shard_name)
        shard.timestamp = timestamp
        # example: primus-JudTier-1-1-1475521237
        shard_name_split = shard_name.split('-')
        tier = shard_name_split[1]
        eh.ensure('Tier' in tier, "ERROR: Strange tier detected, please check shard name parsing")

        result[_idx_ops_params.TierParameter.name] = tier
        # Each argument needs a complete placeholder: the trailing bare '%' used here before
        # made the record fail to format.
        logging.info("Scheduled shard %s download tier %s, splitted name %s", shard_name, tier, shard_name_split)

        shard_resource = search_shards.get_shard_resource_by_name(shard_name)
        if not shard_resource:
            shard_torrent = iss_shards.get_shard_name(shard_name)
            shard_resource = search_shards.get_database_shard_resource(
                tier, shard,
                kill_timeout=self.db_shard_timeout,
                db_type='basesearch',
                check_spam=False,
                database_path=shard_torrent,
            )

        result['shard_resource_id'] = shard_resource.id
        result['shard_resource_task_id'] = shard_resource.task_id

        make_pool_task = self._create_make_pool_task(shard_resource.id, tier, shard_name, ratings, canonized)
        result['make_pool_task_id'] = make_pool_task.id
        return result

    def _create_make_pool_task(self, shard_resource_id, tier, shard_name, ratings, canonized):
        """
            Create a MAKE_IDX_OPS_POOL subtask for a single shard.

            The subtask context is filled with the values of the cloned MakeIdxOpsPool
            parameters; the output table is redirected to the intermediate table and
            shard-specific values are set on top.

            @param canonized: ranking URLs map, passed to the subtask only when truthy
            @return: task object
        """
        # clone input parameters
        cloned_values = (
            (param.name, utils.get_or_default(self.ctx, param))
            for param in _make_pool_main_params + _make_pool_other_params
        )
        subtask_context = {
            'kill_timeout': 14400,
            'notify_via': '',
        }
        subtask_context.update(cloned_values)

        subtask_context[_idx_ops_params.OutputTable.name] += _tmp_table_suffix
        subtask_context[_idx_ops_params.SearchDatabaseResource.name] = shard_resource_id
        subtask_context[_idx_ops_params.TierParameter.name] = tier
        subtask_context[MakeIdxOpsPool.SHARD_NAME] = shard_name
        subtask_context[_idx_ops_params.Ratings.name] = ratings
        if canonized:
            subtask_context[_idx_ops_params.RankingUrlsMap.name] = canonized

        subtask_description = 'Make micropool for shard {}'.format(shard_name)
        return self.create_subtask(
            task_type='MAKE_IDX_OPS_POOL',
            description=subtask_description,
            input_parameters=subtask_context,
            arch='linux',
            execution_space=50*1024,
        )

    def _check_results(self, yt):
        """
            Verify outputs of shard subtasks and build the final table.

            * Checks that factor borders and factor names produced for all shards are identical
              and writes them into the output resources of this task.
            * Sorts the intermediate table by key and reduces it through
              pool_converter | pool_merge_filter | pool_converter into the output table.
            * Removes the intermediate table.

            @param yt: imported yt.wrapper module, its config must already contain the token
            @raise errors.TaskError: if there are no shard results
                                     or factors info differs between shards
        """
        # helper from projects.common.utils; presumably handles failed subtasks -- see its implementation
        utils.restart_broken_subtasks()
        logging.debug("Checking results...")
        results = self.ctx.get(_PROCESS_SHARDS, {})
        factor_infos = None
        for shard_name, shard_result in results.iteritems():
            logging.debug("Shard name: %s", shard_name)
            cur_resources = []
            # the order of resource types must match the order of context keys in zip() below
            for resource_type in (resource_types.FACTOR_BORDERS_TXT, resource_types.FACTOR_NAMES_TXT):
                resource_id = apihelpers.get_task_resource_id(shard_result['make_pool_task_id'], resource_type)
                resource_path = self.sync_resource(resource_id)
                with open(resource_path, 'r') as f:
                    cur_resources.append(f.read())
            if factor_infos is None:
                factor_infos = cur_resources
            elif factor_infos != cur_resources:  # sanity check
                raise errors.TaskError('Mismatch in factors info')

        if factor_infos is None:
            # no shard was processed: fail explicitly instead of a TypeError from zip(None, ...)
            raise errors.TaskError('No shard results found, nothing to build the pool from')

        for factor_info, ctx_name in zip(factor_infos, [OUT_FACTOR_BORDERS, OUT_FACTOR_NAMES]):
            resource = channel.sandbox.get_resource(self.ctx[ctx_name])
            with open(resource.path, 'w') as f:
                f.write(factor_info)

        proxy, table_name = self._split_table_name()
        forshards_table_name = table_name + _tmp_table_suffix
        # first of all, sort table by key
        yt_client = yt.YtClient(proxy=proxy, token=yt.config["token"])
        yt_client.run_sort(forshards_table_name, sort_by=["key"])

        # second, filter (reduce) table
        # FIXME(mvel): MatrixNet factor index???
        # the whole pipeline is single-quoted here and substituted into "bash -c" below
        bash_command = (
            "'./pool_converter -i - -o - mr-proto final | "
            "./pool_merge_filter -U --no-shard-infos  --input - -o - | "
            "./pool_converter -f -z 379 -i - -o - final mr-proto'"
        )

        features_table_name = table_name
        self.set_info("Writing output to table {}".format(features_table_name))
        # both binaries are shipped to the reduce job as local files
        merge_filter_binary = utils.sync_last_stable_resource(
            resource_type=resource_types.POOL_MERGE_FILTER_EXECUTABLE,
        )
        pool_converter_binary = utils.sync_last_stable_resource(
            resource_type=resource_types.POOL_CONVERTER_EXECUTABLE,
        )

        yt_client.run_reduce(
            "bash -o pipefail -c {}".format(bash_command),
            source_table=forshards_table_name,
            destination_table=yt.ypath.TablePath(features_table_name, sorted_by=["key", "subkey"]),
            local_files=[merge_filter_binary, pool_converter_binary],
            format=yt.format.YamrFormat(has_subkey=True, lenval=True),
            table_writer={"max_row_weight": 64*1024*1024},
            memory_limit=10*1024*1024*1024,
            reduce_by=["key"]
        )

        # the intermediate table is not needed once the output table is built
        yt_client.remove(forshards_table_name)


# Module-level alias of the task class (presumably the name the Sandbox task loader looks for -- TODO confirm).
__Task__ = MakeEstimatedPool
