import datetime
import logging
from collections import defaultdict
from enum import Enum, auto
from typing import Dict, List, DefaultDict, Tuple, Optional

from google.protobuf.descriptor import EnumValueDescriptor, Descriptor
from library.python import svn_version
import yt.wrapper as ytw
import yt.yson as yson

from sandbox.projects.yabs.export_keywords_from_proto_to_yt.keyword_data_type import process_message_KDT
from sandbox.projects.yabs.export_keywords_from_proto_to_yt.keyword_type import process_message_KT
from sandbox.projects.yabs.export_keywords_from_proto_to_yt.snapshot_verify import verify_snapshot_content
from sandbox.projects.yabs.export_keywords_from_proto_to_yt.prod_compare import print_diff
from yabs.server.proto.keywords import keywords_schema_pb2, keywords_data_pb2


# Kill switch: when True, verify_snapshots() logs a warning and returns without checking anything.
EXTREMELY_DANGEROUS_DISABLE_SNAPSHOT_VERIFICATION = False

# Name of the child directory (inside the target YT directory) that holds the snapshot tables.
SNAPSHOTS_DIRECTORY_NAME = "keyword_tables_snapshots"
# Separator between a table name and the snapshot id in snapshot node names.
SNAPSHOT_VERSION_DELIMITER = "-"

# Proto field name -> YT column name overrides for TKeyword messages.
# Fields that are not listed keep their proto name (see get_yt_column_name).
YT_COLUMN_BY_PROTO_FIELD_NAME_KEYWORD = {
    "Name": "Field",
    "ObsoleteType": "Type",
}
# Proto field name -> YT column name overrides for TKeywordType messages.
YT_COLUMN_BY_PROTO_FIELD_NAME_KEYWORD_TYPE = {
    "Name": "Data",
}


class YtFieldClasses(Enum):
    """How a proto field is exported to YT.

    Simple fields are copied from the .proto to YT as is, whereas complex fields require nontrivial export
    logic INSTEAD.
    """
    SIMPLE = auto()  # copied verbatim by the simple-field loop in add_yt_row_by_message
    COMPLEX = auto()  # skipped by that loop; presumably filled in by the process_message_* helpers


class TableNames(str, Enum):
    """Names of the YT tables built by this module.

    Members are also ``str`` instances, so they compare equal to their plain-string values and can be
    used interchangeably with them as dictionary keys.
    """
    KEYWORD = "Keyword"
    KEYWORD_TYPE = "KeywordType"
    KEYWORD_DATA_TYPE = "KeywordDataType"

    def __str__(self) -> str:
        """Return the bare table name (e.g. ``"Keyword"``).

        Table names are interpolated into YT paths with f-strings throughout this module. What ``format()``
        produces for a ``(str, Enum)`` mix-in differs between Python versions: older ones use the value,
        newer ones use ``Enum.__str__``, i.e. ``"TableNames.KEYWORD"``. Overriding ``__str__`` makes
        ``str()``, ``format()`` and f-strings all yield the value on every version.
        """
        return self.value


# Routing of TKeyword proto fields: field name -> (YT table that receives it, how it is exported).
# Looked up by get_table(); a field name missing from this mapping raises KeyError there.
TABLE_BY_FIELD_NAME_TKEYWORD: Dict[str, Tuple[TableNames, YtFieldClasses]] = {
    "BlindKeywordID": (TableNames.KEYWORD_DATA_TYPE, YtFieldClasses.SIMPLE),
    "ExpireDays": (TableNames.KEYWORD_DATA_TYPE, YtFieldClasses.SIMPLE),
    "ValueTypeAsInt": (TableNames.KEYWORD_DATA_TYPE, YtFieldClasses.SIMPLE),
    "MaxRecordCount": (TableNames.KEYWORD_DATA_TYPE, YtFieldClasses.SIMPLE),
    "ValueType": (TableNames.KEYWORD_DATA_TYPE, YtFieldClasses.COMPLEX),
    "DataGroup": (TableNames.KEYWORD_DATA_TYPE, YtFieldClasses.COMPLEX),
    "DataOption": (TableNames.KEYWORD_DATA_TYPE, YtFieldClasses.COMPLEX),
    "Location": (TableNames.KEYWORD_DATA_TYPE, YtFieldClasses.COMPLEX),

    "KeywordID": (TableNames.KEYWORD, YtFieldClasses.SIMPLE),
    "Name": (TableNames.KEYWORD, YtFieldClasses.SIMPLE),
    "Description": (TableNames.KEYWORD, YtFieldClasses.SIMPLE),
    "ObsoleteType": (TableNames.KEYWORD, YtFieldClasses.SIMPLE),

    "KeywordType": (TableNames.KEYWORD_TYPE, YtFieldClasses.COMPLEX),
}

