import os
import time
import datetime
import logging

import sandbox.sdk2 as sdk2

from sandbox.common import errors
import sandbox.sandboxsdk.environments as environments
import sandbox.common.types.task as ctt
import sandbox.common.types.resource as ctr

from sandbox.projects.common import file_utils as fu
from sandbox.projects.common import yt_cleaner
from sandbox.projects.common import solomon

from sandbox.projects.resource_types import MR_GEMINICL
from sandbox.projects.geosearch.YtTransfer import YtTransfer
from sandbox.projects.saas.common.resources import SAAS_UTIL_STANDALONE_INDEXER
from sandbox.projects.SaasStandaloneIndexerYt import SaasStandaloneIndexerYtConfigs
from sandbox.projects.SaasStandaloneIndexerYt.Task import SaasStandaloneIndexerYtTask
from sandbox.projects.CalcZenEmbeddings.GrepBarNavigLog import CalcZenEmbeddingsGrepBarNavigLog
from sandbox.projects.CalcZenEmbeddings.GetTopUrls import CalcZenEmbeddingsGetTopUrls
from sandbox.projects.CalcZenEmbeddings.PrepareEmbeddings import CalcZenEmbeddingsPrepareEmbeddings
from sandbox.projects.CalcZenEmbeddings.ConvertToSaasFormat import CalcZenEmbeddingsConvertToSaasFormat


class CalcZenEmbeddingsConveyorResource(sdk2.Resource):
    """Auxiliary TSV files (zen hosts, host->path, host->country) produced by the conveyor."""


