package ru.yandex.chemodan.app.dataapi.core.datasources.ydb.dao;

import java.time.Instant;

import com.yandex.ydb.core.StatusCode;
import com.yandex.ydb.core.UnexpectedResultException;
import com.yandex.ydb.table.description.TableDescription;
import com.yandex.ydb.table.query.Params;
import com.yandex.ydb.table.values.ListType;
import com.yandex.ydb.table.values.OptionalValue;
import com.yandex.ydb.table.values.PrimitiveType;
import com.yandex.ydb.table.values.PrimitiveValue;
import com.yandex.ydb.table.values.StructType;
import com.yandex.ydb.table.values.Value;
import org.springframework.dao.IncorrectResultSizeDataAccessException;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.collection.Tuple2;
import ru.yandex.bolts.collection.Tuple2List;
import ru.yandex.chemodan.app.dataapi.api.context.DatabaseContext;
import ru.yandex.chemodan.app.dataapi.api.db.Database;
import ru.yandex.chemodan.app.dataapi.api.db.filter.DatabasesFilter;
import ru.yandex.chemodan.app.dataapi.api.db.handle.DatabaseHandle;
import ru.yandex.chemodan.app.dataapi.api.db.handle.DatabaseHandles;
import ru.yandex.chemodan.app.dataapi.api.db.ref.DatabaseRef;
import ru.yandex.chemodan.app.dataapi.api.user.DataApiUserId;
import ru.yandex.chemodan.app.dataapi.core.dao.data.DatabaseRevisionMismatchException;
import ru.yandex.chemodan.ydb.dao.OneTableYdbDao;
import ru.yandex.chemodan.ydb.dao.ThreadLocalYdbTransactionManager;
import ru.yandex.chemodan.ydb.dao.YdbTypeUtils;
import ru.yandex.misc.db.q.ConditionUtils;
import ru.yandex.misc.db.q.SqlCondition;
import ru.yandex.misc.db.q.SqlLimits;
import ru.yandex.misc.db.q.SqlOrder;
import ru.yandex.misc.lang.CharsetUtils;
import ru.yandex.misc.lang.StringUtils;

/**
 * YDB data access object for the {@code databases} table.
 *
 * <p>Each row stores the metadata of one user database: owner ({@code user_id}), application
 * ({@code app}), database id ({@code dbId}), handle, revision, creation and modification
 * timestamps, size in bytes, record count and an optional description. The primary key is
 * {@code (user_id, app, dbId)}.
 *
 * <p>String values bound as YDB {@code String} (byte) parameters are encoded as UTF-8 explicitly,
 * so they round-trip with the UTF-8 decoding used when rows are read back in {@link #findHandles}.
 *
 * @author tolmalev
 */
public class DatabasesYdbDao extends OneTableYdbDao {
    /** Schema of the {@code databases} table; {@code description} is the only nullable column. */
    public static final TableDescription DESCRIPTION = TableDescription
            .newBuilder()
            .addNonnullColumn("user_id", PrimitiveType.string())
            .addNonnullColumn("app", PrimitiveType.string())
            .addNonnullColumn("dbId", PrimitiveType.string())
            .addNonnullColumn("handle", PrimitiveType.string())
            .addNonnullColumn("rev", PrimitiveType.int64())
            .addNonnullColumn("creation_time", PrimitiveType.timestamp())
            .addNonnullColumn("modification_time", PrimitiveType.timestamp())
            .addNonnullColumn("size", PrimitiveType.int64())
            .addNonnullColumn("records_count", PrimitiveType.int64())
            .addNullableColumn("description", PrimitiveType.string())
            .setPrimaryKeys("user_id", "app", "dbId")
            .build();

    public static final String TABLE_NAME = "databases";

    /** Unfiltered projection of all columns; callers pass the filter as a separate {@link SqlCondition}. */
    private static final String SELECT_ALL = "SELECT * FROM " + TABLE_NAME;