# Export class of each TKeywordType proto field; get_table() routes all of them to the KeywordType table.
FIELD_CLASS_BY_FIELD_NAME_TKEYWORD_TYPE: Dict[str, YtFieldClasses] = {
    "Type": YtFieldClasses.COMPLEX,
    "Name": YtFieldClasses.SIMPLE,
    "Option": YtFieldClasses.COMPLEX,
}

# (cluster, exception) pairs collected by create_snapshots() and __change_links() respectively.
# main() logs them and raises at the very end. They are module-level, so they live as long as the process.
GLOBAL_SNAPSHOT_CREATION_ERRORS: List[Tuple[str, Exception]] = []
GLOBAL_LINK_CHANGING_ERRORS: List[Tuple[str, Exception]] = []


class YTEnvironmentException(Exception):
    """Something is wrong on one of the YT clusters (missing tables, wrong directory structure, etc.)."""


def select_yt_cluster(cluster: str) -> None:
    """Point the global ``yt.wrapper`` configuration at ``cluster``.

    The cached driver is dropped first, then the proxy is switched; all subsequent ``ytw.*`` calls in
    this module rely on this global selection.

    :param cluster: cluster / proxy name passed to ``ytw.config.set_proxy``
    """
    ytw.config.__dict__["_driver"] = None  # RPC driver ignores set_proxy due to caching proxy url
    ytw.config.set_proxy(cluster)  # pylint: disable=no-member # noqa


def get_table(message, field_name: str) -> Tuple[TableNames, YtFieldClasses]:
    """Resolve which YT table a proto field is exported to and how it is exported.

    The message kind is detected by comparing class names with the generated ``TKeywordType`` and
    ``TKeyword`` classes.

    :param message: protobuf message instance (``TKeyword`` or ``TKeywordType``)
    :param field_name: name of a field of ``message``
    :return: ``(target table, field class)``
    :raises KeyError: if ``field_name`` is not registered in the routing mapping for this message type
    :raises ValueError: if ``message`` is neither a ``TKeyword`` nor a ``TKeywordType``
    """
    message_type_name = type(message).__name__
    if message_type_name == keywords_schema_pb2.TKeywordType.__name__:
        # Every TKeywordType field lands in the KeywordType table; return the enum member (not a bare
        # string) so the result matches the annotation and the TKeyword branch.
        return TableNames.KEYWORD_TYPE, FIELD_CLASS_BY_FIELD_NAME_TKEYWORD_TYPE[field_name]
    if message_type_name == keywords_schema_pb2.TKeyword.__name__:
        return TABLE_BY_FIELD_NAME_TKEYWORD[field_name]
    # This branch is about the message type, not the field name (unknown field names raise KeyError above).
    raise ValueError(f"Unknown message type {message_type_name!r} in .proto; that should not happen ever.")


def get_yt_column_name(message, field_name: str) -> str:
    """Translate a proto field name into the name of the YT column it is stored in.

    Renames are defined per message type; a field without a (non-empty) override keeps its proto name,
    and so does any field of a message that is neither ``TKeyword`` nor ``TKeywordType``.

    :param message: protobuf message instance the field belongs to
    :param field_name: proto field name
    :return: YT column name
    """
    message_type_name = type(message).__name__
    if message_type_name == keywords_schema_pb2.TKeyword.__name__:
        overrides = YT_COLUMN_BY_PROTO_FIELD_NAME_KEYWORD
    elif message_type_name == keywords_schema_pb2.TKeywordType.__name__:
        overrides = YT_COLUMN_BY_PROTO_FIELD_NAME_KEYWORD_TYPE
    else:
        return field_name
    renamed = overrides.get(field_name)
    return renamed if renamed else field_name


