from __future__ import absolute_import
import datetime
import functools
import logging
import os

import retry
import yt.wrapper as yt
import yt.logger as yt_logger

from crypta.lib.python import script_name
import crypta.lib.python.arg_to_list as arg_to_list
from crypta.lib.python.data_size import DataSize
from crypta.lib.python.logging import logging_helpers
import crypta.lib.python.yql.proto_field as yql_proto_field
from crypta.lib.python.yt import path_utils
import crypta.lib.python.yt.schema_utils as schema_utils


# Named presets of table attributes. create_empty_table() looks one up by its
# `compression` argument and copies it into the attributes of the new table.
STORAGE_MODES = {
    # Higher brotli level than 'default'.
    'compact': {
        'compression_codec': 'brotli_8',
    },
    # Preset used when the caller does not pick one.
    'default': {
        'compression_codec': 'brotli_3',
    },
    # No compression codec and an explicit replication factor.
    'foreign': {
        'compression_codec': 'none',
        'replication_factor': 6,
    }
}

# Module-level logger used by the helpers in this file.
logger = logging.getLogger(__name__)


def wrap_yt_run(func):
    """Wrap a YT ``run_*`` callable so each call carries a ``script_name`` annotation.

    The returned wrapper reads the ``spec`` keyword argument, sets
    ``spec['annotations']['script_name']`` to the value returned by
    ``script_name.detect_script_name`` and forwards the call to ``func``.

    The caller's ``spec`` and ``annotations`` dicts are shallow-copied before
    being modified, so a spec object reused between calls is never changed
    behind the caller's back. An explicit ``spec=None`` (or
    ``annotations=None``) is treated as an empty dict.

    Only a ``spec`` passed by keyword is recognised; a spec passed
    positionally is not seen by the wrapper.

    :param func: callable accepting a ``spec`` keyword argument.
    :return: wrapper with the same call signature as ``func``.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # update spec with script name annotations for resource tracking
        spec = dict(kwargs.get('spec') or {})
        annotations = dict(spec.get('annotations') or {})
        annotations['script_name'] = script_name.detect_script_name(skip_locations={"crypta/lib/python/yt"})
        spec['annotations'] = annotations
        kwargs['spec'] = spec
        return func(*args, **kwargs)

    return wrapper


class AnnotatedYtClient(yt.YtClient):
    """``yt.YtClient`` whose operation-launching methods go through ``wrap_yt_run``."""

    def __init__(self, *args, **kwargs):
        super(AnnotatedYtClient, self).__init__(*args, **kwargs)

        # Replace each bound method on this instance with its wrapped version.
        for method_name in (
            'run_map',
            'run_reduce',
            'run_map_reduce',
            'run_sort',
            'run_merge',
            'run_join_reduce',
        ):
            setattr(self, method_name, wrap_yt_run(getattr(self, method_name)))


def get_yt_client(yt_proxy, yt_pool=None, yt_token=None, yt_prefix="//", acl=None, read_parallel=None,
                  remote_temp_tables_directory=None, logging_level=logging.INFO, with_rpc=False):
    """Build an ``AnnotatedYtClient`` from the given connection settings.

    The ``YT_TOKEN`` environment variable, when set, takes precedence over the
    ``yt_token`` argument; the argument is only the fallback.

    As a side effect the level of ``yt_logger.LOGGER`` is set to
    ``logging_level``.

    :param yt_proxy: value for ``config["proxy"]["url"]``.
    :param yt_pool: optional value for ``config["pool"]``.
    :param yt_token: token used when ``YT_TOKEN`` is not set in the environment.
    :param yt_prefix: path prefix; a trailing ``/`` is appended when missing.
    :param acl: when truthy, stored as ``config["spec_defaults"]["acl"]``.
    :param read_parallel: optional value for ``config["read_parallel"]``.
    :param remote_temp_tables_directory: optional value for the config key of the same name.
    :param logging_level: level applied to ``yt_logger.LOGGER``.
    :param with_rpc: when truthy, ``config["backend"]`` is set to ``"rpc"``.
    :raises AssertionError: if no token is available from either source.
    """
    yt_token = os.getenv("YT_TOKEN", yt_token)
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if yt_token is None:
        raise AssertionError("No yt_token is specified, and can't find environment variable YT_TOKEN")

    # YtError: PREFIX should end with /
    if not yt_prefix.endswith("/"):
        yt_prefix += "/"

    yt_config = {
        "proxy": {"url": yt_proxy},
        "token": yt_token,
        "prefix": yt_prefix
    }

    if yt_pool is not None:
        yt_config["pool"] = yt_pool

    if read_parallel is not None:
        yt_config['read_parallel'] = read_parallel

    if remote_temp_tables_directory is not None:
        yt_config['remote_temp_tables_directory'] = remote_temp_tables_directory

    if acl:
        yt_config['spec_defaults'] = {'acl': acl}

    if with_rpc:
        yt_config['backend'] = 'rpc'

    yt_logger.LOGGER.setLevel(logging_level)

    return AnnotatedYtClient(config=yt_config)


def copy_unprocessed_tables(yt_proxy, source_dir, destination_dir, state_table, max_tables=12 * 24 * 7):
    """Link tables from ``source_dir`` that are not yet recorded in ``state_table``.

    Table names listed in ``source_dir`` are compared with the names stored in
    the ``table`` column of ``state_table``. Only the last ``max_tables`` names
    (in sorted order) of the union of both sets are considered; every such name
    that is present in the source but absent from the state is linked into
    ``destination_dir``. Afterwards ``state_table`` is rewritten with the
    considered names. Nothing is written when there is nothing to link.

    Note: because of slice semantics, ``max_tables=0`` keeps every table.

    :param yt_proxy: proxy passed to ``get_yt_client``.
    :param source_dir: Cypress directory listed for source tables.
    :param destination_dir: Cypress directory where links are created.
    :param state_table: table storing the names already handled.
    :param max_tables: size of the window of most recent names to consider.
    """
    logger = logging_helpers.register_stdout_logger(__name__)

    yt_client = get_yt_client(yt_proxy)
    table_field = "table"

    logger.info("Read source tables in %s", source_dir)
    source_tables = sorted(yt_client.list(source_dir))
    logger.info("Source tables: %s", source_tables)
    logger.info("Total: %s", len(source_tables))

    logger.info("Read processed tables from %s", state_table)
    if yt_client.exists(state_table):
        processed_tables = sorted(row[table_field] for row in yt_client.read_table(state_table))
    else:
        processed_tables = []
    logger.info("Processed tables: %s", processed_tables)
    logger.info("Total: %s", len(processed_tables))

    source_tables_set = set(source_tables)
    processed_tables_set = set(processed_tables)
    not_too_old_tables = set(sorted(source_tables_set | processed_tables_set)[-max_tables:])
    unprocessed_tables = sorted((source_tables_set - processed_tables_set) & not_too_old_tables)
    logger.info("Unprocessed tables: %s", unprocessed_tables)
    logger.info("Total: %s", len(unprocessed_tables))

    if not unprocessed_tables:
        logger.info("There is no unprocessed tables. Nothing to do")
        return

    for table in unprocessed_tables:
        # Cypress paths always use "/", so join them with the YT helper rather
        # than the OS-dependent os.path.join.
        source_table = yt.ypath_join(source_dir, table)
        destination_table = yt.ypath_join(destination_dir, table)
        logger.info("Link %s to %s", source_table, destination_table)
        yt_client.link(source_table, destination_table, recursive=True, ignore_existing=True)

    yt_client.create("table", path=state_table, recursive=True, ignore_existing=True)
    rows = [{table_field: table} for table in sorted(not_too_old_tables)]
    logger.info("Update state table %s", state_table)
    yt_client.write_table(state_table, rows)


def get_expiration_time_from_creation_time(cypress_path, ttl_timedelta, yt_client):
    """Return the node's ``creation_time`` attribute shifted forward by ``ttl_timedelta``.

    The attribute is parsed with the format ``%Y-%m-%dT%H:%M:%S.%fZ`` into a
    naive ``datetime``.
    """
    created_at = datetime.datetime.strptime(
        get_attribute(cypress_path, "creation_time", yt_client),
        "%Y-%m-%dT%H:%M:%S.%fZ",
    )
    return created_at + ttl_timedelta


def set_ttl(
    table,
    ttl_timedelta,
    yt_client=yt,
    remove_if_empty=False,
    expiration_time_func=get_expiration_time_from_creation_time,
):
    """Set the ``expiration_time`` attribute of ``table``.

    :param table: Cypress path of the node.
    :param ttl_timedelta: time-to-live passed to ``expiration_time_func``.
    :param yt_client: client (or the ``yt`` module) used for all requests.
    :param remove_if_empty: when true and the table has no rows, the table is
        removed instead of getting an expiration time.
    :param expiration_time_func: callable ``(path, ttl_timedelta, yt_client)``
        returning the ``datetime`` at which the node should expire; its
        ``__name__`` is written to the log.
    """
    if remove_if_empty and is_empty(yt_client, table):
        logger.info("Table %s is empty, so removing instead of setting ttl", table)
        yt_client.remove(table)
        return

    logger.info("Setting ttl for %s of %s using %s", table, ttl_timedelta, expiration_time_func.__name__)
    expiration_time = expiration_time_func(table, ttl_timedelta, yt_client)
    # Build the attribute path with the module's Cypress helper rather than
    # os.path.join, which depends on the host OS path separator.
    yt_client.set(
        _get_attr_path(table, 'expiration_time'),
        expiration_time.strftime("%Y-%m-%d %H:%M:%S.%f+00:00")
    )
    logger.info("Done")


def set_ttl_by_table_name(cypress_path, ttl_timedelta, yt_client, remove_if_empty=False, name_format="%Y-%m-%d"):
    """Set the TTL of a node whose basename is a date formatted as ``name_format``.

    The expiration time is the date parsed from the basename plus
    ``ttl_timedelta``; everything else is delegated to ``set_ttl``.
    """
    # The helper's name is kept as is: set_ttl writes it to the log.
    def get_expiration_time_from_table_name(cypress_path, ttl_timedelta, yt_client):
        parsed_date = datetime.datetime.strptime(path_utils.get_basename(cypress_path), name_format)
        return parsed_date + ttl_timedelta

    return set_ttl(
        cypress_path,
        ttl_timedelta,
        yt_client,
        remove_if_empty=remove_if_empty,
        expiration_time_func=get_expiration_time_from_table_name,
    )


def make_backup_table(src_table, backup_table, ttl_timedelta, yt_client=yt):
    """Copy ``src_table`` to ``backup_table`` (forced, recursive) and give the copy a TTL."""
    logger.info("Backing up %s to %s", src_table, backup_table)
    yt_client.copy(
        src_table,
        backup_table,
        recursive=True,
        force=True,
    )
    logger.info("Done")

    # The expiration time is set by set_ttl with its default settings.
    set_ttl(backup_table, ttl_timedelta, yt_client=yt_client)


def backup_local_file(local_path, yt_path, ttl_timedelta=None, yt_client=yt):
    """Upload a local file to ``yt_path`` and optionally set a TTL on it.

    The file is opened in binary mode. A TTL is set only when
    ``ttl_timedelta`` is not ``None``.
    """
    logger.info("Backing up local file %s to YT as %s ...", local_path, yt_path)
    with open(local_path, "rb") as local_file:
        yt_client.write_file(yt_path, local_file)
    logger.info("Done")

    if ttl_timedelta is None:
        return

    set_ttl(yt_path, ttl_timedelta, yt_client=yt_client)


def _get_attr_path(node_path, attribute):
    """Return the Cypress path of ``attribute`` on ``node_path`` (``<node>/@<attribute>``)."""
    attribute_part = "@{}".format(attribute)
    return yt.ypath_join(node_path, attribute_part)


def set_attribute(node_path, attribute, value, client=yt):
    """Set ``attribute`` of the node at ``node_path`` to ``value``."""
    attribute_path = _get_attr_path(node_path, attribute)
    client.set(attribute_path, value)


def get_attribute(node_path, attribute, client=yt):
    """Return the value of ``attribute`` of the node at ``node_path``."""
    attribute_path = _get_attr_path(node_path, attribute)
    return client.get(attribute_path)


def has_attribute(node_path, attribute, client=yt):
    """Return whether ``client.exists`` reports ``attribute`` on the node at ``node_path``."""
    attribute_path = _get_attr_path(node_path, attribute)
    return client.exists(attribute_path)


def get_attributes(node_path, attr_list, client=yt):
    """Fetch the node with the attributes in ``attr_list`` and return its ``.attributes``."""
    node = client.get(node_path, attributes=attr_list)
    return node.attributes


def wait_for_mounted(client, table, timeout=60):
    """Poll ``table`` until it is mounted and its first tablet has finished preloading.

    The check passes when ``tablet_state`` is ``"mounted"`` and the statistics
    of tablet 0 report ``store_count > 0`` and
    ``preload_pending_store_count == 0``. It is attempted up to ``timeout``
    times with a one-second delay via ``retry.retry_call``.

    :raises AssertionError: presumably re-raised by ``retry`` once all
        attempts are used up — confirm against the ``retry`` package docs.
    """
    def is_mounted():
        # Explicit raises instead of `assert`: asserts are stripped under
        # `python -O`, which would make this wait return immediately.
        tablet_state = get_attribute(table, "tablet_state", client)
        if tablet_state != "mounted":
            raise AssertionError("Table {} has tablet_state {!r}, expected 'mounted'".format(table, tablet_state))
        statistics = client.get(table + "/@tablets/0/statistics")
        if not (statistics["store_count"] > 0 and statistics["preload_pending_store_count"] == 0):
            raise AssertionError("Table {} is mounted but its stores are not preloaded yet".format(table))

    retry.retry_call(is_mounted, tries=timeout, delay=1, exceptions=AssertionError)


def wait_for_frozen(client, table, timeout=60):
    """Poll ``table`` until its ``tablet_state`` attribute equals ``"frozen"``.

    The check is attempted up to ``timeout`` times with a one-second delay via
    ``retry.retry_call``.

    :raises AssertionError: presumably re-raised by ``retry`` once all
        attempts are used up — confirm against the ``retry`` package docs.
    """
    def is_frozen():
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        tablet_state = get_attribute(table, "tablet_state", client)
        if tablet_state != "frozen":
            raise AssertionError("Table {} has tablet_state {!r}, expected 'frozen'".format(table, tablet_state))

    retry.retry_call(is_frozen, tries=timeout, delay=1, exceptions=AssertionError)


def wait_for_unmounted(client, table, timeout=60):
    """Poll ``table`` until its ``tablet_state`` attribute equals ``"unmounted"``.

    The check is attempted up to ``timeout`` times with a one-second delay via
    ``retry.retry_call``.

    :raises AssertionError: presumably re-raised by ``retry`` once all
        attempts are used up — confirm against the ``retry`` package docs.
    """
    def is_unmounted():
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        tablet_state = get_attribute(table, "tablet_state", client)
        if tablet_state != "unmounted":
            raise AssertionError("Table {} has tablet_state {!r}, expected 'unmounted'".format(table, tablet_state))

    retry.retry_call(is_unmounted, tries=timeout, delay=1, exceptions=AssertionError)


def wait_for_replicas_enabled(master_client, master_table, timeout=60):
    """Poll ``master_table`` until every entry of its ``replicas`` attribute has state ``"enabled"``.

    The check is attempted up to ``timeout`` times with a one-second delay via
    ``retry.retry_call``.

    :raises AssertionError: presumably re-raised by ``retry`` once all
        attempts are used up — confirm against the ``retry`` package docs.
    """
    def are_enabled():
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        replicas = get_attribute(master_table, "replicas", master_client)
        if not all(replica["state"] == "enabled" for replica in replicas.values()):
            raise AssertionError("Not all replicas of {} are enabled".format(master_table))

    retry.retry_call(are_enabled, tries=timeout, delay=1, exceptions=AssertionError)


def has_optimize_for(client, path):
    """Return whether ``client.exists`` reports an ``optimize_for`` attribute on ``path``."""
    optimize_for_path = _get_attr_path(path, "optimize_for")
    return client.exists(optimize_for_path)


def get_optimize_for(client, path):
    """Return the value of the ``optimize_for`` attribute of the node at ``path``."""
    return get_attribute(node_path=path, attribute="optimize_for", client=client)


def get_cluster_name(yt_client):
    """Return the ``cluster_name`` attribute of the ``//sys`` node."""
    return get_attribute(node_path="//sys", attribute="cluster_name", client=yt_client)


