import collections
import logging
import six

from .base import truncate_base_prefix
from .general import get_json_md5

from sandbox.projects.yabs.sandbox_task_tracing import trace_calls


# Module-level logger named after this module; used for debug output below.
logger = logging.getLogger(__name__)


class CSDataError(Exception):
    """Raised when importers info or mkdb info is missing or inconsistent."""


def cache_one_value(function):
    """Decorator remembering the result of the most recent call only.

    The positional and keyword arguments of a call are compared (by equality)
    with those of the last successful call; on a match the stored result is
    returned without invoking ``function`` again. Any other combination of
    arguments replaces the stored entry.

    :param function: callable to wrap
    :return: caching wrapper named ``cache_one_value(<function name>)``
    """
    # Single slot holding a (key, result) pair; the fresh object() never
    # equals a real key, so the very first call always misses.
    slot = [(object(), None)]

    def wrapper(*args, **kwargs):
        call_key = (tuple(args), tuple(sorted(kwargs.items())))
        last_key, last_result = slot[0]
        if last_key == call_key:
            return last_result
        fresh_result = function(*args, **kwargs)
        # Stored only after the call succeeded, so exceptions are never cached.
        slot[0] = (call_key, fresh_result)
        return fresh_result

    wrapper.__name__ = '{}({})'.format(cache_one_value.__name__, function.__name__)
    return wrapper


@trace_calls
def build_importer_graph(dependencies):
    """
    Build the directed dependency graph of importers.
    An edge (u, v) means that importer 'v' depends on importer 'u'.

    :param dependencies: Map from importer name to the importers it depends on
    :type dependencies: dict
    :return: Directed graph of importers
    :rtype: nx.DiGraph
    """
    import networkx as nx

    graph = nx.DiGraph()
    for name, required_importers in dependencies.items():
        # Add the node explicitly so that importers without any
        # dependencies are still present in the graph.
        graph.add_node(name)
        incoming_edges = [(required, name) for required in required_importers]
        graph.add_edges_from(incoming_edges)
    return graph


@trace_calls
def build_importer_graph_from_importers_info(importers_info):
    """
    Build the directed dependency graph of importers from importers info.
    An edge (u, v) means that importer 'v' depends on importer 'u'.
    Raises CSDataError if the `'dependencies'` value of some importer is
    missing or is not a list.

    :param importers_info: importers info, output of 'cs import --print-info --outputs-version=2'
    :type importers_info: dict
    :return: Directed graph of importers
    :rtype: nx.DiGraph
    """
    direct_dependencies = {}
    for name, info in importers_info.items():
        # .get() leaves None for a missing key; it is rejected below.
        direct_dependencies[name] = info.get('dependencies')
    logger.debug('Importers direct dependencies: %s', direct_dependencies)

    for name, dependency_list in direct_dependencies.items():
        if isinstance(dependency_list, list):
            continue
        raise CSDataError('importers_info[{!r}][{!r}] is missing or not a list'.format(name, 'dependencies'))

    return build_importer_graph(direct_dependencies)


def get_importers_after(importer, importer_graph):
    """Get importers reachable from this importer, i.e. those that run after it

    :param importer: Importer name
    :type importer: str
    :param importer_graph: Directed graph of importers
    :type importer_graph: nx.DiGraph
    :return: Importers reachable from `importer` (itself excluded), in no particular order
    :rtype: list
    :raises CSDataError: if `importer` is not a node of the graph
    """
    import networkx as nx

    known_importers = importer_graph.nodes
    if importer not in known_importers:
        raise CSDataError('No importer_info for importer {}'.format(importer))

    # The traversal yields the start node as well; drop it from the result.
    reachable = set(nx.dfs_postorder_nodes(importer_graph, importer))
    return list(reachable - {importer})


def get_importers_before(importer, importer_graph):
    """Get importers that run before this importer (its direct and transitive dependencies)

    :param importer: Importer name
    :type importer: str
    :param importer_graph: Directed graph of importers
    :type importer_graph: nx.DiGraph
    :return: Importers that `importer` depends on, in no particular order
    :rtype: list
    :raises CSDataError: if `importer` is not a node of the graph
    """
    import networkx as nx

    known_importers = importer_graph.nodes
    if importer not in known_importers:
        raise CSDataError('No importer_info for importer {}'.format(importer))

    # Edges are walked against their direction; the first element of every
    # yielded edge is collected as a dependency.
    reversed_edges = nx.edge_dfs(importer_graph, importer, orientation='reverse')
    return list({edge[0] for edge in reversed_edges})