def add_yt_row_by_message(message, current_table_name: TableNames, keyword_id, tables_dict) -> None:
    """Construct exactly zero or one row in one table by protobuf message.

    The row always carries ``KeywordID``. Simple fields routed to ``current_table_name`` are copied
    under their YT column names, then the table-specific processor (if any) gets a chance to extend
    the row. The row is appended to ``tables_dict[current_table_name]`` only when some copied field
    differs from its proto default or the processor reports a truthy result.

    :param message: protobuf message to export
    :param current_table_name: table the row is being built for
    :param keyword_id: value of the ``KeywordID`` column
    :param tables_dict: mapping table name -> list of rows; mutated in place
    """
    descriptor: Descriptor = message.DESCRIPTOR
    row = {"KeywordID": keyword_id}  # This column is present in all 3 tables we are building
    # In the previous YT version the table didn't have rows like (0, "", "", []),
    # so a row made only of default values is dropped.
    has_payload = False

    # Simple fields copying
    for name, field in descriptor.fields_by_name.items():
        value = getattr(message, name, None)
        table, field_class = get_table(message, name)
        if table == current_table_name and field_class == YtFieldClasses.SIMPLE:
            column = get_yt_column_name(message, name)
            row[column] = value
            if value != field.default_value:
                has_payload = True

    # Table and field dependent processing
    if current_table_name == TableNames.KEYWORD_DATA_TYPE:
        has_payload |= process_message_KDT(message, row)
    elif current_table_name == TableNames.KEYWORD_TYPE:
        has_payload |= process_message_KT(message, row)

    # Append row if needed
    if not has_payload:
        return
    tables_dict[current_table_name].append(row)


def build_tables(keywords_dict: Dict[int, EnumValueDescriptor]) -> DefaultDict[str, List[Dict]]:
    """Build dict containing all tables to write.

    For every keyword one row is attempted in the Keyword table and one in the KeywordDataType table,
    plus one row in the KeywordType table per entry of the message's ``KeywordType`` field.

    :param keywords_dict: dict containing all keywords (keyword id -> enum value descriptor)
    :return: dict: tableName -> listOfJsonLikeRows
    """
    tables = defaultdict(list)
    per_keyword_tables = (TableNames.KEYWORD, TableNames.KEYWORD_DATA_TYPE)
    value_descriptor: EnumValueDescriptor
    for keyword_id, value_descriptor in keywords_dict.items():
        # The value of the first listed option is taken as the keyword message (presumably a TKeyword).
        first_option = value_descriptor.GetOptions().ListFields()[0]
        keyword_message = first_option[1]
        for table_name in per_keyword_tables:
            add_yt_row_by_message(keyword_message, table_name, keyword_id, tables)
        for type_message in keyword_message.KeywordType:
            add_yt_row_by_message(type_message, TableNames.KEYWORD_TYPE, keyword_id, tables)

    return tables


def prepare_yt_dir_structure(yt_dir_path) -> str:
    """Make sure the target directory and its snapshots subdirectory exist, creating them if needed.

    :param yt_dir_path: YT path of the directory that holds the keyword tables
    :return: path to snapshots directory
    """
    snapshots_path = f"{yt_dir_path}/{SNAPSHOTS_DIRECTORY_NAME}"

    root_exists = ytw.exists(yt_dir_path)
    if not root_exists:
        logging.warning(f"Specified path does not exist, trying to create {yt_dir_path}...")
        ytw.create("map_node", yt_dir_path)
        logging.warning(f"{yt_dir_path} created successfully")

    if ytw.exists(snapshots_path):
        return snapshots_path

    logging.warning(f"Snapshot directory '{SNAPSHOTS_DIRECTORY_NAME}' does not exist, creating...")
    ytw.create("map_node", snapshots_path)
    return snapshots_path


