package ru.yandex.chemodan.app.dataapi.core.datasources.ydb.dao;

import com.yandex.ydb.table.description.TableDescription;
import com.yandex.ydb.table.query.Params;
import com.yandex.ydb.table.result.ResultSetReader;
import com.yandex.ydb.table.values.ListType;
import com.yandex.ydb.table.values.PrimitiveType;
import com.yandex.ydb.table.values.PrimitiveValue;
import com.yandex.ydb.table.values.Value;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.CollectionF;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.MapF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.collection.SetF;
import ru.yandex.bolts.collection.Tuple2;
import ru.yandex.bolts.collection.Tuple2List;
import ru.yandex.bolts.collection.Tuple3;
import ru.yandex.chemodan.app.dataapi.DataApiBenderUtils;
import ru.yandex.chemodan.app.dataapi.api.data.filter.condition.CollectionIdCondition;
import ru.yandex.chemodan.app.dataapi.api.data.filter.condition.RecordCondition;
import ru.yandex.chemodan.app.dataapi.api.data.filter.condition.RecordIdCondition;
import ru.yandex.chemodan.app.dataapi.api.data.filter.ordering.RecordOrder;
import ru.yandex.chemodan.app.dataapi.api.data.record.DataRecord;
import ru.yandex.chemodan.app.dataapi.api.data.record.DataRecordId;
import ru.yandex.chemodan.app.dataapi.api.db.handle.DatabaseHandle;
import ru.yandex.chemodan.app.dataapi.api.db.handle.DatabaseHandleRevisions;
import ru.yandex.chemodan.app.dataapi.api.db.handle.DatabaseHandles;
import ru.yandex.chemodan.app.dataapi.api.db.ref.DatabaseRef;
import ru.yandex.chemodan.app.dataapi.api.db.ref.SpecialDatabases;
import ru.yandex.chemodan.app.dataapi.api.user.DataApiUserId;
import ru.yandex.chemodan.ydb.dao.OneTableYdbDao;
import ru.yandex.chemodan.ydb.dao.ThreadLocalYdbTransactionManager;
import ru.yandex.chemodan.ydb.dao.YdbQueryMapper;
import ru.yandex.chemodan.ydb.dao.YdbRowMapper;
import ru.yandex.misc.bender.BenderMapper;
import ru.yandex.misc.db.SqlFunction0;
import ru.yandex.misc.db.q.ConditionUtils;
import ru.yandex.misc.db.q.SimpleCondition;
import ru.yandex.misc.db.q.SqlCondition;
import ru.yandex.misc.db.q.SqlLimits;
import ru.yandex.misc.db.q.SqlOrder;
import ru.yandex.misc.lang.CharsetUtils;

/**
 * @author tolmalev
 */
public class DataRecordsYdbDao extends OneTableYdbDao {
    private static final BenderMapper mapper = DataApiBenderUtils.mapper();

    public static final String TABLE_NAME = "data";
    public static final TableDescription DESCRIPTION = TableDescription
            .newBuilder()
            .addNonnullColumn("user_id", PrimitiveType.string())
            .addNonnullColumn("handle", PrimitiveType.string())
            .addNonnullColumn("dbId", PrimitiveType.string())
            .addNonnullColumn("app", PrimitiveType.string())
            .addNonnullColumn("collection_id", PrimitiveType.string())
            .addNonnullColumn("record_id", PrimitiveType.string())
            .addNonnullColumn("rev", PrimitiveType.int64())
            .addNonnullColumn("content", PrimitiveType.string())
            .setPrimaryKeys("user_id", "handle", "collection_id", "record_id")
            .build();

    public DataRecordsYdbDao(ThreadLocalYdbTransactionManager transactionManager) {
        super(transactionManager, TABLE_NAME, DESCRIPTION);
    }

    public void deleteAllRecordFromDatabases(DataApiUserId uid, ListF<String> handles) {
        String sql = "" +
                "DECLARE $handles as \"List<String>\";" +
                "DELETE from data WHERE handle IN $handles";

        ListType listType = ListType.of(PrimitiveType.string());
        Params params = Params.create()
                .put("$handles", listType.newValue(handles.map(h -> PrimitiveValue.string(h.getBytes()))));

        execute(sql, params);
    }

