# coding: utf-8


from django.db import router
from django.db.models import signals, AutoField
from django.db.models.sql.subqueries import InsertQuery

from idm.framework.backend.compiler import ConflictAction


class UpsertQuery(InsertQuery):
    """INSERT query that also carries the state needed to emit ``ON CONFLICT``.

    Compiled by the compiler named in ``compiler`` ('SQLUpsertCompiler').
    """

    compiler = 'SQLUpsertCompiler'

    def insert_values(self, fields_to_insert, objs,
                      conflict_target, conflict_action, index_predicate, fields_to_update,
                      raw=False):
        """Record the INSERT payload and the conflict-handling configuration.

        :param fields_to_insert: model fields written by the INSERT; stored both
            as ``self.fields_to_insert`` and ``self.fields``.
        :param objs: model instances whose values are inserted.
        :param conflict_target: conflict target for the ON CONFLICT clause.
        :param conflict_action: what to do on conflict.
        :param index_predicate: optional predicate for the conflict target.
        :param fields_to_update: fields the compiler may update on conflict.
        :param raw: stored as ``self.raw``, as ``InsertQuery.insert_values`` does.
        """
        # State shared with the plain INSERT machinery.
        self.fields = fields_to_insert
        self.fields_to_insert = fields_to_insert
        self.objs, self.raw = objs, raw

        # Upsert-specific state, presumably read by the upsert compiler.
        self.conflict_target = conflict_target
        self.conflict_action = conflict_action
        self.index_predicate = index_predicate
        self.fields_to_update = fields_to_update


class UpsertQueryMixin(object):
    """QuerySet mixin adding single-object inserts with optional ON CONFLICT handling.

    Configure conflict handling with ``on_conflict`` and then insert with
    ``insert_object`` or ``insert_raw``. Without a conflict target a plain
    ``InsertQuery`` is used.
    """

    def __init__(self, *args, **kwargs):
        super(UpsertQueryMixin, self).__init__(*args, **kwargs)
        # Conflict configuration; a falsy conflict_target means "plain INSERT".
        self.conflict_target = None
        self.conflict_action = None
        self.index_predicate = None

    def on_conflict(self, fields, action, index_predicate=None):
        """Return a copy of this queryset configured for ON CONFLICT handling.

        :param fields: conflict target passed to ``UpsertQuery``.
        :param action: conflict action (e.g. ``ConflictAction.DO_NOTHING``).
        :param index_predicate: optional predicate passed to ``UpsertQuery``.
        :return: the queryset returned by ``self.all()`` with the configuration
            set on it; ``self`` itself is not modified.
        """
        # NOTE(review): the attributes live only on this clone; further
        # chaining may drop them unless cloning copies them — confirm.
        new_query_set = self.all()
        new_query_set.conflict_target = fields
        new_query_set.conflict_action = action
        new_query_set.index_predicate = index_predicate
        return new_query_set

    def insert_object(self, obj, send_save_signals, updated_fields=None):
        """Insert ``obj``, as an upsert when ``on_conflict`` was configured.

        :param obj: model instance to insert; its pk (and the fields reported by
            the upsert compiler) are set from the database result.
        :param send_save_signals: if true, send ``pre_save``/``post_save`` with
            ``raw=True``.
        :param updated_fields: names of fields to update on conflict. Both field
            names (``user``) and attnames (``user_id``) are accepted. Fields with
            ``auto_now`` are always included. If empty/None, all inserted fields
            are used.
        :return: ``obj``.
        """
        self._for_write = True

        using = router.db_for_write(obj.__class__, instance=obj)
        origin = obj.__class__
        fields_to_insert = obj._meta.local_concrete_fields
        if not obj.pk:
            # No pk yet: leave the auto field out so the database generates it.
            fields_to_insert = [f for f in fields_to_insert if not isinstance(f, AutoField)]
        if updated_fields:
            # Match on both name and attname so FK names (e.g. from insert_raw
            # keyword arguments) select their "<name>_id" column.
            wanted = set(updated_fields)
            fields_to_update = [
                f for f in fields_to_insert
                if f.attname in wanted
                or f.name in wanted
                or getattr(f, 'auto_now', False)
            ]
        else:
            fields_to_update = fields_to_insert
        fieldnames = [field.attname for field in fields_to_insert]

        if send_save_signals:
            signals.pre_save.send(sender=origin, instance=obj, raw=True, using=using, update_fields=fieldnames)

        if self.conflict_target:
            query = UpsertQuery(self.model)
            query.insert_values(
                fields_to_insert,
                [obj],
                self.conflict_target,
                self.conflict_action,
                self.index_predicate,
                fields_to_update,
                raw=False,
            )
            compiler = query.get_compiler(using=using)
            id_field = query.get_meta().pk
            returning_fields = [id_field.attname] + compiler.get_auto_now_add_fields()
            result = compiler.execute_sql()
            # With ON CONFLICT DO NOTHING, execute_sql may return None on a
            # conflict; in that case there is nothing to copy back.
            if not (result is None and self.conflict_action == ConflictAction.DO_NOTHING):
                # Set the pk and the auto_now_add fields that were filled in by the
                # compiler and must now be overwritten with the returned values.
                for field, value in zip(returning_fields, result):
                    setattr(obj, field, value)
        else:
            query = InsertQuery(self.model)
            query.insert_values(fields_to_insert, [obj], raw=False)
            obj.pk = query.get_compiler(using=using).execute_sql(return_id=True)

        if send_save_signals:
            signals.post_save.send(sender=origin, instance=obj, created=True,  # Cannot tell whether the object was created
                                   update_fields=fieldnames, raw=True, using=using)

        return obj

    def insert_raw(self, send_save_signals, **fields):
        """Build ``self.model(**fields)`` and insert it via ``insert_object``.

        The keyword names are passed as ``updated_fields``. With no keywords the
        list is empty, and ``insert_object`` then uses all inserted fields.

        :return: the inserted model instance.
        """
        obj = self.model(**fields)
        return self.insert_object(obj, send_save_signals, updated_fields=list(fields.keys()))