def create_new_snapshot(yt_dir_path, schemas_by_table_name, tables, snapshot_id, sandbox_task=None):
    """Create, fill and flush one dynamic snapshot table per entry of ``tables``.

    Works against whatever cluster the global ``yt.wrapper`` config currently points at. Steps: ensure the
    directory structure, create the tables, try to move them to SSD (best effort), mount, insert the rows,
    then remount to flush.

    :param yt_dir_path: YT directory that holds the keyword tables
    :param schemas_by_table_name: table name -> schema; each used schema gets its ``attributes`` overwritten
    :param tables: table name -> list of rows to insert
    :param snapshot_id: suffix appended (after the delimiter) to every snapshot table name
    :param sandbox_task: optional task object; ``set_info`` is called on it to report progress
    :raises ytw.errors.YtError: propagated from YT calls, except for failures of the SSD step, which are
        logged and ignored when they are ``YtResponseError``
    """
    with ytw.Transaction():
        logging.info("Preparing YT directory structure...")
        snapshots_path = prepare_yt_dir_structure(yt_dir_path)
    tables_with_path = []
    for table_key, table_rows in tables.items():
        # Keys may be TableNames members; take the plain value so that the node name does not depend on
        # how enum members are rendered by format() in the running Python version.
        table_name = getattr(table_key, "value", table_key)
        # Same "<table><delimiter><snapshot_id>" naming that verification and link changing look for.
        table_path = f"{snapshots_path}/{table_name}{SNAPSHOT_VERSION_DELIMITER}{snapshot_id}"
        tables_with_path.append((table_name, table_rows, table_path))
    logging.info("Creating snapshot tables...")
    with ytw.Transaction():
        for table_name, __, table_path in tables_with_path:
            schema = schemas_by_table_name[table_name]
            schema.attributes = {"strict": True, "unique_keys": True}
            table_attributes = {"schema": schema, "dynamic": True, "enable_dynamic_store_read": True}
            ytw.create("table", table_path, attributes=table_attributes)
    logging.info("Trying to set primary medium to SSD...")
    try:  # Try to set primary medium to SSD, which can fail sometimes depending on quotas and accounts
        # First, check that SSD quota is not zero
        with ytw.Transaction():
            test_path = f"{snapshots_path}/TEST_SSD_QUOTAS"
            ytw.create("table", test_path, attributes={"primary_medium": "ssd_blobs"})
            ytw.write_table(test_path, [{"test": 1}])
            ytw.remove(test_path)
        for __, __, table_path in tables_with_path:
            ytw.set(f"{table_path}/@primary_medium", "ssd_blobs")
        logging.info("Primary medium set to SSD successfully.")
    except ytw.errors.YtResponseError as error:
        logging.warning("Setting primary medium to SSD failed!", exc_info=error)
    logging.info("Done creating snapshot tables.")
    with ytw.Transaction():
        for __, __, table_path in tables_with_path:
            logging.info(f"Mounting {table_path}...")
            # Wait for the mount to complete: rows are inserted right after this loop, and the later
            # remount step already waits the same way.
            ytw.mount_table(table_path, sync=True)
    with ytw.Transaction(type="tablet"):
        for table_name, table_rows, table_path in tables_with_path:
            logging.info(f"Inserting {table_name} rows into {table_path}...")
            ytw.insert_rows(table_path, table_rows)
            if sandbox_task:
                sandbox_task.set_info(f"Table {table_path} on cluster {ytw.config['proxy']['url']} filled with rows.")  # pylint: disable=no-member # noqa
    with ytw.Transaction():
        for __, __, table_path in tables_with_path:
            logging.info(f"Remounting {table_path} to flush rows...")
            ytw.unmount_table(table_path, sync=True)
            ytw.mount_table(table_path, sync=True)


