package ru.yandex.chemodan.app.ydbtest;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.concurrent.ConcurrentHashMap;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.yandex.ydb.table.Session;
import com.yandex.ydb.table.TableClient;
import com.yandex.ydb.table.query.DataQuery;
import com.yandex.ydb.table.query.DataQueryResult;
import com.yandex.ydb.table.query.Params;
import com.yandex.ydb.table.result.ResultSetReader;
import com.yandex.ydb.table.transaction.TxControl;
import com.yandex.ydb.table.values.PrimitiveValue;
import org.joda.time.Instant;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.misc.ExceptionUtils;
import ru.yandex.misc.io.IoFunction;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;
import ru.yandex.misc.time.TimeUtils;
import ru.yandex.misc.web.servlet.HttpServletRequestX;
import ru.yandex.misc.web.servlet.HttpServletResponseX;

/**
 * Test servlet that reads and writes JSON records in the YDB table {@code data}.
 *
 * <p>{@code GET} selects the {@code j_content} column of every row matching the {@code handle}
 * parameter (optionally narrowed by {@code collection_id} and {@code record_id}). It returns those
 * values separated by newlines.
 *
 * <p>{@code POST} writes the request body as {@code j_content} of the row identified by
 * {@code handle}, {@code collection_id} and {@code record_id}, using {@code REPLACE INTO}.
 *
 * <p>Prepared queries are cached per {@link Session}, because each query is prepared through
 * {@link Session#prepareDataQuery}.
 *
 * @author tolmalev
 */
public class YdbtestServlet extends HttpServlet {
    private static final Logger logger = LoggerFactory.getLogger(YdbtestServlet.class);

    private static final Duration SESSION_GET_TIMEOUT = Duration.ofSeconds(5);

    /** Upsert of a single record; every value is bound as a declared query parameter. */
    private static final String REPLACE_QUERY = ""
            + "DECLARE $handle as String;\n"
            + "DECLARE $collection_id as String;\n"
            + "DECLARE $record_id as String;\n"
            + "DECLARE $j_content as Json;\n"
            + "DECLARE $rev as Uint64;\n"
            + "DECLARE $user_id as String;\n"
            + "REPLACE INTO data "
            + "(handle, collection_id, record_id, j_content, rev, user_id) "
            + "VALUES ($handle, $collection_id, $record_id, $j_content, $rev, $user_id)";

    private final TableClient tableClient;

    /** Prepared queries, keyed first by the session they were prepared on, then by YQL text. */
    private final ConcurrentHashMap<Session, ConcurrentHashMap<String, DataQuery>> preparedQueries =
            new ConcurrentHashMap<>();

    public YdbtestServlet(TableClient tableClient) {
        this.tableClient = tableClient;
    }

    /**
     * Returns the {@code j_content} of all matching rows, one per line.
     *
     * <p>Request parameters: {@code handle} (used unconditionally), and optional
     * {@code collection_id} and {@code record_id} filters.
     */
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        HttpServletRequestX reqX = HttpServletRequestX.wrap(req);
        HttpServletResponseX respX = HttpServletResponseX.wrap(resp);

        String handle = reqX.getParameter("handle");
        Option<String> collectionId = reqX.getParameterO("collection_id");
        Option<String> recordId = reqX.getParameterO("record_id");

        String query = buildSelectQuery(collectionId.isPresent(), recordId.isPresent());

        ListF<String> jsons = executeInSession(session -> {
            // The session must be released even if prepare/execute fails, otherwise it leaks.
            try {
                DataQuery prepared = prepare(session, query);

                Params params = prepared.newParams().put("$handle", utf8String(handle));
                if (collectionId.isPresent()) {
                    params.put("$collection_id", utf8String(collectionId.get()));
                }
                if (recordId.isPresent()) {
                    params.put("$record_id", utf8String(recordId.get()));
                }

                Instant start = Instant.now();
                DataQueryResult result = execute(prepared, params, "Can't get result");
                logger.info("execute read took {}", TimeUtils.secondsStringToNow(start));

                start = Instant.now();
                ListF<String> rows = Cf.arrayList();
                ResultSetReader rs = result.getResultSet(0);
                while (rs.next()) {
                    rows.add(rs.getColumn("j_content").getJson());
                }
                logger.info("read data took {}", TimeUtils.secondsStringToNow(start));
                return rows;
            } finally {
                releaseSession(session);
            }
        });

        StringBuilder body = new StringBuilder();
        for (String json : jsons) {
            body.append(json).append('\n');
        }

