from __future__ import print_function
import json
import logging

from sandbox.sdk2 import (
    Vault,
    parameters,
    Requirements,
)
from sandbox.projects.yabs.base_bin_task import BaseBinTask


class YabsStaticToDynamicSorter(BaseBinTask):
    """Sort the newest copy of a static YT table into a key-sorted, unique-key table.

    Each run:
      1. reads the ``timestamp`` of the last processed source table from the JSON
         meta file at ``meta_path`` on ``meta_cluster`` ('' if there is none yet);
      2. picks the cluster from ``clusters`` whose ``src_path`` has the newest
         ``creation_time`` that is greater than that timestamp;
      3. sorts a snapshot of that table by ``key_columns`` into
         ``dst_path/<creation_time>``;
      4. records the new table in the meta file;
      5. removes older result tables from ``dst_path`` on that cluster.
    """

    class Requirements(Requirements):
        cores = 1
        ram = 2048
        disk_space = 2048

        class Caches(Requirements.Caches):
            pass

    class Parameters(BaseBinTask.Parameters):
        with BaseBinTask.Parameters.version_and_task_resource() as version_and_task_resource:
            resource_attrs = parameters.Dict("Filter resource by", default={"name": "YabsStaticToDynamic"})

        with parameters.Group('Sort parameters') as sort_params:
            yt_token_vault = parameters.String('Vault with yt token', required=True)
            clusters = parameters.List('Source yt clusters', required=True)
            meta_cluster = parameters.String('YT cluster with meta info', required=True)
            meta_path = parameters.String('Meta path', required=True)
            src_path = parameters.String('Source path', required=True)
            dst_path = parameters.String('Destination path', required=True)
            key_columns = parameters.List('Key columns', required=True)

    def _yt_client(self, proxy):
        """Return a YtClient for ``proxy`` authenticated with ``self.yt_token``.

        ``self.yt_token`` is set at the start of ``on_execute``.
        """
        # Local import, matching the lazy yt.wrapper imports used by this task.
        from yt.wrapper import YtClient

        return YtClient(proxy=proxy, token=self.yt_token)

    def get_last_timestamp(self):
        """Return the ``timestamp`` stored in the meta file, or '' if the file does not exist.

        '' compares less than any non-empty string, so on the first run every
        reachable source cluster is a candidate in ``get_best_cluster``.
        """
        ytc = self._yt_client(self.Parameters.meta_cluster)
        if not ytc.exists(self.Parameters.meta_path):
            return ''
        return json.loads(ytc.read_file(self.Parameters.meta_path).read())['timestamp']

    def get_best_cluster(self, last_timestamp):
        """Return the cluster whose ``src_path`` creation_time is the newest and > ``last_timestamp``.

        Returns None when no cluster has a newer table. Timestamps are compared
        as strings; this presumes YT creation_time values (ISO-8601) order
        lexicographically — confirm if the meta file format ever changes.
        Clusters whose attribute cannot be read are logged and skipped.
        """
        best_cluster = None
        for cluster in self.Parameters.clusters:
            # Built outside the try so import/configuration errors are not
            # mistaken for an unavailable cluster.
            ytc = self._yt_client(cluster)
            try:
                cluster_time = ytc.get_attribute(self.Parameters.src_path, 'creation_time')
            except Exception as e:
                # Best effort: one unreachable cluster must not fail the task.
                logging.warning('Could not get creation_time on cluster %s: %s', cluster, e)
                continue

            logging.info('Timestamp for cluster %s is %s', cluster, cluster_time)
            if cluster_time > last_timestamp:
                last_timestamp = cluster_time
                best_cluster = cluster

        return best_cluster

    def do_sort(self, cluster):
        """Sort a snapshot of ``src_path`` on ``cluster`` into ``dst_path/<creation_time>``.

        The destination schema is the source schema with ``unique_keys`` set,
        key columns marked ``ascending`` and moved to the front in
        ``key_columns`` order. An existing destination table is overwritten.

        Returns:
            (dst_table, timestamp): path of the sorted table and the source
            table's creation_time.
        """
        logging.info('Starting sort operation on cluster %s', cluster)
        ytc = self._yt_client(cluster)
        key_columns = self.Parameters.key_columns

        # Rank of each key column in the requested sort order (first occurrence
        # wins, like list.index); non-key columns rank after every key column.
        key_rank = {}
        for idx, name in enumerate(key_columns):
            key_rank.setdefault(name, idx)
        non_key_rank = len(key_columns)

        with ytc.Transaction():
            # Address the snapshot-locked node by id so the attributes read below
            # and the sort presumably all refer to the same table version even if
            # src_path is replaced meanwhile.
            lock_id = ytc.lock(self.Parameters.src_path, mode='snapshot')
            snapshot = '#' + ytc.get_attribute('#' + lock_id, 'node_id')
            timestamp = ytc.get_attribute(snapshot, 'creation_time')

            schema = ytc.get_attribute(snapshot, 'schema')
            schema.attributes['unique_keys'] = True
            for col in schema:
                if col['name'] in key_rank:
                    col['sort_order'] = 'ascending'
            # Key columns must lead the schema in sort_by order; list.sort is
            # stable, so the remaining columns keep their original order.
            schema.sort(key=lambda col: key_rank.get(col['name'], non_key_rank))

            dst_table = '{}/{}'.format(self.Parameters.dst_path, timestamp)
            ytc.create('table', dst_table, attributes={'schema': schema}, force=True)
            ytc.run_sort(snapshot, dst_table, sort_by=key_columns)

        return dst_table, timestamp

    def do_cleanup(self, cluster, keep=5):
        """Remove all but the ``keep`` last-named tables in ``dst_path`` on ``cluster``.

        Result tables are named after their source creation_time, so sorting the
        names is assumed to order them chronologically (ISO-8601) — confirm.
        """
        logging.info('Removing old tables on cluster %s', cluster)
        ytc = self._yt_client(cluster)

        tables = sorted(ytc.list(self.Parameters.dst_path, absolute=True))
        # Equivalent to tables[:-keep] for keep >= 1, but also correct for
        # keep == 0 (tables[:-0] would remove nothing).
        stale = tables[:max(len(tables) - keep, 0)]
        for table in stale:
            logging.info('Removing %s', table)
            ytc.remove(table)

    def on_execute(self):
        """Sort the newest source table, publish it in the meta file, then prune old tables."""
        self.yt_token = Vault.data(self.Parameters.yt_token_vault)

        last_timestamp = self.get_last_timestamp()
        best_cluster = self.get_best_cluster(last_timestamp)

        if best_cluster is None:
            logging.info('No new generation found, stopping')
            return

        result_table, timestamp = self.do_sort(best_cluster)

        new_metadata = {
            'timestamp': timestamp,
            'cluster': best_cluster,
            'path': result_table,
        }
        # YtClient.write_file accepts only binary strings as string input on
        # Python 3; json.dumps output is ASCII, so encoding is also safe on Python 2.
        ytc = self._yt_client(self.Parameters.meta_cluster)
        ytc.write_file(self.Parameters.meta_path, json.dumps(new_metadata).encode('utf-8'))

        # Cleanup runs after the meta file is updated; the newest table
        # (just written) is always among the kept ones.
        self.do_cleanup(best_cluster)
