package ru.yandex.chemodan.ydb.dao.pojo;

import java.util.function.Function;

import com.yandex.ydb.table.result.ResultSetReader;
import com.yandex.ydb.table.result.ValueReader;
import com.yandex.ydb.table.transaction.TransactionMode;
import com.yandex.ydb.table.values.PrimitiveValue;
import com.yandex.ydb.table.values.Value;
import lombok.AllArgsConstructor;
import lombok.Data;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.MapF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.collection.SetF;
import ru.yandex.bolts.function.Function0;
import ru.yandex.chemodan.ydb.dao.OneTableYdbDao;
import ru.yandex.chemodan.ydb.dao.ThreadLocalYdbTransactionManager;
import ru.yandex.chemodan.ydb.dao.YdbQueryMapper;
import ru.yandex.chemodan.ydb.dao.YdbRowMapper;
import ru.yandex.chemodan.ydb.dao.YdbUtils;
import ru.yandex.misc.bender.Bender;
import ru.yandex.misc.bender.parse.BenderJsonParser;
import ru.yandex.misc.bender.serialize.BenderJsonSerializer;
import ru.yandex.misc.db.q.SqlCondition;
import ru.yandex.misc.db.q.SqlLimits;
import ru.yandex.misc.db.q.SqlOrder;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;
import ru.yandex.misc.random.Random2;

/**
 * Single-table YDB DAO that stores and loads POJOs of type {@code T}.
 *
 * <p>Objects are converted to column values with a Bender JSON serializer writing into a
 * {@link YdbObjectWriter}, and rows are converted back with a Bender JSON parser reading from a
 * {@link YdbObjectNode}. Both are built from the same Bender configuration supplied by the
 * {@link YdbClassAnalyzer.Description}.
 *
 * @param <T> POJO type stored in the table
 * @author yashunsky
 */
public class OneTablePojoYdbDao<T> extends OneTableYdbDao {
    private static final Logger logger = LoggerFactory.getLogger(OneTablePojoYdbDao.class);

    private final BenderJsonSerializer<T> serializer;
    private final RowMapper<T> rowMapper;

    protected final ThreadLocalYdbTransactionManager transactionManager;

    /**
     * Creates a DAO for the given table.
     *
     * @param transactionManager transaction manager passed to the parent DAO and used by
     *                           {@link #saveOrUpdate}
     * @param tableName          name of the table
     * @param pojoClass          class of the stored objects
     * @param description        source of the table description, hashed columns, create-table settings
     *                           and the Bender configuration
     */
    public OneTablePojoYdbDao(ThreadLocalYdbTransactionManager transactionManager,
            String tableName, Class<T> pojoClass, YdbClassAnalyzer.Description<T> description)
    {
        super(transactionManager, tableName, description.getTableDescription(), description.getHashedColumns(),
                description.getCreateTableSettings());
        this.serializer = Bender.jsonSerializer(pojoClass, description.getBenderConfiguration());
        this.rowMapper = new RowMapper<>(Bender.jsonParser(pojoClass, description.getBenderConfiguration()));
        this.transactionManager = transactionManager;
    }

    /** Serializes every object and passes the batch to {@code upsertBatch}. */
    public void upsert(ListF<T> objects) {
        upsertBatch(objects.map(this::serialize));
    }

    /** Upserts a single object; shortcut for {@link #upsert(ListF)} with a one-element list. */
    public void upsert(T object) {
        upsert(Cf.list(object));
    }

    /** Serializes every object and passes the batch to {@code insertBatch}. */
    public void insert(ListF<T> objects) {
        insertBatch(objects.map(this::serialize));
    }

    /** Inserts a single object; shortcut for {@link #insert(ListF)} with a one-element list. */
    public void insert(T object) {
        insert(Cf.list(object));
    }

    /**
     * Selects all columns of the rows matching {@code condition} and maps them to objects.
     *
     * @param condition filter for the rows
     * @param index     optional index name, forwarded to {@code getTableName(index)}
     * @param order     requested ordering
     * @param limits    requested limits
     * @return mapped objects
     */
    public ListF<T> find(SqlCondition condition, Option<String> index, SqlOrder order, SqlLimits limits) {
        return queryForList("SELECT * FROM " + getTableName(index), condition, order, limits, rowMapper);
    }