def get_table_attributes(yt_client, node):
    """Return whatever ``yt_client.get`` yields for the path ``<node>/@``."""
    attributes_path = '{node}/@'.format(node=node)
    return yt_client.get(attributes_path)


def create_folder(yt_client, folder):
    """Create ``folder`` (with parents) unless ``yt_client.exists`` already reports it."""
    if yt_client.exists(folder):
        return
    yt_client.mkdir(folder, recursive=True)


def write_pandas_dataframe(yt_client, path, df):
    """Write the rows of ``df`` to the table at ``path``, creating the parent folder first.

    Each row is converted with ``to_dict()`` and the rows are streamed lazily
    to ``yt_client.write_table``.
    """
    def iter_rows():
        for _, series in df.iterrows():
            yield series.to_dict()

    parent_folder = os.path.dirname(path)
    create_folder(yt_client, parent_folder)
    yt_client.write_table(path, iter_rows())


def create_empty_table(yt_client, path, compression='default', schema=None, additional_attributes=None,
                       erasure=True, force=True):
    """Create a table at ``path`` with attributes assembled from the arguments.

    Attributes are built in this order, later steps overriding earlier ones:
    the ``STORAGE_MODES[compression]`` preset, the schema (a dict is converted
    with ``schema_utils.yt_schema_from_dict``), ``additional_attributes``, and
    finally ``erasure_codec = 'lrc_12_2_2'`` when ``erasure`` is true.

    :raises ValueError: if the table exists and ``force`` is false.
    """
    if yt_client.exists(path):
        if not force:
            raise ValueError('Table already exists. If you want to overwrite it set force=True')
        yt_client.remove(path)

    table_attributes = dict(STORAGE_MODES[compression])

    if schema:
        table_attributes['schema'] = (
            schema_utils.yt_schema_from_dict(schema) if isinstance(schema, dict) else schema
        )

    if additional_attributes:
        table_attributes.update(additional_attributes)

    if erasure:
        table_attributes['erasure_codec'] = 'lrc_12_2_2'

    yt_client.create('table', path, recursive=True, attributes=table_attributes)


