import datetime
import logging
from typing import Dict, List, Tuple, Optional

from sandbox import sdk2
import yt.wrapper as ytw
import yt.yson as yson

from sandbox.projects.ads.emily.tasks.export_lm_scheme_and_config.lib.compare import is_safe_diff
from sandbox.projects.ads.emily.tasks.export_lm_scheme_and_config.lib.tables import TableNames

# When True, verify_snapshots() only logs a warning and returns without checking anything.
EXTREMELY_DANGEROUS_DISABLE_SNAPSHOT_VERIFICATION = False

# Name of the child directory (inside the export directory) that holds snapshot tables.
SNAPSHOTS_DIRECTORY_NAME = "lm_snapshots"
# Separator placed between a table name and the snapshot id in snapshot table names.
SNAPSHOT_VERSION_DELIMITER = "-"

# Module-level accumulators of (cluster, exception) pairs; main() logs their contents
# and uses them to decide whether the whole run has failed.
GLOBAL_SNAPSHOT_CREATION_ERRORS: List[Tuple[str, Exception]] = []
GLOBAL_LINK_CHANGING_ERRORS: List[Tuple[str, Exception]] = []
# Starts as True and is only ever switched to False (in create_new_snapshot); nothing resets it to True.
GLOBAL_DUPLICATE_TABLES = True  # will change to False during duplicate check


class YTEnvironmentException(Exception):
    """Something is wrong on one of the YT clusters (missing tables, wrong directory structure, etc.).

    Raised by the snapshot verification step and by ``main`` when the run as a whole has failed.
    """


def to_set_of_tuples(lst: List[Dict]):
    """Turn a list of row dicts into a set of hashable row tuples.

    Every row becomes a tuple of its ``(key, value)`` pairs sorted by key, so two rows
    with equal content map to the same tuple regardless of key insertion order, and
    duplicate rows collapse into one element.

    :param lst: rows to convert; all values must be hashable
    :return: set of tuples, one per distinct row
    """
    distinct_rows = set()
    for row in lst:
        ordered_pairs = sorted(row.items())
        distinct_rows.add(tuple(ordered_pairs))
    return distinct_rows


def log_to_task(task: Optional[sdk2.Task], message):
    """Write ``message`` to the log and, when a task is given, to the task info as well.

    :param task: task whose ``set_info`` receives the message; skipped when falsy
    :param message: text to report
    """
    logging.info(message)
    if not task:
        return
    task.set_info(message)


def select_yt_cluster(cluster):
    """Switch the global ``yt.wrapper`` configuration to the given cluster.

    :param cluster: value passed to ``ytw.config.set_proxy``
    """
    # The cached driver is dropped before the proxy is changed; per the note below,
    # the RPC driver would otherwise keep using the previously cached proxy url.
    ytw.config.__dict__["_driver"] = None  # RPC driver ignores set_proxy due to caching proxy url
    ytw.config.set_proxy(cluster)  # pylint: disable=no-member # noqa


def prepare_yt_dir_structure(yt_dir_path) -> str:
    """Make sure the export directory and its snapshots sub-directory exist.

    Each missing node is created as a ``map_node`` on the currently selected cluster.

    :param yt_dir_path: path of the export directory
    :return: path to snapshots directory
    """
    snapshots_dir = f"{yt_dir_path}/{SNAPSHOTS_DIRECTORY_NAME}"

    root_exists = ytw.exists(yt_dir_path)
    if not root_exists:
        logging.warning(f"Specified path does not exist, trying to create {yt_dir_path}...")
        ytw.create("map_node", yt_dir_path)
        logging.warning(f"{yt_dir_path} created successfully")

    if ytw.exists(snapshots_dir):
        return snapshots_dir

    logging.warning(f"Snapshot directory '{SNAPSHOTS_DIRECTORY_NAME}' does not exist, creating...")
    ytw.create("map_node", snapshots_dir)
    return snapshots_dir


