# -*- coding: utf-8 -*-

import logging
import re

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import (
    Delete,
    Insert,
)


log = logging.getLogger(__name__)


INSERT_VALUES_REGEX = re.compile('(?<=VALUES \().+(?=\))')


@compiles(Insert)
def passport_sqlalchemy_insert_ext(insert, compiler, **kw):
    insert.include_insert_from_select_defaults = True
    s = compiler.visit_insert(insert, **kw)

    if insert.kwargs.get('mysql_binary_value_ids') and compiler.dialect.name == 'mysql':
        # Добавляем указанным аргументам из VALUES префикс _binary
        # Нужно, чтобы обойти багофичу, возникшую в mysql-server>=5.5.46
        # (https://bugs.mysql.com/bug.php?id=79317)
        # TODO: сделать поддержку склеенных запросов
        values = INSERT_VALUES_REGEX.search(s).group().split(', ')
        for value_id in set(insert.kwargs['mysql_binary_value_ids']):
            values[value_id] = '_binary %s' % values[value_id]

        s = INSERT_VALUES_REGEX.sub(', '.join(values), s)

    if insert.kwargs.get('mysql_on_duplicate_update_keys') and not insert.kwargs.get('mysql_if_equals'):
        keys = insert.kwargs['mysql_on_duplicate_update_keys']
        if compiler.dialect.name == 'mysql':
            statement = _MySqlOnDuplicateUpdateKeys(insert.table)
            for key in keys:
                statement.add_values_key(key)
            s += ' ' + statement.compile()
        elif compiler.dialect.name == 'sqlite':
            s = s.replace('INSERT', 'INSERT OR REPLACE', 1)

    elif insert.kwargs.get('mysql_on_duplicate_append_keys'):
        keys = insert.kwargs['mysql_on_duplicate_append_keys']
        if compiler.dialect.name == 'mysql':
            statement = _MySqlOnDuplicateUpdateKeys(insert.table)
            for key in keys:
                statement.add_concat_values_key(key)
            s += ' ' + statement.compile()
        elif compiler.dialect.name == 'sqlite':
            s = s.replace('INSERT', 'INSERT OR REPLACE', 1)  # это очень плохо, но иначе никак

    elif insert.kwargs.get('mysql_on_duplicate_update_keys') and insert.kwargs.get('mysql_if_equals'):
        if compiler.dialect.name == 'mysql':
            statement = _MySqlOnDuplicateUpdateKeys(insert.table)
            condition_key, condition_value = insert.kwargs['mysql_if_equals']
            else_null = bool(insert.kwargs.get('mysql_else_null'))
            for key in insert.kwargs['mysql_on_duplicate_update_keys']:
                statement.add_if_condition_values_key(
                    key,
                    condition_key,
                    condition_value,
                    else_null,
                )
            s += ' ' + statement.compile()
        elif compiler.dialect.name == 'sqlite':
            s = s.replace('INSERT', 'INSERT OR REPLACE', 1)

    elif insert.kwargs.get('mysql_on_duplicate_increment_key'):
        if compiler.dialect.name == 'mysql':
            statement = _MySqlOnDuplicateUpdateKeys(insert.table)
            statement.add_increment_key(insert.kwargs['mysql_on_duplicate_increment_key'])
            s += ' ' + statement.compile()
        elif compiler.dialect.name == 'sqlite':
            # в случае отсутствия гонок, этот запрос делает то же самое
            s = s.replace('INSERT', 'INSERT OR REPLACE', 1)
    return s


Insert.argument_for('mysql', 'binary_value_ids', None)
Insert.argument_for('mysql', 'on_duplicate_update_keys', None)
Insert.argument_for('mysql', 'if_equals', None)
Insert.argument_for('mysql', 'else_null', False)
Insert.argument_for('mysql', 'on_duplicate_append_keys', None)
Insert.argument_for('mysql', 'on_duplicate_increment_key', None)


@compiles(Delete)
def passport_sqlalchemy_delete_ext(delete, compiler, **kw):
    s = compiler.visit_delete(delete, **kw)

    if delete.kwargs.get('mysql_delete_limit'):
        limit = delete.kwargs['mysql_delete_limit']
        s += ' LIMIT %d' % limit
    return s


Delete.argument_for('mysql', 'delete_limit', None)


class _MySqlOnDuplicateUpdateKeys(object):
    def __init__(self, table):
        self._bits = dict()
        self._table = table

    def compile(self):
        bits = self._build_bits_with_last_insert_id_key_fix(self._bits)
        bits = [bits[k] for k in sorted(bits)]
        return 'ON DUPLICATE KEY UPDATE %s' % ', '.join(bits)

    def add_values_key(self, key):
        bit = '%s = VALUES(%s)' % (key, key)
        self._bits[key] = bit

    def add_concat_values_key(self, key):
        bit = "%s = CONCAT(%s, ';', VALUES(%s))" % (key, key, key)
        self._bits[key] = bit

    def add_if_condition_values_key(self, key, condition_key, condition_value,
                                    else_null=False):
        args = {
            'condition_key': condition_key,
            'condition_value': condition_value,
            'key': key,
            'else_value': 'NULL' if else_null else key,
        }
        bit = '%(key)s = IF(%(condition_key)s = "%(condition_value)s", VALUES(%(key)s), %(else_value)s)' % args
        self._bits[key] = bit

    def add_increment_key(self, key):
        bit = '%s = %s + 1' % (key, key)
        self._bits[key] = bit

    def _build_bits_with_last_insert_id_key_fix(self, bits):
        # INSERT ON DUPLICATE KEY UPDATE не меняет значение функции
        # LAST_INSERT_ID, когда в таблице уже есть полностью совпадающая
        # строка.
        # В таких случаях эта подпрограмма меняет значение LAST_INSERT_ID на
        # правильное.
        pk_list = list()
        for c in self._table.c:
            if c.primary_key:
                pk_list.append(c)

        if len(pk_list) == 1 and pk_list[0].autoincrement:
            pk = pk_list[0].name
        elif not pk_list or all(not c.autoincrement for c in pk_list):
            return bits
        else:
            # Не могу себе представить зачем может быть нужна таблица с
            # составным первичным автоинкрементным ключом. И не знаю как
            # сделать, чтобы фикс заработал с такой таблицей.
            log.error(
                'Unable to guarantee that inserted_primary_key is valid after '
                'insert into table %s. Also you may see this message, because '
                'forgot to add "autoincrement=False" to the primary keys of '
                'the table.' % self._table.name
            )  # pragma: no cover
            return bits  # pragma: no cover

        if pk and pk not in bits:
            bits = dict(bits)
            bits[pk] = '%s = LAST_INSERT_ID(%s)' % (pk, pk)
        return bits
