package ru.yandex.direct.core.entity.vcard.repository.internal;

import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import one.util.streamex.StreamEx;
import org.jooq.Condition;
import org.jooq.DSLContext;
import org.jooq.Field;
import org.jooq.Record;
import org.jooq.Result;
import org.jooq.SelectConditionStep;
import org.jooq.SelectJoinStep;
import org.jooq.impl.DSL;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Repository;

import ru.yandex.direct.core.entity.vcard.model.PointOnMap;
import ru.yandex.direct.dbschema.ppc.tables.records.MapsRecord;
import ru.yandex.direct.dbutil.sharding.ShardHelper;
import ru.yandex.direct.dbutil.wrapper.DslContextProvider;
import ru.yandex.direct.jooqmapper.JooqMapperWithSupplier;
import ru.yandex.direct.jooqmapper.JooqMapperWithSupplierBuilder;
import ru.yandex.direct.jooqmapperhelper.InsertHelper;

import static java.math.BigDecimal.ROUND_CEILING;
import static java.util.Arrays.asList;
import static ru.yandex.direct.dbschema.ppc.Tables.MAPS;
import static ru.yandex.direct.jooqmapper.ReaderWriterBuilders.property;
import static ru.yandex.direct.utils.FunctionalUtils.mapList;
import static ru.yandex.direct.utils.HashingUtils.getMd5HashAsHexString;

/**
 * Repository for map points stored in the {@code MAPS} table.
 *
 * <p>{@link #getOrCreatePointOnMap} deduplicates points by coordinates: it reuses the id of an
 * already stored point with the same coordinates instead of inserting a copy.
 */
@Repository
public class MapsRepository {

    /** Delimiter appended after every coordinate in the string that is hashed. */
    private static final String SEPARATOR = "~";

    /**
     * Number of fractional digits each coordinate is normalized to before hashing on the Java side.
     * NOTE(review): presumably must equal the scale of the MAPS coordinate columns so that the
     * SQL-side md5 (computed from the columns cast to strings) matches the Java-side one — confirm.
     */
    private static final int SCALE = 6;

    private final DslContextProvider dslContextProvider;
    private final ShardHelper shardHelper;

    private final JooqMapperWithSupplier<PointOnMap> pointMapper;
    /**
     * SQL expression: md5 of the row's coordinates, each cast to a string and followed by
     * {@link #SEPARATOR}. It is intended to produce the same hash as {@link #calcMd5Hash}.
     */
    private final Field<String> md5Field;
    private final Collection<Field<?>> fieldsToRead;

    @Autowired
    public MapsRepository(DslContextProvider dslContextProvider,
                          ShardHelper shardHelper) {
        this.dslContextProvider = dslContextProvider;
        this.shardHelper = shardHelper;

        pointMapper = createPointMapper();
        fieldsToRead = pointMapper.getFieldsToRead();
        // Same layout as calcMd5Hash: X~Y~X1~Y1~X2~Y2~
        md5Field = DSL.md5(DSL.concat(
                DSL.concat(DSL.cast(MAPS.X, String.class), SEPARATOR),
                DSL.concat(DSL.cast(MAPS.Y, String.class), SEPARATOR),
                DSL.concat(DSL.cast(MAPS.X1, String.class), SEPARATOR),
                DSL.concat(DSL.cast(MAPS.Y1, String.class), SEPARATOR),
                DSL.concat(DSL.cast(MAPS.X2, String.class), SEPARATOR),
                DSL.concat(DSL.cast(MAPS.Y2, String.class), SEPARATOR)));
    }

    /**
     * Loads points by their ids.
     *
     * @param shard     shard to read from
     * @param pointsIds ids ({@code MAPS.MID}) of the points to load
     * @return map from point id to the loaded point. Ids that are not found are absent from the map.
     *         When {@code pointsIds} is empty, an empty immutable map is returned without querying.
     */
    public Map<Long, PointOnMap> getPoints(int shard, Collection<Long> pointsIds) {
        if (pointsIds.isEmpty()) {
            return Collections.emptyMap();
        }
        return dslContextProvider.ppc(shard)
                .select(fieldsToRead)
                .from(MAPS)
                .where(MAPS.MID.in(pointsIds))
                .fetchMap(MAPS.MID, pointMapper::fromDb);
    }

