# -*- coding: utf-8 -*-

from itertools import (
    chain,
    repeat,
)

from hamcrest import (
    assert_that,
    contains_inanyorder,
)
import mock
from nose.tools import eq_
from passport.backend.core.db.query import DbTransactionContainer
from passport.backend.core.undefined import Undefined
import six
from six import (
    iteritems,
    string_types,
)
from sqlalchemy.dialects import mysql


def _query_split_field_name(dirty_name):
    begin, _, end = dirty_name.rpartition('_')
    if not begin:
        return end, None
    try:
        return begin, int(end if end.isdigit() else end[1:])
    except ValueError:
        return dirty_name, None


def _convert_query_to_params_list(query, sort=True):
    """Group a compiled query's bind parameters into one dict per index.

    Each parameter name is split with ``_query_split_field_name``; parameters
    sharing the same index end up in one dict mapping base name -> value
    (parameters without an index are grouped under ``None``).

    :param query: a compiled query exposing a ``params`` mapping, or a raw
        SQL string (which has no parameters).
    :param sort: when true, return the groups in a deterministic order so two
        lists can be compared independently of bind-parameter numbering.
    :returns: a list of dicts; ``[]`` for raw SQL strings.
    """
    if isinstance(query, string_types):
        return list()

    grouped = {}
    for field_name, value in iteritems(query.params):
        base_name, index = _query_split_field_name(field_name)
        grouped.setdefault(index, {})[base_name] = value

    # Always a real list: on Python 3 ``dict.values()`` is a view, and two
    # views never compare equal even with identical contents.
    params_list = list(grouped.values())
    if not sort:
        return params_list

    def _sort_key(params):
        # ``None`` (NULL) is ranked before every other value, as Python 2
        # ordered it. Under Python 3 comparing ``None`` with e.g. a string
        # raises TypeError, which made rows mixing NULL and non-NULL values
        # impossible to sort. All other comparisons are unchanged.
        return [
            (name, value is not None, value)
            for name, value in sorted(params.items())
        ]

    return sorted(params_list, key=_sort_key)


def _compile_human_readable_query(query, dialect=None):
    """Render *query* as one line of SQL with its parameters substituted in.

    Meant for display in failure messages only: values are inserted with
    ``%`` formatting, without any quoting. Raw SQL strings are returned as-is.
    The query is compiled for *dialect*, MySQL by default.
    """
    if isinstance(query, string_types):
        return query
    compiled = query.compile(dialect=dialect or mysql.dialect())

    def _bind_value(bind_name):
        raw = compiled.binds[bind_name].value
        if not isinstance(raw, six.binary_type):
            return raw
        # Non-strict decoding: values may be binary strings containing
        # bytes that are not valid UTF-8.
        return raw.decode('utf-8', errors='replace')

    parameters = tuple(_bind_value(name) for name in compiled.positiontup)
    single_line_sql = compiled.string.replace('\n', '')
    return single_line_sql % parameters if parameters else single_line_sql


def compile_query_with_dialect(query, dialect=None):
    """Compile *query* for *dialect* (MySQL when none is given).

    Raw SQL strings are passed through unchanged.
    """
    if not isinstance(query, string_types):
        target_dialect = dialect or mysql.dialect()
        return query.compile(dialect=target_dialect)
    return query


def _get_compiled_queries(queries, human_readable=False):
    """Compile every query in *queries*, returning a list.

    With *human_readable* the queries are rendered as SQL text with inlined
    parameters; otherwise they are compiled with the default dialect.
    """
    compile_one = (
        _compile_human_readable_query if human_readable
        else compile_query_with_dialect
    )
    return [compile_one(query) for query in queries]


def _get_formatted_queries(queries):
    """Return the human-readable text of every query in *queries*.

    Always returns a list (``map`` yields a one-shot iterator on Python 3 but
    a list on Python 2), consistent with ``_get_queries_params``.
    """
    return [
        six.text_type(rendered)
        for rendered in _get_compiled_queries(queries, human_readable=True)
    ]