class CalcZenEmbeddingsConveyor(sdk2.Task):
    """Conveyor that calculates Zen embeddings and optionally publishes them to SaaS.

    Stages (each wrapped in a memoize stage, so completed stages are skipped
    when the task wakes up after waiting for a child task):

    1. Fix the processing day (``date`` parameter, or today when empty).
    2. Dump the Zen hosts table into three resources and create one
       ``CalcZenEmbeddingsGrepBarNavigLog`` subtask for every missing clicks
       day among the last ``dates_count`` bar-navig-log days before that day.
    3. Enqueue those subtasks one at a time.
    4. Select the top urls (``CalcZenEmbeddingsGetTopUrls``) into a temp table.
    5. Copy it to arnold, compute embeddings there
       (``CalcZenEmbeddingsPrepareEmbeddings``) and copy the result back.
    6. Convert the embeddings to SaaS format
       (``CalcZenEmbeddingsConvertToSaasFormat``).

    When ``publish_in_saas`` is set, additionally: sanity-check the result
    (only if ``check_results``), push table-size metrics to Solomon, clean
    old history, then build and publish shards with ``SaasStandaloneIndexerYtTask``.
    """

    class Parameters(sdk2.Task.Parameters):
        date = sdk2.parameters.String(
            'Date for calculate',
        )

        publish_in_saas = sdk2.parameters.Bool('Do you want boil and publish saas shards?')

        with sdk2.parameters.Group('YT input parameters') as yt_parameters:
            yt_vault_token = sdk2.parameters.String(
                'Your yt token name in vault',
                default='yt-token',
                required=True)

            yql_vault_token = sdk2.parameters.String(
                'Your yql token name in vault',
                default='YQL_TOKEN',
                required=True)

            with sdk2.parameters.RadioGroup('Yt host') as yt_host:
                yt_host.values['hahn'] = yt_host.Value(value='Hahn', default=True)
                yt_host.values['banach'] = yt_host.Value(value='Banach')

        with sdk2.parameters.Group('Log extractor parameters') as extractor_parameters:
            bar_navig_log_folder = sdk2.parameters.String('bar-navig-log folder', default='//statbox/bar-navig-log')

            clicks_folder = sdk2.parameters.String(
                'clicks folder', default='//home/geosearch/zhshishkin/iznanka/bar-navig-log-clicks')

            dates_count = sdk2.parameters.Integer('Time interval length in days', default=15)

            zen_hosts_table = sdk2.parameters.String(
                'Zen hosts table', default='//home/geosearch/iznanka/zen_embeddings/zen_hosts')

            with sdk2.parameters.Group('Gemini parameters') as gemini_parameters:
                gemini_resource_id = sdk2.parameters.Resource(
                    'MR Geminicl',
                    resource_type=MR_GEMINICL,
                    state=(ctr.State.READY,),
                    required=True,
                )

                mr_gemini_user = sdk2.parameters.String(
                    'MR gemini user',
                    default='mr-any',
                    required=True,
                )
                mr_gemini_job_count = sdk2.parameters.Integer(
                    'MR gemini job count',
                    default=100,
                    required=True,
                )
                mr_gemini_max_rps = sdk2.parameters.Integer(
                    'MR gemini max rps',
                    default=15000,
                    required=True,
                )
        with sdk2.parameters.Group('Top urls parameters') as top_urls_parameters:
            top_count = sdk2.parameters.Integer(
                'Top urls count',
                default=5 * 10 ** 6,
                required=True,
            )

        with sdk2.parameters.Group('Embeddings parameters') as embeddings_parameters:
            with publish_in_saas.value[False]:
                output_table = sdk2.parameters.String(
                    'Result output table',
                    required=True,
                )
            with publish_in_saas.value[True]:
                output_dir = sdk2.parameters.String(
                    'Result output directory',
                    required=True,
                )

            jupiter_folder = sdk2.parameters.String(
                'Jupiter folder',
                required=True,
                default='//home/jupiter/backup/walrus'
            )

            shards_folder = sdk2.parameters.String(
                'Stable temp folder',
                required=True,
                default='//tmp'
            )

            countries_with_modes = sdk2.parameters.String(
                'Zen countries, space separated',
                required=True,
                default='russia',
            )

            models_folder = sdk2.parameters.String(
                'Models folder',
                required=True,
                default='//home/geosearch/zhshishkin/iznanka/w2v_models',
            )

        with publish_in_saas.value[True]:
            with sdk2.parameters.Group('standalone_indexer parameters') as standalone_indexer_parameters:
                indexer = sdk2.parameters.Resource(
                    'standalone_indexer',
                    resource_type=SAAS_UTIL_STANDALONE_INDEXER,
                    state=ctr.State.READY,
                    default=None
                )
                dst_dir = sdk2.parameters.String(
                    'DST_DIR',
                    description='path to YT dir for indexing results',
                    required=True
                )
                service = sdk2.parameters.String(
                    'SERVICE',
                    description='name of the SaaS service',
                    required=True
                )
                configs = sdk2.parameters.Resource(
                    'Configs',
                    description='query-language rtyserver.conf-common searchmap.json',
                    resource_type=SaasStandaloneIndexerYtConfigs,
                    # A real one-element tuple, consistent with gemini_resource_id above.
                    state=(ctr.State.READY,),
                    required=True
                )
                publish_path = sdk2.parameters.String(
                    'PUB_PATH',
                    description='ypath or znode to publish results (default: "DST_DIR/testing/SERVICE" when PUB_MGR="yt")',
                    required=True
                )

        check_results = sdk2.parameters.Bool('Check results', default=False)
        with check_results.value[True]:
            check_percent = sdk2.parameters.Float('Differ in row_count in percent', default=0.05)

            check_max_row_count = sdk2.parameters.Integer('Max mln key count', default=8)

    class Requirements(sdk2.Task.Requirements):
        cores = 1

        environments = [
            environments.PipEnvironment('yandex-yt'),
            environments.PipEnvironment('yandex-yt-yson-bindings-skynet')
        ]

        class Caches(sdk2.Requirements.Caches):
            pass

    def on_execute(self):
        """Run the conveyor stage by stage (see the class docstring).

        Stages that start a child task raise ``sdk2.WaitTask``; on wake-up the
        already completed memoize stages are skipped.

        NOTE(review): most stages wait for ``ctt.Status.Group.FINISH``, which
        also matches failed children; the next stage consumes the child's
        output without verifying that the child succeeded — confirm intended.
        """
        with self.memoize_stage.save_day:
            requested_date = self.Parameters.date
            # An empty or missing date means "today".
            if requested_date:
                self.Context.day = str(requested_date)
            else:
                self.Context.day = datetime.date.today().strftime('%Y-%m-%d')

        with self.memoize_stage.grep_bar_navig_log:
            dates_count = self.Parameters.dates_count
            # A [-0:] slice would select *every* date and a negative count would
            # drop the oldest dates instead, so reject such values up front.
            if dates_count is None or dates_count <= 0:
                raise errors.TaskError('dates_count must be a positive number of days')

            yt = self._configure_yt(self.Parameters.yt_host)

            zen_hosts_table = list(
                yt.read_table(
                    yt.TablePath(self.Parameters.zen_hosts_table, columns=['host', 'path', 'countries']),
                )
            )

            zen_hosts_resource = self._write_lines_resource(
                'zen_hosts',
                'zen_hosts.tsv',
                (row['host'] for row in zen_hosts_table),
            )
            host_to_path_resource = self._write_lines_resource(
                'host_to_path',
                'host_to_path.tsv',
                ('{}\t{}'.format(row['host'], row['path']) for row in zen_hosts_table),
            )
            host_to_country_resource = self._write_lines_resource(
                'host_to_country',
                'host_to_country.tsv',
                ('{}\t{}'.format(row['host'], row['countries']) for row in zen_hosts_table),
            )

            # Day tables are compared with the processing day as strings; this
            # assumes zero-padded YYYY-MM-DD table names — TODO confirm.
            past_dates = [
                name for name in yt.list(self.Parameters.bar_navig_log_folder)
                if name < self.Context.day
            ]
            good_dates = sorted(past_dates)[-dates_count:]

            tasks = []
            tables = []

            for day in good_dates:
                output_table = os.path.join(self.Parameters.clicks_folder, day)
                tables.append(output_table)
                if yt.exists(output_table):
                    # Reuse an existing clicks table for this day.
                    continue
                task = CalcZenEmbeddingsGrepBarNavigLog(
                    self,
                    description='grep bar-navig-log {}'.format(day),
                    notifications=self.Parameters.notifications,
                    create_sub_task=False,
                    hosts_resource=zen_hosts_resource.id,
                    hosts_to_path_resource=host_to_path_resource.id,
                    host_to_country_resource=host_to_country_resource.id,
                    yt_vault_token=self.Parameters.yt_vault_token,
                    yql_vault_token=self.Parameters.yql_vault_token,
                    input_table=os.path.join(self.Parameters.bar_navig_log_folder, day),
                    output_table=output_table,
                    yt_host=self.Parameters.yt_host,
                    gemini_resource_id=self.Parameters.gemini_resource_id.id,
                    mr_gemini_user=self.Parameters.mr_gemini_user,
                    mr_gemini_job_count=self.Parameters.mr_gemini_job_count,
                    mr_gemini_max_rps=self.Parameters.mr_gemini_max_rps,
                )

                tasks.append(task.id)

            # Subtasks are only created here; the next stage enqueues them.
            self.Context.tasks = tasks
            self.Context.clicks_tables = tables

        # With commits disabled this stage presumably re-runs on every wake-up,
        # so the pending grep subtasks are enqueued strictly one at a time.
        with self.memoize_stage.run_tasks(commit_on_entrance=False, commit_on_wait=False):
            if self.Context.tasks:
                task_id = self.Context.tasks.pop(0)
                found = list(
                    self.find(
                        CalcZenEmbeddingsGrepBarNavigLog,
                        status=ctt.Status.Group.DRAFT,
                        id=task_id,
                    ).limit(1)
                )
                if not found:
                    raise errors.TaskError('Grep subtask {} not found in DRAFT state'.format(task_id))
                found[0].enqueue()
                raise sdk2.WaitTask(task_id, ctt.Status.Group.FINISH)

        with self.memoize_stage.get_top:
            yt = self._configure_yt(self.Parameters.yt_host)

            self.Context.top_url_table = yt.create_temp_table()

            task = CalcZenEmbeddingsGetTopUrls(
                self,
                description='get top urls',
                notifications=self.Parameters.notifications,
                create_sub_task=False,
                yt_vault_token=self.Parameters.yt_vault_token,
                yql_vault_token=self.Parameters.yql_vault_token,
                input_tables=self.Context.clicks_tables,
                output_table=self.Context.top_url_table,
                yt_host=self.Parameters.yt_host,
                top_count=self.Parameters.top_count,
            )
            task.enqueue()
            raise sdk2.WaitTask(task.id, ctt.Status.Group.FINISH)

        # Embeddings are computed on arnold; tables are copied there and back
        # under the same path.
        arnold_host = "arnold"
        with self.memoize_stage.table_to_arnold:
            task = YtTransfer(
                self,
                description='Transfer table to arnold for task {}'.format(self.id),
                notifications=self.Parameters.notifications,
                create_sub_task=False,
                src_table=self.Context.top_url_table,
                dst_table=self.Context.top_url_table,
                src_cluster=self.Parameters.yt_host,
                dst_cluster=arnold_host,
            )
            task.enqueue()

            raise sdk2.WaitTask([task.id], ctt.Status.Group.SUCCEED, wait_all=True)

        with self.memoize_stage.get_embeddings:
            yt = self._configure_yt(arnold_host)

            self.Context.embeddings_table = yt.create_temp_table()

            task = CalcZenEmbeddingsPrepareEmbeddings(
                self,
                description='get embeddings',
                notifications=self.Parameters.notifications,
                kill_timeout=15 * 3600,
                create_sub_task=False,
                yt_vault_token=self.Parameters.yt_vault_token,
                yt_host=arnold_host,
                input_table=self.Context.top_url_table,
                output_table=self.Context.embeddings_table,
                jupiter_folder=self.Parameters.jupiter_folder,
                shards_folder=self.Parameters.shards_folder,
                countries_with_modes=self.Parameters.countries_with_modes,
                models_folder=self.Parameters.models_folder,
                revive_from=0,
            )
            task.enqueue()
            raise sdk2.WaitTask(task.id, ctt.Status.Group.FINISH)

        with self.memoize_stage.table_from_arnold:
            task = YtTransfer(
                self,
                description='Transfer table from arnold for task {}'.format(self.id),
                notifications=self.Parameters.notifications,
                create_sub_task=False,
                src_table=self.Context.embeddings_table,
                dst_table=self.Context.embeddings_table,
                src_cluster=arnold_host,
                dst_cluster=self.Parameters.yt_host,
            )
            task.enqueue()

            raise sdk2.WaitTask([task.id], ctt.Status.Group.SUCCEED, wait_all=True)

        with self.memoize_stage.saas_format:
            if self.Parameters.publish_in_saas:
                # Each publishing run writes a new table named by its unix timestamp.
                self.Context.timestamp = int(time.time())
                output_table = os.path.join(self.Parameters.output_dir, str(self.Context.timestamp))
                self.Context.output_table = output_table
            else:
                output_table = self.Parameters.output_table

            task = CalcZenEmbeddingsConvertToSaasFormat(
                self,
                description='convert to saas format',
                notifications=self.Parameters.notifications,
                create_sub_task=False,
                yt_vault_token=self.Parameters.yt_vault_token,
                yt_host=self.Parameters.yt_host,
                input_table=self.Context.embeddings_table,
                output_table=output_table,
            )
            task.enqueue()
            raise sdk2.WaitTask(task.id, ctt.Status.Group.FINISH)

        if not self.Parameters.publish_in_saas:
            return

        with self.memoize_stage.check_result:
            # Honour the "Check results" switch instead of relying on the
            # hidden check_* parameters happening to be None.
            if self.Parameters.check_results:
                self._check_output(self._configure_yt(self.Parameters.yt_host))

        with self.memoize_stage.monitoring_phase:
            common_labels = {
                'project': 'iznanka_embeddings',
                'cluster': 'sandbox_metrics',
                'service': 'yt_table_sizes',
            }

            yt = self._configure_yt(self.Parameters.yt_host)
            output_table = self.Context.output_table
            # Both points describe the same snapshot, so they share one timestamp.
            now = int(time.time())

            sensors = [
                {
                    'labels': {'sensor': 'row_count', },
                    'ts': now,
                    'value': int(yt.row_count(output_table)),
                },
                {
                    'labels': {'sensor': 'data_size', },
                    'ts': now,
                    'value': int(yt.get_attribute(output_table, 'uncompressed_data_size')),
                },
            ]

            solomon.upload_to_solomon(common_labels, sensors)

        with self.memoize_stage.clean_phase:
            yt = self._configure_yt(self.Parameters.yt_host)
            yt_cleaner.clean_history_folder(
                yt,
                self.Parameters.clicks_folder,
                self.Parameters.dates_count * 2,
            )

            yt_cleaner.clean_history_folder(
                yt,
                self.Parameters.output_dir,
            )
            yt_cleaner.clean_history_folder(
                yt,
                os.path.join(self.Parameters.dst_dir, 'data', self.Parameters.service),
            )

        with self.memoize_stage.standalone_indexer_run:
            if self.Parameters.indexer is not None:
                indexer_id = self.Parameters.indexer.id
            else:
                # Fall back to the newest READY indexer built from trunk.
                latest = SAAS_UTIL_STANDALONE_INDEXER.find(
                    state=ctr.State.READY,
                    attrs=dict(branch='trunk'),
                ).order(-sdk2.Resource.id).first()
                if latest is None:
                    raise errors.TaskError('No READY trunk SAAS_UTIL_STANDALONE_INDEXER resource found')
                indexer_id = latest.id
            self.Context.standalone_indexer_id = indexer_id

            task = SaasStandaloneIndexerYtTask(
                self,
                description='boil shards for saas_kv',
                notifications=self.Parameters.notifications,
                create_sub_task=False,
                yt_vault_token=self.Parameters.yt_vault_token,
                standalone_indexer=self.Context.standalone_indexer_id,
                proxy=self.Parameters.yt_host,
                src=self.Context.output_table,
                dst_dir=self.Parameters.dst_dir,
                service=self.Parameters.service,
                configs=self.Parameters.configs.id,
                timestamp=self.Context.timestamp,
                verbose=True,
            )

            task.enqueue()
            raise sdk2.WaitTask(task.id, ctt.Status.Group.FINISH)

        with self.memoize_stage.standalone_indexer_publish:
            task = SaasStandaloneIndexerYtTask(
                self,
                description='publish shards in saas_kv',
                notifications=self.Parameters.notifications,
                create_sub_task=False,
                yt_vault_token=self.Parameters.yt_vault_token,
                standalone_indexer=self.Context.standalone_indexer_id,
                proxy=self.Parameters.yt_host,
                src=self.Context.output_table,
                dst_dir=self.Parameters.dst_dir,
                service=self.Parameters.service,
                configs=self.Parameters.configs.id,
                timestamp=self.Context.timestamp,
                verbose=True,
                publish=True,
                resume=True,
                publish_path=self.Parameters.publish_path,
            )

            task.enqueue()
            raise sdk2.WaitTask(task.id, ctt.Status.Group.FINISH)

    def _configure_yt(self, host):
        """Point the global ``yt.wrapper`` client at ``<host>.yt.yandex.net`` and return the module.

        ``yt.wrapper`` is imported lazily, presumably because the ``yandex-yt``
        package is provided only by the ``PipEnvironment`` requirements above.
        """
        import yt.wrapper as yt

        yt.config['token'] = sdk2.Vault.data(self.owner, self.Parameters.yt_vault_token)
        yt.config['proxy']['url'] = '{}.yt.yandex.net'.format(host)
        return yt

    def _write_lines_resource(self, name, filename, lines):
        """Create a ``CalcZenEmbeddingsConveyorResource``, fill it with ``lines`` and mark it ready.

        Returns the resource so its id can be passed to subtasks.
        """
        resource = CalcZenEmbeddingsConveyorResource(self, name, filename)
        data = sdk2.ResourceData(resource)
        fu.write_lines(str(data.path), lines)
        data.ready()
        return resource

    def _check_output(self, yt):
        """Fail with ``TaskError`` when the fresh output table looks anomalous.

        Compares its row count with the newest earlier timestamp-named table in
        ``output_dir`` (relative change above ``check_percent``), and caps it at
        ``check_max_row_count`` million rows. Each check is skipped when its
        threshold is None.
        """
        output_table = self.Context.output_table

        if self.Parameters.check_percent is not None:
            current_time = int(os.path.basename(output_table))

            # Only numeric names are run outputs; anything else would make int() raise.
            previous_times = sorted(
                (name for name in yt.list(self.Parameters.output_dir)
                 if name.isdigit() and int(name) < current_time),
                key=int,
            )
            if previous_times:
                previous_row_count = yt.row_count(os.path.join(self.Parameters.output_dir, previous_times[-1]))
                current_row_count = yt.row_count(output_table)
                if previous_row_count == 0:
                    # Relative change is undefined; any non-empty result counts as a big change.
                    big_change = current_row_count != 0
                else:
                    big_change = (float(abs(previous_row_count - current_row_count)) / previous_row_count
                                  > self.Parameters.check_percent)
                if big_change:
                    raise errors.TaskError('Big change in results, please confirm it')
            else:
                logging.info('No history')

        if self.Parameters.check_max_row_count is not None:
            if yt.row_count(output_table) > self.Parameters.check_max_row_count * 10 ** 6:
                raise errors.TaskError('To much different docids')