        // Status must be set before the body is written; afterwards the response may be committed.
        respX.setStatus(200);
        respX.writeContent(body.toString().getBytes(StandardCharsets.UTF_8), "application/json");
        respX.flushBuffer();
    }

    /**
     * Stores the request body (decoded as UTF-8) as the JSON content of one record and responds
     * with 201.
     *
     * <p>Request parameters: {@code handle} (also written as {@code user_id}),
     * {@code collection_id} and {@code record_id}. {@code rev} is always written as 0.
     */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        HttpServletRequestX reqX = HttpServletRequestX.wrap(req);

        byte[] data = reqX.getInputStreamX().readBytes();
        String handle = reqX.getParameter("handle");
        String collectionId = reqX.getParameter("collection_id");
        String recordId = reqX.getParameter("record_id");

        executeInSession(session -> {
            // The session must be released even if prepare/execute fails, otherwise it leaks.
            try {
                DataQuery prepared = prepare(session, REPLACE_QUERY);

                Params params = prepared.newParams()
                        .put("$handle", utf8String(handle))
                        .put("$collection_id", utf8String(collectionId))
                        .put("$record_id", utf8String(recordId))
                        .put("$j_content", PrimitiveValue.json(new String(data, StandardCharsets.UTF_8)))
                        .put("$rev", PrimitiveValue.uint64(0))
                        .put("$user_id", utf8String(handle));

                Instant start = Instant.now();
                execute(prepared, params, "Failed to insert data");
                logger.info("execute write took {}", TimeUtils.secondsStringToNow(start));
            } finally {
                releaseSession(session);
            }
            return null;
        });

        resp.setStatus(201);
        resp.flushBuffer();
    }

    /**
     * Returns the cached prepared form of {@code query} for {@code session}, preparing it on first
     * use. Preparation failures are rethrown via {@link ExceptionUtils#translate}.
     */
    DataQuery prepare(Session session, String query) {
        return preparedQueries
                .computeIfAbsent(session, s -> new ConcurrentHashMap<>())
                .computeIfAbsent(query, q -> {
                    try {
                        return session.prepareDataQuery(q).get().expect("Can't prepare query");
                    } catch (Exception e) {
                        throw ExceptionUtils.translate(e);
                    }
                });
    }

    /**
     * Obtains a session from {@link #tableClient} (waiting at most {@link #SESSION_GET_TIMEOUT})
     * and applies {@code fn} to it.
     *
     * <p>This method does not release the session; {@code fn} is responsible for that. Any
     * exception is rethrown via {@link ExceptionUtils#translate}.
     *
     * @return whatever {@code fn} returns
     */
    protected <T> T executeInSession(IoFunction<Session, T> fn) {
        try {
            Instant start = Instant.now();
            Session session = tableClient.getOrCreateSession(SESSION_GET_TIMEOUT).get().expect("Can't get session");
            logger.info("getOrCreateSession took {}", TimeUtils.secondsStringToNow(start));
            return fn.apply(session);
        } catch (Exception e) {
            throw ExceptionUtils.translate(e);
        }
    }

    /**
     * Builds the SELECT for {@link #doGet}. It always filters by {@code $handle} and adds
     * {@code $collection_id} / {@code $record_id} filters on request. Each used parameter is
     * declared up front. The result is deterministic, so it works as a key for the
     * prepared-query cache.
     */
    private static String buildSelectQuery(boolean byCollectionId, boolean byRecordId) {
        StringBuilder declarations = new StringBuilder("DECLARE $handle as String;\n");
        StringBuilder select = new StringBuilder("SELECT * FROM data WHERE handle = $handle");
        if (byCollectionId) {
            declarations.append("DECLARE $collection_id as String;\n");
            select.append(" AND collection_id = $collection_id");
        }
        if (byRecordId) {
            declarations.append("DECLARE $record_id as String;\n");
            select.append(" AND record_id = $record_id");
        }
        return declarations.append(select).append(';').toString();
    }

    /**
     * Executes {@code prepared} in a serializable read-write transaction and returns its result.
     * {@code errorMessage} is passed to {@code expect}; failures are rethrown via
     * {@link ExceptionUtils#translate}.
     */
    private static DataQueryResult execute(DataQuery prepared, Params params, String errorMessage) {
        try {
            return prepared.execute(TxControl.serializableRw(), params).get().expect(errorMessage);
        } catch (Exception e) {
            throw ExceptionUtils.translate(e);
        }
    }

    /** Calls {@link Session#release()} on {@code session} and logs how long it took. */
    private static void releaseSession(Session session) {
        Instant start = Instant.now();
        session.release();
        logger.info("session release took {}", TimeUtils.secondsStringToNow(start));
    }

    /** Wraps {@code value} as a YDB {@code String} value using its UTF-8 bytes. */
    private static PrimitiveValue utf8String(String value) {
        return PrimitiveValue.string(value.getBytes(StandardCharsets.UTF_8));
    }
}