def _get_queries_params(queries):
    """Return the grouped, sorted bind parameters of every query in *queries*."""
    compiled_queries = _get_compiled_queries(queries)
    return list(map(_convert_query_to_params_list, compiled_queries))


def _call_hook_on_mocked_result(hook, inserted_keys, row_count):
    # Эмулируем возврат БД полноценного идентификатора новой записи
    # для правильного выставления ID на объекте.
    # TODO: проверять тип запроса и выдавать их только для INSERT
    next_inserted = next(inserted_keys)

    mocked_result = mock.Mock(
        inserted_primary_key=[next_inserted],
        rowcount=next(row_count),
    )
    hook(mocked_result)


def _split_query_and_hook(call):
    """Split a recorded call into ``(query, hook)``.

    A recorded call is either a bare query or a ``(query, hook)`` pair. Raw
    SQL strings and non-iterable objects are treated as bare queries with a
    ``None`` hook.
    """
    if isinstance(call, string_types):
        # Strings are iterable; never try to unpack them.
        return call, None
    try:
        query, hook = call
    except TypeError:
        # Not iterable: a bare query object.
        return call, None
    return query, hook


def get_executed_queries(
    executed_calls, is_eav=False, inserted_keys=None, row_count=None
):
    """Flatten recorded DB calls into the list of executed queries.

    :param executed_calls: recorded calls, each a bare query or a
        ``(query, hook)`` pair (see ``_split_query_and_hook``). A query that is
        a ``DbTransactionContainer`` contributes its ``get_queries()`` items,
        processed the same way and wrapped in ``'BEGIN'`` / ``'COMMIT'``;
        containers with no queries contribute nothing.
    :param is_eav: when true, each query is replaced by ``query.to_query()``.
    :param inserted_keys: values handed, in order, to the hooks as the
        inserted primary key.
    :param row_count: values handed, in order, to the hooks as the row count.
    :returns: list of queries (and ``'BEGIN'``/``'COMMIT'`` markers).
    """
    # IDs of inserted rows go to the first N queries that have a hook. All
    # later hooked queries receive Undefined, which is a hint that you may
    # have forgotten to specify an ID for one of the tested INSERTs.
    keys_source = chain(inserted_keys or [], repeat(Undefined))
    counts_source = chain(row_count or [], repeat(Undefined))

    def _process(query, hook):
        # Run the hook first, then convert the query for comparison.
        if hook:
            _call_hook_on_mocked_result(hook, keys_source, counts_source)
        return query.to_query() if is_eav else query

    result = []
    for call in executed_calls:
        query, hook = _split_query_and_hook(call)

        if not isinstance(query, DbTransactionContainer):
            result.append(_process(query, hook))
            continue

        # NOTE(review): a hook attached to the container itself is never
        # called; only the hooks of its inner queries are.
        inner = [
            _process(*_split_query_and_hook(inner_call))
            for inner_call in query.get_queries()
        ]
        if inner:
            result.extend(['BEGIN'] + inner + ['COMMIT'])

    return result


def _assert_eq_lazily(first, second, build_message):
    """Assert ``first == second`` via ``eq_``; build the message only on failure.

    :param build_message: zero-argument callable returning the failure text.
    """
    if not first == second:
        eq_(first, second, build_message())


