from maps.wikimap.stat.libs.nile_utils import (
    get_tables_with_names,
    put_tables_by_prefixes,
)
from nile.api.v1 import (
    clusters,
    Record,
)
import pytest


def get_cluster(use_yql):
    """Create a new mock cluster for a test run.

    Returns a ``clusters.MockYQLCluster`` instance when ``use_yql`` is
    truthy and a ``clusters.MockCluster`` instance otherwise.
    """
    if use_yql:
        return clusters.MockYQLCluster()
    return clusters.MockCluster()


def get_nile_str_factory(use_yql):
    """Pick the callable used to build string values for test records.

    Returns the built-in ``str`` when ``use_yql`` is truthy; otherwise
    returns a callable that converts a ``str`` to UTF-8 encoded ``bytes``.

    NOTE(review): presumably the non-YQL mock cluster represents strings
    as bytes while the YQL one uses ``str`` -- confirm against Nile.
    """
    def _as_utf8_bytes(value):
        return bytes(value, encoding='utf8')

    return str if use_yql else _as_utf8_bytes


class DebugRun(object):
    """Helper that seeds tables, runs a job and reads back its results.

    On construction it creates a job from ``cluster`` (exposed as
    ``self.job`` so the test can build its pipeline on it) and writes every
    ``path -> records`` pair from ``init_tables`` through ``cluster.debug``.
    Calling the instance runs the job and reads the requested tables.
    """

    def __init__(self, cluster, init_tables):
        """Create the job and write the initial tables.

        :param cluster: object providing ``job()`` and a ``debug`` attribute
            with ``write(path, records)`` / ``read(path)``.
        :param init_tables: mapping of table path to the records to write.
        """
        self.job = cluster.job()
        self._debug_cluster = cluster.debug
        for table_path, rows in init_tables.items():
            self._debug_cluster.write(table_path, rows)

    def __call__(self, result_table_paths):
        """Run the job and return the contents of the requested tables.

        :param result_table_paths: iterable of table paths to read after
            ``self.job.debug.run()`` has finished.
        :return: list with one ``read`` result per path, in the given order.
        """
        self.job.debug.run()
        outputs = []
        for table_path in result_table_paths:
            outputs.append(self._debug_cluster.read(table_path))
        return outputs


@pytest.mark.parametrize('use_yql', (True, False))
def test_get_tables_with_names(use_yql):
    """Check that rows from several tables get tagged with their table name.

    Two input tables are written under one directory; the result is expected
    to contain every input row with an extra ``name_column`` holding the
    name of the table the row came from.
    """
    result_path = '//path/to/result_table'
    input_tables = {
        '//path/to/input_dir/table1': [
            Record(test_column=11),
            Record(test_column=12),
        ],
        '//path/to/input_dir/table2': [
            Record(test_column=21),
        ],
    }
    debug_run = DebugRun(get_cluster(use_yql), input_tables)

    named_stream = get_tables_with_names(
        debug_run.job,
        '//path/to/input_dir',
        'name_column',
        ['table1', 'table2'],
    )
    named_stream.put(result_path)

    # Exactly one table was requested, so exactly one result is unpacked.
    [result] = debug_run([result_path])

    expected_rows = [
        Record(name_column='table1', test_column=11),
        Record(name_column='table1', test_column=12),
        Record(name_column='table2', test_column=21),
    ]
    # Both sides are sorted so the comparison ignores row order.
    assert sorted(result) == sorted(expected_rows)


@pytest.mark.parametrize('use_yql', (True, False))
def test_should_put_tables_by_prefixes(use_yql):
    """Check that rows are split into per-prefix tables by the ``value`` column.

    Each of the tables ``abc``, ``bcd`` and ``cde`` under the result
    directory is expected to hold exactly the input rows whose ``value``
    starts with that prefix. The ``def1`` row matches none of the requested
    prefixes and is not expected in any of the checked tables.
    """
    nile_str = get_nile_str_factory(use_yql)

    def row(value, test_column):
        # Builds a new Record each time so input and expected rows never
        # share instances.
        return Record(value=nile_str(value), test_column=test_column)

    input_tables = {
        '//path/to/input_table': [
            row('abc1', 11),
            row('bcd1', 21),
            row('bcd2', 22),
            row('cde1', 31),
            row('def1', 41),
        ],
    }
    debug_run = DebugRun(get_cluster(use_yql), input_tables)

    source_table = debug_run.job.table('//path/to/input_table')
    put_tables_by_prefixes(
        source_table,
        '//path/to/result_dir',
        'value',
        ['abc', 'bcd', 'cde'],
    )

    result_paths = [
        '//path/to/result_dir/abc',
        '//path/to/result_dir/bcd',
        '//path/to/result_dir/cde',
    ]
    abc_rows, bcd_rows, cde_rows = debug_run(result_paths)

    # Both sides are sorted so the comparisons ignore row order.
    expected_abc = [row('abc1', 11)]
    assert sorted(abc_rows) == sorted(expected_abc)

    expected_bcd = [row('bcd1', 21), row('bcd2', 22)]
    assert sorted(bcd_rows) == sorted(expected_bcd)

    expected_cde = [row('cde1', 31)]
    assert sorted(cde_rows) == sorted(expected_cde)
