import os
import logging

import jinja2

from sandbox import sdk2
from sandbox.sandboxsdk import environments

from sandbox.projects.common.yabs.server.util.general import try_get_from_vault
from sandbox.projects.common.yql import run_query
from sandbox.projects.yabs.qa.utils.general import get_task_html_hyperlink, html_hyperlink
from sandbox.projects.common.yabs.server.util.general import check_tasks
from sandbox.projects.yabs.qa.tasks.YabsServerCalculateCSImportDigest import YabsServerCalculateCSImportDigest
from sandbox.projects.yabs.qa.tasks.YabsServerRunEditPageParse import YabsServerRunEditPageParse
from sandbox.projects.yabs.qa.tasks.YqlTaskDiff import YqlTaskDiff
from sandbox.projects.yabs.qa.resource_types import BaseBackupSdk2Resource


logger = logging.getLogger(__name__)

# Directory with the jinja2 templates, resolved relative to this module's directory.
TEMPLATE_DIR = 'templates'
# File name of the HTML report written into the report resource.
DIFF_REPORT_FILENAME = 'diff_report.html'
# Default value of the task's ``target_tables`` parameter.
DEFAULT_TARGET_TABLES = ['Page', 'PageImp', 'PageDSP', 'BlockSettings']
# Template rendered once per table to build that table's YQL diff query.
TABLE_DIFF_YQL_TEMPLATE_NAME = 'table_diff.yql'
# Template rendered to build the YQL query that diffs the two YT digests.
YT_DIGEST_DIFF_YQL_TEMPLATE_NAME = 'yt_digest_diff.yql'


class YabsServerEditpageParseDiffReport(BaseBackupSdk2Resource):
    """Sandbox resource holding the HTML diff report of an EditPage parse comparison."""
    pass


def get_jinja2_env():
    """Create a jinja2 environment that loads templates from TEMPLATE_DIR.

    TEMPLATE_DIR is interpreted relative to the directory containing this
    module, and the resulting path is normalized before use.

    :return: a new ``jinja2.Environment`` backed by a ``FileSystemLoader``.
    """
    module_dir = os.path.dirname(__file__)
    search_path = os.path.normpath(os.path.join(module_dir, TEMPLATE_DIR))
    loader = jinja2.FileSystemLoader(search_path)
    return jinja2.Environment(loader=loader)


def get_table_diff_yql(table_name, keys, fields):
    """Render the YQL query that diffs a single table.

    :param table_name: name of the table to compare; passed to the template.
    :param keys: column names passed to the template as ``keys``.
    :param fields: column names passed to the template as ``fields``.
    :return: the rendered YQL query text.
    """
    template_vars = dict(table_name=table_name, keys=keys, fields=fields)
    env = get_jinja2_env()
    return env.get_template(TABLE_DIFF_YQL_TEMPLATE_NAME).render(**template_vars)


def get_yt_digest_diff_yql(baseline_yt_digest_path, patched_yt_digest_path):
    """Render the YQL query that diffs two YT digests.

    :param baseline_yt_digest_path: path to the baseline digest, passed to the template.
    :param patched_yt_digest_path: path to the patched digest, passed to the template.
    :return: the rendered YQL query text.
    """
    template = get_jinja2_env().get_template(YT_DIGEST_DIFF_YQL_TEMPLATE_NAME)
    digest_paths = {
        'baseline_yt_digest_path': baseline_yt_digest_path,
        'patched_yt_digest_path': patched_yt_digest_path,
    }
    return template.render(**digest_paths)