def compare_queries_strict(executed, expected):
    """Assert that *executed* matches *expected* query by query, in order.

    The query counts are compared first. Then, for each position, the compiled
    SQL text and the grouped bind parameters are compared.

    The failure messages list every query (and its parameters) on both sides.
    They are built only when an assertion actually fails. Building them
    compiles every query of both lists, so building them eagerly on each loop
    iteration made a passing comparison cost O(n^2) compilations. It also let
    any error raised while formatting break an otherwise passing comparison.

    :raises AssertionError: on the first mismatch.
    """
    _assert_eq_lazily(
        len(executed),
        len(expected),
        lambda: 'Queries counts don\'t match. \nExecuted: \n\t%s\n Expected: \n\t%s' % (
            '\n\t'.join(_get_formatted_queries(executed)),
            '\n\t'.join(_get_formatted_queries(expected)),
        ),
    )

    for i, (executed_q, expected_q) in enumerate(zip(executed, expected)):
        executed_compiled_q = compile_query_with_dialect(executed_q)
        expected_compiled_q = compile_query_with_dialect(expected_q)
        executed_sql = str(executed_compiled_q)
        expected_sql = str(expected_compiled_q)

        # The message lambdas below are called (if at all) within this same
        # iteration, so capturing the loop variables is safe.
        _assert_eq_lazily(
            executed_sql,
            expected_sql,
            lambda: ('SQL queries are not the same. The first difference found on position %s.\n'
                     'Found queries: \n\t%s\n'
                     'Expected queries: \n\t%s') % (
                i,
                '\n\t'.join(_get_formatted_queries(executed)),
                '\n\t'.join(_get_formatted_queries(expected)),
            ),
        )

        _assert_eq_lazily(
            _convert_query_to_params_list(executed_compiled_q),
            _convert_query_to_params_list(expected_compiled_q),
            lambda: ('SQL queries are not the same. The first difference found on position %s.\n'
                     'Executing SQL query: %s\n'
                     'Found params: \n\t%s\n'
                     'Expected params: \n\t%s') % (
                i,
                expected_sql,
                '\n\t'.join(map(six.text_type, _get_queries_params(executed))),
                '\n\t'.join(map(six.text_type, _get_queries_params(expected))),
            ),
        )


class QueryWrapper(object):
    """Comparable wrapper around a query, used for order-insensitive matching.

    Two wrappers are equal when their compiled SQL text and their grouped bind
    parameters (as produced by ``_convert_query_to_params_list``) are equal.
    """

    def __init__(self, query_obj):
        self.query_obj = query_obj
        self.compiled = compile_query_with_dialect(query_obj)
        self.parameters = _convert_query_to_params_list(self.compiled)

    def __eq__(self, other):
        if not isinstance(other, QueryWrapper):
            # '{!r}' (not the invalid spec '{:!r}', which made format() itself
            # fail) so the intended message is actually produced.
            raise ValueError('Cannot compare QueryWrapper to {!r}'.format(other))
        return (
            str(other.compiled) == str(self.compiled) and
            other.parameters == self.parameters
        )

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``.
        return not self == other

    def __repr__(self):
        return '{}\n{}'.format(
            six.text_type(_compile_human_readable_query(self.query_obj)),
            six.text_type(self.parameters),
        )


def compare_queries_any_order(executed, expected):
    """Assert that *executed* and *expected* hold the same queries in any order.

    Queries are compared through ``QueryWrapper`` (SQL text plus grouped bind
    parameters) using hamcrest's ``contains_inanyorder``.
    """
    wrapped_executed = [QueryWrapper(query) for query in executed]
    wrapped_expected = [QueryWrapper(query) for query in expected]
    assert_that(wrapped_executed, contains_inanyorder(*wrapped_expected))


# TODO: the is_eav parameter is only ever used with True;
#  with False the behavior is strange.
def _eq_queries(
    executed_calls, expected, is_eav=False, inserted_keys=None, row_count=None,
    comparator=compare_queries_strict,
):
    """Collect the queries from *executed_calls* and check them with *comparator*.

    *comparator* is called with the keyword arguments ``executed`` and
    ``expected``.
    """
    comparator(
        executed=get_executed_queries(executed_calls, is_eav, inserted_keys, row_count),
        expected=expected,
    )


def eq_eav_queries(
    executed_calls, expected, inserted_keys=None, row_count=None, any_order=False,
):
    """Assert that the EAV calls in *executed_calls* produced *expected* queries.

    Each executed query is converted with ``to_query()`` before comparison.
    With *any_order* the order of queries is ignored; otherwise it must match.
    *inserted_keys* and *row_count* are fed to the queries' hooks in order.
    """
    if any_order:
        chosen_comparator = compare_queries_any_order
    else:
        chosen_comparator = compare_queries_strict

    _eq_queries(
        executed_calls,
        expected,
        is_eav=True,
        inserted_keys=inserted_keys,
        row_count=row_count,
        comparator=chosen_comparator,
    )