    /** Same as the four-argument {@code find} with no index, no ordering and no limits. */
    public ListF<T> find(SqlCondition condition) {
        return find(condition, Option.empty(), SqlOrder.unordered(), SqlLimits.all());
    }

    /** Same as the four-argument {@code find} with the given index, no ordering and no limits. */
    public ListF<T> find(SqlCondition condition, String index) {
        return find(condition, Option.of(index), SqlOrder.unordered(), SqlLimits.all());
    }

    /** Same as the four-argument {@code find} with the given optional index, no ordering and no limits. */
    public ListF<T> find(SqlCondition condition, Option<String> index) {
        return find(condition, index, SqlOrder.unordered(), SqlLimits.all());
    }

    /**
     * Runs the four-argument {@code find} and returns the first element of its result, if any.
     * The {@code limits} argument is forwarded unchanged.
     */
    public Option<T> findOne(SqlCondition condition, Option<String> index, SqlOrder order, SqlLimits limits) {
        return find(condition, index, order, limits).firstO();
    }

    /** Returns the first matching object, querying without an index and with a limit of one row. */
    public Option<T> findOne(SqlCondition condition) {
        return find(condition, Option.empty(), SqlOrder.unordered(), SqlLimits.first(1)).firstO();
    }

    /** Returns the first matching object, querying with the given index and a limit of one row. */
    public Option<T> findOne(SqlCondition condition, String index) {
        return find(condition, Option.of(index), SqlOrder.unordered(), SqlLimits.first(1)).firstO();
    }

    /** Returns the first matching object, querying with the given optional index and a limit of one row. */
    public Option<T> findOne(SqlCondition condition, Option<String> index) {
        return find(condition, index, SqlOrder.unordered(), SqlLimits.first(1)).firstO();
    }

    /** Deletes the rows matching {@code condition} without using an index. */
    public void delete(SqlCondition condition) {
        delete(condition, Option.empty());
    }

    /** Deletes the rows matching {@code condition}, selecting them through the given index. */
    public void delete(SqlCondition condition, String index) {
        delete(condition, Option.of(index));
    }

    /**
     * Deletes the rows matching {@code condition}.
     *
     * <p>When an index is given, the primary key columns of the matching rows are first selected
     * through {@code getTableName(index)} into {@code $toDelete}, and the delete is then issued
     * against {@code getTableName()} as {@code DELETE ... ON SELECT * FROM $toDelete}. Without an
     * index a plain {@code DELETE FROM ... <where>} statement is executed.
     *
     * @param condition filter for the rows to delete
     * @param index     optional index name
     */
    public void delete(SqlCondition condition, Option<String> index) {
        YdbQueryMapper.YdbCondition ydbCondition = YdbQueryMapper.mapWhereSql(condition);
        String sql;
        if (index.isPresent()) {
            String columnsString = Cf.x(getTableDescription().getPrimaryKeys()).mkString(", ");
            String selectSql = "$toDelete = (\n" +
                    "    SELECT " + columnsString + " \n" +
                    "    FROM " + getTableName(index) + " \n" +
                    "   " + ydbCondition.whereSql + "\n" +
                    ");\n";

            sql = ydbCondition.declareSql + selectSql + "DELETE FROM " + getTableName() + " ON SELECT * FROM $toDelete;";
        } else {
            sql = ydbCondition.declareSql + "DELETE FROM " + getTableName(index) + " " + ydbCondition.whereSql;
        }
        execute(sql, ydbCondition.params);
    }

    /**
     * Converts an object into a map of column values.
     *
     * <p>The object is written through the Bender serializer into a {@link YdbObjectWriter}. For every
     * hashed column that is not a key of the writer's result, a random {@code uint32} value is added
     * under the name returned by {@link YdbUtils#getHashName}.
     *
     * @param object object to serialize
     * @return values produced by the writer plus the generated random hash values
     */
    protected MapF<String, Value> serialize(T object) {
        YdbObjectWriter ydbObjectWriter = new YdbObjectWriter(hashedColumns);
        serializer.serializeJson(object, ydbObjectWriter);
        MapF<String, Value> result = ydbObjectWriter.getFinalResult();
        MapF<String, Value> randomHashes = hashedColumns.filterNot(result::containsKeyTs)
                .toMap(YdbUtils::getHashName, column -> PrimitiveValue.uint32(Random2.R.nextInt()));
        return result.plus(randomHashes);
    }

