package ru.yandex.chemodan.app.djfs.core.album;

import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Types;
import java.util.Comparator;
import java.util.Objects;

import lombok.Data;
import org.bson.types.ObjectId;
import org.postgresql.util.PSQLException;
import org.postgresql.util.ServerErrorMessage;
import org.springframework.dao.DataIntegrityViolationException;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.RowMapper;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.MapF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.collection.SetF;
import ru.yandex.bolts.collection.Tuple2List;
import ru.yandex.chemodan.app.djfs.core.db.EntityAlreadyExistsException;
import ru.yandex.chemodan.app.djfs.core.db.pg.PgShardedDao;
import ru.yandex.chemodan.app.djfs.core.db.pg.PgShardedDaoContext;
import ru.yandex.chemodan.app.djfs.core.db.pg.ResultSetUtils;
import ru.yandex.chemodan.app.djfs.core.user.DjfsUid;
import ru.yandex.chemodan.app.djfs.core.util.UuidUtils;
import ru.yandex.commune.dynproperties.DynamicProperty;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;

/**
 * PostgreSQL implementation of {@link AlbumItemDao} backed by the {@code disk.album_items} table.
 *
 * <p>Every statement is executed through {@code jdbcTemplate(uid)} for the owning user and filters
 * by {@code uid}. Item and album ids are stored as the byte form of {@link ObjectId}. Each SQL string
 * is prefixed with the result of {@code collectStats(...)}.
 *
 * @author eoshch
 */
public class PgAlbumItemDao extends PgShardedDao implements AlbumItemDao {
    private static final Logger logger = LoggerFactory.getLogger(PgAlbumItemDao.class);

    /**
     * Maps one {@code disk.album_items} row to an {@link AlbumItem}. Nullable columns are read through
     * the {@code ResultSetUtils.get*O} helpers; {@code group_id} is converted to its hex string form.
     */
    private final static RowMapper<AlbumItem> M = (rs, rowNum) -> AlbumItem.builder()
        .id(new ObjectId(rs.getBytes("id")))
        .uid(DjfsUid.cons(rs.getLong("uid")))
        .albumId(new ObjectId(rs.getBytes("album_id")))
        .objectId(rs.getString("obj_id"))
        .objectType(AlbumItemType.R.fromValue(rs.getString("obj_type")))
        .description(ResultSetUtils.getStringO(rs, "description"))
        .orderIndex(ResultSetUtils.getDoubleO(rs, "order_index"))
        .groupId(ResultSetUtils.getUuidO(rs, "group_id").map(UuidUtils::toHexString))
        .faceInfo(ResultSetUtils.getStringO(rs, "face_info").map(FaceInfo::parseJson))
        .dateCreated(ResultSetUtils.getInstantO(rs, "date_created"))
        .build();

    /** Maximum number of rows sent per JDBC batch by {@link #batchInsert}; configured with 30 as its initial value. */
    private final DynamicProperty<Integer> batchSize = new DynamicProperty<>("disk-djfs-albums-insert-album-items-batch-size", 30);

    public PgAlbumItemDao(PgShardedDaoContext context) {
        super(context);
    }

    /**
     * Deletes every album item owned by {@code uid}.
     */
    @Override
    public void deleteAll(DjfsUid uid) {
        String sql = collectStats(uid) + " DELETE FROM disk.album_items WHERE uid = :uid";
        jdbcTemplate(uid).update(sql, Cf.map("uid", uid));
    }

    /**
     * Loads items of the given albums, at most {@code limit} rows, then sorts them in memory by album id
     * and order index (an absent order index sorts as {@code 0}).
     *
     * @return an empty list without querying when {@code albumIds} is empty
     */
    @Override
    public ListF<AlbumItem> getAllAlbumItems(DjfsUid uid, ListF<ObjectId> albumIds, int limit) {
        if (albumIds.isEmpty()) {
            // Avoids rendering an invalid "IN ()" clause.
            return Cf.list();
        }

        // NOTE(review): LIMIT is applied without ORDER BY, so when it truncates, which rows are returned
        // is up to the database; the sort below only orders the rows that were fetched.
        String sql = collectStats(uid)
                + " SELECT * FROM disk.album_items WHERE uid = :uid AND album_id IN (:album_ids) LIMIT :limit";
        MapF<String, Object> params = Cf.map(
                "uid", uid,
                "album_ids", albumIds.map(ObjectId::toByteArray),
                "limit", limit
        );
        ListF<AlbumItem> items = jdbcTemplate(uid).query(sql, M, params);

        Comparator<AlbumItem> comparator = Comparator
                .comparing(AlbumItem::getAlbumId)
                .thenComparing(x -> x.getOrderIndex().getOrElse(0d));
        return items.sorted(comparator);
    }