    public ListF<DataRecord> find(
            DataApiUserId uid, DatabaseRef dbRef,
            CollectionIdCondition collectionIdCond, RecordIdCondition recordIdCond, RecordCondition recordCond,
            RecordOrder order, SqlLimits limits)
    {
        SqlCondition cond =
                getBasicRecordsCondition(uid, dbRef, "$handle_by_db", collectionIdCond, recordIdCond)
                        .and(recordCond.getCondition());

        return find(handleByDbQuery(uid, dbRef), uid, collectionIdCond, recordIdCond, recordCond, order, limits, cond);
    }

    public ListF<DataRecord> find(
            DataApiUserId uid, DatabaseHandle databaseHandle,
            CollectionIdCondition collectionIdCond, RecordIdCondition recordIdCond, RecordCondition recordCond,
            RecordOrder order, SqlLimits limits)
    {
        SqlCondition cond =
                getBasicRecordsCondition(uid, databaseHandle, collectionIdCond, recordIdCond)
                        .and(recordCond.getCondition());

        return queryForList("SELECT user_id, app, dbId, handle, rev, collection_id, record_id, content FROM data",
                cond, order.toSqlOrder(), limits, new RecordMapper());
    }

    private ListF<DataRecord> find(String sqlPrefix, DataApiUserId uid, CollectionIdCondition collectionIdCond,
           RecordIdCondition recordIdCond, RecordCondition recordCond, RecordOrder order, SqlLimits limits,
           SqlCondition cond)
    {
        if (collectionIdCond.getCondition().isConstantFalse()) {
            return Cf.list();
        }

        return queryForList(sqlPrefix + "SELECT user_id, app, dbId, handle, rev, collection_id, record_id, content FROM data",
                cond, order.toSqlOrder(), limits, new RecordMapper());
    }

    public ListF<DataRecord> findRecords(DataApiUserId uid, DatabaseHandle dbHandle, ListF<DataRecordId> recordIds) {
        SqlCondition condition = SqlCondition.trueCondition()
                .and(SqlCondition.column("user_id").eq(uid.toString()))
                .and(SqlCondition.column("handle").eq(dbHandle.handle))
                .and(SqlCondition.column("(collection_id, record_id)").inSet(recordIds.map(rid -> new Tuple2<>(rid.collectionId(), rid.recordId()))));

        return queryForList("SELECT * FROM data", condition, new RecordMapper(uid, dbHandle));
    }

    public int count(DataApiUserId uid, DatabaseRef dbRef, CollectionIdCondition collectionIdCond,
                     RecordIdCondition recordIdCond)
    {
        return count(handleByDbQuery(uid, dbRef), uid,
                getBasicRecordsCondition(uid, dbRef, "$handle_by_db", collectionIdCond, recordIdCond));
    }

    public int count(DataApiUserId uid, DatabaseHandle dbHandle, CollectionIdCondition collectionIdCond,
                     RecordIdCondition recordIdCond)
    {
        return count("", uid,
                getBasicRecordsCondition(uid, dbHandle, collectionIdCond, recordIdCond));
    }

    public int count(String sqlPrefix, DataApiUserId uid, SqlCondition cond) {
        return (int) queryForLong(sqlPrefix + "SELECT count(*) as count FROM data", cond);
    }

    private SqlCondition deleteCondition(SetF<DataRecordId> deletedIds) {
        return SqlCondition
                .column("(handle, collection_id, record_id)")
                .inSet(deletedIds.map(id -> Tuple3.tuple(id.handle.handle, id.collectionId(), id.recordId())));
    }

    private SqlCondition getBasicRecordsCondition(DataApiUserId uid, DatabaseHandle databaseHandle,
                                                  CollectionIdCondition collectionIdCond, RecordIdCondition recordIdCond)
    {
        return getBasicRecordsCondition(uid, databaseHandle.dbRef, "'" + databaseHandle.handle + "'",
                collectionIdCond, recordIdCond);
    }

    private SqlCondition getBasicRecordsCondition(DataApiUserId uid, DatabaseRef dbRef, String handleCondition,
                                                  CollectionIdCondition collectionIdCond, RecordIdCondition recordIdCond)
    {
        return SqlCondition.trueCondition()
                .and(Option.when(SpecialDatabases.isDataRecordsPartialIndexed(dbRef), () -> SqlCondition.trueCondition()
                        .and(ConditionUtils.column("app").eq(dbRef.dbAppId()))
                        .and(ConditionUtils.column("dbId").eq(dbRef.databaseId()))))
                .and(new SimpleCondition("handle = " + handleCondition, Cf.list()))
                .and(ConditionUtils.column("user_id").eq(uid.toString()))
                .and(collectionIdCond.getCondition())
                .and(recordIdCond.getCondition());
    }