    /**
     * Creates or updates the object selected by {@code condition} inside one
     * {@link TransactionMode#SERIALIZABLE_READ_WRITE} transaction.
     *
     * <p>The first row returned by the select is taken as the current object. If it exists,
     * {@code ifExists} is applied to it and only the serialized values whose keys are in
     * {@code fieldsToUpdate} are passed to {@code update}. Otherwise the object supplied by
     * {@code ifNew} is upserted as a whole.
     *
     * @param condition      mapped condition selecting the object
     * @param ifNew          supplier of the object to store when nothing was found
     * @param ifExists       transformation applied to the found object
     * @param fieldsToUpdate keys of the serialized values to write when the object exists
     */
    protected void saveOrUpdate(YdbQueryMapper.YdbCondition condition,
            Function0<T> ifNew, Function<T, T> ifExists, SetF<String> fieldsToUpdate)
    {
        String findSql = condition.declareSql + "SELECT * FROM " + getTableName() + " " + condition.whereSql;

        transactionManager.executeInTx(() -> {
            logger.debug("Starting select in tx");
            Option<T> current = queryForList(findSql, toParams(condition.params), rowMapper).firstO();

            logger.debug("Processing select result");
            T toUpdate = current.map(ifExists::apply).getOrElse(ifNew);

            if (current.isPresent()) {
                // Serialize only on this path: upsert(T) below performs its own serialization,
                // so doing it up front for a new object would be wasted work.
                MapF<String, Value> values = serialize(toUpdate);
                logger.debug("Starting update in tx");
                update(condition, values.filterKeys(fieldsToUpdate::containsTs).mapValues(v -> v));
            } else {
                logger.debug("Starting upsert in tx");
                upsert(toUpdate);
            }
            return null;
        }, TransactionMode.SERIALIZABLE_READ_WRITE);
    }

    /**
     * Builds an equality condition on {@code column}.
     *
     * <p>For a hashed column the condition additionally requires the column named by
     * {@link YdbUtils#getHashName} to equal the {@code uint32} of {@link YdbUtils#getHashValue} for
     * the value; for other columns only the plain equality is applied.
     *
     * @param column column name
     * @param value  value the column must equal
     * @return the combined condition
     */
    protected SqlCondition getColumnCondition(String column, Object value) {
        return (hashedColumns.containsTs(column)
                ? SqlCondition.column(YdbUtils.getHashName(column)).eq(PrimitiveValue.uint32(YdbUtils.getHashValue(value)))
                : SqlCondition.trueCondition()).and(SqlCondition.column(column).eq(value));
    }

    /**
     * Serializes {@code toUpdate} and passes to {@code update} only the values whose keys are in
     * {@code fieldsToSet}.
     *
     * @param condition   mapped condition selecting the rows to update
     * @param toUpdate    object providing the new values
     * @param fieldsToSet keys of the serialized values to write
     */
    protected void update(YdbQueryMapper.YdbCondition condition, T toUpdate, SetF<String> fieldsToSet) {
        update(condition, serialize(toUpdate).filterKeys(fieldsToSet::containsTs).mapValues(v -> v));
    }

    /**
     * Maps a result set row to an object by handing its columns to a Bender JSON parser.
     * Declared static because it never touches the enclosing DAO instance.
     */
    @AllArgsConstructor
    private static class RowMapper<V> implements YdbRowMapper<V> {
        private final BenderJsonParser<V> parser;

        @Override
        public V mapRow(ResultSetReader rs, int rowNum) {
            return mapRow(rs);
        }

        /** Collects the columns of {@code rs} by name and parses them as a {@link YdbObjectNode}. */
        public V mapRow(ResultSetReader rs) {
            MapF<String, ValueReader> values = Cf.range(0, rs.getColumnCount()).toMap(rs::getColumnName, rs::getColumn);
            return parser.parseJson(new YdbObjectNode(values));
        }
    }

    /** Returns the default timestamp iteration config: chunk size 100 and 10 threads. */
    protected TimestampIterationConfig initTimestampIterationConfig() {
        return new TimestampIterationConfig(100, 10);
    }

    /** Immutable pair of settings: a chunk size and a thread count. */
    @AllArgsConstructor
    @Data
    protected static class TimestampIterationConfig {
        private final int chunkSize;
        private final int threads;
    }
}