def make_sample_with_rate(yt_client, source_table, destination_table, rate):
    """Run a merge from ``source_table`` to ``destination_table`` with reader ``sampling_rate`` set to ``rate``."""
    merge_spec = {
        'job_io': {'table_reader': {'sampling_rate': rate}},
        "force_transform": True,
    }
    yt_client.run_merge(source_table, destination_table, spec=merge_spec)


def make_sample_with_size(yt_client, source_table, destination_table, size):
    """Sample ``source_table`` into ``destination_table`` at a rate derived from ``size``.

    The sampling rate is ``size / row_count`` capped at ``1.0``. An empty
    source table (``row_count == 0``) is sampled with rate ``1.0`` instead of
    failing on a division by zero.
    """
    src_size = int(yt_client.get_attribute(source_table, 'row_count'))
    if src_size > 0:
        rate = min(1.0 * size / src_size, 1.0)
    else:
        # Nothing to sample from; any rate gives the same (empty) result.
        rate = 1.0
    make_sample_with_rate(yt_client, source_table, destination_table, rate)


def is_sorted(yt_client, table, sorted_by=None):
    """Return whether ``table`` is sorted, optionally by exactly the given columns.

    With ``sorted_by`` given, it is normalised via ``arg_to_list.arg_to_list``
    and compared with the table's ``sorted_by`` attribute.
    """
    if not yt_client.is_sorted(table):
        return False
    if sorted_by is None:
        return True

    actual_columns = yt_client.get_attribute(table, 'sorted_by')
    expected_columns = list(arg_to_list.arg_to_list(sorted_by))
    return expected_columns == actual_columns