    static String handleByDbQuery(DataApiUserId uid, DatabaseRef dbRef) {
        SqlCondition condition = SqlCondition.all(
                ConditionUtils.column("user_id").eq(uid.toString()),
                ConditionUtils.column("app").eq(dbRef.dbAppId()),
                ConditionUtils.column("dbId").eq(dbRef.databaseId())
        );
        return "$handle_by_db = (SELECT handle FROM databases WHERE " + condition.toConcreteSql() + "); \n";
    }

    private byte[] serializeJson(DataRecord record) {
        return mapper.serializeJson(record);
    }

    public void bulkInsertDeleteUpdate(DataApiUserId uid, DatabaseRef dbRef, ListF<DataRecord> newRecords, SetF<DataRecordId> deletedIds,
                                       ListF<DataRecord> updatedRecords)
    {
        if (newRecords.isEmpty() && deletedIds.isEmpty() && updatedRecords.isEmpty()) {
            return;
        }

        ListF<Tuple2<String, String>> newIds = newRecords.map(dr -> new Tuple2<>(dr.getCollectionId(), dr.getRecordId()));
        ListF<Tuple2<String, String>> updateIds = updatedRecords.map(dr -> new Tuple2<>(dr.getCollectionId(), dr.getRecordId()));
        ListF<Tuple2<String, String>> deleteIdsList = deletedIds.map(dr -> new Tuple2<>(dr.collectionId(), dr.recordId()));

        SetF<Tuple2<String, String>> newIdsSet = newIds.unique();
        SetF<Tuple2<String, String>> updateIdsSet = updateIds.unique();
        SetF<Tuple2<String, String>> deleteIdsSet = deleteIdsList.unique();

        if (newIds.size() != newIdsSet.size()
                || updateIds.size() != updateIdsSet.size()
                || deletedIds.size() != deleteIdsSet.size())
        {
            throw new IllegalArgumentException("Bad request - some id's not unique");
        }

        if (newIdsSet.size() + updateIdsSet.size() + deletedIds.size() != deleteIdsSet.plus(newIdsSet).plus(updateIdsSet).unique().size()) {
            throw new IllegalArgumentException("Bad request - some id's not unique");
        }


        StringBuilder declareSb = new StringBuilder();
        StringBuilder actionSb = new StringBuilder();

        MapF<String, Value<?>> params = Cf.hashMap();

        if (deletedIds.isNotEmpty()) {
            YdbQueryMapper.YdbCondition ydbCondition = YdbQueryMapper.mapWhereSql(deleteCondition(deletedIds));

            declareSb.append(ydbCondition.declareSql);
            actionSb.append("$to_delete = (SELECT * FROM data " + ydbCondition.whereSql + ");\n");
            params.putAll(ydbCondition.params);
        }

        if (updatedRecords.isNotEmpty() || newRecords.isNotEmpty()) {
            ListType listType = ListType.of(structType);
            declareSb.append("DECLARE $upsert_items AS \"" + listType + "\";\n");
            actionSb.append("$to_upsert = (SELECT * FROM AS_TABLE($upsert_items));\n");

            params.put("$upsert_items", listType.newValue(updatedRecords.plus(newRecords).map(record -> {
                MapF<String, Value> values = Tuple2List.<String, Value>fromPairs(
                        "user_id", PrimitiveValue.string(uid.toString().getBytes()),
                        "handle", PrimitiveValue.string(record.getDatabaseHandle().getBytes()),
                        "dbId", PrimitiveValue.string(dbRef.databaseId().getBytes()),
                        "app", PrimitiveValue.string(dbRef.dbAppId().getBytes()),
                        "collection_id", PrimitiveValue.string(record.getCollectionId().getBytes()),
                        "record_id", PrimitiveValue.string(record.getRecordId().getBytes()),
                        "rev", PrimitiveValue.int64(record.rev),
                        "content", PrimitiveValue.string(serializeJson(record))
                ).toMap();

                return structType.newValue(values);
            })));
        }

        if (deletedIds.isNotEmpty()) {
            actionSb.append("DELETE FROM data ON SELECT * FROM $to_delete;\n");
        }
        if (updatedRecords.isNotEmpty() || newRecords.isNotEmpty()) {
            actionSb.append("UPSERT INTO data SELECT * FROM $to_upsert;\n");
        }

        execute(declareSb.toString() + actionSb.toString(), params);
    }