def create_new_snapshot(yt_dir_path: str,
                        tables: List[Tuple[str, yson.YsonList, List[Dict]]],
                        snapshot_id: str,
                        sandbox_task: Optional[sdk2.Task] = None):
    """Create, fill and flush snapshot tables on the currently selected YT cluster.

    Steps, in order:
      1. make sure the directory structure exists;
      2. compare every exported table with the table currently readable at
         ``{yt_dir_path}/{table_name}`` and skip the export when everything is a duplicate;
      3. create one dynamic table per exported table in the snapshots directory;
      4. try (best effort) to move the new tables to the ``ssd_blobs`` medium;
      5. mount the tables, insert the rows, then unmount/mount again to flush them.

    :param yt_dir_path: export directory on the cluster
    :param tables: ``(table_name, table_schema, table_rows)`` triples to export
    :param snapshot_id: suffix appended to every snapshot table name
    :param sandbox_task: optional task that receives progress messages
    :return: ``False`` when the export was skipped because all tables are duplicates,
             ``True`` once the snapshot tables have been created and filled
    :raises: YT errors from table creation, mounting and row insertion are not caught here
    """
    # The flag is module-wide and is only ever cleared here: once any call has seen a
    # non-duplicate (or unreadable) table, later calls will export as well.
    global GLOBAL_DUPLICATE_TABLES  # Should be fine as global because exported tables (possibly duplicates) are not changed during task excution # noqa
    cluster = ytw.config['proxy']['url']  # pylint: disable=no-member # noqa
    with ytw.Transaction():
        logging.info("Preparing YT directory structure...")
        snapshots_path = prepare_yt_dir_structure(yt_dir_path)
    # Use the shared delimiter constant so the names match the ones the verification
    # and link-changing steps look for.
    tables_with_path = [
        (table_name, table_schema, table_rows,
         f"{snapshots_path}/{table_name}{SNAPSHOT_VERSION_DELIMITER}{snapshot_id}")
        for table_name, table_schema, table_rows in tables
    ]
    for table_name, __, table_rows in tables:
        try:
            existing_table = list(ytw.read_table(f"{yt_dir_path}/{table_name}", format="json"))
            for row in existing_table:
                # "YTHash" is dropped before comparing (presumably it is absent from the
                # exported rows - TODO confirm). pop() is used so that a row without this
                # key does not abort the whole export with an uncaught KeyError.
                row.pop("YTHash", None)
            if to_set_of_tuples(existing_table) == to_set_of_tuples(table_rows):
                log_to_task(sandbox_task, f"Table {table_name} on cluster {cluster} is duplicate of exported data.")
            else:
                GLOBAL_DUPLICATE_TABLES = False
        except ytw.errors.YtResponseError as error:
            # An unreadable current table is treated as "not a duplicate".
            logging.warning("YT error while trying to check for duplicates", exc_info=error)
            GLOBAL_DUPLICATE_TABLES = False
    if GLOBAL_DUPLICATE_TABLES:
        log_to_task(sandbox_task, f"All tables built for {cluster} are duplicates of existing tables, will not export.")
        return False
    logging.info("Creating snapshot tables...")
    with ytw.Transaction():
        for __, table_schema, __, table_path in tables_with_path:
            # Note: this writes the attributes onto the caller's schema object.
            table_schema.attributes = {"strict": True, "unique_keys": True}
            table_attributes = {"schema": table_schema, "dynamic": True, "enable_dynamic_store_read": True}
            ytw.create("table", table_path, attributes=table_attributes)
    logging.info("Trying to set primary medium to SSD...")
    try:  # Try to set primary medium to SSD, which can fail sometimes depending on quotas and accounts
        # First, check that SSD quota is not zero
        with ytw.Transaction():
            test_path = f"{snapshots_path}/TEST_SSD_QUOTAS"
            ytw.create("table", test_path, attributes={"primary_medium": "ssd_blobs"})
            ytw.write_table(test_path, [{"test": 1}])
            ytw.remove(test_path)
        for __, __, __, table_path in tables_with_path:
            ytw.set(f"{table_path}/@primary_medium", "ssd_blobs")
        logging.info("Primary medium set to SSD successfully.")
    except ytw.errors.YtResponseError as error:
        # Best effort only: the snapshot is still created on the default medium.
        logging.warning("Setting primary medium to SSD failed!", exc_info=error)
    logging.info("Done creating snapshot tables.")
    with ytw.Transaction():
        for __, __, __, table_path in tables_with_path:
            logging.info(f"Mounting {table_path}...")
            ytw.mount_table(table_path)
    with ytw.Transaction(type="tablet"):
        for table_name, __, table_rows, table_path in tables_with_path:
            logging.info(f"Inserting {len(table_rows)} {table_name} rows into {table_path}...")
            ytw.insert_rows(table_path, table_rows)
            log_to_task(sandbox_task, f"Table {table_path} on cluster {cluster} filled with rows.")
    with ytw.Transaction():
        for __, __, __, table_path in tables_with_path:
            logging.info(f"Remounting {table_path} to flush rows...")
            ytw.unmount_table(table_path, sync=True)
            ytw.mount_table(table_path, sync=True)
    return True