def sort_if_needed(yt_client, table, sort_by):
    """Sort ``table`` in place by ``sort_by`` unless ``is_sorted`` already reports that order."""
    if is_sorted(yt_client, table, sort_by):
        return
    yt_client.run_sort(table, sort_by=sort_by)


def is_empty(yt_client, table):
    """Return ``True`` when the ``row_count`` attribute of ``table`` is zero."""
    row_count = int(get_attribute(table, "row_count", yt_client))
    return row_count == 0


def set_yql_proto_field(table, field, message_type, yt_client=yt):
    """Set on ``table`` the attribute returned by ``yql_proto_field.get_attr`` for ``field``."""
    key_and_value = yql_proto_field.get_attr(field, message_type)
    attr_key, attr_value = key_and_value
    set_attribute(table, attr_key, attr_value, yt_client)


def set_yql_proto_fields(table, message_type, yt_client=yt):
    """Set on ``table`` every attribute returned by ``yql_proto_field.get_attrs`` for ``message_type``."""
    proto_attrs = yql_proto_field.get_attrs(message_type)
    for key, value in proto_attrs.items():
        set_attribute(table, key, value, yt_client)


def get_yt_client_from_nv_parameters(nv_parameters, tmp_directory='//tmp'):
    """Build a YT client from the ``mr-default-cluster``, ``yt-pool`` and ``yt-token`` entries of ``nv_parameters``.

    ``tmp_directory`` is passed as the remote temp tables directory.
    """
    proxy = nv_parameters['mr-default-cluster']
    pool = nv_parameters['yt-pool']
    token = nv_parameters['yt-token']
    return get_yt_client(
        yt_proxy=proxy,
        yt_pool=pool,
        yt_token=token,
        remote_temp_tables_directory=tmp_directory,
    )