    /**
     * Returns ids for the given points, inserting the points that are not stored yet.
     *
     * <p>A point is reused when a row with exactly the same coordinates exists in {@code MAPS}.
     * Points of the input whose coordinate hashes (see {@link #calcMd5Hash}) are equal share a
     * single id and are inserted only once. Each newly inserted point gets its generated id
     * assigned via {@code setId}; points that already exist are not modified.
     *
     * @param shard  shard to read from and insert into
     * @param points points to look up or create
     * @return point ids, in the same order as {@code points}
     */
    public List<Long> getOrCreatePointOnMap(int shard, List<PointOnMap> points) {
        List<String> pointsHashes = mapList(points, this::calcMd5Hash);

        Map<String, Long> existingPoints = getExistingPoints(shard, points);

        // Allocate exactly one id per distinct hash that has no stored point. Subtracting
        // existingPoints.size() instead under-allocates (and makes ids.next() throw) whenever a
        // stored row's SQL-side hash does not match the Java-side hash of the input point.
        int idCount = (int) StreamEx.of(pointsHashes)
                .distinct()
                .remove(existingPoints::containsKey)
                .count();
        Iterator<Long> ids = shardHelper.generateMapsIds(idCount).iterator();

        DSLContext dslContext = dslContextProvider.ppc(shard);
        InsertHelper<MapsRecord> insertHelper = new InsertHelper<>(dslContext, MAPS);

        List<Long> pointsIds = StreamEx.zip(points, pointsHashes, (point, pointHash) -> {

            Long existingPointId = existingPoints.get(pointHash);
            if (existingPointId != null) {
                return existingPointId;
            }

            Long newPointId = ids.next();
            point.setId(newPointId);

            // Record the freshly generated id as "existing" so that a later duplicate of this
            // point in the same input reuses it instead of being inserted a second time.
            existingPoints.put(pointHash, newPointId);

            insertHelper.add(pointMapper, point).newRecord();

            return newPointId;
        }).toList();

        insertHelper.executeIfRecordsAdded();

        return pointsIds;
    }

    /**
     * Finds stored points whose coordinates exactly equal those of any of the given points.
     *
     * @param shard  shard to read from
     * @param points points to look for
     * @return mutable map from the SQL-side coordinate hash ({@code md5Field}) to the id of a stored
     *         point with those coordinates. If several stored rows share a hash, only one of them is
     *         kept. The map must stay mutable because the caller adds newly generated ids to it.
     */
    private Map<String, Long> getExistingPoints(int shard, List<PointOnMap> points) {
        if (points.isEmpty()) {
            return new HashMap<>();
        }

        Field<Long> idAlias = DSL.field("maps_id", Long.class);
        Field<String> md5Alias = DSL.field("maps_md5", String.class);

        // One exact-coordinates match per requested point; the rows of interest match any of them.
        List<Condition> conditions = mapList(points, point -> MAPS.X.eq(point.getX())
                .and(MAPS.Y.eq(point.getY()))
                .and(MAPS.X1.eq(point.getX1()))
                .and(MAPS.Y1.eq(point.getY1()))
                .and(MAPS.X2.eq(point.getX2()))
                .and(MAPS.Y2.eq(point.getY2())));

        SelectJoinStep<Record> selectJoinStep = dslContextProvider.ppc(shard)
                .select(asList(MAPS.MID.as(idAlias.getName()), md5Field.as(md5Alias.getName())))
                .from(MAPS);
        SelectConditionStep<Record> selectConditionStep = selectJoinStep.where(DSL.or(conditions));

        Result<Record> res = selectConditionStep.fetch();

        return StreamEx.of(res)
                // the map collector rejects duplicate keys, and several stored rows may share a hash
                .distinct(rec -> rec.getValue(md5Alias))
                .mapToEntry(rec -> rec.getValue(md5Alias), rec -> rec.getValue(idAlias))
                .toCustomMap(HashMap::new);
    }

    /**
     * Computes the hex md5 of the point's coordinates in the order X, Y, X1, Y1, X2, Y2. Each
     * coordinate is set to {@link #SCALE} fractional digits (rounding towards positive infinity)
     * and followed by {@link #SEPARATOR}.
     * NOTE(review): must produce the same string as {@code md5Field} does for stored rows — confirm.
     */
    private String calcMd5Hash(PointOnMap point) {
        String pointStr = point.getX().setScale(SCALE, ROUND_CEILING) + SEPARATOR +
                point.getY().setScale(SCALE, ROUND_CEILING) + SEPARATOR +
                point.getX1().setScale(SCALE, ROUND_CEILING) + SEPARATOR +
                point.getY1().setScale(SCALE, ROUND_CEILING) + SEPARATOR +
                point.getX2().setScale(SCALE, ROUND_CEILING) + SEPARATOR +
                point.getY2().setScale(SCALE, ROUND_CEILING) + SEPARATOR;
        return getMd5HashAsHexString(pointStr.getBytes(StandardCharsets.UTF_8));
    }

    /** Builds the mapper between {@link PointOnMap} properties and {@code MAPS} columns. */
    private JooqMapperWithSupplier<PointOnMap> createPointMapper() {
        return JooqMapperWithSupplierBuilder.builder(PointOnMap::new)
                .map(property(PointOnMap.ID, MAPS.MID))
                .map(property(PointOnMap.X, MAPS.X))
                .map(property(PointOnMap.Y, MAPS.Y))
                .map(property(PointOnMap.X1, MAPS.X1))
                .map(property(PointOnMap.Y1, MAPS.Y1))
                .map(property(PointOnMap.X2, MAPS.X2))
                .map(property(PointOnMap.Y2, MAPS.Y2))
                .build();
    }
}
