from nile.api.v1 import clusters, local
from nile.api.v1.stream import BatchStream


class Job:
    '''
    Marker type that stands in for a job object.

    An argument whose exact type is `Job` is swapped for the mock cluster job
    created by `yt_run()` / `yql_run()` before the tested function is called.
    The class carries no state or behaviour of its own.
    '''


class Table:
    '''
    Holder for a list of records that plays the role of an input table.

    An argument whose exact type is `Table` is replaced, by `yt_run()` /
    `yql_run()`, with a labelled job table whose content is supplied through
    `local.StreamSource(data, schema)`.
    '''
    def __init__(self, data, schema=None):
        '''
        Parameters
        ----------
        data
            Records to be fed into the stream source. Must be exactly a
            `list` (subclasses of `list` are rejected by the assertion).
        schema
            Optional schema handed to the stream source together with the
            data; `None` by default.
        '''
        # Only plain lists are accepted; the check is an exact type match.
        assert type(data) is list

        self.data, self.schema = data, schema


def __prepare_args_sources(args, job, sources):
    '''
    Build the positional arguments for the tested function.

    Parameters
    ----------
    args
        Positional arguments as given by the caller.
    job
        Job object substituted for every `Job` placeholder and used to create
        the labelled tables that replace `Table` arguments.
    sources
        Mapping that is filled in place: label -> `local.StreamSource` for
        every `Table` argument.

    Result
    ------
    A new list with the same length and order as `args`, where `Table` and
    `Job` values are replaced and every other value is kept as is.
    '''
    prepared = []
    # Each iteration appends exactly one item, so the index in `args` is also
    # the index in `prepared`.
    for position, arg in enumerate(args):
        arg_type = type(arg)

        if arg_type is Job:
            prepared.append(job)
            continue

        if arg_type is not Table:
            prepared.append(arg)
            continue

        label = f'nile ut - {position}'
        sources[label] = local.StreamSource(arg.data, arg.schema)
        table = job.table(f'//path/to/table/{position}')
        prepared.append(table.label(label))

    return prepared


def __prepare_kwargs_sources(kwargs, job, sources):
    '''
    Build the keyword arguments for the tested function.

    Parameters
    ----------
    kwargs
        Keyword arguments as given by the caller.
    job
        Job object substituted for every `Job` placeholder and used to create
        the labelled tables that replace `Table` arguments.
    sources
        Mapping that is filled in place: label -> `local.StreamSource` for
        every `Table` argument. Labels are derived from the argument name.

    Result
    ------
    A new dict with the same keys as `kwargs`, where `Table` and `Job` values
    are replaced and every other value is kept as is.
    '''
    prepared = {}
    for key, arg in kwargs.items():
        arg_type = type(arg)

        if arg_type is Job:
            prepared[key] = job
            continue

        if arg_type is not Table:
            prepared[key] = arg
            continue

        label = f'nile ut - {key}'
        sources[label] = local.StreamSource(arg.data, arg.schema)
        table = job.table(f'//path/to/table/{key}')
        prepared[key] = table.label(label)

    return prepared


def __prepare_result_sinks(func_result, sinks):
    '''
    Label the streams returned by the tested function and register list sinks
    for them.

    Parameters
    ----------
    func_result
        Value returned by the tested function: either a `BatchStream` (exact
        type match) that is iterated stream by stream, or a single object
        with a `label()` method.
    sinks
        Mapping that is filled in place: label -> `local.ListSink`.

    Result
    ------
    The list object(s) handed to the sinks: a flat list for a single stream,
    or a list of lists (one per stream, in iteration order) for a
    `BatchStream`.
    '''
    if type(func_result) is not BatchStream:
        single_label = 'nile ut result'
        records = []
        func_result.label(single_label)
        sinks[single_label] = local.ListSink(records)
        return records

    per_stream = []
    # Labels are numbered from 1, following the order of the batch.
    for number, stream in enumerate(func_result, start=1):
        label = f'nile ut result: {number}'
        bucket = []
        per_stream.append(bucket)
        stream.label(label)
        sinks[label] = local.ListSink(bucket)

    return per_stream


def __run_function(job, func, args, kwargs):
    '''
    Call `func` with prepared arguments and run the given job locally.

    Parameters
    ----------
    job
        Job object used to create input tables and to perform the local run.
    func
        The function under test.
    args, kwargs
        Positional and keyword arguments for `func`; `Table` and `Job`
        values are substituted before the call.

    Result
    ------
    The list (or list of lists) registered with the result sinks, returned
    after `job.local_run()` has been called.
    '''
    sources, sinks = {}, {}

    positional = __prepare_args_sources(args, job, sources)
    keyword = __prepare_kwargs_sources(kwargs, job, sources)

    # The function only builds the stream graph; sinks are attached to
    # whatever it returns before the job is actually run.
    streams = func(*positional, **keyword)
    collected = __prepare_result_sinks(streams, sinks)

    job.local_run(sources=sources, sinks=sinks)

    return collected


def yt_run(func, *args, **kwargs):
    '''
    Run locally, on a `clusters.MockCluster()` job, a function that takes
    streams as its input and returns a stream (or several streams) as its
    output. Lists of Records are used in place of real tables.

    Parameters
    ----------
    func
        A function to be executed.
    args, kwargs
        Arguments to be passed to the executed function. Any argument of type
        `Table` is replaced with a labelled job table backed by a
        `local.StreamSource()`, and any argument of type `Job` is replaced
        with the `MockCluster().job()` object. Other arguments are passed
        through unchanged.

    Result
    ------
    Records collected by `local.ListSink()` as a result of the local run of
    `func`.

    If `func` returns exactly one stream then all records are returned in a
    list.

    If `func` returns several streams then records are returned in a list of
    lists (where each sublist corresponds to a stream).

    Example
    -------
    >>> def swap_columns(table):
    ...     return table.project(a='b', b='a')
    ...
    >>> nile_ut.yt_run(
    ...     swap_columns,
    ...     table=nile_ut.Table([
    ...         Record(a=1, b=1),
    ...         Record(a=2, b=4),
    ...         Record(a=3, b=9),
    ...     ])
    ... )
    [Record(a=1, b=1), Record(a=4, b=2), Record(a=9, b=3)]
    '''
    mock_job = clusters.MockCluster().job()
    return __run_function(mock_job, func, args, kwargs)


def yql_run(func, *args, **kwargs):
    '''
    Run locally, on a `clusters.MockYQLCluster()` job, a function that takes
    streams as its input and returns a stream (or several streams) as its
    output. Lists of Records are used in place of real tables.

    Parameters and return values are the same as for the `yt_run()` function.

    `yql/library/python` must be added to the `PEERDIR` section of the tests'
    `ya.make` if this function is used.
    '''
    mock_job = clusters.MockYQLCluster().job()
    return __run_function(mock_job, func, args, kwargs)