    /**
     * @param transactionManager transaction manager handed to {@link OneTableYdbDao}
     */
    public DatabasesYdbDao(ThreadLocalYdbTransactionManager transactionManager) {
        super(transactionManager, TABLE_NAME, DESCRIPTION);
    }

    /**
     * Finds a database of {@code uid} by its handle.
     *
     * @param uid    owner of the database
     * @param handle handle to look up; its {@code dbAppId()} selects the application
     * @return the first matching database, or an empty option if there is none
     */
    public Option<Database> findByHandle(DataApiUserId uid, DatabaseHandle handle) {
        SqlCondition condition = SqlCondition.trueCondition()
                .and(ConditionUtils.column("user_id").eq(uid.toString()))
                .and(ConditionUtils.column("app").eq(handle.dbAppId()))
                .and(ConditionUtils.column("handle").eq(handle.handle));

        return queryForList(SELECT_ALL, condition, DatabaseMapper.FROM_REAL).firstO();
    }

    /**
     * Finds a single database by its reference.
     *
     * @param uid   owner of the database
     * @param dbRef reference providing the database context and id
     * @return the database, or an empty option if it does not exist
     */
    public Option<Database> find(DataApiUserId uid, DatabaseRef dbRef) {
        return find(uid, dbRef.dbContext(), Cf.list(dbRef.databaseId())).firstO();
    }

    /**
     * Lists all databases of a user across all applications.
     *
     * @param uid owner of the databases
     * @return all rows with {@code user_id = uid}
     */
    public ListF<Database> find(DataApiUserId uid) {
        SqlCondition condition = SqlCondition.trueCondition()
                .and(SqlCondition.column("user_id").eq(uid.toString()));

        return queryForList(SELECT_ALL, condition, DatabaseMapper.FROM_REAL);
    }

    /**
     * Lists all databases of a user within one application.
     *
     * @param uid       owner of the databases
     * @param dbContext context whose {@code dbAppId()} selects the application
     * @return all rows with matching {@code user_id} and {@code app}
     */
    public ListF<Database> find(DataApiUserId uid, DatabaseContext dbContext) {
        SqlCondition condition = SqlCondition.trueCondition()
                .and(SqlCondition.column("user_id").eq(uid.toString()))
                .and(SqlCondition.column("app").eq(dbContext.dbAppId()));

        return queryForList(SELECT_ALL, condition, DatabaseMapper.FROM_REAL);
    }

    /**
     * Finds the databases with the given ids of a user within one application.
     *
     * @param uid         owner of the databases
     * @param dbContext   context whose {@code dbAppId()} selects the application
     * @param databaseIds ids to look up; an empty list yields an empty result without a query
     * @return the existing databases among {@code databaseIds}, in no particular order
     */
    public ListF<Database> find(DataApiUserId uid, DatabaseContext dbContext, ListF<String> databaseIds) {
        if (databaseIds.isEmpty()) {
            // Avoid issuing an IN over an empty set: nothing can match anyway.
            return Cf.list();
        }

        // A plain IN condition is used here; delete() instead joins against AS_TABLE($keys)
        // to work around missing index usage for IN $keys (YQL-4253).
        SqlCondition condition = SqlCondition.trueCondition()
                .and(ConditionUtils.column("user_id").eq(uid.toString()))
                .and(ConditionUtils.column("app").eq(dbContext.dbAppId()))
                .and(ConditionUtils.column("dbId").inSet(databaseIds));

        return queryForList(SELECT_ALL, condition, SqlOrder.unordered(), SqlLimits.all(), DatabaseMapper.FROM_REAL);
    }

    /**
     * Inserts a new database row owned by {@code database.uid}.
     *
     * <p>If YDB reports {@link StatusCode#PRECONDITION_FAILED}, this is treated as "database
     * already exists" and the exception built by {@code database.consExistsException()} is thrown.
     * Any other {@link UnexpectedResultException} is rethrown unchanged.
     *
     * @param database database to insert
     */
    public void insert(Database database) {
        try {
            insertBatch(database.uid, Cf.list(database));
        } catch (UnexpectedResultException e) {
            if (e.getStatusCode() == StatusCode.PRECONDITION_FAILED) {
                throw database.consExistsException();
            } else {
                throw e;
            }
        }
    }

