package ru.yandex.partner.core.entity.simplemodels;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import org.jetbrains.annotations.NotNull;
import org.jooq.Condition;
import org.jooq.DSLContext;
import org.jooq.Field;
import org.jooq.Record;
import org.jooq.Record2;
import org.jooq.SelectQuery;
import org.jooq.impl.TableImpl;

import ru.yandex.direct.model.Model;
import ru.yandex.direct.model.ModelProperty;
import ru.yandex.direct.multitype.entity.LimitOffset;
import ru.yandex.partner.core.filter.CoreFilterNode;
import ru.yandex.partner.core.filter.container.ModelFilterContainer;
import ru.yandex.partner.core.utils.OrderBy;

import static org.jooq.impl.DSL.count;

@ParametersAreNonnullByDefault
/**
 * Base jOOQ implementation of {@link SimpleRepository} for models read from a single table.
 *
 * <p>Subclasses supply the table, the mapping from model property names to table columns, the id column,
 * the model class and the record-to-model mapper. This class builds the {@code SELECT} statements and applies
 * {@link CoreFilterNode} filters, ordering and paging to them.
 *
 * @param <M> model type handled by the repository
 * @param <R> record type of the backing table
 */
@ParametersAreNonnullByDefault
public abstract class AbstractSimpleRepository<M extends Model, R extends Record> implements SimpleRepository<M> {
    private final DSLContext dslContext;
    private final ModelFilterContainer<M> modelFilterContainer;

    /**
     * @param dslContext           jOOQ context used to build and execute every query of this repository
     * @param modelFilterContainer handed to {@link CoreFilterNode} when a filter is converted into a condition
     */
    public AbstractSimpleRepository(DSLContext dslContext, ModelFilterContainer<M> modelFilterContainer) {
        this.dslContext = dslContext;
        this.modelFilterContainer = modelFilterContainer;
    }

    /**
     * Loads the models matching a filter.
     *
     * @param coreFilterNode filter converted into the query's conditions
     * @param limitOffset    optional paging; no {@code LIMIT}/{@code OFFSET} is added when {@code null}
     * @param orderByList    optional ordering, applied in list order; no ordering is added when {@code null}
     * @param forUpdate      value passed to {@link SelectQuery#setForUpdate(boolean)}
     * @return models mapped by {@link #fetchQueryFunction()}, in the order the database returned the rows
     * @throws IllegalArgumentException if an ordering property has no column in {@link #getDbFields()}
     */
    @Override
    public List<M> getAll(@NotNull CoreFilterNode<? super M> coreFilterNode,
                          @Nullable LimitOffset limitOffset,
                          @Nullable List<OrderBy> orderByList,
                          boolean forUpdate) {
        SelectQuery<Record> query = getBaseSelectQuery();
        applyFilterToQuery(coreFilterNode, query);

        if (orderByList != null) {
            applyOrderToQuery(orderByList, query);
        }
        if (limitOffset != null) {
            query.addOffset(limitOffset.offset());
            query.addLimit(limitOffset.limit());
        }

        query.setForUpdate(forUpdate);

        return getModels(query);
    }

    /**
     * Counts the table rows matching a filter.
     *
     * <p>The value is read straight from a {@code SELECT COUNT(*) ... WHERE ...} query. That query must not be
     * passed to {@code DSLContext.fetchCount}, because jOOQ wraps its argument as
     * {@code SELECT COUNT(*) FROM (<query>)}. An ungrouped aggregate always yields exactly one row, so the
     * wrapped form would always report 1.
     *
     * @param coreFilterNode filter converted into the query's conditions
     * @return number of matching rows
     */
    @Override
    public Long getCountByCondition(CoreFilterNode<M> coreFilterNode) {
        var query = dslContext.selectCount().from(getTable()).getQuery();
        applyFilterToQuery(coreFilterNode, query);
        Long count = query.fetchOne(0, Long.class);
        // An ungrouped COUNT(*) always returns one row; the guard only protects against a null value.
        return count != null ? count : 0L;
    }

    /**
     * Creates the base query selecting {@link #getDbFieldsForQuery()} from {@link #getTable()}.
     *
     * @return a new, unfiltered query that callers may further modify
     */
    protected SelectQuery<Record> getBaseSelectQuery() {
        return dslContext.select(getDbFieldsForQuery())
                .from(getTable())
                .getQuery();
    }

