# -*- coding: utf-8 -*-
from functools import wraps
from io import StringIO
from intranet.yandex_directory.src.yandex_directory.directory_logging.logger import log
from intranet.yandex_directory.src.yandex_directory.common.utils import grouper
from intranet.yandex_directory.src.yandex_directory.common.db import (
    get_shard_numbers,
    get_main_connection,
    get_meta_connection,
)


def by_shards(meta_for_write=True, main_for_write=True, start_transaction=True):
    def inner(func):
        """
        Принимает функцию, которая в качестве двух первых аргументов ожидает
        meta и main коннекты, и вызывает её столько раз, скольо у нас есть шардов,
        каждый раз передавая новый main_connection.
        """

        @wraps(func)
        def wrapper(*args, **kwargs):
            shards = get_shard_numbers()
            with get_meta_connection(for_write=meta_for_write, no_transaction=not start_transaction) as meta_connection:
                for shard in shards:
                    with log.fields(shard=shard), \
                         get_main_connection(shard, for_write=main_for_write, no_transaction=not start_transaction) as main_connection:
                        func(meta_connection, main_connection, *args, **kwargs)
        return wrapper
    return inner


def batch_insert(connection, table, data, batch_size=1000):
    """Вставляет данные в любую произвольную таблицу, имя которой указано в параметре table.
    """
    def prepare(value):
        if isinstance(value, (list, tuple)):
            return '{%s}' % ','.join(map(prepare, value))
        else:
            return str(value)

    batches = grouper(batch_size, data)
    for batch in batches:
        lines = ['\t'.join([prepare(n) for n in list(x.values())])
                      for x in batch]
        output = StringIO('\n'.join(lines))
        cursor = connection.connection.cursor()
        cursor.copy_from(
            output,
            table,
            columns=list(data[0].keys())
        )