    /**
     * Converts each database into a column-name to YDB-value map and inserts them via the base DAO.
     *
     * @param uid  owner written into {@code user_id} for every row
     * @param list databases to insert
     */
    private void insertBatch(DataApiUserId uid, ListF<Database> list) {
        insertBatch(list.map(db -> {
            OptionalValue descriptionValue = getOptionalStringValue(db.meta.description);

            return Tuple2List.<String, Value>fromPairs(
                    "user_id", PrimitiveValue.string(utf8(uid.toString())),
                    "app", PrimitiveValue.string(utf8(db.dbAppId())),
                    "dbId", PrimitiveValue.string(utf8(db.dbRef().databaseId())),
                    "handle", PrimitiveValue.string(utf8(db.handleValue())),
                    "rev", PrimitiveValue.int64(db.rev),
                    "creation_time", PrimitiveValue.timestamp(Instant.ofEpochMilli(db.meta.creationTime.getMillis())),
                    "modification_time", PrimitiveValue.timestamp(Instant.ofEpochMilli(db.meta.modificationTime.getMillis())),
                    "size", PrimitiveValue.int64(db.meta.size.toBytes()),
                    "records_count", PrimitiveValue.int64(db.meta.recordsCount),
                    "description", descriptionValue
            ).toMap();
        }));
    }

    /**
     * Updates revision, modification time, description, record count and size of {@code db}.
     *
     * <p>Uses an optimistic-concurrency check: the row is changed only if its key, handle and stored
     * revision match, i.e. the stored {@code rev} equals {@code currentRevision}.
     *
     * @param db              new state of the database; {@code db.rev} becomes the stored revision
     * @param currentRevision revision the caller expects to be stored now
     * @throws DatabaseRevisionMismatchException if no row matches the key, handle and expected revision
     */
    public void save(Database db, long currentRevision) {
        // The SELECT counts the rows the following UPDATE will touch; queryForLong presumably reads
        // that count from the first result set (NOTE(review): confirm), so 0 means nothing was updated.
        String sql = "DECLARE $rev_new as Int64;\n"
                + "DECLARE $rev_old as Int64;\n"
                + "DECLARE $modification_time as Timestamp;\n"
                + "DECLARE $description as \"String?\";\n"
                + "DECLARE $records_count as Int64;\n"
                + "DECLARE $size as Int64;\n"
                + "DECLARE $user_id as String;\n"
                + "DECLARE $app as String;\n"
                + "DECLARE $dbId as String;\n"
                + "DECLARE $handle as String;\n"
                + "\n"
                + "SELECT count(*) FROM databases "
                + "WHERE user_id = $user_id AND app = $app AND dbId = $dbId AND handle = $handle AND rev = $rev_old; \n"
                + "UPDATE databases SET "
                + "rev = $rev_new, modification_time = $modification_time, description = $description, records_count = $records_count, size = $size "
                + "WHERE user_id = $user_id AND app = $app AND dbId = $dbId AND handle = $handle AND rev = $rev_old";

        OptionalValue descriptionValue = getOptionalStringValue(db.meta.description);

        Params params = Params.create()
                .put("$rev_new", PrimitiveValue.int64(db.rev))
                .put("$modification_time", PrimitiveValue.timestamp(Instant.ofEpochMilli(db.meta.modificationTime.getMillis())))
                .put("$description", descriptionValue)
                .put("$records_count", PrimitiveValue.int64(db.meta.recordsCount))
                .put("$size", PrimitiveValue.int64(db.meta.size.toBytes()))
                .put("$user_id", PrimitiveValue.string(utf8(db.uid.toString())))
                .put("$app", PrimitiveValue.string(utf8(db.dbAppId())))
                .put("$dbId", PrimitiveValue.string(utf8(db.databaseId())))
                .put("$handle", PrimitiveValue.string(utf8(db.handleValue())))
                .put("$rev_old", PrimitiveValue.int64(currentRevision));

        if (queryForLong(sql, params) == 0) {
            throw new DatabaseRevisionMismatchException(StringUtils.format(
                    "Database {} has not {} revision or does not exist", db.databaseId(), currentRevision));
        }
    }

