# -*- coding: utf-8 -*-
from datetime import timedelta, datetime
import logging

from sandbox import sdk2
from sandbox.sdk2 import yav
import sandbox.common.types.task as ctt
from sandbox.projects.common import task_env

# Owner and record name of a YQL token stored in the Sandbox Vault.
YQL_TOKEN_OWNER = "STATKEY"
YQL_TOKEN_NAME = "YQL_TOKEN"

# Relative layouts of a snapshot table below the base directory; the task's
# "History path date template" radio group offers these two values.
BY_DATE_TEMPLATE = '{date}/{table_name}'
BY_TABLE_TEMPLATE = '{table_name}/{date}'


class MakeHistoryForYtTable(sdk2.Task):
    """Accumulate daily snapshots of YT tables into per-table history tables.

    For every selected table the task runs a YQL query that rewrites
    ``<history prefix>/history/<table>`` with the rows it already had for
    dates before ``report_date`` plus the fresh snapshot, tagged with a
    ``report_date`` column.  Optionally the snapshot dated
    ``report_date - days_to_store`` is removed afterwards, and finally the
    ``<tables_dir>/last`` link is re-pointed.
    """

    class Requirements(task_env.TinyRequirements):
        disk_space = 10000
        ram = 4 * 1024

    class Context(sdk2.Context):
        # Names of the tables handled by the current run (set in on_execute).
        tables_for_history = []

    class Parameters(sdk2.Parameters):
        # Layout of snapshot tables below ``tables_dir``:
        # "by_date"  -> <tables_dir>/<date>/<table>
        # "by_table" -> <tables_dir>/<table>/<date>
        with sdk2.parameters.RadioGroup("History path date template") as history_path_date_template:
            history_path_date_template.values["by_date"] = history_path_date_template.Value(BY_DATE_TEMPLATE, default=True)
            history_path_date_template.values["by_table"] = history_path_date_template.Value(BY_TABLE_TEMPLATE)
        history_path_same_as_table_path = sdk2.parameters.Bool('History path same as table path?', default=True)
        # Only read when ``history_path_same_as_table_path`` is False.
        history_path = sdk2.parameters.String('History path', required=False)
        get_tables_by_yt_list = sdk2.parameters.Bool('Get tables by yt.list', default=False)
        black_tables_list = sdk2.parameters.String('Black list of tables, separate by comma', required=False, default="")
        tables_dir = sdk2.parameters.String('Base path for dirs with tables', required=True)
        # Only read when ``get_tables_by_yt_list`` is False.
        tables = sdk2.parameters.String('Tables separate by comma', required=False)
        report_date = sdk2.parameters.String('Report date, example 2021-06-11', required=True)
        yt_cluster = sdk2.parameters.String('Yt cluster', default='hahn')
        max_parallel_operations = sdk2.parameters.String('Max parallel operations in YT', required=True, default="10")
        days_to_store = sdk2.parameters.Integer('Days to store dates foldes', required=True, default=3)
        # NOTE(review): the label mentions 7 days, but clean_history() uses
        # ``days_to_store`` -- confirm which one is intended.
        clean_old_tables = sdk2.parameters.Bool("Clean tables older than 7 days, until TM-1988", default=False)
        YQLSecretID = sdk2.parameters.String(
            'YQL token secret ID',
            required=True
        )
        YQLTokenKey = sdk2.parameters.String(
            'YQL token key',
            required=True
        )

    @staticmethod
    def _parse_table_names(raw):
        """Split a comma separated string into stripped, non-empty names.

        :param raw: comma separated names; may be ``None`` or empty
        :return: list of names in their original order
        """
        names = []
        for chunk in (raw or "").split(","):
            name = chunk.strip()
            if name:
                names.append(name)
        return names

    @staticmethod
    def _to_yql_string_list(names):
        """Render ``names`` as a YQL list literal of single-quoted strings.

        Backslashes and single quotes are escaped so that a name cannot
        terminate its own literal.
        """
        quoted = [
            "'" + name.replace("\\", "\\\\").replace("'", "\\'") + "'"
            for name in names
        ]
        return "[" + ", ".join(quoted) + "]"

    def get_yt_token(self):
        """Return the token stored under ``YQLTokenKey`` in the ``YQLSecretID`` yav secret."""
        secret = yav.Secret(self.Parameters.YQLSecretID)
        return secret.data()[self.Parameters.YQLTokenKey]

    def get_yt_client(self):
        """Configure the ``yt.wrapper`` module for this task and return it.

        The token and the proxy are written into the module level
        configuration of ``yt.wrapper``; the module itself is the returned
        "client".
        """
        # Imported lazily: the YT library is only needed while the task runs.
        import yt.wrapper as yt
        yt.config.config['token'] = self.get_yt_token()
        yt.config.set_proxy(self.Parameters.yt_cluster)

        return yt

    def prepare_yql(self, tables):
        """Build the YQL query that merges fresh snapshots into history tables.

        :param tables: list/tuple of table names; any other value is inserted
            into the query as is (treated as a ready YQL expression)
        :return: text of the query
        :raises ValueError: if a separate history path is requested but
            ``history_path`` is empty
        """
        path = self.Parameters.tables_dir
        if self.Parameters.history_path_same_as_table_path:
            history_prefix = path
        else:
            history_prefix = self.Parameters.history_path
        if not history_prefix:
            # Without this check the query would target "None/history".
            raise ValueError(
                "'History path' must be set when it is not the same as the table path"
            )
        history_table_prefix = "{path}/history".format(path=history_prefix)

        if isinstance(tables, (list, tuple)):
            # Do not rely on repr(): it yields u'...' items for unicode
            # strings on Python 2 and does not follow YQL escaping rules.
            tables_literal = self._to_yql_string_list(tables)
        else:
            tables_literal = tables

        query = '''
            PRAGMA yt.PublishedAutoMerge = 'economy';
            PRAGMA yt.TemporaryAutoMerge = 'disabled';
            PRAGMA yt.MinPublishedAvgChunkSize = '4G';
            PRAGMA yt.Pool = 'statbox-cooked-logs-batch';
            PRAGMA yt.NightlyCompress;
            PRAGMA yt.ParallelOperationsLimit = '{max_parallel_operations}';
            PRAGMA SimpleColumns;

            $table_names = {tables};

            DEFINE ACTION $has_history($old, $new, $history_table, $fresh_table, $redate) as
                INSERT INTO $history_table with truncate
                select * from (select * from $old($history_table, $redate) union all select * from $new($fresh_table, $redate))
                order by report_date;
            END DEFINE;

            DEFINE ACTION $no_history($new, $history_table, $fresh_table, $report_date) as
                INSERT INTO $history_table with truncate
                select * from (select * from $new($fresh_table, $report_date))
                order by report_date;
            END DEFINE;

            DEFINE SUBQUERY $get_old_table($old, $report_date) AS
                select * from $old where report_date < $report_date
            END DEFINE;

            DEFINE SUBQUERY $get_fresh_table($fresh_table, $report_date) AS
                select $report_date as report_date, s.* from $fresh_table as s
            END DEFINE;


            $report_date = "{report_date}";
            $path_prefix = '{path}' || '/';
            $history_table_prefix = '{history_table_prefix}';


            EVALUATE FOR $table IN $table_names DO BEGIN
                $history_table = $history_table_prefix || '/' || $table;
                $fresh_table = $path_prefix || {yql_template};

                $exist = SELECT ListHas( AGGREGATE_LIST( TableName(Path, "yt")), $table) FROM {yt_cluster}.FOLDER($history_table_prefix);

                EVALUATE IF $exist == true
                    DO $has_history($get_old_table, $get_fresh_table, $history_table, $fresh_table, $report_date)
                ELSE
                    DO $no_history($get_fresh_table, $history_table, $fresh_table, $report_date)
            END DO;
        '''.format(
            yql_template=self.get_yql_template(),
            history_table_prefix=history_table_prefix,
            tables=tables_literal,
            path=path,
            report_date=self.Parameters.report_date,
            max_parallel_operations=self.Parameters.max_parallel_operations,
            # The folder lookup must hit the same cluster the query runs on.
            yt_cluster=self.Parameters.yt_cluster
        )

        return query

    def run_query(self, query):
        """Run ``query`` through the YQL API on the configured cluster.

        :param query: text of the YQL query
        :return: the ``table`` attribute of the query results
        :raises RuntimeError: if the query did not succeed; the message holds
            the reported errors, one per line
        """
        from yql.api.v1.client import YqlClient

        yql_client = YqlClient(db=self.Parameters.yt_cluster, token=self.get_yt_token())
        request = yql_client.query(query, syntax_version=1)
        request.encoding = 'utf-8'
        request.run()

        results = request.get_results()
        if not results.is_success:
            raise RuntimeError('\n'.join(str(err) for err in results.errors))

        return results.table

    def prepare_and_run_yql(self, tables):
        """Build the history query for ``tables`` and execute it."""
        self.run_query(self.prepare_yql(tables))

    def get_tables_by_yt(self):
        """Discover the tables to process by listing the YT directory.

        Lists ``<tables_dir>/<report_date>`` ("by_date" layout) or
        ``<tables_dir>`` (otherwise), drops the ``history`` node and the
        black-listed names, and keeps only the names whose snapshot for
        ``report_date`` exists.

        :return: list of table names in listing order
        """
        yt_client = self.get_yt_client()
        tables_dir = self.Parameters.tables_dir
        report_date = self.Parameters.report_date

        if self.Parameters.history_path_date_template == "by_date":
            listed = yt_client.list(tables_dir + "/{date}".format(date=report_date))
            table_path_template = "{tables_dir}/{date}/{table_name}"
        else:
            listed = yt_client.list(tables_dir)
            table_path_template = "{tables_dir}/{table_name}/{date}"

        excluded = set(self._parse_table_names(self.Parameters.black_tables_list))
        excluded.add("history")

        # Build a new list instead of removing items from the list being
        # iterated: in-place removal skips the element after each removed one.
        existing = []
        for table in listed:
            if table in excluded:
                continue
            table_path = table_path_template.format(
                tables_dir=tables_dir,
                table_name=table,
                date=report_date
            )
            if yt_client.exists(table_path):
                existing.append(table)

        return existing

    def on_create(self):
        """Attach a tasks binary resource to the task requirements.

        Uses the first ``SandboxTasksBinary`` resource found with attributes
        ``Name=MakeHistoryForYtTable`` and a stable release status.
        """
        self.Requirements.tasks_resource = sdk2.service_resources.SandboxTasksBinary.find(
            attrs={"Name": "MakeHistoryForYtTable", "release": ctt.ReleaseStatus.STABLE},
        ).first()

    def clean_history(self, tables):
        """Remove the snapshot dated ``report_date - days_to_store``.

        Only that single date is removed: the whole date directory in the
        "by_date" layout, or the per-table node for each of ``tables``
        otherwise.  Missing paths are skipped.

        :param tables: table names, used only in the "by_table" layout
        """
        yt = self.get_yt_client()

        report_day = datetime.strptime(self.Parameters.report_date, '%Y-%m-%d')
        d = (report_day - timedelta(days=self.Parameters.days_to_store)).strftime('%Y-%m-%d')
        logging.info("Date for delete table is %s", d)

        if self.Parameters.history_path_date_template == "by_date":
            date_path = "{prefix}/{date}".format(prefix=self.Parameters.tables_dir, date=d)
            if yt.exists(date_path):
                yt.remove(date_path, recursive=True, force=True)
                logging.info("Path '%s' DELETED", date_path)
            return

        for table in tables:
            table_path = "{prefix}/{table_name}/{date}".format(
                prefix=self.Parameters.tables_dir, date=d, table_name=table
            )
            logging.info("Table path to delete is: %s", table_path)
            if yt.exists(table_path):
                yt.remove(table_path)
                logging.info("Table path '%s' DELETED", table_path)

    def get_table_name_date_str(self, table_name, date):
        """Return (and log) the relative snapshot path for the selected layout.

        :return: ``<date>/<table_name>`` for "by_date",
            ``<table_name>/<date>`` otherwise
        """
        if self.Parameters.history_path_date_template == "by_date":
            template = BY_DATE_TEMPLATE
        else:
            template = BY_TABLE_TEMPLATE
        relative_path = template.format(table_name=table_name, date=date)
        logging.info(relative_path)
        return relative_path

    def get_yql_template(self):
        """Return the YQL expression that builds a snapshot's relative path.

        The expression refers to the ``$report_date`` and ``$table`` bindings
        of the query produced by prepare_yql().
        """
        if self.Parameters.history_path_date_template == "by_date":
            return "$report_date || '/' || $table"
        return "$table || '/' || $report_date"

    def create_link_for_last_table(self):
        """Point ``<tables_dir>/last`` at the greatest child of ``tables_dir``.

        Children are compared as strings; ``history`` and the link itself are
        ignored.  When nothing is left to link to, the existing link is left
        untouched.
        """
        # NOTE(review): in the "by_table" layout the children of tables_dir
        # are table names rather than dates -- confirm the link is wanted.
        # Uses the same yav-backed client as the rest of the task.
        yt = self.get_yt_client()

        tables_dir = self.Parameters.tables_dir
        link_path = tables_dir + "/last"

        # List before touching the link, so that an empty directory neither
        # breaks max() nor leaves the task with the old link already removed.
        candidates = [name for name in yt.list(tables_dir) if name not in ("history", "last")]
        if not candidates:
            logging.warning("Nothing to link in '%s', link '%s' is left as is", tables_dir, link_path)
            return

        if yt.exists(link_path):
            yt.remove(link_path)
        yt.link(tables_dir + "/" + max(candidates), link_path)

    def on_execute(self):
        """Task entry point: update history, optionally clean up, refresh the link."""
        if self.Parameters.get_tables_by_yt_list:
            tables = self.get_tables_by_yt()
        else:
            # ``tables`` is optional: tolerate None/empty values and stray
            # whitespace around the commas.
            tables = self._parse_table_names(self.Parameters.tables)
        self.Context.tables_for_history = tables

        if tables:
            self.prepare_and_run_yql(tables)
        else:
            logging.warning("No tables to process, history update is skipped")

        if self.Parameters.clean_old_tables:
            self.clean_history(tables)
        self.create_link_for_last_table()