def verify_snapshots(clusters, yt_dir_path, snapshot_id, sandbox_task=None):
    """Check on every cluster that the snapshot tables exist, are non-empty and pass content verification.

    Does nothing (except for a warning) when ``EXTREMELY_DANGEROUS_DISABLE_SNAPSHOT_VERIFICATION`` is set.
    ``YTEnvironmentException`` and ``YtResponseError`` raised while checking a cluster are logged and the
    next cluster is checked; they are NOT re-raised. Any other exception (e.g. whatever
    ``verify_snapshot_content`` raises) propagates to the caller.

    :param clusters: cluster names to check
    :param yt_dir_path: YT directory that holds the keyword tables
    :param snapshot_id: snapshot suffix to look for
    :param sandbox_task: optional task object; the printed diff is reported through ``set_info``
    """
    if EXTREMELY_DANGEROUS_DISABLE_SNAPSHOT_VERIFICATION:
        logging.warning("REALLY DANGEROUS! Snapshot verification is disabled! Hope you know what you are doing!")
        return
    snapshots_dir_path = f"{yt_dir_path}/{SNAPSHOTS_DIRECTORY_NAME}"
    failed_clusters = []
    for cluster in clusters:
        try:
            logging.info(f"Checking target snapshot integrity on cluster {cluster}...")
            select_yt_cluster(cluster)
            if not ytw.exists(snapshots_dir_path):
                # Raise (instead of only logging) so that we do not go on to list a missing directory.
                raise YTEnvironmentException(f"Snapshots dir does not exist on cluster {cluster}")
            child_nodes = set(ytw.list(snapshots_dir_path))
            # Use .value explicitly: how an enum member is rendered inside an f-string depends on the
            # Python version, and these names must match the real node names.
            node_name_by_table = {
                table_name: f"{table_name.value}{SNAPSHOT_VERSION_DELIMITER}{snapshot_id}"
                for table_name in TableNames
            }
            missing_nodes = set(node_name_by_table.values()) - child_nodes
            if missing_nodes:
                raise YTEnvironmentException(
                    f"Snapshot with suffix {snapshot_id} for tables {missing_nodes} "
                    f"not found on cluster {cluster}!"
                )
            table_dict = {
                table_name: list(ytw.read_table(f"{snapshots_dir_path}/{node_name}"))
                for table_name, node_name in node_name_by_table.items()
            }
            try:  # The diff is informational only, so its failure must not fail the verification
                diff_list = [f"Diff for cluster {cluster}:"]
                print_diff(lambda s: diff_list.append(str(s)),
                           table_dict[TableNames.KEYWORD],
                           table_dict[TableNames.KEYWORD_TYPE],
                           table_dict[TableNames.KEYWORD_DATA_TYPE])
                if sandbox_task:
                    sandbox_task.set_info("\n".join(diff_list))
            except ytw.errors.YtError:
                logging.warning(f"Printing diff failed on cluster {cluster}.")
            if not all(table_dict.values()):
                raise YTEnvironmentException("One of snapshot tables returns 0 rows when read.")  # That REALLY should not happen # noqa
            logging.info("Confirmed all snapshots have rows")
            verify_snapshot_content(table_dict)
        except (YTEnvironmentException, ytw.errors.YtResponseError) as error:
            failed_clusters.append(cluster)
            logging.exception("YT error", exc_info=error)
    if failed_clusters:
        # NOTE(review): failures are only logged here, so main() still proceeds to link changing after
        # a failed verification - confirm whether this should raise instead.
        logging.error(f"Snapshot verification FAILED on clusters {failed_clusters}, see errors above.")
    else:
        logging.info("Snapshot integrity is verified on target clusters.")


