# coding: utf-8

from django.db import DatabaseError
from django.db.models.sql.compiler import (
    SQLCompiler,
    SQLInsertCompiler as BaseSQLInsertCompiler,
    SQLDeleteCompiler,
    SQLUpdateCompiler,
    SQLAggregateCompiler
)  # NOQA


# Public names of this compiler module: Django's stock compilers imported
# above are re-exported, alongside the two INSERT compilers defined below.
__all__ = [
    'SQLCompiler',
    'SQLInsertCompiler',
    'SQLDeleteCompiler',
    'SQLUpdateCompiler',
    'SQLAggregateCompiler',
    'SQLUpsertCompiler',
]


class SQLInsertCompiler(BaseSQLInsertCompiler):
    """Insert compiler that ignores unique conflicts for a fixed set of tables.

    For tables listed in ``upsert_tables`` every INSERT statement produced by
    Django's compiler gets an ``ON CONFLICT DO NOTHING`` clause, placed before
    the ``RETURNING`` clause when there is one. Other tables are unaffected.

    NOTE(review): with ``ON CONFLICT DO NOTHING`` PostgreSQL returns no row
    from ``RETURNING`` for a conflicting insert; confirm that callers which
    need the new primary key can cope with that.
    """

    # Tables whose INSERTs should silently skip conflicting rows.
    upsert_tables = (
        'upravlyator_internalrole',
        'upravlyator_rolelock',
        'groupmembership_inconsistency',
    )

    @staticmethod
    def _add_on_conflict_do_nothing(sql):
        """Return *sql* with ``ON CONFLICT DO NOTHING`` before RETURNING, or appended."""
        # RETURNING is always the last clause of the INSERT, so search from the end.
        returning_index = sql.rfind('RETURNING')
        if returning_index == -1:
            return sql + ' ON CONFLICT DO NOTHING'
        return '{} ON CONFLICT DO NOTHING {}'.format(
            sql[:returning_index].rstrip(),
            sql[returning_index:],
        )

    def as_sql(self):
        """Return Django's INSERT statements, adding ON CONFLICT DO NOTHING when needed.

        Every ``(sql, params)`` statement returned by the base compiler is
        rewritten, so inserts that compile to several statements keep all of
        them.
        """
        statements = super(SQLInsertCompiler, self).as_sql()
        if self.query.get_meta().db_table not in self.upsert_tables:
            return statements
        return [
            (self._add_on_conflict_do_nothing(sql), params)
            for sql, params in statements
        ]


class ConflictAction(object):
    """Integer constants selecting what an upsert does when the INSERT conflicts.

    ``SQLUpsertCompiler.as_sql`` compares ``query.conflict_action`` against
    these values.
    """

    # DO_NOTHING (0): emit ``ON CONFLICT ... DO NOTHING``.
    # UPDATE (1): emit ``ON CONFLICT ... DO UPDATE SET ...``.
    DO_NOTHING, UPDATE = 0, 1


