# -*- coding: utf-8 -*-
import logging
from sandbox import sdk2
import time
from sandbox.sandboxsdk import environments
from sandbox import common
import datetime
from collections import defaultdict

# Table queried with select_rows for the latest UpdateTime per TableName/LogType.
PROCESSING_TIME_TABLE = "//home/yabs/stat/ProcessLogUpdateTime"
# Base name of the attribute written on each table; "_<LogType>" is appended
# for per-log-type values, and the bare name holds the overall maximum.
ATTR_NAME = "last_sync_time"
# Node (on the "locke" proxy) that is locked exclusively before the sync loop runs.
LOCK_PATH = "//tmp/MarkLock"
# Minimum duration of one sync iteration, in seconds; the loop sleeps out the remainder.
ITER_DURATION = 15
# How long, in seconds since task start, new iterations keep being started.
LOOP_DURATION = 15 * 60
# Value passed as the proxy "retries" count of every per-cluster YT client.
YT_RETRY_COUNT = 2


def get_lock_count(ytc, path):
    """Return how many entries of the "locks" attribute of `path` have mode "exclusive".

    `ytc` is any client exposing ``get_attribute(path, name)``.
    """
    exclusive = [entry for entry in ytc.get_attribute(path, "locks") if entry["mode"] == "exclusive"]
    return len(exclusive)


def format_time(ts):
    """Format a Unix timestamp as a UTC string of the form ``YYYY-MM-DDTHH:MM:SSZ``.

    The conversion goes straight from ``time.gmtime`` to ``time.strftime`` and
    never passes through local time, so the result does not depend on the
    host's time zone or DST rules. Fractional seconds in `ts` are dropped.
    """
    return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(ts))


class BatchClientWrapper(object):
    """
    Collects requests issued on a batch client together with a caller-supplied context.

    Any method that is not defined on the wrapper itself is forwarded to the
    underlying batch client; the pending response is remembered until
    commit_batch() is called.
    """

    def __init__(self, batch_client):
        self.batch_client = batch_client
        # Queue of (context, pending response, error text) tuples awaiting commit.
        self.requests = []

    def __getattr__(self, name):
        # Only reached for names not found through normal lookup, i.e. for
        # the batch-client operations being proxied.
        def enqueue(*args, **kwargs):
            # 'context' is consumed here and never passed to the batch client.
            ctx = kwargs.pop('context', None)
            failure_text = 'Cannot do operation {}(args={}, kwargs={}) on {}'.format(name, args, kwargs, self.cluster)
            operation = getattr(self.batch_client, name)
            pending = operation(*args, **kwargs)
            self.requests.append((ctx, pending, failure_text))

        return enqueue

    @property
    def cluster(self):
        """Proxy URL taken from the batch client's config."""
        proxy_config = self.batch_client.config['proxy']
        return proxy_config['url']

    def commit_batch(self):
        """
        Commit all queued requests and return ``(results, success)``.

        ``results`` is a list of ``(context, value)`` pairs for the requests
        that succeeded, in the order they were queued. ``success`` is False
        when the commit itself raised or when at least one request failed;
        failures are logged as warnings. The queue is emptied in every case.
        """
        if not self.requests:
            return [], True
        pending, self.requests = self.requests, []
        try:
            self.batch_client.commit_batch()
        except Exception as ex:
            logging.warning('Exception = {}. Cannot commit batch on {}'.format(ex, self.cluster))
            return [], False
        collected = []
        all_ok = True
        for ctx, response, failure_text in pending:
            if not response.is_ok():
                all_ok = False
                logging.warning('Exception = {}. {}'.format(response.get_error(), failure_text))
                continue
            collected.append((ctx, response.get_result()))
        return collected, all_ok