def set_directory_auto_compression(
    map_node,
    min_table_age=datetime.timedelta(),
    min_table_size=DataSize(b=0),
    erasure_codec="lrc_12_2_2",
    compression_codec="brotli_8",
    pool="crypta_all",
    yt_client=yt,
):
    """Write the ``nightly_compression_settings`` attribute on ``map_node``.

    The settings are always stored with ``enabled: True``; the age is stored
    as seconds (``total_seconds()``) and the size as bytes (``total_bytes()``).
    """
    compression_settings = {
        "enabled": True,
        "erasure_codec": erasure_codec,
        "compression_codec": compression_codec,
        "min_table_age": min_table_age.total_seconds(),
        "min_table_size": min_table_size.total_bytes(),
        "pool": pool,
    }
    yt_client.set_attribute(map_node, "nightly_compression_settings", compression_settings)


def get_yt_schema_dict_from_table(yt_client, table):
    """Return a ``{column name: column type}`` dict built from the ``schema`` attribute of ``table``."""
    columns = yt_client.get_attribute(table, 'schema')
    return {column['name']: column['type'] for column in columns}


def get_yt_secure_vault_env_var_for(variable):
    """Return ``variable`` prefixed with ``YT_SECURE_VAULT_``."""
    env_var_name = "YT_SECURE_VAULT_{0}".format(variable)
    return env_var_name


