from yt.wrapper import ypath_join
from copy import deepcopy


def get_nodes_in_folder(yt_client, folder):
    """Return the result of listing ``folder`` through ``yt_client``.

    The listing is requested with the ``target_path`` and ``type``
    attributes attached to every returned node.
    """
    requested_attributes = ['target_path', 'type']
    return yt_client.list(folder, attributes=requested_attributes)


def get_yt_tables(yt_client, archive_root):
    """Walk ``archive_root`` recursively and index the nodes found by path.

    Returns a dict mapping a path to an info dict with these keys:

    * ``'group'`` - set of paths treated as aliases of one another.  The
      very same set object is shared by every member of the group.
    * ``'type'`` - the node's ``type`` attribute (present for every node
      that was listed; absent for a target path that was never listed).
    * ``'target_path'`` - only for nodes carrying a ``target_path``
      attribute (presumably links - confirm against the YT API).

    Nodes of type ``map_node`` are descended into and are not recorded
    themselves.
    """
    registry = {}

    def record_link(link_path, link_type, destination):
        # Start from the group the link already belongs to, if any.
        aliases = registry[link_path]['group'] if link_path in registry else set()
        if destination in registry:
            destination_aliases = registry[destination]['group']
            if aliases:
                # Both ends already have a group: fold the destination's
                # members into ours and repoint each of them at the
                # merged set so that the whole group shares one object.
                for member in destination_aliases:
                    aliases.add(member)
                    registry[member]['group'] = aliases
            else:
                aliases = destination_aliases
        aliases.add(link_path)
        aliases.add(destination)
        registry[link_path] = {
            'type': link_type,
            'target_path': destination,
            'group': aliases,
        }
        # The destination may lie outside the walked tree; it is still
        # registered so that it can be looked up by path.
        registry.setdefault(destination, {})['group'] = aliases

    def walk(folder):
        for node in get_nodes_in_folder(yt_client, folder):
            child_path = ypath_join(folder, node)
            attrs = node.attributes
            node_type = attrs['type']
            if node_type == "map_node":
                walk(child_path)
                continue
            if 'target_path' in attrs:
                record_link(child_path, node_type, attrs.get('target_path'))
                continue
            # Plain node: it may already be registered as some link's
            # target, in which case only its type is filled in.
            entry = registry.setdefault(child_path, {'group': {child_path}})
            entry['type'] = node_type

    walk(archive_root)
    return registry


def set_sampled_tables(tables, sampled_tables):
    """Return a deep copy of ``tables`` annotated with sampled paths.

    For every item of ``sampled_tables`` the entry keyed by the item's
    ``absolute_input_path`` receives a ``'sampled_path'`` key holding the
    item's ``absolute_output_path``.  ``tables`` itself is not modified.

    Raises:
        Exception: if an item's ``absolute_input_path`` is not a key of
            ``tables``.
    """
    annotated = deepcopy(tables)
    for item in sampled_tables:
        source_path = item.absolute_input_path
        if source_path not in tables:
            raise Exception('Not found table: ' + source_path)
        annotated[source_path]['sampled_path'] = item.absolute_output_path
    return annotated


def replace_paths_in_spec(spec, tables, sampled_tables):
    """Return a deep copy of ``spec`` with known paths swapped for sampled ones.

    Every top-level value of ``spec`` that is a key of ``tables`` is
    replaced by a ``'sampled_path'`` taken from its alias group
    (``tables[path]['group']``).  The path's own sampled table is
    preferred; otherwise the remaining aliases are tried in sorted order,
    so the outcome does not depend on set iteration order.  Values that
    are not keys of ``tables``, or whose group has no sampled member, are
    left as they are.  ``spec`` itself is not modified.

    Args:
        spec: mapping whose top-level values may be table paths.
        tables: mapping of path to an info dict with a ``'group'`` set and
            an optional ``'sampled_path'``.
        sampled_tables: unused; kept so that existing callers keep working.
    """
    updated_spec = deepcopy(spec)
    for key in spec:
        path = spec[key]
        try:
            is_known = path in tables
        except TypeError:
            # Unhashable values (lists, dicts, ...) cannot be keys of
            # ``tables``, so they are not table paths: leave them alone.
            continue
        if not is_known:
            continue
        group = tables[path]['group']
        # Deterministic lookup order: the path itself first, then the rest
        # of its aliases sorted.
        candidates = [path] + sorted(alias for alias in group if alias != path)
        for alias in candidates:
            if 'sampled_path' in tables[alias]:
                updated_spec[key] = tables[alias]['sampled_path']
                break
    return updated_spec


# TODO: Make yt requests in parallel with sampling yql request
def replace_table_paths_with_sampled(yt_client, spec, sampled_tables, archive_root):
    """Return a copy of ``spec`` with table paths swapped for sampled ones.

    Indexes the nodes under ``archive_root`` with ``get_yt_tables``,
    annotates that index with ``sampled_tables`` through
    ``set_sampled_tables``, and hands the result to
    ``replace_paths_in_spec``.
    """
    return replace_paths_in_spec(
        spec,
        set_sampled_tables(get_yt_tables(yt_client, archive_root), sampled_tables),
        sampled_tables,
    )
