package ru.yandex.chemodan.app.telemost.repository.dao.impl;

import java.io.IOException;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.UUID;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.SneakyThrows;
import org.jetbrains.annotations.NotNull;
import org.joda.time.Instant;
import org.springframework.jdbc.core.RowMapper;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.MapF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.collection.Tuple2;
import ru.yandex.bolts.function.Function;
import ru.yandex.chemodan.app.telemost.exceptions.TelemostRuntimeException;
import ru.yandex.chemodan.app.telemost.repository.dao.UserStateDtoDao;
import ru.yandex.chemodan.app.telemost.repository.model.UserBackData;
import ru.yandex.chemodan.app.telemost.repository.model.UserStateDto;
import ru.yandex.chemodan.app.telemost.services.model.PassportOrYaTeamUid;
import ru.yandex.misc.spring.jdbc.JdbcTemplate3;

import static ru.yandex.bolts.collection.Tuple2.tuple;

/**
 * PostgreSQL-backed {@link UserStateDtoDao} over the {@code telemost.user_state} table.
 *
 * <p>Each row stores a JSON {@code state} document keyed by {@code user_data_id}, plus a
 * {@code version} counter that the write queries increment. The {@code state} column is always
 * selected as text and parsed with the injected {@link ObjectMapper}.
 */
public class UserStateDtoPgDaoImpl extends AbstractPgUUIDKeyDao<UserStateDto> implements UserStateDtoDao {
    private static final ListF<String> FIELDS_TO_SELECT =
            Cf.list("id", "created_at", "user_data_id", "version", "state", "passport_data_required");

    private static final String SQL_MERGE = "merge";
    private static final String SQL_FIND_STATES_BY_CONFERENCE_AND_PEER_IDS = "find-states-by-conference-and-peer-ids";
    private static final String SQL_ADD_USER_BACK_DATA = "add-user-back-data";
    private static final String SQL_INCREMENT_VERSION = "increment-version";
    private static final String SQL_SET_PASSPORT_DATA_REQUIRED = "passport-data-required";
    private static final String SQL_SET_PASSPORT_DATA_REQUIRED_BY_UID = "passport-data-required-by-uid";

    /** Rewrites the {@code state} column to {@code state::text} so it can be read with {@code getString}. */
    private static final Function<String, String> JSON_FIELD_MAPPER =
            field -> !"state".equals(field) ? field : "state::text";

    private final MapF<String, String> queries = Cf.map(
            // Upsert keyed by user_data_id. On conflict, a 'user_data' subtree already stored in the row is
            // grafted into the incoming state; if the row has none, 'user_data' is stripped from the incoming
            // state. The version is bumped only when the WHERE clause passes; otherwise RETURNING yields no row.
            // NOTE(review): the WHERE compares the stored state with the raw incoming :state, not with the
            // merged value the SET produces. When the row already holds 'user_data' (and the incoming state
            // does not carry an identical one) the two always differ, so the version is bumped even if nothing
            // else changed. Confirm whether the comparison should target the merged value instead.
            SQL_MERGE, "INSERT INTO " + getTableName() + " as t (created_at, user_data_id, version, state) " +
                    "VALUES (:created_at, :user_data_id, :initial_version, :state::jsonb) " +
                    "ON CONFLICT (user_data_id) DO UPDATE SET" +
                    " state = CASE WHEN t.state::jsonb -> 'user_data' IS NULL THEN :state::jsonb - 'user_data' ELSE " +
                    " jsonb_set(:state::jsonb, '{\"user_data\"}', t.state::jsonb -> 'user_data', true) END, " +
                    " version = t.version + 1 " +
                    " where t.state != :state::jsonb " +
                    "RETURNING " + returningColumns(),
            SQL_FIND_STATES_BY_CONFERENCE_AND_PEER_IDS,
            "select u.user_id as peer_id, " +
                    getFieldsToSelect().map(JSON_FIELD_MAPPER).map(field -> "s." + field).mkString(", ") +
                    " from " + getTableName() + " s " +
                    " join telemost.user_data u on u.conference_id = :conference_id and u.id = s.user_data_id " +
                    " where u.user_id in ( :peer_ids )",
            // Replaces the 'user_data' key of the stored state with the one from :back_state.
            SQL_ADD_USER_BACK_DATA,
            "UPDATE " + getTableName() + " SET state = state::jsonb - 'user_data' || :back_state::jsonb, version = version + 1 " +
                    "WHERE user_data_id = :user_data_id " +
                    "RETURNING " + returningColumns(),
            SQL_INCREMENT_VERSION,
            "UPDATE " + getTableName() + " SET version = version + 1 " +
                    "WHERE user_data_id in ( :user_data_ids ) " +
                    "RETURNING " + returningColumns())
            .plus1(SQL_SET_PASSPORT_DATA_REQUIRED,
                "UPDATE " + getTableName() + " SET passport_data_required = :required WHERE id = :id")
            .plus1(SQL_SET_PASSPORT_DATA_REQUIRED_BY_UID,
                "UPDATE " + getTableName() + " SET passport_data_required = :required" +
                        " WHERE user_data_id IN (SELECT id FROM telemost.user_data WHERE uid = :uid)");

    private final ObjectMapper objectMapper;

    public UserStateDtoPgDaoImpl(JdbcTemplate3 jdbcTemplate, ObjectMapper objectMapper) {
        super(jdbcTemplate);
        this.objectMapper = objectMapper;
    }