class SQLUpsertCompiler(BaseSQLInsertCompiler):
    """Compile a single PostgreSQL ``INSERT ... ON CONFLICT ... RETURNING`` statement.

    Besides what Django's insert compiler uses, ``self.query`` is expected to
    carry these custom attributes (set by code not visible here):

    * ``conflict_target`` -- list/tuple of names placed in ``ON CONFLICT (...)``;
    * ``index_predicate`` -- optional SQL appended as ``WHERE ...`` to the target;
    * ``conflict_action`` -- a ``ConflictAction`` value;
    * ``fields_to_insert`` -- fields, presumably in the same order as the
      INSERT params (they are zipped together);
    * ``fields_to_update`` -- fields overwritten by ``DO UPDATE``.
    """

    def build_update_params(self, params):
        """Pick the insert params belonging to fields listed in ``fields_to_update``.

        Returns ``[(attname, ...), (value, ...)]``, or ``[]`` when no inserted
        field is to be updated. Kept unchanged for backward compatibility;
        ``as_sql`` uses ``_update_assignments``, which yields real column names.
        """
        for_update = [
            (field.attname, value)
            for field, value in zip(self.query.fields_to_insert, params)
            if field in self.query.fields_to_update
        ]
        return list(zip(*for_update))  # transpose pairs into (names, values)

    def get_auto_now_add_fields(self):
        """Return attnames of the query's fields that have ``auto_now_add`` set."""
        return [field.attname for field in self.query.fields if getattr(field, 'auto_now_add', False)]

    def get_conflict_target(self):
        """Return the parenthesised conflict target, e.g. ``(a, b)``.

        The names are inserted verbatim (not quoted).

        Raises:
            DatabaseError: if ``conflict_target`` is not a list/tuple, or is empty.
        """
        target = self.query.conflict_target
        if not isinstance(target, (list, tuple)):
            raise DatabaseError('conflict_target must be a list of field names')
        if not target:
            raise DatabaseError('conflict_target must not be empty')
        return '({})'.format(', '.join(target))

    def get_index_predicate(self):
        """Return `` WHERE <index_predicate>`` or an empty string if none is set."""
        if not self.query.index_predicate:
            return ''
        return ' WHERE {}'.format(self.query.index_predicate)

    def _update_assignments(self, params):
        """Return ``(column, value)`` pairs for the ``DO UPDATE SET`` clause.

        Uses ``field.column`` (not ``attname``) so fields with a custom
        ``db_column`` are addressed by their actual column name.
        """
        return [
            (field.column, value)
            for field, value in zip(self.query.fields_to_insert, params)
            if field in self.query.fields_to_update
        ]

    def _build_conflict_clause(self, params):
        """Return ``(sql, extra_params)`` for the ``ON CONFLICT`` clause.

        Raises:
            DatabaseError: on an unknown ``conflict_action``, an invalid
                conflict target, or a ``DO UPDATE`` with nothing to update.
        """
        action = self.query.conflict_action
        if action not in (ConflictAction.DO_NOTHING, ConflictAction.UPDATE):
            raise DatabaseError('Invalid ON CONFLICT action: {}'.format(action))

        prefix = ' ON CONFLICT {}{}'.format(
            self.get_conflict_target(),
            self.get_index_predicate(),
        )
        if action == ConflictAction.DO_NOTHING:
            return prefix + ' DO NOTHING', []

        assignments = self._update_assignments(params)
        if not assignments:
            raise DatabaseError('ON CONFLICT DO UPDATE requires at least one field to update')
        quote = self.connection.ops.quote_name
        set_sql = ', '.join('{}=%s'.format(quote(column)) for column, _ in assignments)
        # The SET placeholders come after all INSERT placeholders in the SQL,
        # so their values are appended after the INSERT params.
        return prefix + ' DO UPDATE SET ' + set_sql, [value for _, value in assignments]

    def _build_returning_clause(self):
        """Return `` RETURNING <pk>, <auto_now_add columns...>`` with quoted columns.

        The column order matches ``get_auto_now_add_fields`` (pk first).
        """
        quote = self.connection.ops.quote_name
        columns = [self.query.get_meta().pk.column]
        columns.extend(
            field.column for field in self.query.fields if getattr(field, 'auto_now_add', False)
        )
        return ' RETURNING {}'.format(', '.join(quote(column) for column in columns))

    def as_sql(self):
        """Return one ``(sql, params)`` tuple for the upsert.

        Unlike Django's ``SQLInsertCompiler.as_sql`` this returns a single
        tuple, not a list of them; ``execute_sql`` relies on that.

        Raises:
            DatabaseError: if the base compiler produced anything other than
                exactly one statement, or the ON CONFLICT settings are invalid.
        """
        raw_statements = super(SQLUpsertCompiler, self).as_sql()
        # Explicit check rather than ``assert``: asserts are stripped under -O,
        # and the rest of this method relies on a single statement.
        if len(raw_statements) != 1:
            raise DatabaseError(
                'Upsert expects exactly one INSERT statement, got {}'.format(len(raw_statements))
            )
        sql, params = raw_statements[0]
        params = list(params)

        conflict_sql, conflict_params = self._build_conflict_clause(params)
        params.extend(conflict_params)

        # Drop any RETURNING clause emitted by the base compiler; ours replaces
        # it. RETURNING is the last clause, so search from the end.
        returning_index = sql.rfind('RETURNING')
        if returning_index != -1:
            sql = sql[:returning_index].rstrip()

        return sql + conflict_sql + self._build_returning_clause(), tuple(params)

    def execute_sql(self):
        """Run the upsert and return the single RETURNING row.

        Besides pinning down the returned values, this skips the checks of
        Django's ``execute_sql``, since the ways this compiler is reached are
        known. Returns ``(pk, *auto_now_add values)``, or ``None`` when
        PostgreSQL returns no row (e.g. ``DO NOTHING`` on a conflicting row).
        """
        # Build the SQL first so a compilation error does not acquire a cursor.
        sql, params = self.as_sql()
        with self.connection.cursor() as cursor:
            cursor.execute(sql, params)
            return cursor.fetchone()
