package ru.yandex.direct.core.entity.moderation.repository.bulk_update;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Predicate;

import com.google.common.collect.Iterables;
import org.jooq.Configuration;
import org.jooq.Field;
import org.jooq.Record;
import org.jooq.TableField;

import ru.yandex.direct.jooqmapper.JooqMapperUtils;

import static java.util.stream.Collectors.toList;

/**
 * Accumulates per-row column changes for a single table, keyed by {@code keyField}, and flushes them
 * with chunked {@code UPDATE ... SET column = CASE key WHEN ... END ... WHERE key IN (...)} statements.
 *
 * <p>Changes are registered through {@link #forId(Number)} and written by {@link #execute(Configuration)}.
 * The internal state is a plain {@link HashMap} with no synchronization, so an instance should be
 * confined to one thread.
 *
 * @param <R> jOOQ record type of the updated table
 * @param <H> type of the key column values
 */
public class BulkUpdate<R extends Record, H extends Number> {
    private static final int CHUNK_SIZE = 500;

    private final TableField<R, H> keyField;
    private final Map<H, RowChanges<R>> rowChangesMap = new HashMap<>();

    private final BiFunction<Object, Object, Object> valueMerger;
    private final Predicate<RowChanges<? extends Record>> validator;

    /**
     * @param keyField    column identifying a row; its values are the ids passed to {@link #forId(Number)}
     * @param valueMerger invoked by {@link RowChanges#set} as {@code merge(current, newValue)} when a column
     *                    that already holds a different non-null value is set again; its result is stored
     * @param validator   every accumulated {@link RowChanges} must satisfy it, otherwise
     *                    {@link #execute(Configuration)} throws before any statement is run
     */
    public BulkUpdate(TableField<R, H> keyField,
                      BiFunction<Object, Object, Object> valueMerger,
                      Predicate<RowChanges<? extends Record>> validator) {
        this.keyField = keyField;
        this.valueMerger = valueMerger;
        this.validator = validator;
    }

    /**
     * Returns the change set for the row with the given id, creating an empty one on first access.
     *
     * @param id value of {@code keyField} for the row
     * @return the (mutable, shared) change set for that row
     */
    public RowChanges<R> forId(H id) {
        return rowChangesMap.computeIfAbsent(id, e -> new RowChanges<>(valueMerger));
    }

    /**
     * Checks every change set against the validator.
     *
     * @throws IllegalStateException for the first change set the validator rejects
     */
    void validateChanges(Collection<RowChanges<R>> changes) {
        for (var row : changes) {
            if (!validator.test(row)) {
                throw new IllegalStateException("Attempt to execute incorrect update " + row);
            }
        }
    }

    /**
     * Writes all accumulated changes and then forgets them.
     *
     * <p>All change sets are validated first, so an invalid one aborts before any SQL is sent. Ids are
     * then sorted ascending and processed in chunks of {@value #CHUNK_SIZE}; each chunk becomes one
     * {@code UPDATE} statement (with {@code ORDER BY keyField}). Chunks in which no row changes any
     * column are skipped. The accumulated changes are cleared only after every statement has run; if a
     * statement throws, they are kept.
     *
     * @param configuration jOOQ configuration whose {@code dsl()} is used to run the statements
     * @throws IllegalStateException if any accumulated change set fails validation
     */
    public void execute(Configuration configuration) {
        if (rowChangesMap.isEmpty()) {
            return;
        }

        validateChanges(rowChangesMap.values());

        // NOTE(review): sorting presumably gives concurrent bulk updates a consistent row-lock order — confirm.
        var sortedKeys = rowChangesMap.keySet().stream().sorted().collect(toList());
        for (var part : Iterables.partition(sortedKeys, CHUNK_SIZE)) {
            // Built per chunk: CASE expressions from a previous chunk must not leak into this statement.
            Map<TableField<R, ?>, Field<?>> cases = buildChunkCases(part);
            if (cases.isEmpty()) {
                // Nothing to set for these rows; an UPDATE with an empty SET clause would be invalid.
                continue;
            }

            configuration.dsl()
                    .update(keyField.getTable())
                    .set(cases)
                    .where(keyField.in(part))
                    .orderBy(keyField)
                    .execute();
        }

        // Just in case: clear the accumulated changes so a repeated execute() does not redo the update.
        rowChangesMap.clear();
    }

    /**
     * Builds the SET clause for one chunk of ids: for every column changed by at least one row of the
     * chunk, a CASE expression (via {@link JooqMapperUtils#makeCaseStatement}) mapping each such row's id
     * to its new value.
     *
     * <p>NOTE(review): rows of the chunk that did not change a given column rely on the fallback that
     * {@code makeCaseStatement} generates for unmatched keys — confirm it keeps the current value.
     */
    @SuppressWarnings("unchecked")
    private Map<TableField<R, ?>, Field<?>> buildChunkCases(Collection<H> ids) {
        Map<TableField<R, ?>, Map<H, Object>> valuesByField = new HashMap<>();
        for (H id : ids) {
            for (var change : rowChangesMap.get(id).getNewColumnValues().entrySet()) {
                valuesByField.computeIfAbsent(change.getKey(), k -> new HashMap<>())
                        .put(id, change.getValue());
            }
        }

        Map<TableField<R, ?>, Field<?>> cases = new HashMap<>();
        for (var entry : valuesByField.entrySet()) {
            // Unchecked: each value was stored through RowChanges.set with the field's own type.
            cases.put(entry.getKey(), JooqMapperUtils.makeCaseStatement(keyField,
                    (TableField<R, Object>) entry.getKey(), entry.getValue()));
        }
        return cases;
    }

    /**
     * @return name of the table owning {@code keyField}
     */
    public String getTableName() {
        return keyField.getTable().getName();
    }

    /**
     * New column values for a single row.
     *
     * @param <G> jOOQ record type of the table the columns belong to
     */
    public static class RowChanges<G extends Record> {
        private final Map<TableField<G, ?>, Object> values = new HashMap<>();
        private final BiFunction<Object, Object, Object> valueMerger;

        /**
         * @param valueMerger applied as {@code merge(current, newValue)} when a column already holding a
         *                    different non-null value is set again
         */
        public RowChanges(BiFunction<Object, Object, Object> valueMerger) {
            this.valueMerger = valueMerger;
        }

        /**
         * @return the live (mutable) map of column to new value; not a copy
         */
        public Map<TableField<G, ?>, Object> getNewColumnValues() {
            return values;
        }

        /**
         * Records a new value for a column.
         *
         * <p>If the column already holds a non-null value that differs from {@code value}, the stored
         * value becomes {@code valueMerger.apply(current, value)} instead of {@code value}, rather than
         * rejecting the conflicting change.
         *
         * @return this, for chaining
         */
        @SuppressWarnings("unchecked")
        public <T> RowChanges<G> set(TableField<G, T> field, T value) {
            T currentObject = (T) values.get(field);

            if (currentObject != null && !currentObject.equals(value)) {
                value = (T) valueMerger.apply(currentObject, value);
            }

            values.put(field, value);

            return this;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            RowChanges<?> that = (RowChanges<?>) o;
            return Objects.equals(values, that.values);
        }

        // Not cached: `values` is mutable, so a cached hash code would go stale after set().
        @Override
        public int hashCode() {
            return Objects.hash(values);
        }

        @Override
        public String toString() {
            return "RowChanges{" +
                    "values=" + values +
                    '}';
        }
    }

}
