# -*- coding: utf-8 -*-
"""
Создаём многотомный бинарный индекс хроносрезов из не-legacy-таблиц.
"""
from __future__ import print_function
import os
import json
import tempfile
import logging
import subprocess
from pipes import quote
import pathlib2

from sandbox import sdk2
import sandbox.sdk2.helpers
import sandbox.common.types.client as ctc
import sandbox.common.types.resource as ctr
from sandbox.common.errors import TaskError, TaskFailure
from sandbox.common.types.task import Semaphores, Status
from sandbox.projects.advq.AdvqGenChronoChunk import AdvqGenChronoChunk
from sandbox.projects.advq.artifacts import MRKIT_LEGACY_READER, CHRONO_DB_GENERATOR
from sandbox.projects.advq.common import get_chrono_resource_class, SHELL_COMMAND_PREFIX
from sandbox.projects.advq.common.parameters import PhitsParameters, convert_ttl
from sandbox.projects.advq.common.yt_utils import get_yt_env_from_parameters, setup_yt_from_parameters
from sandbox.sdk2.resource import ResourceData
from sandbox.sdk2.task import WaitTask
from sandbox.sandboxsdk.environments import PipEnvironment

# NOTE(review): not referenced anywhere in the code visible in this file;
# presumably a timeout for chunk generation, in seconds (8 hours) -- confirm.
CHRONO_CHUNK_TIMEOUT = 8 * 60 * 60

# Name template of the semaphore acquired in on_enqueue; it is formatted with
# the phits type (`type`) and the chrono type (`chrono_type`) parameters.
SEMAPHORE_GENERATION_NAME_TEMPLATE = 'advq_chrono_db_generation_single_{type}_{chrono_type}'
# Capacity passed together with the semaphore name when it is acquired.
SEMAPHORE_CAPACITY = 24

# Values offered by the "Chrono period" radio group of the task parameters.
CHRONO_TYPE_WEEK = 'week'
CHRONO_TYPE_MONTH = 'month'


