package ru.yandex.partner.core.entity.simplemodels;

import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

import one.util.streamex.StreamEx;
import org.apache.commons.lang3.tuple.Pair;
import org.jetbrains.annotations.NotNull;
import org.jooq.Condition;
import org.jooq.DSLContext;
import org.jooq.Field;
import org.jooq.Record;
import org.jooq.impl.DSL;

import ru.yandex.direct.jooqmapper.JooqMapper;
import ru.yandex.direct.jooqmapperhelper.InsertHelper;
import ru.yandex.direct.model.Model;
import ru.yandex.partner.core.entity.utils.DSLUtils;
import ru.yandex.partner.core.filter.container.ModelFilterContainer;

/**
 * Base repository for simple tables whose rows are identified by a fixed list of key columns.
 * Provides upsert ({@link #createOrUpdate}), plain insert, update and key-based delete on top of
 * a {@link JooqMapper}.
 *
 * @param <M> model type
 * @param <R> jOOQ record type of the backing table
 */
public abstract class AbstractSimpleRepositoryWithCommonMethods<M extends Model, R extends Record>
        extends AbstractSimpleRepository<M, R> {

    /** Maps models to and from table column values. */
    private final JooqMapper<M> mapper;
    /** Columns that identify a row; an unmodifiable snapshot of the constructor argument. */
    private final List<Field<?>> keyFields;
    /** Columns read by the mapper that are not key columns; overwritten on duplicate key. */
    private final List<Field<?>> fieldsToUpdate;

    /**
     * @param dslContext           jOOQ context used for all queries
     * @param modelFilterContainer filter container passed to the base repository
     * @param mapper               mapper between models and table columns
     * @param keyFields            columns identifying a row; copied defensively
     */
    public AbstractSimpleRepositoryWithCommonMethods(DSLContext dslContext,
                                                     ModelFilterContainer<M> modelFilterContainer,
                                                     JooqMapper<M> mapper,
                                                     List<Field<?>> keyFields) {
        super(dslContext, modelFilterContainer);
        this.mapper = mapper;
        // Snapshot the keys: fieldsToUpdate is derived from them once, so later changes to the
        // caller's list must not make the two lists disagree.
        this.keyFields = List.copyOf(keyFields);
        this.fieldsToUpdate = prepareFieldsToUpdate(mapper, this.keyFields);
    }

    protected JooqMapper<M> getMapper() {
        return mapper;
    }

    /**
     * Returns the mapper's read fields minus the key fields, preserving the mapper's order.
     */
    private static List<Field<?>> prepareFieldsToUpdate(JooqMapper<?> mapper, List<Field<?>> keyFields) {
        return mapper.getFieldsToRead().stream()
                .filter(field -> !keyFields.contains(field))
                .toList();
    }

    /**
     * Inserts the entities, updating every non-key column when a row with the same key already
     * exists, then re-reads the affected rows by key.
     *
     * @param entities models to upsert
     * @return the models read back from the table for the entities' keys; empty for empty input
     */
    @Override
    @SuppressWarnings({"rawtypes", "unchecked"})
    public List<M> createOrUpdate(List<M> entities) {
        if (entities.isEmpty()) {
            return List.of();
        }

        InsertHelper<R> insertHelper = new InsertHelper<>(getDslContext(), getTable());
        insertHelper.addAll(mapper, entities);

        insertHelper.onDuplicateKeyUpdate();
        for (Field<?> field : fieldsToUpdate) {
            // Presumably getValuesStatement renders the inserted value for this column
            // (MySQL VALUES(col)) — confirm in DSLUtils. The raw cast bridges the wildcard type.
            insertHelper.set(field, (Field) DSLUtils.getValuesStatement(field));
        }

        insertHelper.execute();

        var selectQuery = getBaseSelectQuery();
        selectQuery.addConditions(commonCondition(entities));
        return getModels(selectQuery);
    }

    /**
     * Builds a condition matching the rows whose key columns equal those of any given entity.
     * For a two-column key the condition is grouped by the first key as
     * {@code (k1 = a AND k2 IN (...)) OR ...}; otherwise it is an OR of per-entity conjunctions.
     *
     * @param entities models whose keys to match
     * @return the key condition, or {@code false} when {@code entities} is empty
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    @NotNull
    private Condition commonCondition(Collection<M> entities) {
        if (entities.isEmpty()) {
            return DSL.falseCondition();
        }

        if (keyFields.size() == 2) {
            Field key1 = keyFields.get(0);
            Field key2 = keyFields.get(1);
            return StreamEx.of(entities)
                    .map(entity -> {
                        var values = mapper.getDbFieldValues(entity, getTable());
                        return Pair.of(values.get(key1), values.get(key2));
                    })
                    .groupingBy(Pair::getKey, Collectors.mapping(Pair::getValue, Collectors.toSet()))
                    .entrySet().stream()
                    .map(entry -> key1.eq(entry.getKey()).and(key2.in(entry.getValue())))
                    .reduce(Condition::or)
                    .orElse(DSL.falseCondition());
        }

        return StreamEx.of(entities)
                .map(this::entityKeyCondition)
                .reduce(Condition::or)
                .orElse(DSL.falseCondition());
    }

    /**
     * Builds {@code k1 = v1 AND k2 = v2 AND ...} from one entity's key column values.
     * Returns {@code false} when there are no key columns.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private Condition entityKeyCondition(M entity) {
        var values = mapper.getDbFieldValues(entity, getTable());
        return keyFields.stream()
                .map(key -> ((Field) key).eq(values.get(key)))
                .reduce(Condition::and)
                .orElse(DSL.falseCondition());
    }

    /**
     * Inserts the entities with a single insert statement, without any duplicate-key handling.
     *
     * @param entities models to insert
     * @return the value returned by {@code InsertHelper.execute()}; 0 for empty input
     */
    @Override
    public int insert(List<M> entities) {
        if (entities.isEmpty()) {
            return 0;
        }

        InsertHelper<R> insertHelper = new InsertHelper<>(getDslContext(), getTable());
        for (M entity : entities) {
            insertHelper.add(mapper, entity).newRecord();
        }
        return insertHelper.execute();
    }

    /**
     * Updates the entities by delegating to {@link #createOrUpdate}. Rows that do not exist yet
     * are therefore inserted.
     *
     * @return the number of models read back after the upsert
     */
    @Override
    public int update(List<M> entities) {
        return createOrUpdate(entities).size();
    }

    /** @return the key columns (unmodifiable) */
    public List<Field<?>> getKeyFields() {
        return keyFields;
    }

    /** @return the non-key columns overwritten on duplicate key (unmodifiable) */
    public List<Field<?>> getFieldsToUpdate() {
        return fieldsToUpdate;
    }

    /**
     * Deletes the rows whose key columns match any of the given models.
     * Does nothing, and issues no query, for an empty collection.
     */
    @Override
    public void delete(Collection<M> currentModels) {
        if (currentModels.isEmpty()) {
            // Nothing can match; avoid a needless "DELETE ... WHERE false" round-trip.
            return;
        }
        getDslContext()
                .delete(getTable())
                .where(commonCondition(currentModels))
                .execute();
    }
}