def get_importers_with_dependencies(importers, importers_info):
    """Extend the given importers with everything they depend on

    :param importers: importer names
    :type importers: iterable
    :param importers_info: importers info, output of 'cs import --print-info --outputs-version=2'
    :type importers_info: dict
    :return: sorted list of unique importers: the requested ones plus their dependencies
    :rtype: list
    """
    graph = build_importer_graph_from_importers_info(importers_info)

    collected = set()
    for importer in importers:
        logger.debug('Importer: %s', importer)

        prerequisites = get_importers_before(importer, graph)
        logger.debug('Importers before: %s', prerequisites)

        collected.update(prerequisites)
        collected.add(importer)

    ordered = list(collected)
    ordered.sort()
    return ordered


def get_bases_importers_with_dependencies(base_tags, importers_info, mkdb_info):
    """Get importers involved in bases generation, together with their dependencies

    :param base_tags: binary base tags
    :type base_tags: iterable
    :param importers_info: importers info, output of 'cs import --print-info --outputs-version=2'
    :type importers_info: dict
    :param mkdb_info: mkdb info, output of 'dbtool mkdb_info'
    :type mkdb_info: dict
    :return: sorted list of unique importers
    :rtype: list
    """
    direct_importers = get_bases_importers(base_tags, importers_info, mkdb_info)
    logger.debug('Direct importers for bases %s: %s', base_tags, direct_importers)
    with_dependencies = get_importers_with_dependencies(direct_importers, importers_info)
    return with_dependencies


def get_bases_importers(base_tags, importers_info, mkdb_info):
    """Get importers that directly feed at least one of the given bases

    :param base_tags: binary base tags
    :type base_tags: iterable
    :param importers_info: importers info, output of 'cs import --print-info --outputs-version=2'
    :type importers_info: dict
    :param mkdb_info: mkdb info, output of 'dbtool mkdb_info'
    :type mkdb_info: dict
    :return: importer names in sorted order
    :rtype: list
    """
    requested_bases = {truncate_base_prefix(tag) for tag in base_tags}

    # An importer qualifies when its bases intersect the requested ones.
    return [
        name
        for name, info in sorted(importers_info.items())
        if not get_importer_bases(info, mkdb_info).isdisjoint(requested_bases)
    ]


@cache_one_value
@trace_calls
def _get_mkdb_queries_bases(mkdb_info):
    """Map every mkdb query name to the set of bases listing it in 'Queries'.

    :param mkdb_info: mkdb info, output of 'dbtool mkdb_info'
    :type mkdb_info: dict
    :return: plain dict (unknown query names raise KeyError) from query name to set of bases
    :rtype: dict
    :raises CSDataError: if one base lists the same query name more than once
    """
    bases_by_query = {}
    for base, base_mkdb_info in sorted(mkdb_info.items()):
        for query in base_mkdb_info['Queries']:
            # A plain dict with setdefault keeps lookups of unknown keys failing with KeyError.
            known_bases = bases_by_query.setdefault(query['Name'], set())
            if base in known_bases:
                raise CSDataError('Seems like base {} invokes query {} multiple times, base_mkdb_info={}'.format(
                    base,
                    query,
                    base_mkdb_info,
                ))
            known_bases.add(base)
    return bases_by_query


def _get_base_group(base):
    # remove all trailing digits and underscores; zeroth shard (e.g. lmng_0) is an exception forming its own group
    return base if base.endswith('_0') else base.rstrip('_0123456789')


def _get_output_bases(output, mkdb_info):
    """Get bases that use the given importer output.

    Collects the bases of every query in output['mkdb_queries'] and drops
    those whose base group is listed in output['excluded_base_groups'].

    :param output: single entry of an importer's 'outputs'
    :type output: dict
    :param mkdb_info: mkdb info, output of 'dbtool mkdb_info'
    :type mkdb_info: dict
    :return: bases set
    :rtype: set
    """
    bases_by_query = _get_mkdb_queries_bases(mkdb_info)

    candidates = set()
    for query_name in output['mkdb_queries']:
        candidates |= bases_by_query[query_name]

    return {
        base
        for base in candidates
        if _get_base_group(base) not in output['excluded_base_groups']
    }


