import logging
from datetime import timedelta

from sandbox import sdk2
from sandbox.sandboxsdk import environments
from sandbox.common.errors import TaskError, TaskFailure
from sandbox.projects.common.yabs.server.util.general import try_get_from_vault
from sandbox.projects.common.yabs.server.db.yt_bases import prepare_yt_env, YT_POOL, get_jailed_yt_token
from sandbox.projects.common.yql import run_query
from sandbox.projects.yabs.qa.resource_types import MR_UTILS
from sandbox.projects.yabs.qa.utils import yt_utils
from sandbox.projects.yabs.qa.utils.general import get_yt_path_html_hyperlink


logger = logging.getLogger(__name__)

# Tolerance for approximate float comparison; passed verbatim as ``mr_utils diff --eps``,
# hence kept as a string (it goes straight into the command line).
EPS = '0.001'
# TTL, in seconds, applied to the per-run diff destination node on YT.
NODE_TTL = timedelta(hours=72).total_seconds()


class YqlTaskDiff(sdk2.Task):
    """
    Run a YQL diff query between the results of two input tasks.

    The query text may reference fields of the input tasks' context and parameters
    (see the ``yql_query`` parameter). Its last result must be a single
    boolean-convertible cell; a truthy value means "no diff". Optionally, per-table
    diffs are materialized on YT with the mr_utils tool.
    """

    class Requirements(sdk2.Task.Requirements):
        environments = [
            environments.PipEnvironment('yql'),
            environments.PipEnvironment('yandex-yt', use_wheel=True),
        ]

    class Context(sdk2.Task.Context):
        # Mirrors the ``has_diff`` output parameter; None until it is computed.
        has_diff = None

    class Parameters(sdk2.Task.Parameters):
        task_1 = sdk2.parameters.Task('Task 1')
        task_2 = sdk2.parameters.Task('Task 2')
        yql_query = sdk2.parameters.String(
            'YQL query',
            multiline=True,
            description="""YQL query accepting parametrization with task context or parameter fields, in curly braces and prefixed with
                        "task_(1|2)_(parameters|context)_". Query must return at least 1 result, and the last result must contain exactly
                        1 row, 1 column with value convertible to boolean - representing success of diff calculation.
                        Result value of True will result in has_diff=False and vice versa."""
        )
        yt_proxy = sdk2.parameters.String('YT proxy', default='hahn')
        yql_token_vault_name = sdk2.parameters.String('YQL token vault name', default='yql_token')

        check_only_digest = sdk2.parameters.Bool('Check only digests, if some table does not exist', default=False)
        with check_only_digest.value[True]:
            task_context_with_digest = sdk2.parameters.String('Task output parameter with digest', description='In input tasks task_1 and task_2')

        with sdk2.parameters.Group('Make table with diff') as make_table:
            make_table_with_diff = sdk2.parameters.Bool('Make table with diff', default=False)
            with make_table_with_diff.value[True]:
                mr_utils = sdk2.parameters.LastResource(
                    'Tool for diff tables',
                    resource_type=MR_UTILS,
                    attrs={"released": "stable"},
                )
                sort_by = sdk2.parameters.List('Sort by fields')
                exclude_node = sdk2.parameters.String('Exclude node', description='Do not include fields in diff. Learn more in https://wiki.yandex-team.ru/advmachine/kak-najjti-diff-dvux-tablic/')
                limit = sdk2.parameters.Integer('Limit', default=100)
                diff_mode = sdk2.parameters.String('Diff mode', choices=(('inner', 'inner'), ('full', 'full')), default='full')
                destination_path = sdk2.parameters.String('Destination path', description='Task adds current task_id')
                task_context_with_result_path = sdk2.parameters.String('Task output parameter with result path', description='In input tasks task_1 and task_2')

        with sdk2.parameters.Output:
            has_diff = sdk2.parameters.Bool('Has diff', required=True)

    def context_to_dict(self, context, prefix):
        """Flatten task fields into a dict with prefixed keys.

        :param context: iterable of ``(name, value)`` pairs; used here with the sdk2
            ``Context`` and ``Parameters`` of the input tasks.
        :param prefix: string prepended to every field name.
        :return: ``{prefix + name: value}``.
        """
        return {'{}{}'.format(prefix, key): value for key, value in context}

    def get_table_size(self, table):
        """Return ``(row_count, col_count)`` of a fetched YQL result table.

        The column count is taken from the first row and is 0 for an empty table.
        """
        row_count = len(table.rows)
        col_count = len(table.rows[0]) if row_count else 0
        return row_count, col_count

    def _get_yt_client(self, yt_token):
        """Create a YT client for the configured proxy with the given token."""
        # yt.wrapper is installed by the task's PipEnvironment, so import lazily.
        from yt.wrapper import YtClient

        return YtClient(proxy=self.Parameters.yt_proxy, token=yt_token)

    def _get_result_paths(self):
        """Return the result paths of task_1 and task_2.

        The paths are read from the input tasks' parameters named by
        ``task_context_with_result_path``.
        """
        param_name = self.Parameters.task_context_with_result_path
        return (
            getattr(self.Parameters.task_1.Parameters, param_name),
            getattr(self.Parameters.task_2.Parameters, param_name),
        )

    def run_mr_utils(self, mr_utils_path, old_table, new_table, dst, tmp_cache_path):
        """Run ``mr_utils diff`` for one pair of tables and post a link to the result.

        :param mr_utils_path: local path to the mr_utils binary.
        :param old_table: YT path of the table from task_1.
        :param new_table: YT path of the table from task_2.
        :param dst: YT path where mr_utils writes the diff.
        :param tmp_cache_path: YT node handed to ``prepare_yt_env`` as the cache path.
        :raises TaskFailure: if mr_utils exits with a non-zero code.
        """
        yt_token = get_jailed_yt_token()
        yt_proxy = self.Parameters.yt_proxy
        yt_client = self._get_yt_client(yt_token)

        new_columns = {col['name'] for col in yt_client.get_attribute(new_table, 'schema')}
        old_columns = {col['name'] for col in yt_client.get_attribute(old_table, 'schema')}
        common_columns = new_columns & old_columns
        # TODO: Compare schema

        cmd = [
            mr_utils_path,
            'diff',
            '-s', yt_proxy,
            '--old', old_table,
            '--new', new_table,
            '--dst', dst,
            '--limit', str(self.Parameters.limit),
            '--eps', EPS,
            '--diff-mode', self.Parameters.diff_mode,
        ]
        # Drop sort keys absent from either schema. If the schemas have no common
        # columns at all, pass every requested key through unchanged.
        for field in self.Parameters.sort_by or ():
            if not common_columns or field in common_columns:
                cmd += ['--sort-by', field]
        if self.Parameters.exclude_node:
            cmd += ['--exclude-node', self.Parameters.exclude_node]

        env = prepare_yt_env(yt_token, tmp_cache_path, YT_POOL)
        with sdk2.helpers.ProcessLog(self, logger='mr_utils') as log, sdk2.helpers.ProcessRegistry:
            logger.info('Running mr_utils: %s', cmd)
            try:
                output = sdk2.helpers.subprocess.check_output(cmd, stderr=log.stderr, env=env)
            except sdk2.helpers.subprocess.CalledProcessError as error:
                raise TaskFailure('mr_utils exited with code {}'.format(error.returncode))
            logger.info('mr_utils result is: %s', output)

        self.set_info("See diff of tables {} and {} in {}".format(
            get_yt_path_html_hyperlink(proxy=yt_proxy, path=old_table),
            get_yt_path_html_hyperlink(proxy=yt_proxy, path=new_table),
            get_yt_path_html_hyperlink(proxy=yt_proxy, path=dst)),
            do_escape=False
        )

    def get_certain_diff(self, table_with_hashes):
        """Materialize mr_utils diffs for every table listed in a YQL result.

        Diffs are written under ``<destination_path>/<task id>/<table>Diff``. That
        node gets a ``NODE_TTL`` TTL, and a temporary cache node is created inside it.

        :param table_with_hashes: YQL result table with a ``table`` column. Each value
            names a table relative to the input tasks' result paths.
        :raises TaskError: if the result has no ``table`` column.
        """
        from yt.wrapper import ypath_join

        table_with_hashes.fetch_full_data()
        column_names = table_with_hashes.column_names
        # Membership test first: ``index`` raises when the column is missing and
        # returns 0 (falsy) when it is the first column.
        if 'table' not in column_names:
            raise TaskError("The second last table of YQL result has no column 'table'")
        table_index = column_names.index('table')
        tables = [row[table_index] for row in table_with_hashes.rows]

        yt_client = self._get_yt_client(get_jailed_yt_token())
        mr_utils_path = str(sdk2.ResourceData(self.Parameters.mr_utils).path)
        task_1_path, task_2_path = self._get_result_paths()

        destination_path = ypath_join(self.Parameters.destination_path, str(self.id))
        yt_utils.create_node(
            path=destination_path,
            yt_client=yt_client)
        yt_utils.set_yt_node_ttl(destination_path, NODE_TTL, yt_client)
        tmp_cache_path = yt_utils.create_tmp_node(
            yt_client,
            destination_path,
            ttl=timedelta(hours=3).total_seconds(),
            use_expiration_timeout=True,
        )
        for table in tables:
            self.run_mr_utils(
                mr_utils_path,
                ypath_join(task_1_path, table),
                ypath_join(task_2_path, table),
                ypath_join(destination_path, table + 'Diff'),
                tmp_cache_path,
            )

    def checking_tables_exist(self):
        """Check that both input tasks' result paths exist on YT.

        If a result path is missing and ``check_only_digest`` is set, ``has_diff``
        is decided by comparing the digests in the input tasks' parameters named by
        ``task_context_with_digest``.

        :return: True if the YQL diff should run. False if ``has_diff`` was already
            set from the digest comparison.
        :raises TaskError: if a result path is missing and ``check_only_digest`` is off.
        """
        if not self.Parameters.task_context_with_result_path:
            # The result-path parameter is only exposed together with
            # make_table_with_diff. Without it there is nothing to check.
            return True

        yt_proxy = self.Parameters.yt_proxy
        yt_client = self._get_yt_client(get_jailed_yt_token())

        task_1_path, task_2_path = self._get_result_paths()
        logger.info("Tables with results: %s, %s", task_1_path, task_2_path)
        missing = [
            (name, path)
            for name, path in (('task_1', task_1_path), ('task_2', task_2_path))
            if not yt_client.exists(path)
        ]
        if not missing:
            return True

        info = "The tables don't exist:\n{}".format('\n'.join(
            '{} table {}'.format(name, get_yt_path_html_hyperlink(proxy=yt_proxy, path=path))
            for name, path in missing
        ))
        if not self.Parameters.check_only_digest:
            raise TaskError(info)

        digest_param = self.Parameters.task_context_with_digest
        task_1_digest = getattr(self.Parameters.task_1.Parameters, digest_param)
        task_2_digest = getattr(self.Parameters.task_2.Parameters, digest_param)
        digests_equal = task_1_digest == task_2_digest
        self.Parameters.has_diff = self.Context.has_diff = not digests_equal
        self.set_info(info, do_escape=False)
        sign = '==' if digests_equal else "!="
        self.set_info("Check digests: {} {} {}".format(task_1_digest, sign, task_2_digest))
        return False

    def _format_query(self):
        """Substitute the input tasks' context/parameter fields into the YQL query.

        Keys are prefixed with ``task_(1|2)_(context|parameters)_``.

        :raises TaskError: if the query has unknown or malformed placeholders.
        """
        sources = (
            (self.Parameters.task_1.Context, 'task_1_context_'),
            (self.Parameters.task_2.Context, 'task_2_context_'),
            (self.Parameters.task_1.Parameters, 'task_1_parameters_'),
            (self.Parameters.task_2.Parameters, 'task_2_parameters_'),
        )
        format_dict = {}
        for fields, prefix in sources:
            format_dict.update(self.context_to_dict(fields, prefix))
        logger.info('YQL query format parameters: %s', format_dict)
        try:
            # str.format semantics: literal braces in the query must be doubled ('{{' / '}}').
            return self.Parameters.yql_query.format(**format_dict)
        except (KeyError, IndexError, ValueError) as error:
            raise TaskError('Cannot substitute task fields into YQL query: {!r}'.format(error))

    def on_execute(self):
        """Run the YQL diff query and set the ``has_diff`` output.

        :raises TaskError: on an empty query, unset input tasks, a malformed query
            template, or an unexpected YQL result shape.
        """
        if not self.Parameters.yql_query:
            raise TaskError('Empty YQL query')
        if self.Parameters.task_1 is None or self.Parameters.task_2 is None:
            raise TaskError('Both task_1 and task_2 must be set')
        if not self.checking_tables_exist():
            # has_diff was already decided by the digest comparison.
            return

        formatted_yql_query = self._format_query()
        self.set_info('Running query:\n{}'.format(formatted_yql_query))
        yql_token = try_get_from_vault(self, self.Parameters.yql_token_vault_name)
        request = run_query(
            query_id='YQL task diff query',
            query=formatted_yql_query,
            yql_token=yql_token,
            db=self.Parameters.yt_proxy,
            wait=False,
            syntax_version=1)
        self.set_info('<a href="{}" target="_blank">YQL diff operation share link</a>'.format(request.share_url),
                      do_escape=False
                      )

        tables = list(request.get_results(wait=True))
        if not tables:
            raise TaskError("No results fetched (either query error or it doesn't return results)")
        result_table = tables[-1]
        result_table.fetch_full_data()
        row_count, col_count = self.get_table_size(result_table)
        if (row_count != 1) or (col_count != 1):
            raise TaskError("""Wrong output format of YQL diff query.
                                Expected: rows = 1, cols = 1
                                Actual: rows = {}, cols = {}""".format(row_count, col_count))

        bool_result = bool(result_table.rows[0][0])
        self.set_info('YQL diff query has returned {}'.format(bool_result))
        has_diff = not bool_result
        self.Parameters.has_diff = self.Context.has_diff = has_diff

        if self.Parameters.make_table_with_diff and has_diff:
            # The second-to-last result must list the differing tables (column 'table').
            if len(tables) < 2:
                raise TaskError(
                    "Make table with diff requires the YQL query to return a table with column 'table' "
                    "right before the final result, but only {} result(s) were returned".format(len(tables)))
            self.get_certain_diff(tables[-2])  # Output differing hashes of tables
