from .utils import (
    find_newest_backup,
    is_yp_backup_folder,
    sort_by_date,
)
from .solomon_wrapper import create_solomon_client

from infra.yp.lib.retrier import retry_me

from yp.admin import clone_db

from yt.wrapper import YtClient

from random import shuffle
import argparse
import datetime
import functools
import logging
import os
import requests


# Configure the root logger at import time so every debug message of this script is emitted.
logging.basicConfig(level=logging.DEBUG)

# Cache of YtClient instances keyed by cluster (proxy) name; filled lazily by get_yt_client_base.
yt_clients = {}

# Path handed to clone_db as its second argument (presumably the root of the YP database to clone).
YP_PATH = "//yp"


def get_yt_client_base(cluster, token):
    """Return the cached YtClient for *cluster*, creating and caching it on first use.

    The cache is keyed by cluster name only, so *token* is used solely when
    the client for that cluster is created for the first time.
    """
    # Fast path: a client for this cluster was already built.
    if cluster in yt_clients:
        return yt_clients[cluster]

    new_client = YtClient(proxy=cluster, token=token)
    yt_clients[cluster] = new_client
    return new_client


def get_yt_client(cluster):
    """Return a YT client for *cluster* using the token from the YT_TOKEN environment variable.

    Raises:
        KeyError: if YT_TOKEN is not set in the environment.
    """
    yt_token = os.environ["YT_TOKEN"]
    return get_yt_client_base(cluster, token=yt_token)


def build_base_backup_path(yt_path, yp_cluster):
    """Return the directory that holds all backups of *yp_cluster*: ``<yt_path>/<yp_cluster>``."""
    return "/".join((yt_path, yp_cluster))


def build_backup_path(yt_path, yp_cluster, directory_name):
    """Return the path of a single backup: ``<yt_path>/<yp_cluster>/<directory_name>``."""
    base_path = build_base_backup_path(yt_path, yp_cluster)
    return "/".join((base_path, directory_name))


def copy_table(yp_cluster, yt_cluster, yt_path, directory_name, solomon_client):
    """Clone the YP database of *yp_cluster* into a backup directory on *yt_cluster*.

    The destination is ``<yt_path>/<yp_cluster>/<directory_name>``; clone_db is
    called with remove_if_exists=True. The "create_snapshot_duration" and
    "history_events_table_size" values returned by clone_db are forwarded to
    *solomon_client*.
    """
    source_client = get_yt_client(yp_cluster)
    target_client = get_yt_client(yt_cluster)

    target_path = build_backup_path(yt_path, yp_cluster, directory_name)
    logging.debug("Copy table to %s", target_path)

    clone_stats = clone_db(
        source_client,
        YP_PATH,
        target_client,
        target_path,
        remove_if_exists=True,
        no_init_yp_cluster=True,
    )

    # Report the clone metrics one at a time, in the same order as they are read.
    snapshot_duration = clone_stats["create_snapshot_duration"]
    solomon_client.collect_create_snapshot_duration(snapshot_duration)

    history_size = clone_stats["history_events_table_size"]
    solomon_client.collect_history_events_table_size(history_size)


def get_backup_depth_in_clusters(yt_clusters, yt_path, yp_cluster):
    """Count existing backups of *yp_cluster* on every YT cluster and return the extremes.

    For each cluster the per-YP-cluster directory under *yt_path* is created if
    it is missing, then the entries accepted by is_yp_backup_folder are counted.
    A cluster whose YT calls raise is logged and left out of the statistics.

    Args:
        yt_clusters: iterable of YT cluster names to inspect.
        yt_path: base YT path that contains one directory per YP cluster.
        yp_cluster: name of the YP cluster whose backups are counted.

    Returns:
        Tuple ``(max_backups, min_backups)`` over the clusters that responded.

    Raises:
        ValueError: if no cluster could be inspected (empty input or all failed).
    """
    backup_count = {}

    for yt_cluster_name in yt_clusters:
        client = get_yt_client(yt_cluster_name)

        try:
            base_path = build_base_backup_path(yt_path, yp_cluster)
            if yp_cluster not in client.list(yt_path):
                client.create("map_node", base_path)

            backup_count[yt_cluster_name] = sum(
                1 for folder in client.list(base_path) if is_yp_backup_folder(folder)
            )
        except Exception:
            # Best effort: one unavailable cluster must not abort the whole run.
            logging.exception("Generic cluster error for cluster %s, skipping", yt_cluster_name)

    logging.debug("Backup count is {}".format(backup_count))

    if not backup_count:
        # max()/min() on an empty sequence would raise an opaque ValueError;
        # keep the exception type but say what actually went wrong.
        raise ValueError(
            "Could not get backup count from any of the clusters {}".format(list(yt_clusters))
        )

    counts = backup_count.values()
    max_backups = max(counts)
    min_backups = min(counts)

    logging.debug("Max backups = {}, min backups = {}".format(max_backups, min_backups))

    return max_backups, min_backups