def __change_links(snapshot_id, clusters, yt_dir_path, sandbox_task):
    """Re-point ``<yt_dir_path>/<table>`` on every cluster to the snapshot with the given id.

    Per cluster, all tables are switched inside one ``ytw.Transaction``. If the existing node is a link
    whose target still exists, that old target gets an ``expiration_time`` 90 days from now. Whatever
    node exists at the table path is then removed and replaced with a link to the new snapshot.
    ``YtError`` raised for a cluster is recorded in ``GLOBAL_LINK_CHANGING_ERRORS`` (and reported to the
    sandbox task, if any) instead of being raised, and the next cluster is processed.

    :param snapshot_id: suffix of the snapshot tables to link to
    :param clusters: cluster names to process
    :param yt_dir_path: YT directory that holds the keyword tables
    :param sandbox_task: task object used for progress reporting, or a falsy value
    """
    snapshots_dir_path = f"{yt_dir_path}/{SNAPSHOTS_DIRECTORY_NAME}"
    for cluster in clusters:
        logging.info(f"Dangerous operation! Changing links on cluster {cluster}...")
        select_yt_cluster(cluster)
        try:
            with ytw.Transaction():
                for table in TableNames:
                    # Use .value explicitly: how an enum member is rendered inside an f-string depends
                    # on the Python version, and these must be real YT paths.
                    table_name = table.value
                    table_path = f"{yt_dir_path}/{table_name}"
                    snapshot_path = f"{snapshots_dir_path}/{table_name}{SNAPSHOT_VERSION_DELIMITER}{snapshot_id}"
                    node_exists = ytw.exists(table_path)
                    if not node_exists:
                        logging.warning(f"Node {table_path} on cluster {cluster} does not exist. Continuing...")
                    elif ytw.get(f"{table_path}&/@type") != "link":
                        # NOTE(review): a non-link node is still removed below - confirm this is intended.
                        logging.warning(f"Node {table_path} on cluster {cluster} isn't a link. Continuing...")
                    else:
                        old_snapshot_path = ytw.get(f"{table_path}&/@target_path")
                        # NOTE(review): naive local time - confirm which timezone YT assumes here.
                        expiration_time = (datetime.datetime.now() + datetime.timedelta(days=90)).isoformat()
                        if ytw.exists(old_snapshot_path):
                            logging.info(f"Setting TTL {expiration_time} for old snapshot {old_snapshot_path}")
                            ytw.set_attribute(old_snapshot_path, "expiration_time", expiration_time)
                        else:
                            logging.warning(f"{table_path} links to {old_snapshot_path}, which does not exist.")
                    if node_exists:
                        ytw.remove(table_path)
                    ytw.link(snapshot_path, table_path)
                    if sandbox_task:
                        sandbox_task.set_info(f"Linked {table_path} to {snapshot_path} on {cluster}")
                    ytw.set_attribute(table_path, "linked_to_prod_datetime", datetime.datetime.now().isoformat())
        except ytw.YtError as exception:
            GLOBAL_LINK_CHANGING_ERRORS.append((cluster, exception))
            if sandbox_task:
                error_text = f"Changing links failed on cluster {cluster}, see logs for details."
                error_html = f'<p style="color: red">{error_text}</p>'
                sandbox_task.set_info(error_html, do_escape=False)


def create_snapshots(schemas_by_table_name: Dict[str, yson.YsonList], clusters: List[str], yt_dir_path: str,
                     snapshot_id="", sandbox_task=None):
    """Build the keyword tables once and write them as a new snapshot to every cluster.

    A ``YTEnvironmentException`` / ``YtError`` raised for a cluster does not stop the loop: it is
    recorded in ``GLOBAL_SNAPSHOT_CREATION_ERRORS`` and reported to the sandbox task, if one is given.

    :param schemas_by_table_name: table name -> YT schema
    :param clusters: cluster names to write the snapshot to
    :param yt_dir_path: YT directory that holds the keyword tables
    :param snapshot_id: suffix of the snapshot tables
    :param sandbox_task: optional task object; ``set_info`` is called on it to report progress
    """
    rev = svn_version.svn_last_revision()
    logging.info(f"Building tables from rev {rev}...")
    if sandbox_task:
        source_note = (
            f"Building tables from .proto files. The file should be here but "
            f"(IMPORTANT) it may have been edited locally, without being pushed.\n"
            f"https://a.yandex-team.ru/arc_vcs/yabs/server/proto/keywords/keywords_data.proto?rev={rev}"  # noqa
        )
        sandbox_task.set_info(source_note)

    keyword_descriptors = keywords_data_pb2.EKeyword.DESCRIPTOR.values_by_number
    tables = build_tables(keyword_descriptors)
    logging.info("Done building tables.")
    if sandbox_task:
        sandbox_task.set_info("All keyword tables built from .proto files")

    for cluster in clusters:
        logging.info(f"Creating snapshot on cluster {cluster}...")
        select_yt_cluster(cluster)
        try:
            create_new_snapshot(yt_dir_path, schemas_by_table_name, tables, snapshot_id, sandbox_task=sandbox_task)
        except (YTEnvironmentException, ytw.errors.YtError) as error:
            GLOBAL_SNAPSHOT_CREATION_ERRORS.append((cluster, error))
            if not sandbox_task:
                continue
            failure_text = f"Snapshot creation failed on cluster {cluster}, see logs for details."
            sandbox_task.set_info(f'<p style="color: red">{failure_text}</p>', do_escape=False)