    /**
     * Loads all items of one album, in the order the database returns them.
     */
    @Override
    public ListF<AlbumItem> getAllAlbumItems(DjfsUid uid, ObjectId albumId) {
        String sql = collectStats(uid)
                + " SELECT * FROM disk.album_items WHERE uid = :uid AND album_id = :album_id";
        MapF<String, Object> params = Cf.map(
                "uid", uid,
                "album_id", albumId.toByteArray()
        );
        return jdbcTemplate(uid).query(sql, M, params);
    }

    /**
     * Counts the items of {@code album}, using the album's own uid and id.
     */
    @Override
    public long getAlbumItemCount(Album album) {
        final DjfsUid uid = album.getUid();
        final String sql = collectStats(uid) + " SELECT count(*) FROM disk.album_items WHERE uid = ? AND album_id = ?";
        return jdbcTemplate(uid).queryForLong(sql, uid, album.getId().toByteArray());
    }

    /**
     * Inserts a single album item, including {@code face_info} and {@code date_created}.
     *
     * @throws EntityAlreadyExistsException if the insert violates the {@code pk_album_items} or
     *         {@code uk_album_items_id} constraint
     * @throws DataIntegrityViolationException for any other integrity violation
     */
    @Override
    public void insert(AlbumItem item) {
        String sql = collectStats(item)
                + " INSERT INTO disk.album_items "
                + " (uid, id, album_id, description, group_id, order_index, obj_id, obj_type, face_info, date_created) "
                + " VALUES (:uid, :id, :album_id, :description, :group_id, :order_index, :obj_id, "
                + " :obj_type::disk.album_item_type, :face_info::jsonb, :date_created)";

        MapF<String, Object> params = Cf.toMap(Tuple2List.fromPairs(
                "uid", item.getUid(),
                "id", item.getId().toByteArray(),
                "album_id", item.getAlbumId().toByteArray(),
                "description", item.getDescription().getOrNull(),
                "group_id", item.getGroupId().map(UuidUtils::fromHex).getOrNull(),
                "order_index", item.getOrderIndex().getOrNull(),
                "obj_id", item.getObjectId(),
                "obj_type", item.getObjectType().value(),
                "face_info", item.getFaceInfo().map(FaceInfo::serializeJson).getOrNull(),
                "date_created", item.getDateCreated()
        ));

        try {
            jdbcTemplate(item.getUid()).update(sql, params);
        } catch (DataIntegrityViolationException e) {
            Throwable cause = e.getCause();
            if (cause instanceof PSQLException) {
                ServerErrorMessage error = ((PSQLException) cause).getServerErrorMessage();
                // Only a duplicate item id is translated; other integrity errors are rethrown as-is.
                if (error != null && (Objects.equals(error.getConstraint(), "pk_album_items")
                        || Objects.equals(error.getConstraint(), "uk_album_items_id")))
                {
                    logger.info("PgAlbumItemDao.insert(AlbumItem) handled exception for user "
                            + item.getUid() + " : ", e);
                    throw new EntityAlreadyExistsException(item.getId().toHexString(), e);
                }
            }
            throw e;
        }
    }

    /**
     * Finds the items of one album that reference {@code objectId}.
     */
    @Override
    public ListF<AlbumItem> findObjectInAlbum(DjfsUid uid, ObjectId albumId, String objectId) {
        String sql = collectStats(uid)
                + " SELECT * FROM disk.album_items WHERE uid = :uid AND album_id = :album_id"
                + " AND obj_id = :obj_id";
        MapF<String, Object> params = Cf.map(
                "uid", uid,
                "album_id", albumId.toByteArray(),
                "obj_id", objectId
        );
        return jdbcTemplate(uid).query(sql, M, params);
    }