def verify_snapshots(clusters, yt_dir_path, snapshot_id, sandbox_task=None):
    """Check the snapshot with the given id on every cluster and reject unsafe diffs.

    For each cluster the function checks that the snapshots directory and all snapshot
    tables exist, reads them, compares them with the tables currently readable at
    ``{yt_dir_path}/{table_name}`` via ``is_safe_diff`` and checks that every snapshot
    table has at least one row.

    Environment problems (missing directory or tables, empty tables, YT response errors)
    are logged per cluster and do NOT abort the run; only an unsafe diff raises.

    :param clusters: clusters to check
    :param yt_dir_path: export directory on each cluster
    :param snapshot_id: suffix of the snapshot tables to verify
    :param sandbox_task: optional task that receives the unsafe-diff report
    :raises YTEnvironmentException: when the diff is unsafe on at least one cluster
    """
    if EXTREMELY_DANGEROUS_DISABLE_SNAPSHOT_VERIFICATION:
        logging.warning("REALLY DANGEROUS! Snapshot verification is disabled! Hope you know what you are doing!")
        return
    summary_diff_safeness = True
    failed_clusters = []  # clusters whose integrity check ended with a logged error
    for cluster in clusters:
        try:
            logging.info(f"Checking target snapshot integrity on cluster {cluster}...")
            select_yt_cluster(cluster)
            snapshots_dir_path = f"{yt_dir_path}/{SNAPSHOTS_DIRECTORY_NAME}"
            if not ytw.exists(snapshots_dir_path):
                # Raise instead of only logging, so that a directory known to be missing
                # is not listed anyway; the handler below logs it.
                raise YTEnvironmentException(f"Snapshots dir does not exist on cluster {cluster}")
            child_nodes = set(ytw.list(snapshots_dir_path))
            snapshot_table_names = {f"{table_name}{SNAPSHOT_VERSION_DELIMITER}{snapshot_id}" for table_name in TableNames}  # noqa
            if not snapshot_table_names.issubset(child_nodes):
                raise YTEnvironmentException(
                    f"Snapshot with suffix {snapshot_id} for tables {snapshot_table_names - child_nodes} "
                    f"not found on cluster {cluster}!"
                )
            table_paths = {table_name: f"{snapshots_dir_path}/{table_name}{SNAPSHOT_VERSION_DELIMITER}{snapshot_id}"
                           for table_name in TableNames}
            table_dict = {table_name: list(ytw.read_table(table_paths[table_name], format='json'))
                          for table_name in TableNames}
            try:
                old_tables = {table_name: list(ytw.read_table(f"{yt_dir_path}/{table_name}", format="json"))
                              for table_name in TableNames}
                safe_diff, error_strs = is_safe_diff(old_tables, table_dict)
                if safe_diff:
                    logging.info(f"Confirmed safe diff on {cluster}")
                else:
                    summary_diff_safeness = False
                    log_to_task(sandbox_task, f"Unsafe diff for on cluster {cluster}: {error_strs}")
            except ytw.errors.YtError as error:
                # A diff that cannot be computed is not treated as unsafe.
                logging.warning(f"Computing diff failed on cluster {cluster}.", exc_info=error)
            if all(len(table) for table in table_dict.values()):
                logging.info("Confirmed all snapshots have rows")
            else:
                raise YTEnvironmentException("One of snapshot tables returns 0 rows when read.")  # That REALLY should not happen # noqa
        except (YTEnvironmentException, ytw.errors.YtResponseError) as error:
            logging.exception("YT error", exc_info=error)
            failed_clusters.append(cluster)
    # The disable flag needs no second check here: the function has already returned when it is set.
    if not summary_diff_safeness:
        raise YTEnvironmentException("Some snapshot diff is unsafe, see logs")
    if failed_clusters:
        # Do not claim success for clusters whose check ended with an error.
        logging.warning(f"Snapshot integrity could not be verified on clusters {failed_clusters}, see errors above.")
    else:
        logging.info("Snapshot integrity is verified on target clusters.")


