package ru.yandex.chemodan.app.dataapi.core.dao.data;

import java.nio.charset.StandardCharsets;
import java.sql.ResultSet;
import java.sql.SQLException;

import lombok.AllArgsConstructor;
import org.springframework.jdbc.core.PreparedStatementCallback;
import org.springframework.jdbc.core.RowMapper;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.CollectionF;
import ru.yandex.bolts.collection.IteratorF;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.MapF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.collection.Tuple2List;
import ru.yandex.bolts.function.Function2;
import ru.yandex.chemodan.app.dataapi.DataApiBenderUtils;
import ru.yandex.chemodan.app.dataapi.api.data.filter.condition.CollectionIdCondition;
import ru.yandex.chemodan.app.dataapi.api.data.filter.condition.DataCondition;
import ru.yandex.chemodan.app.dataapi.api.data.filter.condition.RecordCondition;
import ru.yandex.chemodan.app.dataapi.api.data.filter.condition.RecordIdCondition;
import ru.yandex.chemodan.app.dataapi.api.data.filter.ordering.ByIdRecordOrder;
import ru.yandex.chemodan.app.dataapi.api.data.filter.ordering.RecordOrder;
import ru.yandex.chemodan.app.dataapi.api.data.protobuf.ProtobufDataUtils;
import ru.yandex.chemodan.app.dataapi.api.data.record.DataRecord;
import ru.yandex.chemodan.app.dataapi.api.data.record.DataRecordId;
import ru.yandex.chemodan.app.dataapi.api.db.handle.DatabaseHandle;
import ru.yandex.chemodan.app.dataapi.api.db.handle.DatabaseHandleRevisions;
import ru.yandex.chemodan.app.dataapi.api.db.handle.DatabaseHandles;
import ru.yandex.chemodan.app.dataapi.api.db.ref.DatabaseRef;
import ru.yandex.chemodan.app.dataapi.api.db.ref.SpecialDatabases;
import ru.yandex.chemodan.app.dataapi.api.user.DataApiUserId;
import ru.yandex.chemodan.app.dataapi.core.dao.ShardPartitionDataSource;
import ru.yandex.chemodan.app.dataapi.core.dao.support.DataApiShardPartitionDaoSupport;
import ru.yandex.chemodan.ratelimiter.chunk.ChunkRateLimiter;
import ru.yandex.chemodan.util.postgres.PgSqlQueryUtils;
import ru.yandex.commune.dynproperties.DynamicProperty;
import ru.yandex.commune.test.random.RunWithRandomTest;
import ru.yandex.misc.ExceptionUtils;
import ru.yandex.misc.bender.BenderMapper;
import ru.yandex.misc.db.SqlFunction0;
import ru.yandex.misc.db.q.ConditionUtils;
import ru.yandex.misc.db.q.SimpleCondition;
import ru.yandex.misc.db.q.SqlCondition;
import ru.yandex.misc.db.q.SqlLimits;
import ru.yandex.misc.db.q.SqlQueryUtils;
import ru.yandex.misc.monica.annotation.GroupByDefault;
import ru.yandex.misc.monica.annotation.MonicaContainer;
import ru.yandex.misc.monica.annotation.MonicaMetric;
import ru.yandex.misc.monica.core.blocks.Meter;
import ru.yandex.misc.monica.core.blocks.Statistic;
import ru.yandex.misc.monica.core.name.MetricGroupName;
import ru.yandex.misc.monica.core.name.MetricName;

/**
 * JDBC DAO for data-api records stored in the sharded {@code p_data_%} tables.
 *
 * <p>A record's payload lives in one of two columns: protobuf bytes in {@code content} or JSON in
 * {@code jcontent}. Batched inserts/updates choose the column according to {@link #useJsonbSerialization};
 * reads prefer a non-empty {@code jcontent} and fall back to {@code content}.
 *
 * @author tolmalev
 * @author osidorkin
 */