def main(schemas_by_table_name: Dict[str, yson.YsonList], clusters: List[str], yt_dir_path: str,
         yt_token: Optional[str] = None, should_create_new_snapshot=False, should_change_links=False, snapshot_id="",
         sandbox_task=None):
    """Entry point: optionally create snapshots, verify them and switch the production links.

    :param schemas_by_table_name: table name -> YT schema
    :param clusters: cluster names to operate on
    :param yt_dir_path: YT directory that holds the keyword tables
    :param yt_token: YT token; when falsy the existing ``yt.wrapper`` config is left as is
    :param should_create_new_snapshot: build the tables and write a new snapshot to every cluster
    :param should_change_links: re-point the table links to the snapshot ``snapshot_id``
    :param snapshot_id: snapshot suffix; when empty and a new snapshot is created from a sandbox task,
        the task id is used
    :param sandbox_task: optional task object used for progress reporting
    :raises YTEnvironmentException: after all steps, if snapshot creation or link changing failed on
        at least one cluster
    """
    if not snapshot_id and should_create_new_snapshot and sandbox_task:
        snapshot_id = str(sandbox_task.id)
    if yt_token:
        ytw.config["token"] = yt_token  # pylint: disable=unsupported-assignment-operation # noqa
    ytw.config["backend"] = "rpc"  # pylint: disable=unsupported-assignment-operation # noqa
    if should_create_new_snapshot:
        create_snapshots(schemas_by_table_name, clusters, yt_dir_path, snapshot_id, sandbox_task=sandbox_task)
    if should_create_new_snapshot or should_change_links:
        verify_snapshots(clusters, yt_dir_path, snapshot_id, sandbox_task=sandbox_task)
    if should_change_links:
        __change_links(snapshot_id, clusters, yt_dir_path, sandbox_task=sandbox_task)
    for cluster_name, error in GLOBAL_SNAPSHOT_CREATION_ERRORS:
        logging.exception(f"Creating snapshot failed on cluster {cluster_name}", exc_info=error)
    for cluster_name, error in GLOBAL_LINK_CHANGING_ERRORS:
        logging.exception(f"Changing links failed on cluster {cluster_name}", exc_info=error)
    if GLOBAL_LINK_CHANGING_ERRORS or GLOBAL_SNAPSHOT_CREATION_ERRORS:
        error_msg = ""
        if GLOBAL_SNAPSHOT_CREATION_ERRORS:
            error_clusters = [cluster for cluster, __ in GLOBAL_SNAPSHOT_CREATION_ERRORS]
            error_msg += f"snapshot creation failed on clusters {error_clusters}, "
        # Report link failures from their own list (not from the snapshot creation one).
        if GLOBAL_LINK_CHANGING_ERRORS:
            error_clusters = [cluster for cluster, __ in GLOBAL_LINK_CHANGING_ERRORS]
            error_msg += f"link changing failed on clusters {error_clusters}, "
        # Upper-case the first letter only: str.capitalize() would also lower-case the cluster names.
        error_msg = error_msg[:1].upper() + error_msg[1:]
        raise YTEnvironmentException(f"{error_msg}see info above or logs for details.")