def __change_links(snapshot_id, clusters, yt_dir_path, sandbox_task):
    """Re-point ``{yt_dir_path}/{table_name}`` on every cluster to the snapshot with the given id.

    Per cluster, inside one transaction and for every table name:
      * if the current node is a link whose target exists, the target gets an
        ``expiration_time`` three days from now;
      * the current node (link or not) is removed when it exists;
      * a new link to the snapshot table is created and stamped with ``linked_to_prod_datetime``.

    Afterwards ``max_unix_time`` is set on every table when the duplicates flag is still
    set or no link-changing error has been recorded so far (on any cluster).

    YT errors are not raised: they are appended to ``GLOBAL_LINK_CHANGING_ERRORS``.

    :param snapshot_id: suffix of the snapshot tables to link to
    :param clusters: clusters to process
    :param yt_dir_path: export directory on each cluster
    :param sandbox_task: optional task that receives progress and error messages
    """
    for cluster in clusters:
        logging.info(f"Dangerous operation! Changing links on cluster {cluster}...")
        select_yt_cluster(cluster)
        snapshots_dir_path = f"{yt_dir_path}/{SNAPSHOTS_DIRECTORY_NAME}"
        try:
            with ytw.Transaction():
                for table_name in TableNames:
                    table_path = f"{yt_dir_path}/{table_name}"
                    snapshot_path = f"{snapshots_dir_path}/{table_name}{SNAPSHOT_VERSION_DELIMITER}{snapshot_id}"
                    node_exists = ytw.exists(table_path)
                    if not node_exists:
                        logging.warning(f"Node {table_path} on cluster {cluster} does not exist. Continuing...")
                    elif ytw.get(f"{table_path}&/@type") != "link":
                        logging.warning(f"Node {table_path} on cluster {cluster} isn't a link. Continuing...")
                    else:
                        old_snapshot_path = ytw.get(f"{table_path}&/@target_path")
                        # NOTE(review): datetime.now() is naive local time - confirm this is
                        # the time zone the cluster expects for expiration_time.
                        expiration_time = (datetime.datetime.now() + datetime.timedelta(days=3)).isoformat()
                        if ytw.exists(old_snapshot_path):
                            logging.info(f"Setting TTL {expiration_time} for old snapshot {old_snapshot_path}")
                            ytw.set_attribute(old_snapshot_path, "expiration_time", expiration_time)
                        else:
                            logging.warning(f"{table_path} links to {old_snapshot_path}, which does not exist.")
                    if node_exists:
                        ytw.remove(table_path)
                    ytw.link(snapshot_path, table_path)
                    if sandbox_task:
                        sandbox_task.set_info(f"Linked {table_path} to {snapshot_path} on {cluster}")
                    ytw.set_attribute(table_path, "linked_to_prod_datetime", datetime.datetime.now().isoformat())
        except ytw.YtError as exc:
            # Recorded here, logged later by main().
            GLOBAL_LINK_CHANGING_ERRORS.append((cluster, exc))
            if sandbox_task:
                error_text = f"Changing links failed on cluster {cluster}, see logs for details."
                error_html = f'<p style="color: red">{error_text}</p>'
                sandbox_task.set_info(error_html, do_escape=False)
        if GLOBAL_DUPLICATE_TABLES or not GLOBAL_LINK_CHANGING_ERRORS:
            # Bind the name before the try block: if the transaction fails before the loop
            # body runs, the handler below would otherwise hit an unbound (or stale)
            # ``table_path`` and mask the real YT error.
            table_path = yt_dir_path
            try:
                with ytw.Transaction():
                    for table_name in TableNames:
                        table_path = f"{yt_dir_path}/{table_name}"
                        ytw.set_attribute(table_path, "max_unix_time", int(datetime.datetime.now().timestamp()))
            except ytw.YtError as exc:
                logging.exception(f"Could not set 'max_unix_time' attr for {table_path} on {cluster}", exc_info=exc)
                GLOBAL_LINK_CHANGING_ERRORS.append((cluster, exc))