class AdvqGenChronoIndex(sdk2.Task):
    class Requirements(sdk2.Task.Requirements):
        # Class-level requirements of the task. Note that on_enqueue overwrites
        # disk_space, ram and semaphores for every task instance.
        client_tags = ctc.Tag.LINUX_TRUSTY & ctc.Tag.IPV6
        # These are the maximum memory and disk requirements, those of normal_rus with 1 chunk.
        # For other variants a smaller value can be set in on_enqueue.
        disk_space = 128 * 1024
        ram = 8 * 1024
        cores = 32
        # NOTE(review): presumably these pip packages provide the `yt` modules
        # that on_execute imports lazily -- confirm.
        environments = (
            PipEnvironment("yandex-yt"),
            PipEnvironment("yandex-yt-yson-bindings-skynet")
        )

        # Empty Caches declaration.
        # NOTE(review): presumably tells Sandbox that no shared caches are needed -- confirm.
        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(PhitsParameters):
        # Task inputs. Fields read by this task but not declared below
        # (advq_phits_type, advq_build_binaries, yt_proxy, yt_token_vault_*)
        # presumably come from PhitsParameters -- verify against that class.

        # Chrono period of the index: one of CHRONO_TYPE_MONTH / CHRONO_TYPE_WEEK.
        with sdk2.parameters.RadioGroup("Chrono period", required=True) as advq_chrono_type:
            advq_chrono_type.values[CHRONO_TYPE_MONTH] = advq_chrono_type.Value(value=CHRONO_TYPE_MONTH)
            advq_chrono_type.values[CHRONO_TYPE_WEEK] = advq_chrono_type.Value(value=CHRONO_TYPE_WEEK)
        advq_db = sdk2.parameters.String("db (rus, tur-robots, etc)", required=True)
        date = sdk2.parameters.String("Database week number or month number (YYYYNN)", required=True)
        epoch = sdk2.parameters.Integer("Epoch", required=True, default=0)
        epodate = sdk2.parameters.String(
            "Epodate string (bare date or {date}-{epoch} or {date}-delta{epoch})",
            required=True
        )
        # on_execute requires this flag to agree with the presence of 'delta' in epodate.
        is_delta = sdk2.parameters.Bool("Table is delta", required=True)
        # Passed as kill_timeout to the child chunk tasks in parallel mode.
        child_parallel_kill_timeout = sdk2.parameters.Integer("Child parallel kill timeout for db chunk, in seconds")
        input_table = sdk2.parameters.String("Input table", required=True)
        # Non-parallel mode only: passed to the generator as --mem-limit when set.
        chunk_size = sdk2.parameters.Float("Floating-point chunk size, GB")
        # Multiplied by 1024 and stored into Requirements.disk_space in on_enqueue.
        disk_space_limit = sdk2.parameters.Float("Generation disk limit, GB", required=True)
        # Non-parallel mode only: passed to the generator as --measure-period.
        measure_period = sdk2.parameters.Integer("Number of lines to measure data size", default=1000000)
        # Selects _generate_parallel (child tasks) over _generate_non_parallel.
        generate_parallel = sdk2.parameters.Bool("Use parallel generation", required=True, default=False)
        # Chunk counts returned by _calc_total_chunks for non-delta, non-'tur' databases.
        chunk_num_week = sdk2.parameters.Integer("Number of chunks to generate for week", required=True)
        chunk_num_month = sdk2.parameters.Integer("Number of chunks to generate for month", required=True)
        # When non-empty, assigned to the `released` attribute of the chunk resources.
        releaseTo = sdk2.parameters.String("Release attribute value", required=False)
        # Forwarded to child tasks (parallel) or converted with convert_ttl for resources (non-parallel).
        ttl = sdk2.parameters.Integer("TTL for released chunks (days, always; 0 for inf)", default=720, required=True)
        # Upper bound compared against Context.retries in _generate_parallel.
        retries_count = sdk2.parameters.Integer("Number of retries for chunk generation subtasks", default=3, required=False)

    def on_enqueue(self):
        """Set per-instance requirements, then delegate to the base on_enqueue.

        Overwrites the semaphores, disk_space and ram requirements declared on
        the Requirements class.
        """
        params = self.Parameters

        # The semaphore name depends on phits_type and chrono_type.
        semaphore_name = SEMAPHORE_GENERATION_NAME_TEMPLATE.format(
            type=params.advq_phits_type,
            chrono_type=params.advq_chrono_type,
        )
        acquire = Semaphores.Acquire(name=semaphore_name, capacity=SEMAPHORE_CAPACITY)
        self.Requirements.semaphores = Semaphores(acquires=[acquire])

        # disk_space_limit is a Float parameter, so the product may be a float.
        # NOTE(review): confirm that Sandbox accepts a non-integer disk_space.
        self.Requirements.disk_space = 1024 * params.disk_space_limit

        # Fixed value that replaces the class-level `ram = 8 * 1024`.
        self.Requirements.ram = 1024 * 300

        super(AdvqGenChronoIndex, self).on_enqueue()

    def on_execute(self):
        """Validate the parameters, configure YT and run the selected generation mode.

        :raises TaskFailure: if one of the required string parameters is empty,
            or if ``is_delta`` disagrees with the presence of 'delta' in ``epodate``.
        """
        # Imported lazily rather than at module level.
        # NOTE(review): presumably `yt` is only importable once the task's
        # PipEnvironment requirements are prepared -- confirm.
        import yt.wrapper as yt
        import yt.logger as yt_logger
        yt_logger.LOGGER.setLevel(logging.DEBUG)

        params = self.Parameters
        required_values = (
            params.advq_db,
            params.advq_chrono_type,
            params.date,
            params.epodate,
            params.input_table,
        )
        if not all(required_values):
            # When a task is created by another task, nobody checks that all the
            # required parameters are set, so we have to do it ourselves.
            raise TaskFailure("Some required parameters are empty")

        # The check fails in both directions (flag set without 'delta' in epodate,
        # and 'delta' in epodate without the flag), so the message reports both values.
        epodate_has_delta = 'delta' in params.epodate
        if bool(params.is_delta) != epodate_has_delta:
            raise TaskFailure(
                "is_delta flag ({!r}) doesn't match epodate {!r}: "
                "epodate must contain 'delta' if and only if is_delta is set".format(
                    bool(params.is_delta), params.epodate))

        env = get_yt_env_from_parameters(params)
        setup_yt_from_parameters(params)

        if params.generate_parallel:
            self._generate_parallel(yt, env)
        else:
            self._generate_non_parallel(yt, env)

    def _calc_total_chunks(self):
        """Return the number of chunks the index is split into.

        Delta tables and the 'tur' database always use a single chunk; otherwise
        the count comes from the parameter matching the chrono type.

        :raises TaskFailure: if the chrono type is neither 'week' nor 'month'.
        """
        params = self.Parameters

        single_chunk = params.is_delta or params.advq_db == 'tur'
        if single_chunk:
            return 1

        chrono_type = params.advq_chrono_type
        if chrono_type == 'week':
            return params.chunk_num_week
        if chrono_type == 'month':
            return params.chunk_num_month

        raise TaskFailure("Unknown task type")

    def _get_attrs(self):
        """Return the attribute dict that _find_chunks uses to look up chunk resources.

        May raise TaskFailure through _calc_total_chunks.
        """
        params = self.Parameters
        attrs = dict(
            advq_input_table=params.input_table,
            advq_is_delta=params.is_delta,
            advq_db=params.advq_db,
            advq_chrono_type=params.advq_chrono_type,
            advq_phits_type=params.advq_phits_type,
            advq_epodate=params.epodate,
            advq_total_chunks=self._calc_total_chunks(),
            advq_epoch=params.epoch,
            advq_date=params.date,
        )
        return attrs

    def _find_chunks(self):
        """Return the READY chunk resources whose attributes match this task.

        The search is limited to 100 results; every hit is re-read through
        ``sdk2.Resource[id]`` before being returned.
        """
        resource_class = get_chrono_resource_class(self.Parameters.advq_phits_type)
        query = sdk2.Resource.find(
            resource_type=resource_class,
            state=ctr.State.READY,
            attrs=self._get_attrs(),
        )
        return [sdk2.Resource[found.id] for found in query.limit(100)]

    def _generate_parallel(self, yt, env):
        """Generate the index by fanning chunk generation out to child tasks.

        The method is re-entered every time the task wakes up after waiting for
        its children. On each entry it looks up the chunks that already exist as
        READY resources and spawns an AdvqGenChronoChunk child only for the
        missing ones. Once no chunk is missing, the found resources get their
        ``released`` attribute set to ``releaseTo`` (when that parameter is set).

        A generation round is counted in ``self.Context.retries`` only when
        children actually have to be spawned. This way the results of the last
        permitted round are still inspected before the task gives up; counting
        on every entry would fail the task right after that round regardless of
        its outcome.

        :param yt: the ``yt.wrapper`` module, already configured by the caller.
        :param env: not used by this method.
        :raises TaskFailure: if the chunk count is not positive, or if chunks are
            still missing after ``retries_count`` generation rounds.
        :raises WaitTask: to suspend this task until the spawned children finish.
        """
        chunks_total = self._calc_total_chunks()
        if not chunks_total or chunks_total < 1:
            # Without this guard an empty chunk list would look like "all chunks ready".
            raise TaskFailure("Invalid number of chunks: {!r}".format(chunks_total))

        ready_resources = self._find_chunks()
        chunks_found = set(resource.advq_chunk for resource in ready_resources)
        missing_chunks = [i for i in range(1, chunks_total + 1) if i not in chunks_found]

        if not missing_chunks:
            # Every chunk is available: mark the resources and finish.
            release_to = self.Parameters.releaseTo
            if release_to:
                for resource in ready_resources:
                    logging.info("Releasing %s", resource.id)
                    resource.released = release_to
            return

        # Context.retries is falsy until the first round has been started.
        retries = (self.Context.retries or 0) + 1
        self.Context.retries = retries
        logging.info("Retry number %d", retries)
        if retries > self.Parameters.retries_count:
            logging.error("Chunks %r are still missing after %d rounds", missing_chunks, retries - 1)
            raise TaskFailure("Generation failed")

        row_count = yt.row_count(self.Parameters.input_table)
        # Ceiling division: rows per chunk so that chunks_total chunks cover the table.
        batch_len = (row_count + chunks_total - 1) // chunks_total
        task_ids = []
        for i in missing_chunks:
            logging.info("Chunk %d not found", i)
            descr = ("Generate chrono index for {}_{}_{} from {!r}, chunk {}".format(
                self.Parameters.advq_chrono_type,
                self.Parameters.advq_db,
                self.Parameters.epodate,
                self.Parameters.input_table,
                i,
            ))
            task = AdvqGenChronoChunk(
                self,
                description=descr,
                kill_timeout=self.Parameters.child_parallel_kill_timeout,
                yt_proxy=self.Parameters.yt_proxy,
                yt_token_vault_user=self.Parameters.yt_token_vault_user,
                yt_token_vault_name=self.Parameters.yt_token_vault_name,
                advq_phits_type=self.Parameters.advq_phits_type,
                advq_build_binaries=self.Parameters.advq_build_binaries,
                advq_chrono_type=self.Parameters.advq_chrono_type,
                advq_db=self.Parameters.advq_db,
                date=self.Parameters.date,
                epoch=self.Parameters.epoch,
                epodate=self.Parameters.epodate,
                is_delta=self.Parameters.is_delta,
                input_table=self.Parameters.input_table,
                chunk_number=i,
                chunks_total=chunks_total,
                # Row range of the chunk: [start_index, end_index), the end clamped to the table size.
                start_index=(i - 1) * batch_len,
                end_index=min(i * batch_len, row_count),
                ttl=self.Parameters.ttl,
            )
            logging.info("Running %s in %s", descr, task.id)
            task_ids.append(task.id)
            task.enqueue()

        # missing_chunks is non-empty here, so there is always something to wait for.
        self.Context.tasks_to_wait = task_ids
        raise WaitTask(self.Context.tasks_to_wait, statuses=(Status.Group.FINISH + Status.Group.BREAK),
                       wait_all=True)

    def _generate_non_parallel(self, yt, env):
        """Build the whole chrono index inside this task and publish every chunk.

        Pipes the input table through the MRKIT legacy reader into the chrono DB
        generator, parses the pipeline's stdout as JSON (expected to be a list of
        chunk file paths) and turns every listed file into a ready resource.

        :param yt: the ``yt.wrapper`` module, already configured by the caller.
        :param env: environment passed to the generator subprocess.
        :raises TaskError: if the input table or its ``@advq_totals`` attribute is missing.
        :raises subprocess.CalledProcessError: if the shell pipeline exits with a
            non-zero status (the error text is also added to the task info).
        """
        params = self.Parameters
        res_class = get_chrono_resource_class(params.advq_phits_type)

        chrono_db_prefix = '{chrono_type}lyhits_{type}_{db}_{epodate}'.format(
            type=params.advq_phits_type,
            chrono_type=params.advq_chrono_type,
            db=params.advq_db,
            epodate=params.epodate,
        )

        binaries = ResourceData(params.advq_build_binaries)

        output_dir = pathlib2.Path('output')
        output_dir.mkdir()

        output_file_base = output_dir.joinpath(chrono_db_prefix)

        mrkit_legacy_reader = str(binaries.path.joinpath(MRKIT_LEGACY_READER))
        chrono_db_generator = str(binaries.path.joinpath(CHRONO_DB_GENERATOR))

        REGION_HITS_FIELDS = (
            'RegionHits',
            'PhoneRegionHits',
            'TabletRegionHits',
        )
        CHANNELS = 'all,phone,tablet'
        FIELDS = ('OrigSanitized',) + REGION_HITS_FIELDS

        if not yt.exists(params.input_table):
            raise TaskError("Input table {!r} doesn't exist".format(params.input_table))

        with sandbox.sdk2.helpers.ProcessLog(self, logger=logging.getLogger("advq-chrono-db-generate")) as pl:
            # Save the totals into a file that is handed to the generator.
            totals_attr = params.input_table + '/@advq_totals'
            with tempfile.NamedTemporaryFile(prefix=chrono_db_prefix + '.total.') as totals_file:
                if not yt.exists(totals_attr):
                    raise TaskError(
                        "Failed to find totals for chrono table {!r}".format(params.input_table))
                totals = yt.get(totals_attr)
                totals_fields = [totals[field] for field in REGION_HITS_FIELDS]
                print('\t'.join(totals_fields), file=totals_file.file)
                # The generator reads the file by name, so the data must be on disk first.
                totals_file.file.flush()

                extra_args = []
                if params.chunk_size:
                    extra_args.extend(['--mem-limit', str(params.chunk_size)])
                extra_args.extend(['--measure-period', str(params.measure_period)])

                # Every interpolated value is shell-quoted because the pipeline is
                # passed as one command string after SHELL_COMMAND_PREFIX.
                command = (
                    "{mrkit_legacy_reader} {tbl} {fields} | "
                    "{chrono_db_generator} {advq_db} {period} {date} --total-file {inp_totals_file}"
                    " --sorted --channels {channels} --output {out_file_path} {extra_args}"
                ).format(
                    mrkit_legacy_reader=quote(mrkit_legacy_reader),
                    tbl=quote(params.input_table),
                    fields=' '.join(quote(fld) for fld in FIELDS),
                    chrono_db_generator=quote(chrono_db_generator),
                    advq_db=quote(params.advq_db),
                    period=quote(params.advq_chrono_type),
                    date=quote(params.epodate),
                    inp_totals_file=quote(totals_file.name),
                    channels=quote(CHANNELS),
                    out_file_path=quote(str(output_file_base)),
                    extra_args=' '.join(quote(arg) for arg in extra_args),
                )

                # Run the database generation.
                try:
                    chunks_json = subprocess.check_output(
                        SHELL_COMMAND_PREFIX + [command],
                        stderr=pl.stdout,
                        env=env,
                    )
                except subprocess.CalledProcessError as ex:
                    # str(ex) carries the command and exit status; ex.message is
                    # empty for CalledProcessError on Python 2 and absent on Python 3.
                    self.set_info(str(ex))
                    raise
                logging.info("Got chunks: %r", chunks_json)
                chunks = json.loads(chunks_json)
                total_chunks = len(chunks)

        # Chunk numbers are 1-based.
        for chunk_id, chunk_path in enumerate(chunks, 1):
            chrono_chunk_filename = '{}.{}.{}.db'.format(chrono_db_prefix, chunk_id, total_chunks)

            res = res_class(
                task=self,
                path=chrono_chunk_filename,
                advq_phits_type=params.advq_phits_type,
                advq_chrono_type=params.advq_chrono_type,
                advq_db=params.advq_db,
                advq_date=params.date,
                advq_epoch=params.epoch,
                advq_epodate=params.epodate,
                advq_is_delta=params.is_delta,
                description=("{}: {!r} from {!r}".format(
                    params.description,
                    chrono_chunk_filename,
                    params.input_table)),
                advq_chunk=chunk_id,
                advq_total_chunks=total_chunks,
                advq_input_table=params.input_table,
                ttl=convert_ttl(params.ttl),
            )
            if params.releaseTo:
                res.released = params.releaseTo

            # Move the generated file to the resource path before marking it ready.
            os.rename(chunk_path, str(res.path))

            ResourceData(res).ready()