class YabsServerRunEditPageParseCmp(sdk2.Task):
    """Compare the outputs of two YabsServerRunEditPageParse tasks.

    Pipeline (see ``on_execute``):

    1. freeze the target tables in both task workdirs;
    2. compute a YT digest of each workdir with YabsServerCalculateCSImportDigest;
    3. diff the two digests with YQL to find which target tables changed;
    4. launch one YqlTaskDiff per changed table and wait for them;
    5. unfreeze the target tables;
    6. publish an HTML report as a resource and set the ``has_diff`` /
       ``diff_url`` output parameters.
    """
    name = 'YABS_SERVER_RUN_EDITPAGE_PARSE_CMP'
    description = 'Run EditPage parse CMP'
    # Lazily created YT client, see the ``yt`` property.
    _yt = None

    class Parameters(sdk2.Parameters):
        yql_token_vault_name = sdk2.parameters.String(
            'YQL token vault name', required=True, default='yabs-cs-sb-yql-token')
        yt_token_vault_name = sdk2.parameters.String(
            'YT token vault name', required=True, default='yabs-cs-sb-yt-token')
        yt_proxy = sdk2.parameters.String(
            'YT cluster', required=True, default='hahn')
        baseline_task = sdk2.parameters.Task(
            'Baseline task', required=True, task_type=YabsServerRunEditPageParse.type)
        patched_task = sdk2.parameters.Task(
            'Patched task', required=True, task_type=YabsServerRunEditPageParse.type)
        target_tables = sdk2.parameters.List(
            'List of tables to check', required=True, default=DEFAULT_TARGET_TABLES)

        with sdk2.parameters.Output:
            has_diff = sdk2.parameters.Bool('Has diff', required=True)
            diff_url = sdk2.parameters.String('Diff report link')

    class Requirements(sdk2.Task.Requirements):
        cores = 1
        ram = 1024
        environments = (
            environments.PipEnvironment('yql'),
            environments.PipEnvironment('yandex-yt', use_wheel=True),
        )

        class Caches(sdk2.Requirements.Caches):
            pass

    class Context(sdk2.Task.Context):
        has_diff = False
        diff_url = None
        yt_digest_yql_diff_link = None
        # IDs of the YabsServerCalculateCSImportDigest child tasks (set in on_execute).
        yt_digest_baseline_task_id = None
        yt_digest_patched_task_id = None
        # table name -> YqlTaskDiff task id. on_execute always assigns a fresh
        # dict here instead of mutating this class-level default in place.
        table_diff_tasks_id = {}
        # Target tables reported as changed by the YT digest diff.
        tables_with_diff = []

    @property
    def has_diff(self):
        """Whether at least one target table differs (mirrored to the output parameter)."""
        return self.Context.has_diff

    @has_diff.setter
    def has_diff(self, value):
        self.Context.has_diff = value
        self.Parameters.has_diff = value

    @property
    def diff_url(self):
        """Link to the HTML report resource (mirrored to the output parameter)."""
        return self.Context.diff_url

    @diff_url.setter
    def diff_url(self, value):
        self.Context.diff_url = value
        self.Parameters.diff_url = value

    def _make_yt(self):
        """Create a YtClient for ``yt_proxy`` using the token stored in Vault."""
        from yt.wrapper import YtClient

        yt_token = sdk2.Vault.data(self.Parameters.yt_token_vault_name)
        return YtClient(proxy=self.Parameters.yt_proxy, token=yt_token)

    @property
    def yt(self):
        """YT client, created on first access and cached on the instance."""
        if self._yt is None:
            self._yt = self._make_yt()
        return self._yt

    @property
    def baseline_workdir(self):
        """YT working directory of the baseline parse task."""
        return self.Parameters.baseline_task.Parameters.workdir

    @property
    def patched_workdir(self):
        """YT working directory of the patched parse task."""
        return self.Parameters.patched_task.Parameters.workdir

    def get_target_tables_path(self):
        """Yield the YT path of every target table in the baseline, then the patched, workdir."""
        from yt.wrapper import ypath_join

        workdirs = [
            self.baseline_workdir,
            self.patched_workdir,
        ]
        for workdir in workdirs:
            for table_name in self.Parameters.target_tables:
                table_path = ypath_join(workdir, table_name)
                yield table_path

    def freeze_tables(self):
        """Freeze every target table that is not already frozen.

        :return: paths of the tables frozen by this call.
        """
        paths = []
        for table_path in self.get_target_tables_path():
            tablet_state = self.yt.get(table_path + '/@tablet_state')
            if tablet_state != 'frozen':
                logger.info('Freeze [%s]', table_path)
                self.yt.freeze_table(table_path, sync=True)
                paths.append(table_path)
        return paths

    def unfreeze_tables(self):
        """Unfreeze every target table that is currently frozen.

        NOTE(review): this also unfreezes tables that were already frozen
        before ``freeze_tables`` ran, not only the ones this task froze.

        :return: paths of the tables unfrozen by this call.
        """
        paths = []
        for table_path in self.get_target_tables_path():
            tablet_state = self.yt.get(table_path + '/@tablet_state')
            if tablet_state == 'frozen':
                logger.info('Unfreeze [%s]', table_path)
                self.yt.unfreeze_table(table_path, sync=True)
                paths.append(table_path)
        return paths

    def get_keys_and_fields(self, workdir, table_name):
        """Split a table's schema columns into sorted (key) and non-key columns.

        The ``YTHash`` column is ignored.

        :return: ``(keys, fields)`` lists of column names, in schema order.
        """
        from yt.wrapper import ypath_join

        table_path = ypath_join(workdir, table_name)
        schema = self.yt.get(table_path + '/@schema')
        keys = []
        fields = []
        for field in schema:
            name = field['name']
            if name == 'YTHash':
                continue
            if 'sort_order' in field:
                keys.append(name)
            else:
                fields.append(name)
        return keys, fields

    def launch_yt_digest(self, path):
        """Enqueue a YabsServerCalculateCSImportDigest child task for ``path``."""
        task = YabsServerCalculateCSImportDigest(
            self,
            description=self.description + ': YtDigest ' + path,
            yt_pool='',
            import_destination_path=path,
            skip_list=[],
            ignore_attributes=True,
        )
        task.enqueue()
        return task

    def launch_yql_diff(self, table_name, task_1, task_2, yql_query):
        """Enqueue a YqlTaskDiff child task running ``yql_query`` for ``table_name``."""
        task = YqlTaskDiff(
            self,
            description=self.description + ': YQL diff ' + table_name,
            task_1=task_1,
            task_2=task_2,
            yql_query=yql_query,
            yt_proxy=self.Parameters.yt_proxy,
            yql_token_vault_name=self.Parameters.yql_token_vault_name,
        )
        task.enqueue()
        return task

    def run_yt_digest_diff_yql(self):
        """Diff the two YT digests with YQL and return the changed target tables.

        The shared YQL link is shown in the task info and stored in
        ``Context.yt_digest_yql_diff_link``.

        :return: names of target tables reported as changed, without
            duplicates, in the order the query returned them.
        """
        yql_token = try_get_from_vault(self, self.Parameters.yql_token_vault_name)
        baseline_task = self.find(id=self.Context.yt_digest_baseline_task_id).first()
        patched_task = self.find(id=self.Context.yt_digest_patched_task_id).first()
        yql_query = get_yt_digest_diff_yql(
            baseline_yt_digest_path=baseline_task.Context.intermediate_hashes_full_path,
            patched_yt_digest_path=patched_task.Context.intermediate_hashes_full_path
        )
        request = run_query(
            query_id='YQL task diff yt digest query',
            query=yql_query,
            yql_token=yql_token,
            db=self.Parameters.yt_proxy,
            wait=False,
            syntax_version=1
        )
        self.set_info(
            html_hyperlink(link=request.share_url, text='YT Digest YQL diff link'),
            do_escape=False
        )
        self.Context.yt_digest_yql_diff_link = request.share_url

        target_tables = set(self.Parameters.target_tables)
        tables_with_diff = []
        for table in request.get_results(wait=True):
            # Only the result set whose first column is ``table`` lists changed
            # tables; guard against result sets with no columns at all.
            if not table.column_names or table.column_names[0] != 'table':
                continue
            table.fetch_full_data()
            for row in table.rows:
                table_name = row[0].lstrip('/')
                # Skip non-target tables and repeated rows so that each table
                # gets at most one diff task.
                if table_name in target_tables and table_name not in tables_with_diff:
                    tables_with_diff.append(table_name)
            break
        return tables_with_diff

    def check_table(self, table_name):
        """Launch a YqlTaskDiff comparing ``table_name`` between the two workdirs.

        Schema mismatches are reported via ``set_info``; the diff then runs on
        the columns present in both tables.

        :return: the enqueued YqlTaskDiff task.
        """
        keys1, fields1 = self.get_keys_and_fields(self.baseline_workdir, table_name)
        keys2, fields2 = self.get_keys_and_fields(self.patched_workdir, table_name)
        if keys1 != keys2:
            self.set_info('Keys in table [{}] are not equal'.format(table_name))
        if fields1 != fields2:
            self.set_info('Fields in table [{}] are not equal'.format(table_name))

        # Keep the common key columns in baseline schema order: a plain set
        # intersection has arbitrary order, which made the query non-deterministic.
        patched_keys = set(keys2)
        keys = [key for key in keys1 if key in patched_keys]
        patched_fields = set(fields2)
        fields = sorted(field for field in fields1 if field in patched_fields)
        return self.launch_yql_diff(
            table_name,
            self.Parameters.baseline_task,
            self.Parameters.patched_task,
            yql_query=get_table_diff_yql(table_name, keys, fields)
        )

    def on_execute(self):
        # NOTE(review): unfreeze only happens at the end of this method, so the
        # target tables stay frozen if any step in between raises.
        with self.memoize_stage.freeze_tables(commit_on_entrance=False):
            self.freeze_tables()

        with self.memoize_stage.yt_digest_baseline(commit_on_entrance=False):
            task = self.launch_yt_digest(self.baseline_workdir)
            self.Context.yt_digest_baseline_task_id = task.id

        with self.memoize_stage.yt_digest_patched(commit_on_entrance=False):
            task = self.launch_yt_digest(self.patched_workdir)
            self.Context.yt_digest_patched_task_id = task.id
        check_tasks(self, [
            self.Context.yt_digest_baseline_task_id,
            self.Context.yt_digest_patched_task_id
        ])

        with self.memoize_stage.yt_digest_diff(commit_on_entrance=False):
            self.Context.tables_with_diff = self.run_yt_digest_diff_yql()
            self.Context.save()

        with self.memoize_stage.check_tables(commit_on_entrance=False):
            # Build the mapping locally and assign it once, rather than mutating
            # the Context default dict in place, then persist it explicitly.
            table_diff_tasks_id = {}
            for table_name in self.Context.tables_with_diff:
                table_diff_tasks_id[table_name] = self.check_table(table_name).id
            self.Context.table_diff_tasks_id = table_diff_tasks_id
            self.Context.save()

        diff_tasks_id = list(self.Context.table_diff_tasks_id.values())
        check_tasks(self, diff_tasks_id)

        with self.memoize_stage.unfreeze_tables(commit_on_entrance=False):
            self.unfreeze_tables()

        report_html = self._get_report_html(diff_tasks_id)
        self.set_info(report_html, do_escape=False)
        report_resource = self._make_report(report_html)

        self.diff_url = report_resource.http_proxy
        self.has_diff = bool(diff_tasks_id)

    def _get_report_html(self, failed_tasks_id):
        """Build the HTML report: the digest diff link plus the per-table diff tasks.

        :param failed_tasks_id: IDs of the per-table YqlTaskDiff tasks; an
            empty collection means no target table differs.
        """
        messages = [
            html_hyperlink(
                link=self.Context.yt_digest_yql_diff_link,
                text='YT Digest YQL diff link'
            )
        ]
        if failed_tasks_id:
            tasks_link = ', '.join(map(get_task_html_hyperlink, failed_tasks_id))
            messages.append('Failed diff tasks: [{}]'.format(tasks_link))
        else:
            messages.append('No diff in target tables')
        return '<hr/>'.join(messages)

    def _make_report(self, report_html):
        """Write ``report_html`` to a new report resource, mark it ready and return it."""
        resource = YabsServerEditpageParseDiffReport(
            self, 'Report resource', DIFF_REPORT_FILENAME)
        resource_data = sdk2.ResourceData(resource)
        with open(str(resource_data.path), 'w') as report_file:
            report_file.write(report_html)
        resource_data.ready()
        return resource