def create_snapshots(tables: List[Tuple[str, yson.YsonList, List[Dict]]],
                     clusters: List[str],
                     yt_dir_path: str,
                     snapshot_id="",
                     sandbox_task: Optional[sdk2.Task] = None):
    """Create the snapshot on every cluster, recording failures instead of raising.

    A cluster whose snapshot creation fails with ``YTEnvironmentException`` or a YT error
    is added to ``GLOBAL_SNAPSHOT_CREATION_ERRORS`` (and reported to the task, if any);
    the remaining clusters are still processed.

    :param tables: ``(table_name, table_schema, table_rows)`` triples to export
    :param clusters: clusters to create the snapshot on
    :param yt_dir_path: export directory on each cluster
    :param snapshot_id: suffix of the snapshot tables
    :param sandbox_task: optional task that receives progress and error messages
    """
    if sandbox_task:
        sandbox_task.set_info("All tables built")

    for cluster in clusters:
        logging.info(f"Creating snapshot on cluster {cluster}...")
        select_yt_cluster(cluster)
        try:
            create_new_snapshot(yt_dir_path, tables, snapshot_id, sandbox_task=sandbox_task)
        except (YTEnvironmentException, ytw.errors.YtError) as creation_error:
            GLOBAL_SNAPSHOT_CREATION_ERRORS.append((cluster, creation_error))
            if not sandbox_task:
                continue
            failure_note = f"Snapshot creation failed on cluster {cluster}, see logs for details."
            sandbox_task.set_info(f'<p style="color: red">{failure_note}</p>', do_escape=False)


def main(tables: List[Tuple[str, yson.YsonList, List[Dict]]],
         clusters: List[str],
         yt_dir_path: str,
         yt_token: Optional[str] = None,
         should_create_new_snapshot=False,
         should_change_links=False,
         snapshot_id="",
         sandbox_task=None):
    """Entry point: optionally create snapshots, verify them and re-point the links.

    :param tables: ``(table_name, table_schema, table_rows)`` triples to export
    :param clusters: clusters to work on
    :param yt_dir_path: export directory on each cluster
    :param yt_token: YT token; the client configuration is left untouched when empty
    :param should_create_new_snapshot: create snapshot tables before verifying
    :param should_change_links: re-point the links after verifying
    :param snapshot_id: snapshot suffix; defaults to the task id when a snapshot is
                        created and a task is given
    :param sandbox_task: optional task that receives progress messages
    :raises YTEnvironmentException: when snapshot creation failed on any cluster, or link
                                    changing failed while the duplicates flag is cleared

    NOTE(review): the duplicates flag starts as True and is only cleared while creating
    snapshots, so in a links-only run link-changing errors are logged but never raised -
    confirm this is intended.
    """
    if not snapshot_id:
        if should_create_new_snapshot and sandbox_task:
            snapshot_id = str(sandbox_task.id)

    if yt_token:
        ytw.config["token"] = yt_token  # pylint: disable=unsupported-assignment-operation # noqa
    ytw.config["backend"] = "rpc"  # pylint: disable=unsupported-assignment-operation # noqa

    if should_create_new_snapshot:
        create_snapshots(tables, clusters, yt_dir_path, snapshot_id, sandbox_task=sandbox_task)
    needs_verification = should_create_new_snapshot or should_change_links
    if needs_verification:
        verify_snapshots(clusters, yt_dir_path, snapshot_id, sandbox_task=sandbox_task)
    if should_change_links:
        __change_links(snapshot_id, clusters, yt_dir_path, sandbox_task=sandbox_task)

    # Report everything that was collected along the way before deciding on the outcome.
    for cluster_name, error in GLOBAL_SNAPSHOT_CREATION_ERRORS:
        logging.exception(f"Creating snapshot failed on cluster {cluster_name}", exc_info=error)
    for cluster_name, error in GLOBAL_LINK_CHANGING_ERRORS:
        logging.exception(f"Changing links failed on cluster {cluster_name}", exc_info=error)

    links_failed = bool(GLOBAL_LINK_CHANGING_ERRORS)
    if GLOBAL_SNAPSHOT_CREATION_ERRORS or (links_failed and not GLOBAL_DUPLICATE_TABLES):
        raise YTEnvironmentException("Something failed, see info above and logs for details.")
    if links_failed:
        # Reaching this line with link errors means the duplicates flag is still set.
        log_to_task(sandbox_task, "Link changing failed but it's probably just because of not-created duplicates")