def find_cluster_to_backup(yt_clusters, yt_path, yp_cluster):
    """Return the YT clusters in the order in which a new backup should be attempted.

    For each cluster the per-YP-cluster directory is created if missing and the
    newest backup folder (by sort_by_date) is recorded; a cluster without any
    backup is recorded as "ypbackup_0". A cluster whose YT calls raise is
    logged and omitted.

    Args:
        yt_clusters: iterable of candidate YT cluster names.
        yt_path: base YT path that contains one directory per YP cluster.
        yp_cluster: name of the YP cluster being backed up.

    Returns:
        List of cluster names. The cluster that find_newest_backup reports as
        holding the newest backup comes last; all others precede it in random
        order. The list is empty when no cluster could be inspected.
    """
    cluster_backup_info = {}

    for yt_cluster_name in yt_clusters:
        client = get_yt_client(yt_cluster_name)

        try:
            base_path = build_base_backup_path(yt_path, yp_cluster)
            if yp_cluster not in client.list(yt_path):
                client.create("map_node", base_path)

            backup_tables = [folder for folder in client.list(base_path)
                             if is_yp_backup_folder(folder)]

            if backup_tables:
                # sort_by_date is a cmp-style comparator; cmp_to_key keeps it
                # usable on both Python 2.7 and Python 3.
                backup_tables_sorted = sorted(
                    backup_tables, key=functools.cmp_to_key(sort_by_date), reverse=True
                )
                newest_backup_date = backup_tables_sorted[0]
                logging.debug("Cluster {} sorted backup tables = {}".format(yt_cluster_name, backup_tables_sorted))
            else:
                # Sentinel folder name for a cluster that has no backups yet.
                newest_backup_date = "ypbackup_0"

            cluster_backup_info[yt_cluster_name] = newest_backup_date
        except Exception:
            # Best effort: one unavailable cluster must not abort the whole run.
            logging.exception("Generic cluster error for cluster %s, skipping", yt_cluster_name)

    logging.debug("Cluster backup info = {}".format(cluster_backup_info))

    if not cluster_backup_info:
        # Nothing reachable: let the caller report that no cluster accepted the backup.
        return []

    if len(cluster_backup_info) == 1:
        return list(cluster_backup_info)

    newest_backup_cluster, _ = find_newest_backup(cluster_backup_info)

    # Try the clusters without the newest backup first, in random order, and
    # fall back to the cluster that already holds the newest backup.
    clusters = [name for name in cluster_backup_info if name != newest_backup_cluster]
    shuffle(clusters)
    clusters.append(newest_backup_cluster)
    return clusters


def remove_oldest_backup(yt_cluster_name, yt_path, yp_cluster, max_backups):
    """Delete the oldest backups of *yp_cluster* on a YT cluster until at most *max_backups* remain.

    Only entries accepted by is_yp_backup_folder are considered, matching how
    backups are counted elsewhere in this file, so unrelated nodes in the
    backup directory are never counted or removed.

    Args:
        yt_cluster_name: YT cluster on which the rotation is performed.
        yt_path: base YT path that contains one directory per YP cluster.
        yp_cluster: name of the YP cluster whose backups are rotated.
        max_backups: number of backups to keep.
    """
    client = get_yt_client(yt_cluster_name)

    base_path = build_base_backup_path(yt_path, yp_cluster)
    backup_tables = [folder for folder in client.list(base_path) if is_yp_backup_folder(folder)]
    logging.debug("Backup tables = {}".format(backup_tables))

    # sort_by_date is a cmp-style comparator; cmp_to_key keeps it usable on
    # both Python 2.7 and Python 3. Ascending order puts the oldest first.
    backup_tables_sorted = sorted(backup_tables, key=functools.cmp_to_key(sort_by_date))
    logging.debug("Sorted latest backup = {}".format(backup_tables_sorted))

    excess_count = len(backup_tables_sorted) - max_backups
    if excess_count <= 0:
        return

    for table in backup_tables_sorted[:excess_count]:
        table_to_delete = build_backup_path(yt_path, yp_cluster, table)
        logging.info("Table to delete = {} in {}".format(table_to_delete, yt_cluster_name))

        client.remove(table_to_delete, recursive=True)