    /**
     * Builds the comma-separated column list used by the RETURNING clauses.
     * Called from the {@code queries} field initializer; it relies only on {@link #getFieldsToSelect()}.
     */
    private String returningColumns() {
        return getFieldsToSelect().map(JSON_FIELD_MAPPER).mkString(", ");
    }

    /**
     * Inserts or merges the peer state for the given {@code user_data} id.
     *
     * @param userId {@code user_data_id} of the row to upsert
     * @param state  new state document; serialized to JSON and cast to {@code jsonb}
     * @return the inserted/updated row, or empty when the conflict-update's WHERE clause rejected the row
     */
    @Override
    @SneakyThrows
    public Option<UserStateDto> updatePeerState(UUID userId, JsonNode state) {
        MapF<String, Object> params = Cf.map(
                "user_data_id", userId,
                "created_at", Instant.now(),
                "initial_version", 0,
                "state", objectMapper.writeValueAsString(state)
        );

        return getJdbcTemplate().query(
                queries.getTs(SQL_MERGE),
                mapper(),
                params
        ).firstO();
    }

    /**
     * Loads states of the given peers within one conference.
     *
     * @return map from peer id ({@code user_data.user_id}) to its state; empty when {@code peerIds} is empty
     */
    @Override
    public MapF<String, UserStateDto> findStates(UUID conferenceId, ListF<String> peerIds) {
        if (peerIds.isEmpty()) {
            // An empty IN () list is not valid SQL, so short-circuit.
            return Cf.map();
        }
        return getJdbcTemplate().query(queries.getTs(SQL_FIND_STATES_BY_CONFERENCE_AND_PEER_IDS),
                (rs, rowNum) -> tuple(rs.getString("peer_id"), parseRow(rs)),
                Cf.map("peer_ids", peerIds, "conference_id", conferenceId))
                .toMap(Tuple2::get1, Tuple2::get2);
    }

    /**
     * Stores {@code userData} under the {@code user_data} key of the state and bumps the version.
     *
     * @return the updated row, or empty when no row exists for {@code userId}
     */
    @Override
    @SneakyThrows
    public Option<UserStateDto> insertUserBackDataIntoState(UUID userId, UserBackData userData) {
        ObjectNode coreNode = objectMapper.createObjectNode();
        coreNode.set("user_data", objectMapper.valueToTree(userData));
        MapF<String, Object> params = Cf.map(
                "user_data_id", userId,
                "back_state", objectMapper.writeValueAsString(coreNode)
        );

        return getJdbcTemplate().query(
                queries.getTs(SQL_ADD_USER_BACK_DATA),
                mapper(),
                params
        ).firstO();
    }

    /**
     * Sets {@code passport_data_required} on the state row with the given primary key.
     *
     * @throws IllegalArgumentException if {@code id} is null
     */
    @Override
    public void setPassportDataRequired(UUID id, boolean required) {
        if (id == null) {
            throw new IllegalArgumentException("id must not be null");
        }
        MapF<String, Object> params = Cf.map(
                "id", id,
                "required", required
        );
        getJdbcTemplate().update(queries.getTs(SQL_SET_PASSPORT_DATA_REQUIRED), params);
    }

    /**
     * Sets {@code passport_data_required} on every state row whose {@code user_data} has the given uid.
     *
     * @throws IllegalArgumentException if {@code uid} is null
     */
    @Override
    public void setPassportDataRequiredByUid(PassportOrYaTeamUid uid, boolean required) {
        if (uid == null) {
            throw new IllegalArgumentException("uid must not be null");
        }
        MapF<String, Object> params = Cf.map(
                "uid", uid.asString(),
                "required", required
        );
        getJdbcTemplate().update(queries.getTs(SQL_SET_PASSPORT_DATA_REQUIRED_BY_UID), params);
    }

    /**
     * Increments {@code version} for every row whose {@code user_data_id} is in {@code userDataIds}.
     *
     * @return the updated rows; empty (without querying) when {@code userDataIds} is empty
     */
    @Override
    public ListF<UserStateDto> incrementVersion(ListF<UUID> userDataIds) {
        if (userDataIds.isEmpty()) {
            return Cf.list();
        }
        MapF<String, Object> params = Cf.map(
                "user_data_ids", userDataIds
        );
        return getJdbcTemplate().query(queries.getTs(SQL_INCREMENT_VERSION), mapper(), params);
    }

    @NotNull
    private RowMapper<UserStateDto> mapper() {
        return (rs, rowNum) -> parseRow(rs);
    }

    @Override
    protected String getTableName() {
        return "telemost.user_state";
    }

    @Override
    protected ListF<String> getFieldsToSelect() {
        return FIELDS_TO_SELECT;
    }

    /**
     * Maps the current result-set row to a {@link UserStateDto}.
     *
     * @throws TelemostRuntimeException wrapping any JDBC or JSON-parsing failure,
     *                                  including a NULL {@code created_at}
     */
    @Override
    protected UserStateDto parseRow(ResultSet rs) {
        try {
            Timestamp createdAt = rs.getTimestamp("created_at");
            if (createdAt == null) {
                // Joda's Instant(Object) treats null as "now", which would silently fabricate a timestamp.
                throw new SQLException("Column created_at is NULL in " + getTableName());
            }
            return new UserStateDto(
                    Option.of(UUID.fromString(rs.getString("id"))),
                    UUID.fromString(rs.getString("user_data_id")),
                    new Instant(createdAt.getTime()),
                    rs.getLong("version"),
                    objectMapper.readTree(rs.getString("state")),
                    rs.getBoolean("passport_data_required")
            );
        } catch (SQLException | IOException e) {
            throw new TelemostRuntimeException(e);
        }
    }
}