    /**
     * Finds the items of one album that reference any of {@code objectsIds}.
     *
     * @return an empty list without querying when {@code objectsIds} is empty
     */
    @Override
    public ListF<AlbumItem> findObjectsInAlbum(DjfsUid uid, ObjectId albumId, ListF<String> objectsIds) {
        if (objectsIds.isEmpty()) {
            // An empty collection would render "IN ()", which PostgreSQL rejects as a syntax error.
            return Cf.list();
        }

        String sql = collectStats(uid)
                + " SELECT * FROM disk.album_items WHERE uid = :uid AND album_id = :album_id"
                + " AND obj_id IN (:obj_ids)";
        MapF<String, Object> params = Cf.map(
                "uid", uid,
                "album_id", albumId.toByteArray(),
                "obj_ids", objectsIds
        );
        return jdbcTemplate(uid).query(sql, M, params);
    }

    /**
     * Finds every item of the user, across all albums, that references {@code objectId}.
     */
    @Override
    public ListF<AlbumItem> find(DjfsUid uid, String objectId) {
        String sql = collectStats(uid)
                + " SELECT * FROM disk.album_items WHERE uid = :uid AND obj_id = :obj_id";
        MapF<String, Object> params = Cf.map(
                "uid", uid,
                "obj_id", objectId
        );
        return jdbcTemplate(uid).query(sql, M, params);
    }

    /**
     * Looks up one item by its id.
     *
     * @param albumItemId hex string form of the item's {@link ObjectId}; parsed with {@code new ObjectId(String)}
     */
    @Override
    public Option<AlbumItem> findByItemId(DjfsUid uid, String albumItemId) {
        String sql = collectStats(uid)
                + " SELECT * FROM disk.album_items WHERE uid = :uid AND id = :id";
        MapF<String, Object> params = Cf.map(
                "uid", uid,
                "id", new ObjectId(albumItemId).toByteArray()
        );
        return jdbcTemplate(uid).queryForOption(sql, M, params);
    }

    /**
     * Counts the items of one album.
     */
    @Override
    public int countObjectsInAlbum(DjfsUid uid, ObjectId albumId) {
        String sql = collectStats(uid)
                + " SELECT COUNT(*) FROM disk.album_items WHERE uid = ? AND album_id = ?";
        return jdbcTemplate(uid).queryForInt(sql, uid, albumId.toByteArray());
    }

    /**
     * Removes the items of one album that reference {@code objectId}.
     *
     * @return {@code true} if at least one row was deleted
     */
    @Override
    public boolean removeFromAlbum(DjfsUid uid, ObjectId albumId, String objectId) {
        String sql = collectStats(uid)
                + " DELETE FROM disk.album_items WHERE uid = :uid AND album_id = :album_id"
                + " AND obj_id = :obj_id";
        MapF<String, Object> params = Cf.map(
                "uid", uid,
                "album_id", albumId.toByteArray(),
                "obj_id", objectId
        );
        return jdbcTemplate(uid).update(sql, params) > 0;
    }

    /**
     * Removes the items of one album that reference any of {@code objectIds}.
     *
     * @return {@code true} if at least one row was deleted; {@code false} without querying when
     *         {@code objectIds} is empty
     */
    @Override
    public boolean removeFromAlbum(DjfsUid uid, ObjectId albumId, SetF<String> objectIds) {
        if (objectIds.isEmpty()) {
            return false;
        }

        String sql = collectStats(uid)
                + " DELETE FROM disk.album_items WHERE uid = :uid AND album_id = :album_id"
                + " AND obj_id IN (:obj_ids)";
        MapF<String, Object> params = Cf.map(
                "uid", uid,
                "album_id", albumId.toByteArray(),
                "obj_ids", objectIds
        );
        return jdbcTemplate(uid).update(sql, params) > 0;
    }

    /**
     * Removes every item of one album.
     *
     * @return the update count reported for the delete
     */
    @Override
    public int removeAllItemsFromAlbum(DjfsUid uid, ObjectId albumId) {
        String sql = collectStats(uid)
                + " DELETE FROM disk.album_items WHERE uid = :uid AND album_id = :album_id";
        MapF<String, Object> params = Cf.map(
                "uid", uid,
                "album_id", albumId.toByteArray()
        );
        return jdbcTemplate(uid).update(sql, params);
    }