    public ListF<DataRecord> findByDatabaseHandles(DataApiUserId uid, DatabaseHandles handles) {
        CollectionF<String> databaseHandles = handles.toList().map(dbHandle -> dbHandle.handle);
        if (databaseHandles.isEmpty()) {
            return Cf.list();
        }
        SqlCondition condition = SqlCondition.column("handle").inSet(databaseHandles);

        return queryForList("SELECT * FROM data", condition, new RecordMapper());
    }

    public ListF<DataRecord> findByHandlesAndMinRevisionsOrderedByRev(
            DataApiUserId uid, DatabaseHandleRevisions minRevisions, SqlLimits limits)
    {
        Tuple2List<String, Long> handlesMinRevisions = minRevisions.toHandleRevisionTuples();

        if (handlesMinRevisions.size() > 1) {
            String sql = "" +
                    "DECLARE $handles as \"List<String>\";\n" +
                    "DECLARE $rev_by_handle as \"Dict<String, Int64>\";\n" +
                    "\n" +
                    "SELECT * FROM data\n" +
                    "WHERE handle IN $handles AND rev > $rev_by_handle{handle}\n" +
                    "ORDER BY rev " + limits.toMysqlLimits() + ";";

            Params params = Params.create()
                    .put("$handles", YdbQueryMapper.rawArgToValue(handlesMinRevisions.get1())._2)
                    .put("$rev_by_handle", YdbQueryMapper.rawArgToValue(handlesMinRevisions.toMap())._2);

            return queryForList(sql, params, new RecordMapper());
        } else if (handlesMinRevisions.size() == 1) {
            Tuple2<String, Long> single = handlesMinRevisions.single();

            SqlCondition condition = SqlCondition.trueCondition()
                    .and(SqlCondition.column("handle").eq(single._1))
                    .and(SqlCondition.column("rev").gt(single._2))
                    ;

            return queryForList("SELECT * FROM data", condition, SqlOrder.orderByColumn("rev"), limits, new RecordMapper());
        } else {
            return Cf.list();
        }
    }

    private static class RecordMapper implements YdbRowMapper<DataRecord> {
        private final Option<DataApiUserId> uid;
        private final MapF<String, DatabaseHandle> handles;

        public RecordMapper() {
            this.uid = Option.empty();
            this.handles = Cf.map();
        }

        public RecordMapper(DataApiUserId uid, DatabaseHandle handle) {
            this(uid, Cf.list(handle));
        }

        public RecordMapper(DataApiUserId uid, CollectionF<DatabaseHandle> handles) {
            this.uid = Option.of(uid);
            this.handles = handles.toMapMappingToKey(DatabaseHandle::handleValue);
        }

        @Override
        public DataRecord mapRow(ResultSetReader rs, int rowNum) {
            DatabaseHandle handle = getHandle(rs);

            DataApiUserId uid = this.uid.getOrElse(
                    (SqlFunction0<DataApiUserId>) () -> DataApiUserId.parse(rs.getColumn("user_id").getString(CharsetUtils.UTF8_CHARSET))
            );

            byte[] jsonData = rs.getColumn("content").getString();

            long rev = rs.getColumn("rev").getInt64();
            String collectionId = rs.getColumn("collection_id").getString(CharsetUtils.UTF8_CHARSET);
            String recordId = rs.getColumn("record_id").getString(CharsetUtils.UTF8_CHARSET);

            DataRecordId id = new DataRecordId(handle, collectionId, recordId);
            return mapper.parseJson(DataRecord.class, jsonData)
                        .withNewUid(uid)
                        .withNewRev(rev)
                        .withNewId(id);
        }

        private DatabaseHandle getHandle(ResultSetReader rs) {
            if (handles.isEmpty()) {
                return new DatabaseHandle(
                        rs.getColumn("app").getString(CharsetUtils.UTF8_CHARSET),
                        rs.getColumn("dbId").getString(CharsetUtils.UTF8_CHARSET),
                        rs.getColumn("handle").getString(CharsetUtils.UTF8_CHARSET)
                );
            }
            if (handles.size() == 1) {
                return handles.values().iterator().next();
            }

            return handles.getTs(rs.getColumn("handle").getString(CharsetUtils.UTF8_CHARSET));
        }
    }
}