    /**
     * Counts the rows matching a filter, grouped by one column.
     *
     * @param fieldName      key in {@link #getDbFields()} naming the grouping column
     * @param coreFilterNode filter converted into the query's conditions
     * @param <T1>           Java type of the grouping column, chosen by the caller (cast is unchecked)
     * @return map from each distinct column value to its row count
     * @throws IllegalArgumentException if {@code fieldName} has no column in {@link #getDbFields()}
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T1> Map<T1, Integer> getCountGroupedBy(String fieldName, @NotNull CoreFilterNode<M> coreFilterNode) {
        Field<T1> fieldToGroupBy = (Field<T1>) requireFieldByName(fieldName);
        var query = dslContext.select(fieldToGroupBy, count().as("count"))
                .from(getTable())
                .groupBy(fieldToGroupBy)
                .getQuery();
        applyFilterToQuery(coreFilterNode, query);
        var data = query.fetch();
        return data.stream().collect(Collectors.toMap(Record2::component1, Record2::component2));
    }

    /**
     * Returns the distinct values of {@link #getIdField()} among the rows matching a filter.
     *
     * @param coreFilterNode filter converted into the query's conditions
     * @param <T1>           Java type of the id column, chosen by the caller (cast is unchecked)
     * @return set of matching ids
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T1> Set<T1> existingIds(@NotNull CoreFilterNode<M> coreFilterNode) {
        Field<T1> idField = (Field<T1>) getIdField();
        var query = dslContext.selectDistinct(idField)
                .from(getTable())
                .getQuery();
        applyFilterToQuery(coreFilterNode, query);
        return query.fetchSet(idField);
    }

    /**
     * Converts a filter into a condition for {@link #getModelClass()} and adds it to the query.
     *
     * @param coreFilterNode filter to convert
     * @param query          query that receives the condition
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    protected void applyFilterToQuery(CoreFilterNode<? super M> coreFilterNode, SelectQuery<?> query) {
        // Raw cast: the node is only typed "? super M", while toCondition is invoked with M's class.
        applyConditionsToQuery(((CoreFilterNode) coreFilterNode)
                .toCondition(getModelClass(), modelFilterContainer), query);
    }

    /**
     * Adds a condition to the query; subclasses may override this to customise how conditions are attached.
     *
     * @param condition condition to add
     * @param query     query that receives the condition
     */
    protected void applyConditionsToQuery(Condition condition, SelectQuery<?> query) {
        query.addConditions(condition);
    }

    /**
     * Appends one {@code ORDER BY} term per entry, resolving each model property through {@link #getDbFields()}.
     *
     * @param orderByList ordering terms, applied in list order
     * @param query       query that receives the ordering
     * @throws IllegalArgumentException if a property has no column in {@link #getDbFields()}
     */
    protected void applyOrderToQuery(List<OrderBy> orderByList, SelectQuery<Record> query) {
        for (OrderBy orderBy : orderByList) {
            ModelProperty<? extends Model, ?> modelProperty = orderBy.getModelProperty();
            Field<?> field = requireFieldByName(modelProperty.name());
            query.addOrderBy(field.sort(orderBy.getDirection().getSortOrder()));
        }
    }

    /**
     * @return table every query of this repository selects from
     */
    protected abstract TableImpl<R> getTable();

    /**
     * Executes the query and maps each record with {@link #fetchQueryFunction()}.
     *
     * @param query query to execute
     * @return mapped models in result order
     */
    protected List<M> getModels(SelectQuery<Record> query) {
        // Resolve the mapper once rather than once per fetched record.
        Function<Record, M> mapper = fetchQueryFunction();
        return query.fetch(record -> mapper.apply(record));
    }

    /**
     * @return function mapping a fetched record to a model; used by {@link #getModels(SelectQuery)}
     */
    protected abstract Function<Record, M> fetchQueryFunction();

    /**
     * @return jOOQ context supplied at construction
     */
    protected DSLContext getDslContext() {
        return dslContext;
    }

    /**
     * @return mapping from model property name to table column; the values form the default select list and
     *         the keys are looked up for ordering and grouping
     */
    protected abstract Map<String, Field<?>> getDbFields();

    /**
     * @return column whose distinct values {@link #existingIds(CoreFilterNode)} returns
     */
    protected abstract Field<?> getIdField();

    /**
     * @return columns selected by {@link #getBaseSelectQuery()}; defaults to all values of {@link #getDbFields()}
     */
    protected Collection<Field<?>> getDbFieldsForQuery() {
        return getDbFields().values();
    }

    /**
     * Looks up a column by model property name.
     *
     * @param name key in {@link #getDbFields()}
     * @return the mapped column, or {@code null} when the name is not mapped
     */
    protected Field<?> getFieldByName(String name) {
        return getDbFields().get(name);
    }

    /**
     * Resolves a column through {@link #getFieldByName(String)}, failing fast when the name is unmapped.
     *
     * @param name key in {@link #getDbFields()}
     * @return the mapped column, never {@code null}
     * @throws IllegalArgumentException if no column is mapped to {@code name}
     */
    private Field<?> requireFieldByName(String name) {
        Field<?> field = getFieldByName(name);
        if (field == null) {
            throw new IllegalArgumentException("No DB field mapped for model property: " + name);
        }
        return field;
    }

    /**
     * @return model class passed to {@link CoreFilterNode} when filters are converted into conditions
     */
    protected abstract Class<M> getModelClass();

    /**
     * Deletes the given models; the deletion logic is provided by the subclass.
     *
     * @param currentModels models to delete
     */
    public abstract void delete(Collection<M> currentModels);
}