class YabsTablesTimeSync(sdk2.Task):
    """
    Task for https://st.yandex-team.ru/BSDEV-69727

    While holding an exclusive lock on LOCK_PATH, repeatedly reads the latest
    update times from PROCESSING_TIME_TABLE on every configured cluster and
    stores them (formatted by format_time) in the ``last_sync_time*``
    attributes of the configured tables, or of their replicas when a table
    is a replicated table.
    """

    class Parameters(sdk2.Task.Parameters):
        clusters = sdk2.parameters.List("Clusters to process")
        clusters_blacklist = sdk2.parameters.List("Clusters to skip")
        tables = sdk2.parameters.List("Tables to process")
        yt_token_secret_id = sdk2.parameters.YavSecret(
            label="YT token secret id",
            required=True,
            description='secret should contain keys: YT_TOKEN',
        )

    class Requirements(sdk2.Task.Requirements):
        environments = (
            environments.PipEnvironment('yandex-yt', '0.9.17', use_wheel=True),
            environments.PipEnvironment('yandex-yt-yson-bindings-skynet', use_wheel=True),
        )

        cores = 1
        ram = 4096
        disk_space = 4096

        class Caches(sdk2.Requirements.Caches):
            pass

    # proxy -> (YtClient, BatchClientWrapper).
    # NOTE: defined on the class, so the dict is shared by every instance in the process.
    connection_cache = {}
    # Becomes True as soon as any YT request fails; checked at the end of on_execute.
    task_failure = False

    def get_connection(self, proxy):
        """Return the cached ``(YtClient, BatchClientWrapper)`` pair for `proxy`, creating it on first use.

        Reads ``self.yt_token``, which on_execute sets before any connection is requested.
        """
        if proxy not in self.connection_cache:
            import yt.wrapper as yt
            ytc = yt.YtClient(config={
                "proxy": {"url": proxy, "retries": {"count": YT_RETRY_COUNT}},
                "token": self.yt_token,
            })
            bc = BatchClientWrapper(ytc.create_batch_client())
            self.connection_cache[proxy] = ytc, bc
        return self.connection_cache[proxy]

    def commit_batches(self, clusters):
        """Commit the pending batch of every cluster in `clusters`.

        Every cluster must already have a cached connection. Returns a dict
        ``cluster -> [(context, value), ...]`` that contains only clusters
        with at least one successful request. Any batch failure is recorded
        in ``self.task_failure``.
        """
        results = {}
        for cluster in clusters:
            _, bc = self.connection_cache[cluster]
            batch_result, batch_success = bc.commit_batch()
            if batch_result:
                results[cluster] = batch_result
            self.task_failure = self.task_failure or not batch_success
        return results

    def get_max_update_time(self, rows):
        """Turn the selected rows of one table into ``(attribute suffix, time)`` pairs.

        Each row yields ``('_' + LogType, Time)``; one extra pair with an empty
        suffix carries the maximum Time over all rows. Returns an empty list
        when there are no rows.
        """
        if not rows:
            return []
        result = [('_' + row['LogType'], row['Time']) for row in rows]
        result.append(('', max([row['Time'] for row in rows])))
        return result

    def process_all(self, yt_token):
        """Run one synchronisation pass over all configured clusters and tables.

        `yt_token` is kept for interface compatibility and is not used here;
        connections are built from ``self.yt_token`` in get_connection.
        Request failures are logged and recorded in ``self.task_failure``
        rather than raised.
        """
        from yt.wrapper import ypath_split

        blacklist = self.Parameters.clusters_blacklist

        # Stage 1: check which of the configured tables exist on each cluster.
        for cluster in self.Parameters.clusters:
            if cluster in blacklist:
                continue
            _, bc = self.get_connection(cluster)
            for table in self.Parameters.tables:
                bc.exists(table, context=table)
        results = self.commit_batches(set(self.Parameters.clusters).difference(blacklist))

        # Stage 2: fetch the node type of every existing table.
        for cluster, result in results.items():
            _, bc = self.get_connection(cluster)
            for path, is_exist in result:
                if is_exist:
                    bc.get(path + '/@type', context=path)
                else:
                    logging.info('Table {}.{} doesn\'t exist. Skip'.format(cluster, path))
        results = self.commit_batches(results.keys())

        # Stage 3: plain tables are targets themselves (lag 0); for replicated
        # tables request the replica list and use the replicas as targets.
        # ready_results: source cluster -> [(target cluster, target path, lag), ...]
        ready_results = {}
        for cluster, result in results.items():
            _, bc = self.get_connection(cluster)
            ready_results[cluster] = ready_result = []
            for path, node_type in result:
                if node_type == "replicated_table":
                    bc.get(path + '/@replicas')
                else:
                    ready_result.append((cluster, path, 0))
        results = self.commit_batches(results.keys())
        for cluster, result in results.items():
            for _, replicas in result:
                for replica in replicas.values():
                    dest_cluster = replica['cluster_name']
                    dest_path = replica['replica_path']
                    replica_lag = replica['replication_lag_time']
                    ready_results[cluster].append((dest_cluster, dest_path, replica_lag))
                    logging.info('Find replica {}.{} with lag={}'.format(dest_cluster, dest_path, replica_lag))

        # Stage 4: on each source cluster select the latest update time per
        # table name and log type, then group the rows by target cluster.
        select_results = defaultdict(list)
        for cluster, result in ready_results.items():
            if cluster in blacklist:
                continue
            if not result:
                # No targets were collected (for example the replicas request
                # failed or returned nothing). An empty "in ()" list gives
                # nothing to look up and would only risk a spurious failure.
                continue
            ytc, _ = self.get_connection(cluster)
            table_names = ["'{}'".format(ypath_split(path)[-1]) for _, path, _ in result]  # noqa
            query = (
                "TableName, LogType, max(UpdateTime) as Time from [{}] "
                "where TableName in ({}) group by TableName, LogType"
            ).format(PROCESSING_TIME_TABLE, ','.join(table_names))
            try:
                rows = list(ytc.select_rows(query))
            except Exception as ex:
                logging.warning('Exception = {}. Cannot select rows on {}'.format(ex, cluster))
                self.task_failure = True
                continue
            by_table_name = defaultdict(list)
            for row in rows:
                by_table_name[row['TableName']].append(row)
            for dest_cluster, dest_path, replication_lag in result:
                table_name = ypath_split(dest_path)[-1]
                select_results[dest_cluster].append((dest_path, replication_lag, by_table_name[table_name]))

        # Stage 5: read the current attribute values of every target; the
        # new value travels in the request context.
        for cluster, result in select_results.items():
            _, bc = self.get_connection(cluster)
            for path, lag, rows in result:
                for log_type, update_time in self.get_max_update_time(rows):
                    # NOTE(review): lag is divided by 1000, so it is presumably in milliseconds - confirm.
                    update_time = update_time - lag / 1000
                    str_time = format_time(update_time)
                    bc.get(path, attributes=[ATTR_NAME + log_type], context=(path, log_type, str_time))
        results = self.commit_batches(select_results.keys())

        # Stage 6: write the attribute only when the new value is greater than
        # the stored one. Both are strings, so the comparison is lexicographic.
        for cluster, result in results.items():
            _, bc = self.get_connection(cluster)
            for (path, log_type, new_time), node in result:
                attr = ATTR_NAME + log_type
                cur_time = node.attributes.get(attr)
                # A missing attribute is always updated; checking for None
                # explicitly avoids comparing a string with a number.
                if cur_time is None or new_time > cur_time:
                    logging.info('Update last_sync_time on %s.%s for log-type %s to %s. Previos was %s', cluster, path, log_type, new_time, cur_time)
                    bc.set_attribute(path, attr, new_time)
                else:
                    logging.info('Update last_sync_time on %s.%s for log-type %s to %s. No changes', cluster, path, log_type, cur_time)
        self.commit_batches(results.keys())

    def on_execute(self):
        """Take the exclusive lock and run sync iterations for up to LOOP_DURATION seconds.

        Returns early without doing anything when more than one exclusive
        lock is already listed on LOCK_PATH.

        Raises:
            common.errors.TaskFailure: if any YT request failed during the run.
        """
        if self.Parameters.clusters_blacklist:
            logging.info('Clusters {} in blacklist and will be ignored'.format(self.Parameters.clusters_blacklist))
        import yt.wrapper as yt
        start_time = time.time()
        logging.info("Started at {}".format(start_time))
        self.yt_token = self.Parameters.yt_token_secret_id.data()["YT_TOKEN"]
        ytc = yt.YtClient(config={"token": self.yt_token, "proxy": {"url": "locke"}})
        ytc.create("int64_node", LOCK_PATH, ignore_existing=True, recursive=True)
        if get_lock_count(ytc, LOCK_PATH) > 1:
            logging.info("Too many tasks already waiting for lock. I quit.")
            return
        with ytc.Transaction():
            logging.info("Waiting for lock")
            ytc.lock(LOCK_PATH, mode="exclusive", waitable=True, wait_for=600000)
            # Keep iterating until the time budget is spent or the number of
            # exclusive locks on LOCK_PATH is no longer exactly one.
            while time.time() - start_time < LOOP_DURATION and get_lock_count(ytc, LOCK_PATH) == 1:
                logging.info("Start iteration")
                iter_start = time.time()
                self.process_all(self.yt_token)
                logging.info("Finish iteration")
                # Pad short iterations so that each one lasts at least ITER_DURATION seconds.
                sleep_duration = iter_start + ITER_DURATION - time.time()
                if sleep_duration > 0:
                    time.sleep(sleep_duration)
        if self.task_failure:
            raise common.errors.TaskFailure("Problems with clusters")