public class DataRecordsJdbcDaoImpl extends DataApiShardPartitionDaoSupport
        implements MonicaContainer, DataRecordsJdbcDao
{
    private static final BenderMapper mapper = DataApiBenderUtils.mapper();

    /** Number of rows added to a JDBC batch before the batch is executed. */
    private static final int DML_BATCH_SIZE = 100;

    /**
     * When {@code true}, batched inserts/updates write the record as JSON into {@code jcontent} and leave
     * {@code content} null; when {@code false}, they write protobuf bytes into {@code content} and leave
     * {@code jcontent} null.
     */
    public final DynamicProperty<Boolean> useJsonbSerialization =
            new DynamicProperty<>("dataapi-use-jsonb-serialization", true);

    /** Byte size of payloads serialized by batched inserts/updates. */
    @MonicaMetric
    @GroupByDefault
    private final Statistic serializedSize = new Statistic();
    /** Byte size of payloads read back by {@link RecordMapper}. */
    @MonicaMetric
    @GroupByDefault
    private final Statistic parsedSize = new Statistic();
    /** Number of records returned by unfiltered whole-database reads. */
    @MonicaMetric
    @GroupByDefault
    private final Statistic snapshotSize = new Statistic();
    /** Number of records returned by reads filtered only by collection id. */
    @MonicaMetric
    @GroupByDefault
    private final Statistic collectionSnapshotSize = new Statistic();
    @MonicaMetric
    @GroupByDefault
    private final Meter created = new Meter();
    @MonicaMetric
    @GroupByDefault
    private final Meter deleted = new Meter();
    @MonicaMetric
    @GroupByDefault
    private final Meter updated = new Meter();

    public DataRecordsJdbcDaoImpl(ShardPartitionDataSource dataSource) {
        super(dataSource);
    }

    /** Returns all records of the given database handle in the default order. */
    @Override
    @RunWithRandomTest
    public ListF<DataRecord> find(DataApiUserId uid, final DatabaseHandle databaseHandle) {
        return find(uid, databaseHandle,
                CollectionIdCondition.all(), RecordIdCondition.all(), DataCondition.all(),
                RecordOrder.defaultOrder(), SqlLimits.all());
    }

    /**
     * Finds records of the database identified by {@code dbRef}; its current handle is resolved by a
     * subquery against {@code databases_%} (see {@link #handleByDbQuery}).
     */
    @Override
    @RunWithRandomTest
    public ListF<DataRecord> find(
            DataApiUserId uid, DatabaseRef dbRef,
            CollectionIdCondition collectionIdCond, RecordIdCondition recordIdCond, RecordCondition recordCond,
            RecordOrder order, SqlLimits limits)
    {
        SqlCondition cond =
                getBasicRecordsCondition(uid, dbRef, handleByDbQuery(uid, dbRef), collectionIdCond, recordIdCond)
                        .and(recordCond.getCondition());

        return find(uid, collectionIdCond, recordIdCond, recordCond, order, limits, cond);
    }

    /** Finds records of an already-resolved database handle. */
    @Override
    @RunWithRandomTest
    public ListF<DataRecord> find(
            DataApiUserId uid, DatabaseHandle databaseHandle,
            CollectionIdCondition collectionIdCond, RecordIdCondition recordIdCond, RecordCondition recordCond,
            RecordOrder order, SqlLimits limits)
    {
        SqlCondition cond =
                getBasicRecordsCondition(uid, databaseHandle, collectionIdCond, recordIdCond)
                .and(recordCond.getCondition());

        return find(uid, collectionIdCond, recordIdCond, recordCond, order, limits, cond);
    }

    /**
     * Returns up to {@code limit} records ordered by (collection_id, record_id) that come strictly after
     * {@code prev} (or from the beginning when {@code prev} is empty).
     *
     * @param forceCollateC compare and order ids with {@code COLLATE "C"} (see {@link AfterRecordCondition})
     */
    @Override
    @RunWithRandomTest
    public ListF<DataRecord> findNext(
            DataApiUserId uid, DatabaseHandle databaseHandle,
            CollectionIdCondition collectionIdCond, RecordIdCondition recordIdCond, RecordCondition recordCond,
            Option<DataRecordId> prev, int limit, boolean forceCollateC)
    {
        SqlCondition basicCond = getBasicRecordsCondition(uid, databaseHandle, collectionIdCond, recordIdCond);

        SqlCondition cond = prev.<SqlCondition>map(id -> new AfterRecordCondition(uid, id, forceCollateC))
                .getOrElse(SqlCondition::trueCondition)
                .and(basicCond);

        RecordOrder recordOrder = forceCollateC
                ? ByIdRecordOrder.COLLECTION_ID_ASC_RECORD_ID_ASC_COLLATE_C
                : ByIdRecordOrder.COLLECTION_ID_ASC_RECORD_ID_ASC;

        return find(uid, collectionIdCond, recordIdCond, recordCond, recordOrder, SqlLimits.first(limit), cond);
    }

    /** Basic condition for a resolved handle: filters by the handle value directly. */
    private SqlCondition getBasicRecordsCondition(DataApiUserId uid, DatabaseHandle databaseHandle,
            CollectionIdCondition collectionIdCond, RecordIdCondition recordIdCond)
    {
        return getBasicRecordsCondition(uid, databaseHandle.dbRef, plainHandle(databaseHandle.handle),
                collectionIdCond, recordIdCond);
    }

    /**
     * Runs the record SELECT for {@code cond} and updates the snapshot-size metrics.
     * Short-circuits to an empty list when the collection condition is constant false.
     */
    private ListF<DataRecord> find(DataApiUserId uid, CollectionIdCondition collectionIdCond,
            RecordIdCondition recordIdCond, RecordCondition recordCond, RecordOrder order, SqlLimits limits,
            SqlCondition cond)
    {
        if (collectionIdCond.getCondition().isConstantFalse()) {
            return Cf.list();
        }

        ListF<DataRecord> result = getReadJdbcTemplate(uid).query(""
                + "SELECT user_id, app, dbId, handle, rev, collection_id, record_id, content, jcontent from p_data_%"
                + cond.whereSql()
                + " " + order.toSql()
                + " " + limits.toMysqlLimits(),
                new RecordMapper(parsedSize), cond.args());

        // Only unfiltered reads (whole database or whole collections) count as snapshots.
        if (recordIdCond.isAll() && recordCond.isAll()) {
            (collectionIdCond.isAll() ? snapshotSize : collectionSnapshotSize).update(result.size());
        }
        return result;
    }

    /**
     * Builds {@code handle = (SELECT handle FROM databases_% WHERE user_id = ? AND app = ? AND dbId = ?)},
     * i.e. matches the handle registered for {@code dbRef} of {@code uid}.
     */
    static SqlCondition handleByDbQuery(DataApiUserId uid, DatabaseRef dbRef) {
        SqlCondition condition = SqlCondition.all(
                ConditionUtils.column("user_id").eq(uid.toString()),
                ConditionUtils.column("app").eq(dbRef.dbAppId()),
                ConditionUtils.column("dbId").eq(dbRef.databaseId())
        );
        return new SimpleCondition(
                "handle = (SELECT handle FROM databases_% WHERE " + condition.sql() + ")", condition.args());
    }

    /** Builds {@code handle = ?} for a known handle value. */
    static SqlCondition plainHandle(String handle) {
        return ConditionUtils.column("handle").eq(handle);
    }

    /**
     * Combines the handle, user, collection and record-id conditions. For databases where
     * {@link SpecialDatabases#isDataRecordsPartialIndexed} is true, {@code app} and {@code dbId} are added
     * as well (presumably so a partial index can be used — TODO confirm).
     */
    static SqlCondition getBasicRecordsCondition(DataApiUserId uid, DatabaseRef dbRef, SqlCondition handleCondition,
            CollectionIdCondition collectionIdCond, RecordIdCondition recordIdCond)
    {
        return SqlCondition.trueCondition()
                .and(Option.when(SpecialDatabases.isDataRecordsPartialIndexed(dbRef), () -> SqlCondition.trueCondition()
                        .and(ConditionUtils.column("app").eq(dbRef.dbAppId()))
                        .and(ConditionUtils.column("dbId").eq(dbRef.databaseId()))))
                .and(handleCondition)
                .and(ConditionUtils.column("user_id").eq(uid.toString()))
                .and(collectionIdCond.getCondition())
                .and(recordIdCond.getCondition());
    }

    /** Returns all records belonging to any of the given handles; empty input yields an empty list. */
    @Override
    @RunWithRandomTest
    public ListF<DataRecord> findByDatabaseHandles(DataApiUserId uid, DatabaseHandles handles) {
        CollectionF<String> databaseHandles = handles.toList().map(dbHandle -> dbHandle.handle);
        if (databaseHandles.isEmpty()) {
            return Cf.list();
        }
        String q = "SELECT * from p_data_% WHERE handle IN ("
                + Cf.repeat("?", databaseHandles.size()).mkString(",")
                + ")";

        return getReadJdbcTemplate(uid).query(q, new RecordMapper(uid, handles.toList(), parsedSize), databaseHandles);
    }

    @Override
    public int count(DataApiUserId uid, DatabaseRef dbRef, CollectionIdCondition collectionIdCond,
            RecordIdCondition recordIdCond)
    {
        return count(uid,
                getBasicRecordsCondition(uid, dbRef, handleByDbQuery(uid, dbRef), collectionIdCond, recordIdCond));
    }

    @Override
    public int count(DataApiUserId uid, DatabaseHandle dbHandle, CollectionIdCondition collectionIdCond,
            RecordIdCondition recordIdCond, RecordCondition recordCond)
    {
        return count(uid,
                getBasicRecordsCondition(uid, dbHandle, collectionIdCond, recordIdCond)
                        .and(recordCond.getCondition()));
    }

    @Override
    public int count(DataApiUserId uid, SqlCondition cond) {
        return getReadJdbcTemplate(uid)
                .queryForInt("SELECT count(*) from p_data_%" + cond.whereSql(), cond.args());
    }

    /**
     * Returns records whose {@code rev} is greater than the per-handle minimum revision, ordered by
     * {@code rev}. With several handles, a LATERAL join over the unnested (handle, rev) pairs is used.
     *
     * <p>NOTE(review): {@code limits} is applied both inside the per-handle subquery and to the outer query.
     * If {@code limits} can carry a non-zero offset, that offset is effectively applied twice — confirm
     * callers only pass first-N limits.
     */
    @Override
    public ListF<DataRecord> findByHandlesAndMinRevisionsOrderedByRev(
            DataApiUserId uid, DatabaseHandleRevisions minRevisions, SqlLimits limits)
    {
        Tuple2List<String, Long> handlesMinRevisions = minRevisions.toHandleRevisionTuples();

        // handle/rev are SQL expressions (placeholders or column references), never user data.
        Function2<String, String, String> singleQueryF = (handle, rev) -> ""
                + "SELECT * FROM p_data_%"
                + " WHERE handle = " + handle + " AND rev > " + rev
                + " ORDER BY rev " + limits.toMysqlLimits();

        if (handlesMinRevisions.size() > 1) {
            String array = "UNNEST(ARRAY[" + SqlQueryUtils.qms(handlesMinRevisions.size()) + "])";

            return getReadJdbcTemplate(uid).query(""
                    + "SELECT rec.*"
                    + " FROM (SELECT " + array + " h, " + array + " r) hr,"
                    + " LATERAL (" + singleQueryF.apply("hr.h", "hr.r") + ") rec"
                    + " ORDER BY rec.rev " + limits.toMysqlLimits(),
                    new RecordMapper(uid, minRevisions.handles(), parsedSize),
                    handlesMinRevisions.get1(),
                    handlesMinRevisions.get2());

        } else if (handlesMinRevisions.isNotEmpty()) {
            return getReadJdbcTemplate(uid).query(
                    singleQueryF.apply("?", "?"),
                    new RecordMapper(uid, minRevisions.handles(), parsedSize), handlesMinRevisions.single());
        } else {
            return Cf.list();
        }
    }

    /**
     * Inserts records in JDBC batches of {@link #DML_BATCH_SIZE}. The payload column is chosen once per
     * record from {@link #useJsonbSerialization}, so exactly one of {@code content}/{@code jcontent} is set.
     */
    @Override
    public void insertBatched(DataApiUserId uid, DatabaseRef dbRef, CollectionF<DataRecord> records) {
        if (records.isEmpty()) {
            return;
        }
        getJdbcTemplate(uid).execute("INSERT INTO p_data_% "
                + "(user_id, handle, dbId, app, collection_id, record_id, rev, content, jcontent) "
                + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?::jsonb)",
                (PreparedStatementCallback<Object>) ps -> {
                    for (IteratorF<ListF<DataRecord>> it = records.iterator().paginate(DML_BATCH_SIZE);
                         it.hasNext(); )
                    {
                        for (DataRecord record : it.next()) {
                            SerializedContent content = serializeContent(record);
                            ps.setString(1, uid.toString());
                            ps.setString(2, record.getDatabaseHandle());
                            ps.setString(3, dbRef.databaseId());
                            ps.setString(4, dbRef.dbAppId());
                            ps.setString(5, record.getCollectionId());
                            ps.setString(6, record.getRecordId());
                            ps.setLong(7, record.rev);
                            ps.setBytes(8, content.protobuf);
                            ps.setString(9, content.jsonb);
                            ps.addBatch();
                        }
                        ps.executeBatch();
                    }
                    return null;
                });
        created.inc(records.size());
    }

    /** Payload values for the {@code content} and {@code jcontent} columns; exactly one is non-null. */
    private static final class SerializedContent {
        final byte[] protobuf;
        final String jsonb;

        SerializedContent(byte[] protobuf, String jsonb) {
            this.protobuf = protobuf;
            this.jsonb = jsonb;
        }
    }

    /**
     * Serializes {@code record} for storage, reading {@link #useJsonbSerialization} exactly once so that a
     * concurrent flip of the property cannot leave both columns null or put JSON into {@code content}.
     */
    private SerializedContent serializeContent(DataRecord record) {
        if (useJsonbSerialization.get()) {
            byte[] json = mapper.serializeJson(record);
            serializedSize.update(json.length);
            return new SerializedContent(null, jsonBytesToPg(json));
        }
        byte[] protobuf = ProtobufDataUtils.serialize(record);
        serializedSize.update(protobuf.length);
        return new SerializedContent(protobuf, null);
    }

    /** Decodes serialized JSON bytes as UTF-8 (not the platform charset) and escapes them for Postgres. */
    private String jsonBytesToPg(byte[] json) {
        return escapeJsonForPg(new String(json, StandardCharsets.UTF_8));
    }

    private String escapeJsonForPg(String json) {
        return PgSqlQueryUtils.escapeJson(json);
    }

    /**
     * Rewrites a record's payload in the requested format, but only for the row with matching
     * handle/collection/record/rev whose target column is still null.
     *
     * @param migrateToJsonb {@code true} to fill {@code jcontent} (clearing {@code content}),
     *                       {@code false} to fill {@code content} (clearing {@code jcontent})
     */
    @Override
    public void updateContentsWithRevisionCheck(DataApiUserId uid, DatabaseRef dbRef, DataRecord record,
            boolean migrateToJsonb)
    {
        byte[] protobuf = migrateToJsonb ? null : ProtobufDataUtils.serialize(record);
        String jsonb = migrateToJsonb ? jsonBytesToPg(mapper.serializeJson(record)) : null;

        getJdbcTemplate(uid).update("UPDATE p_data_% SET "
                + "dbId = ?, app = ?, content = ?, jcontent = ?::jsonb WHERE "
                + "handle = ? AND "
                + "collection_id = ? AND "
                + "record_id = ? AND "
                + "rev = ? AND "
                +  (migrateToJsonb ? "jcontent" : "content") + " is null",
                dbRef.databaseId(), dbRef.dbAppId(),
                protobuf,
                jsonb,
                record.getDatabaseHandle(), record.getCollectionId(), record.getRecordId(), record.rev);
    }

    /**
     * Updates rev and payload of existing records in JDBC batches of {@link #DML_BATCH_SIZE}; the payload
     * column is chosen once per record from {@link #useJsonbSerialization}.
     */
    @Override
    public void updateBatched(DataApiUserId uid, CollectionF<DataRecord> records) {
        if (records.isEmpty()) {
            return;
        }
        getJdbcTemplate(uid).execute("UPDATE p_data_% SET "
                + "rev = ?, content = ?, jcontent = ?::jsonb WHERE "
                + "handle = ? AND "
                + "collection_id = ? AND "
                + "record_id = ?",
                (PreparedStatementCallback<Object>) ps -> {
                    for (IteratorF<ListF<DataRecord>> it = records.iterator().paginate(DML_BATCH_SIZE);
                         it.hasNext(); )
                    {
                        for (DataRecord record : it.next()) {
                            SerializedContent content = serializeContent(record);
                            ps.setLong(1, record.rev);
                            ps.setBytes(2, content.protobuf);
                            ps.setString(3, content.jsonb);
                            ps.setString(4, record.getDatabaseHandle());
                            ps.setString(5, record.getCollectionId());
                            ps.setString(6, record.getRecordId());
                            ps.addBatch();
                        }
                        ps.executeBatch();
                    }
                    return null;
                });
        updated.inc(records.size());
    }

    /** Deletes the given records in JDBC batches of {@link #DML_BATCH_SIZE}. */
    @Override
    public void deleteRecordsBatched(DataApiUserId uid, CollectionF<DataRecordId> recordIds) {
        if (recordIds.isEmpty()) {
            return;
        }
        getJdbcTemplate(uid).execute("DELETE FROM p_data_% WHERE "
                + "handle = ? AND "
                + "collection_id = ? AND "
                + "record_id = ?",
                (PreparedStatementCallback<Object>) ps -> {
                    for (IteratorF<ListF<DataRecordId>> it = recordIds.iterator().paginate(DML_BATCH_SIZE);
                            it.hasNext(); )
                    {
                        for (DataRecordId recordId : it.next()) {
                            ps.setString(1, recordId.handleValue());
                            ps.setString(2, recordId.collectionId());
                            ps.setString(3, recordId.recordId());
                            ps.addBatch();
                        }
                        ps.executeBatch();
                    }
                    return null;
                });
        deleted.inc(recordIds.size());
    }

    /** Returns the records of {@code handle} whose (collection_id, record_id) is in {@code recordIds}. */
    @Override
    @RunWithRandomTest
    public ListF<DataRecord> findRecords(DataApiUserId uid, DatabaseHandle handle,
            CollectionF<DataRecordId> recordIds)
    {
        if (recordIds.isEmpty()) {
            return Cf.list();
        }
        String q = "SELECT * from p_data_% WHERE handle = ? AND (collection_id, record_id) IN ("
                + Cf.repeat("(?, ?)", recordIds.size()).mkString(", ")
                + ")";

        return getReadJdbcTemplate(uid).query(q,
                new RecordMapper(uid, handle, parsedSize),
                handle.handle, recordIds.flatMap(r -> Cf.<Object>list(r.collectionId(), r.recordId()))
        );
    }

    /** Deletes all records of the given handles with a single DELETE statement. */
    @Override
    @RunWithRandomTest
    public void deleteAllRecordFromDatabases(DataApiUserId uid, ListF<String> handles) {
        SqlCondition condition = ConditionUtils.column("handle").inSet(handles);
        getJdbcTemplate(uid).update("DELETE from p_data_%" + condition.whereSql(), condition.args());
    }

    /** Deletes all records of the given handles in chunks throttled by {@code limiter}. */
    @Override
    @RunWithRandomTest
    public void deleteAllRecordFromDatabases(DataApiUserId uid, ListF<String> handles, ChunkRateLimiter limiter) {
        SqlCondition condition = ConditionUtils.column("handle").inSet(handles);
        deleteAllByChunks(uid, "p_data_%", Cf.list("handle", "collection_id", "record_id"), condition, limiter);
    }

    @Override
    public MetricGroupName groupName(String s) {
        return new MetricGroupName(
                "dataapi",
                new MetricName("dataapi", "dao", "record-proto"),
                "Records jdbc dao"
        );
    }

    /**
     * Maps a {@code p_data_%} row to a {@link DataRecord}, parsing {@code jcontent} when it is non-empty and
     * {@code content} (protobuf) otherwise.
     *
     * <p>The uid and handle are taken from the constructor when provided; otherwise they are read from the
     * row's {@code user_id}, {@code app}, {@code dbId} and {@code handle} columns.
     */
    private static class RecordMapper implements RowMapper<DataRecord> {
        private final Option<DataApiUserId> uid;
        private final MapF<String, DatabaseHandle> handles;
        private final Statistic parsedSize;

        public RecordMapper(Statistic parsedSize) {
            this.uid = Option.empty();
            this.handles = Cf.map();
            this.parsedSize = parsedSize;
        }

        public RecordMapper(DataApiUserId uid, DatabaseHandle handle, Statistic parsedSize) {
            this(uid, Cf.list(handle), parsedSize);
        }

        public RecordMapper(DataApiUserId uid, CollectionF<DatabaseHandle> handles, Statistic parsedSize) {
            this.uid = Option.of(uid);
            this.handles = handles.toMapMappingToKey(DatabaseHandle::handleValue);
            this.parsedSize = parsedSize;
        }

        /**
         * @throws IllegalStateException if the row has neither a non-empty {@code jcontent} nor a
         *                               {@code content} value
         */
        @Override
        public DataRecord mapRow(ResultSet rs, int rowNum) throws SQLException {
            DatabaseHandle handle = getHandle(rs);

            DataApiUserId uid = this.uid.getOrElse(
                    (SqlFunction0<DataApiUserId>) () -> DataApiUserId.parse(rs.getString("user_id"))
            );

            byte[] jsonData = rs.getBytes("jcontent");
            if (jsonData != null && jsonData.length > 0) {
                parsedSize.update(jsonData.length);
                long rev = rs.getLong("rev");
                String collectionId = rs.getString("collection_id");
                String recordId = rs.getString("record_id");
                DataRecordId id = new DataRecordId(handle, collectionId, recordId);
                // uid, rev and id come from the row/mapper rather than from the stored JSON.
                return mapper.parseJson(DataRecord.class, jsonData)
                        .withNewUid(uid)
                        .withNewRev(rev)
                        .withNewId(id);
            }

            byte[] protobufData = rs.getBytes("content");
            if (protobufData == null) {
                // Previously surfaced as a bare NullPointerException; name the offending row instead.
                throw new IllegalStateException("Record has neither jcontent nor content: handle="
                        + handle.handleValue()
                        + ", collection_id=" + rs.getString("collection_id")
                        + ", record_id=" + rs.getString("record_id"));
            }
            parsedSize.update(protobufData.length);
            return ProtobufDataUtils.parse(uid, handle, protobufData);
        }

        /**
         * Resolves the row's handle: from the row columns when no handles were supplied, the single
         * supplied handle when there is exactly one, or by looking up the row's {@code handle} value.
         */
        private DatabaseHandle getHandle(ResultSet rs) {
            if (handles.isEmpty()) {
                try {
                    return new DatabaseHandle(
                            rs.getString("app"),
                            rs.getString("dbId"),
                            rs.getString("handle")
                    );
                } catch (SQLException e) {
                    throw ExceptionUtils.translate(e);
                }
            }
            if (handles.size() == 1) {
                return handles.values().iterator().next();
            }

            try {
                return handles.getTs(rs.getString("handle"));
            } catch (SQLException e) {
                throw ExceptionUtils.translate(e);
            }
        }
    }

    /**
     * Row-value condition {@code (user_id, handle, collection_id, record_id) > (?, ?, ?, ?)} selecting rows
     * strictly after {@code recordId}; used by {@link #findNext} to continue from the previous page.
     */
    @AllArgsConstructor
    private static class AfterRecordCondition extends SqlCondition {
        private final DataApiUserId uid;
        private final DataRecordId recordId;
        private final boolean forceCollateC;

        @Override
        public ListF<Object> args() {
            return Cf.list(uid.toString(), recordId.handleValue(), recordId.collectionId(), recordId.recordId());
        }

        @Override
        public String sql(Option<String> tableName) {
            String prefix = tableName.map(s -> s + ".").getOrElse("");
            // XXX we have wrong collate at apidb and smcdb first shard CHEMODAN-53027
            return Cf.list("user_id", "handle", "collection_id", "record_id")
                    .map(name -> prefix + name + (forceCollateC ? " COLLATE \"C\"" : ""))
                    .mkString("(", ", ", ") > (?, ?, ?, ?)");
        }
    }
}