def write_stats_to_yt(yt_client, table_path, data_to_write, schema=None, fielddate='fielddate', date=None):
    """Append statistics rows to ``table_path``, creating the table when it is missing.

    Rows that lack ``fielddate`` (or have it set to ``None``) get ``date``
    written into that column; the rows are updated in place.

    :param yt_client: client used for all requests.
    :param table_path: destination table.
    :param data_to_write: a single row or a list of rows (dict-like).
    :param schema: ``{column: type}`` dict, required only when the table has to
        be created; a ``fielddate`` string column is added to a copy of it
        when absent, the caller's dict is left untouched.
    :param fielddate: name of the date column.
    :param date: value for the date column; defaults to today's date
        (``str(datetime.date.today())``) evaluated at call time, not at
        import time.
    :raises AssertionError: if the table does not exist and ``schema`` is ``None``.
    """
    if date is None:
        # Computed per call: a default evaluated in the signature would be
        # frozen at import time and go stale in long-running processes.
        date = str(datetime.date.today())

    if not isinstance(data_to_write, list):
        data_to_write = [data_to_write]

    for row in data_to_write:
        if fielddate not in row or row[fielddate] is None:
            row[fielddate] = date

    if not yt_client.exists(table_path):
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        if schema is None:
            raise AssertionError("schema is required to create {}".format(table_path))
        if fielddate not in schema:
            # Work on a copy so the caller's schema object is not modified.
            schema = schema.copy()
            schema[fielddate] = 'string'
        create_empty_table(
            yt_client=yt_client,
            path=table_path,
            schema=schema,
        )

    yt_client.write_table(
        yt_client.TablePath(table_path, append=True),
        data_to_write,
    )