    /**
     * Inserts {@code albumItems} in JDBC batches of at most {@link #batchSize} rows.
     *
     * <p>NOTE(review): unlike {@link #insert(AlbumItem)}, this statement does not write {@code face_info}
     * or {@code date_created}; confirm that relying on the column defaults is intended.
     *
     * @return one update-count array per executed batch, as returned by {@code batchUpdate}
     */
    @Override
    public ListF<int[]> batchInsert(DjfsUid uid, ListF<AlbumItem> albumItems) {
        String query = collectStats(uid) + " INSERT INTO disk.album_items (uid, id, album_id, description, group_id, order_index, obj_id, obj_type) "
                + " VALUES (?, ?, ?, ?, ?, ?, ?, ?::disk.album_item_type)";
        // A non-positive configured value would yield an empty first batch and silently insert nothing.
        int maxBatchSize = Math.max(1, batchSize.get());
        ListF<int[]> result = Cf.arrayList();
        ListF<AlbumItem> remaining = albumItems;
        ListF<AlbumItem> batch = remaining.take(maxBatchSize);
        while (batch.isNotEmpty()) {
            result.add(jdbcTemplate(uid).batchUpdate(query, new AlbumItemBatchHandler(batch, batch.size())));
            remaining = remaining.drop(batch.size());
            batch = remaining.take(maxBatchSize);
        }
        return result;
    }

    /**
     * Moves all of the user's items from {@code srcAlbumId} to {@code dstAlbumId}.
     *
     * @return the update count reported for the update
     */
    @Override
    public int changeAlbumId(DjfsUid uid, ObjectId srcAlbumId, ObjectId dstAlbumId) {
        String query = collectStats(uid)
                + " UPDATE disk.album_items SET album_id = :dst_album_id WHERE uid = :uid AND album_id = :src_album_id";
        return jdbcTemplate(uid).update(query,
                Cf.map("uid", uid, "src_album_id", srcAlbumId.toByteArray(), "dst_album_id", dstAlbumId.toByteArray()));
    }

    /**
     * Loads the items of one album that have a matching row in {@code disk.files} for the same user.
     * A match is {@code '\x' || obj_id} equal to {@code files.id} cast to text (presumably {@code obj_id}
     * holds the hex form of a bytea {@code files.id}; confirm).
     */
    @Override
    public ListF<AlbumItem> getExistingAlbumItems(DjfsUid uid, ObjectId albumId) {
        // Select only album_items columns: with "SELECT *" the joined files columns (uid, id, ...) would
        // duplicate labels read by the row mapper, which then depends on JDBC's first-match rule.
        String sql = collectStats(uid)
                + " SELECT a.* " +
                " FROM disk.album_items a LEFT JOIN disk.files f ON a.uid = f.uid and '\\x' || a.obj_id = CAST(f.id as text) " +
                " WHERE a.uid = :uid and a.album_id=:album_id AND fid is not NULL";
        MapF<String, Object> params = Cf.map(
                "uid", uid,
                "album_id", albumId.toByteArray()
        );
        return jdbcTemplate(uid).query(sql, M, params);
    }

    /**
     * Binds {@link AlbumItem}s to the positional parameters of the {@link #batchInsert} statement.
     */
    @Data
    private static class AlbumItemBatchHandler implements BatchPreparedStatementSetter {

        private final ListF<AlbumItem> albumItems;

        /** Upper bound on the rows bound by this setter; the batch size is the smaller of this and the list size. */
        private final int maxBatchSize;

        @Override
        public void setValues(PreparedStatement ps, int i) throws SQLException {
            AlbumItem item = albumItems.get(i);
            ps.setLong(1, item.getUid().asLong());
            ps.setBytes(2, item.getId().toByteArray());
            ps.setBytes(3, item.getAlbumId().toByteArray());
            ps.setString(4, item.getDescription().getOrNull());
            ps.setObject(5, item.getGroupId().map(UuidUtils::fromHex).getOrNull());
            Double orderIndex = item.getOrderIndex().getOrNull();
            if (orderIndex != null) {
                ps.setDouble(6, orderIndex);
            } else {
                // setDouble takes a primitive: unboxing an absent index would throw NullPointerException.
                ps.setNull(6, Types.DOUBLE);
            }
            ps.setString(7, item.getObjectId());
            ps.setString(8, item.getObjectType().value());
        }

        @Override
        public int getBatchSize() {
            return Math.min(maxBatchSize, albumItems.size());
        }
    }
}