def notify_backup(yp_cluster, status, description=None):
    """Push a "yp_backup" event for *yp_cluster* to the Juggler push endpoint.

    A truthy *status* is sent as "OK" with the description "All ok"; a falsy
    one is sent as "CRIT" with *description*. The POST is executed through
    retry_me (second argument 10 - presumably the number of attempts; confirm
    against retry_me), and a non-success HTTP response raises via
    raise_for_status().
    """
    def push_event():
        # The payload is built on every attempt, exactly like the request itself.
        if status:
            event_description = "All ok"
            event_status = "OK"
        else:
            event_description = description
            event_status = "CRIT"

        event = {
            "description": event_description,
            "host": "yp-{}.yandex.net".format(yp_cluster),
            "instance": "",
            "service": "yp_backup",
            "status": event_status,
        }
        payload = {"source": "yp_backup", "events": [event]}
        return requests.post("http://juggler-push.search.yandex.net/events", json=payload, timeout=10)

    response = retry_me(push_event, 10)
    response.raise_for_status()


def main_impl(args, solomon_client):
    """Run one backup of the YP cluster and report the outcome to Juggler.

    The candidate YT clusters are tried in the order returned by
    find_cluster_to_backup; the first cluster on which both the copy and the
    rotation of old backups succeed ends the loop.

    Args:
        args: parsed command-line arguments providing ``yp_cluster``,
            ``yt_clusters`` (comma-separated), ``yt_base_path`` and ``max_backups``.
        solomon_client: metrics client passed through to copy_table.

    Raises:
        Exception: if the backup failed on every cluster. Any other exception
            raised during preparation is re-raised as is. In both cases a
            single CRIT event carrying the error text is sent first.
    """
    logging.debug("Preparing for the backup")

    try:
        yt_clusters_to_backup = args.yt_clusters.split(",")
        yt_base_path = args.yt_base_path
        yp_cluster = args.yp_cluster
        max_backups = args.max_backups

        max_backups_in_cluster, min_backups_in_cluster = \
            get_backup_depth_in_clusters(yt_clusters_to_backup, yt_base_path, yp_cluster)

        yt_clusters_to_backup = find_cluster_to_backup(yt_clusters_to_backup, yt_base_path, yp_cluster)
        logging.debug("Will backup to {}".format(yt_clusters_to_backup))

        # NOTE: datetime.now() is naive local time, so this value is the Unix
        # timestamp shifted by the local UTC offset. It is deliberately left
        # unchanged so that new folder names stay comparable with existing ones.
        timestamp = int((datetime.datetime.now() - datetime.datetime(1970, 1, 1)).total_seconds())
        directory_name = "ypbackup_{}".format(timestamp)

        backuped = False
        for yt_cluster_to_backup in yt_clusters_to_backup:
            try:
                logging.debug("Backup in {}".format(yt_cluster_to_backup))
                copy_table(
                    yp_cluster,
                    yt_cluster_to_backup,
                    yt_base_path,
                    directory_name,
                    solomon_client,
                )
                remove_oldest_backup(yt_cluster_to_backup, yt_base_path, yp_cluster, max_backups)
                backuped = True
                break
            except Exception:
                # Log and fall through to the next candidate cluster.
                logging.exception("Unhandled exception while trying to backup to {}".format(yt_cluster_to_backup))

        if not backuped:
            logging.error("Backup totally failed")
            # The handler below sends the single CRIT event, using this
            # message as its description.
            raise Exception("Backup failed to all clusters {}".format(yt_clusters_to_backup))

    except Exception as ex:
        logging.exception("Unhandled exception")
        notify_backup(args.yp_cluster, False, str(ex))
        raise

    # Reaching this point means the backup succeeded on some cluster; only the
    # balance of backup counts between clusters can still downgrade the status.
    if max_backups_in_cluster > min_backups_in_cluster * 2:
        notify_backup(args.yp_cluster, False, "Backup count in one cluster two times bigger than in another")
    else:
        # NOTE: a "too few backups" check (fewer than max_backups / 2) used to
        # be sketched here and is intentionally not enabled.
        notify_backup(args.yp_cluster, True)


def parse_arguments():
    """Parse the command line and return the resulting argparse namespace.

    All four options are required: --yp_cluster, --yt_clusters, --yt_base_path
    (strings) and --max_backups (int).
    """
    cli = argparse.ArgumentParser(add_help=True)

    # Plain string options, registered in the order they appear in --help.
    for option in ("--yp_cluster", "--yt_clusters", "--yt_base_path"):
        cli.add_argument(option, required=True)

    cli.add_argument("--max_backups", required=True, type=int)
    return cli.parse_args()


def main():
    """Script entry point: parse the arguments and run the backup inside a Solomon client context."""
    args = parse_arguments()
    solomon_context = create_solomon_client(args.yp_cluster)
    with solomon_context as solomon_client:
        main_impl(args, solomon_client)


# Run the backup only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