    /**
     * Deletes the given databases of a user within one application.
     *
     * @param uid         owner of the databases
     * @param dbContext   context whose {@code dbAppId()} selects the application
     * @param databaseIds ids to delete; an empty list is a no-op and issues no query
     * @throws IncorrectResultSizeDataAccessException if the number of rows matched (counted before
     *         deletion) differs from {@code databaseIds.size()}
     */
    public void delete(DataApiUserId uid, DatabaseContext dbContext, ListF<String> databaseIds) {
        if (databaseIds.isEmpty()) {
            // Nothing to delete; the count check below would trivially pass (0 == 0).
            return;
        }

        StructType structType = YdbTypeUtils.structTypeFromFields(Tuple2List.fromPairs(
                "user_id", PrimitiveType.string(),
                "app", PrimitiveType.string(),
                "dbId", PrimitiveType.string()
        ));

        ListType listType = ListType.of(structType);

        // We need this JOIN because there is no index usage for IN $keys.
        // YQL-4253
        String sql = "" +
                "PRAGMA SimpleColumns=\"true\";\n" +
                "" +
                "DECLARE $keys as \"List<Struct<" +
                "   'user_id': String," +
                "   'app': String," +
                "   'dbId': String" +
                ">>\";\n" +
                "\n" +
                "$to_delete = (" +
                "   SELECT t.* FROM AS_TABLE($keys) AS k" +
                "   INNER JOIN databases AS t\n" +
                "   ON t.user_id = k.user_id AND t.app = k.app AND t.dbId = k.dbId" +
                ");\n" +
                "SELECT COUNT(*) FROM $to_delete;\n" +
                "DELETE FROM databases ON SELECT * FROM $to_delete";

        Params params = Params.create()
                .put("$keys", listType.newValue(databaseIds.map(dbId -> structType.newValue(Cf.map(
                        "user_id", PrimitiveValue.string(utf8(uid.toString())),
                        "app", PrimitiveValue.string(utf8(dbContext.dbAppId())),
                        "dbId", PrimitiveValue.string(utf8(dbId))
                )))));

        long affected = queryForLong(sql, params);

        if (databaseIds.size() != affected) {
            throw new IncorrectResultSizeDataAccessException(databaseIds.size(), (int) affected);
        }
    }

    /**
     * Returns the {@code (dbId, handle)} pairs of the user's databases that match {@code filter},
     * wrapped into {@link DatabaseHandles}.
     *
     * @param uid    owner of the databases
     * @param filter additional filter, converted with {@code filter.toSqlCondition()}
     * @return handles of the matching databases
     */
    public DatabaseHandles findHandles(DataApiUserId uid, DatabasesFilter filter) {
        SqlCondition condition = ConditionUtils.column("user_id").eq(uid.toString())
                .and(filter.toSqlCondition());

        // Column 0 is dbId and column 1 is handle, following the SELECT list order.
        return DatabaseHandles.fromDatabaseIdHandleTuples(filter,
                queryForList("SELECT dbId, handle FROM databases", condition, (rs, rowNum) -> new Tuple2<>(
                        rs.getColumn(0).getString(CharsetUtils.UTF8_CHARSET),
                        rs.getColumn(1).getString(CharsetUtils.UTF8_CHARSET))));
    }

    /**
     * Encodes {@code value} as UTF-8 for binding as a YDB {@code String} parameter.
     *
     * <p>Uses an explicit charset instead of {@link String#getBytes()}, which depends on the JVM's
     * default charset, so writes match the UTF-8 decoding used in {@link #findHandles}.
     */
    private static byte[] utf8(String value) {
        return value.getBytes(CharsetUtils.UTF8_CHARSET);
    }
}