def get_importer_bases(importer_info, mkdb_info):
    """Get bases that depend on importer (union over all of its outputs)

    :param importer_info: importer info, single value from output of 'cs import --print-info --outputs-version=2'
    :type importer_info: dict
    :param mkdb_info: mkdb info, output of 'dbtool mkdb_info'
    :type mkdb_info: dict
    :return: bases set
    :rtype: set
    """
    importer_bases = set()
    for output in importer_info['outputs']:
        importer_bases.update(_get_output_bases(output, mkdb_info))
    return importer_bases


def get_importer_bases_by_tables(tables, importers_info, mkdb_info):
    """Get bases that depend on importers reading any of the given source tables

    A source table is identified by its 'path', extended with '/<table>' when
    the table entry has a 'table' key.

    :param tables: full names of source tables
    :type tables: container
    :param importers_info: importers info, output of 'cs import --print-info --outputs-version=2'
    :type importers_info: dict
    :param mkdb_info: mkdb info, output of 'dbtool mkdb_info'
    :type mkdb_info: dict
    :return: bases set
    :rtype: set
    """
    bases = set()
    for importer_info in six.itervalues(importers_info):
        for table in importer_info["tables"]:
            full_table_name = table["path"]
            if "table" in table:
                full_table_name += '/' + table["table"]
            if full_table_name in tables:
                bases.update(get_importer_bases(importer_info, mkdb_info))
                # The importer's bases do not depend on which table matched,
                # so one match is enough; skip its remaining tables.
                break
    return bases


def get_importer_output_tables(importer_info, mkdb_info, base_tags):
    """Get sorted paths of importer outputs relevant for the given bases

    An output is kept when it has no mkdb queries at all, or when at least one
    of the requested bases uses it.

    :param importer_info: importer info, single value from output of 'cs import --print-info --outputs-version=2'
    :type importer_info: dict
    :param mkdb_info: mkdb info, output of 'dbtool mkdb_info'
    :type mkdb_info: dict
    :param base_tags: binary base tags
    :type base_tags: iterable
    :return: sorted output paths
    :rtype: list
    """
    requested_bases = {truncate_base_prefix(tag) for tag in base_tags}

    paths = []
    for output in importer_info['outputs']:
        if not output['mkdb_queries'] or not _get_output_bases(output, mkdb_info).isdisjoint(requested_bases):
            paths.append(output['path'])
    paths.sort()
    return paths


def get_importers_source_tables(importers_info):
    """Collect source tables of all importers

    A table's full name is its 'path', extended with '/<table>' when the table
    entry has a 'table' key.

    :param importers_info: importers info; every value must have a 'tables' list
    :type importers_info: dict
    :return: pair of (list of full table names, duplicates kept;
        set of full names, without leading slashes, of tables with a truthy 'need_mount')
    :rtype: tuple
    """
    all_tables = []
    mounted_tables = set()
    for importer_data in importers_info.values():
        for table in importer_data["tables"]:
            if "table" in table:
                name = table["path"] + '/' + table["table"]
            else:
                name = table["path"]
            all_tables.append(name)
            if table.get('need_mount'):
                # str.lstrip treats its argument as a set of characters,
                # so this removes every leading '/'.
                mounted_tables.add(name.lstrip('//'))
    return all_tables, mounted_tables


def get_importer_mkdb_info_version(importer_info, mkdb_info):
    """Get a version string for the mkdb info of the importer's bases

    :param importer_info: importer info, single value from output of 'cs import --print-info --outputs-version=2'
    :type importer_info: dict
    :param mkdb_info: mkdb info, output of 'dbtool mkdb_info'
    :type mkdb_info: dict
    :return: 'N/A' when the importer has no bases, otherwise the result of
        get_json_md5 over the mkdb info entries of the importer's bases
    :raises CSDataError: if mkdb info of some base is missing or empty
    """
    bases = get_importer_bases(importer_info, mkdb_info)
    if not bases:
        return 'N/A'

    relevant_mkdb_info = {}
    for base in bases:
        relevant_mkdb_info[base] = mkdb_info.get(base)

    # Validate only after the whole mapping is collected.
    for base, entry in relevant_mkdb_info.items():
        if entry:
            continue
        raise CSDataError('No mkdb_info for base {!r}'.format(base))

    return get_json_md5(relevant_mkdb_info)


def is_mysql_importer(importer_info):
    """Tell whether the importer info has a non-empty "queries" value."""
    return True if importer_info["queries"] else False
