from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert, insert
from sqlalchemy.util.langhelpers import public_factory


# Public API of this module: the Upsert statement class and its
# function-style constructor ``upsert``.
__all__ = ['Upsert', 'upsert']


class Upsert(Insert):
    """An ``INSERT`` statement that also carries conflict-handling options.

    The statement itself is built by :class:`Insert`; this subclass only
    records the extra options so dialect-specific ``@compiles`` hooks can
    turn them into the right SQL.

    :param table: target table, forwarded to :class:`Insert`.
    :param values: row data, forwarded to :class:`Insert` and additionally
        stored unchanged on ``self.values``.
    :param ignore: when true, conflicting rows are skipped rather than
        overwritten.
    :param constraint: conflict target stored on ``self.constraint``.
    :param where: optional predicate stored on ``self.where``.

    NOTE(review): ``self.values`` is set as an instance attribute, which hides
    the inherited generative ``Insert.values()`` method on instances.
    """

    def __init__(self, table, values, ignore=False, constraint=None, where=None):
        super(Upsert, self).__init__(table, values)
        # Keep the raw constructor arguments; Insert.__init__ may have
        # normalised its own copies.
        self.table, self.values = table, values
        # Conflict-handling options consumed by the dialect compilers.
        self.ignore, self.constraint, self.where = ignore, constraint, where


@compiles(Upsert, 'postgresql')
def postgres_upsert(upsert_stmt, compiler, **kwargs):
    """Compile an :class:`Upsert` to ``INSERT ... ON CONFLICT`` for PostgreSQL.

    With ``ignore`` set, the statement becomes ``ON CONFLICT DO NOTHING``.
    Otherwise ``constraint`` is used as the conflict target
    (``index_elements``), ``where`` as its ``index_where``, and every inserted
    column that is not part of the target is overwritten with the incoming
    ``EXCLUDED`` value. If every inserted column belongs to the target there
    is nothing to overwrite, so ``DO NOTHING`` with the same target is used.

    :param upsert_stmt: the :class:`Upsert` being compiled.
    :param compiler: the active SQL compiler; ``kwargs`` are forwarded to
        ``compiler.process``.
    :returns: the compiled SQL string produced by ``compiler.process``.
    :raises ValueError: if ``ignore`` is false and no ``constraint`` is set.
    """
    index_elements = upsert_stmt.constraint
    stmt = postgresql.insert(upsert_stmt.table, upsert_stmt.values)

    if upsert_stmt.ignore:
        return compiler.process(stmt.on_conflict_do_nothing(), **kwargs)

    if index_elements is None:
        # Fail with a clear message instead of a TypeError from `in None`.
        raise ValueError(
            'Upsert with ignore=False requires a constraint '
            '(conflict target columns) on PostgreSQL'
        )

    # Compare by column key so string names and Column objects match.
    target_keys = {getattr(element, 'key', element) for element in index_elements}

    # A single-row insert stores its parameters as one dict; a multi-row
    # insert stores a list of per-row dicts. The first row names the columns.
    params = stmt.parameters
    if isinstance(params, list):
        params = params[0] if params else {}
    inserted_keys = [getattr(column, 'key', column) for column in (params or {})]

    update_keys = [key for key in inserted_keys if key not in target_keys]
    if update_keys:
        stmt = stmt.on_conflict_do_update(
            index_elements=index_elements,
            set_={key: getattr(stmt.excluded, key) for key in update_keys},
            index_where=upsert_stmt.where,
        )
    else:
        # Only key columns were supplied: nothing to overwrite, and
        # on_conflict_do_update rejects an empty ``set_``.
        stmt = stmt.on_conflict_do_nothing(
            index_elements=index_elements,
            index_where=upsert_stmt.where,
        )
    return compiler.process(stmt, **kwargs)


@compiles(Upsert, 'sqlite')
def sqlite_upsert(upsert_stmt, compiler, **kwargs):
    """Compile an :class:`Upsert` for SQLite.

    Emits a plain ``INSERT`` prefixed with ``OR IGNORE`` when ``ignore`` is
    set and ``OR REPLACE`` otherwise. The ``constraint`` and ``where``
    options are not used by this compiler.

    NOTE(review): SQLite's ``OR REPLACE`` deletes the conflicting row before
    inserting, so columns not supplied are reset rather than preserved. This
    differs from an ``ON CONFLICT DO UPDATE`` — confirm callers expect that.

    :param upsert_stmt: the :class:`Upsert` being compiled.
    :param compiler: the active SQL compiler; ``kwargs`` are forwarded to
        ``compiler.process``.
    """
    conflict_action = 'OR IGNORE' if upsert_stmt.ignore else 'OR REPLACE'
    plain_insert = insert(upsert_stmt.table, upsert_stmt.values)
    return compiler.process(plain_insert.prefix_with(conflict_action), **kwargs)


# Function-style constructor for Upsert, generated with SQLAlchemy's
# ``public_factory`` helper (presumably mirroring how SQLAlchemy exposes its
# own constructors such as ``insert()``).
# NOTE(review): ``public_factory`` comes from ``sqlalchemy.util.langhelpers``,
# an internal module, and the ``'.expression.upsert'`` location string looks
# modelled on SQLAlchemy's own documentation layout — confirm both are
# appropriate for the SQLAlchemy version in use.
upsert = public_factory(Upsert, '.expression.upsert')
