package ru.yandex.travel.api.services.avia.variants.repositories;

import java.sql.Timestamp;
import java.time.Instant;
import java.util.List;
import java.util.UUID;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;

import ru.yandex.travel.api.config.avia.AviaBookingConfiguration;
import ru.yandex.travel.api.services.avia.variants.model.AviaCachedVariantCheck;

/**
 * JDBC repository for the {@code cached_variant_checks} table, which stores a check id and an
 * expiration instant per (partner_id, variant_id) pair.
 *
 * <p>Every method is annotated with {@link Propagation#MANDATORY}, so callers must already be
 * inside a Spring-managed transaction. The row lock acquired by {@link #selectForUpdateNoWait}
 * is therefore bound to the caller's transaction.
 */
@Service
@ConditionalOnBean(AviaBookingConfiguration.class)
@RequiredArgsConstructor
@Slf4j
public class AviaCachedVariantRepository {
    private final JdbcTemplate jdbcTemplate;

    /**
     * Reads the cached check for the given partner/variant pair while taking a row lock with
     * {@code FOR UPDATE NOWAIT}.
     *
     * @param partnerId partner identifier (matched against {@code partner_id})
     * @param variantId variant identifier (matched against {@code variant_id})
     * @return the cached record, or {@code null} if no row exists for the pair; a null or empty
     *         {@code check_id} is mapped to a {@code null} check id, and a null
     *         {@code expires_at} to a {@code null} expiration
     * @throws IllegalStateException if more than one row matches the pair
     */
    @Transactional(propagation = Propagation.MANDATORY)
    public AviaCachedVariantCheck selectForUpdateNoWait(String partnerId, String variantId) {
        // The attempt to lock a record will either fail immediately with org.springframework.dao.CannotAcquireLockException
        // or immediately return in case of either successful locking or record absence.
        // The lock will be held until the current transaction is either committed or rolled back.
        // Uses the varargs query(sql, rowMapper, args...) overload; the Object[] variant is deprecated since Spring 5.3.
        List<AviaCachedVariantCheck> checks = jdbcTemplate.query("select check_id, expires_at " +
                        "from cached_variant_checks " +
                        "where partner_id = ? and variant_id = ? " +
                        "for update nowait",
                (rs, i) -> {
                    String checkId = rs.getString("check_id");
                    return AviaCachedVariantCheck.builder()
                            .partnerId(partnerId)
                            .variantId(variantId)
                            .checkId(!Strings.isNullOrEmpty(checkId) ? UUID.fromString(checkId) : null)
                            .expiresAt(toInstant(rs.getTimestamp("expires_at")))
                            .build();
                },
                partnerId, variantId);
        Preconditions.checkState(checks.size() <= 1, "at most one row is expected but got: %s", checks);
        return checks.isEmpty() ? null : checks.get(0);
    }

    /**
     * Inserts a new row for the record's partner/variant pair.
     *
     * @param cacheRecord record to persist; its expiration may be {@code null}
     * @throws IllegalStateException if the driver reports an affected-row count other than 1
     */
    @Transactional(propagation = Propagation.MANDATORY)
    public void insert(AviaCachedVariantCheck cacheRecord) {
        int updated = jdbcTemplate.update(
                "insert into cached_variant_checks(partner_id, variant_id, check_id, expires_at) values(?, ?, ?, ?)",
                cacheRecord.getPartnerId(), cacheRecord.getVariantId(),
                cacheRecord.getCheckId(), toTimestamp(cacheRecord.getExpiresAt()));
        Preconditions.checkState(updated == 1, "unexpected insert result count: %s", updated);
    }

    /**
     * Overwrites the check id and expiration of the existing row for the record's
     * partner/variant pair.
     *
     * @param cacheRecord record carrying the new values; its expiration may be {@code null}
     * @throws IllegalStateException if no row (or more than one row) was updated
     */
    @Transactional(propagation = Propagation.MANDATORY)
    public void update(AviaCachedVariantCheck cacheRecord) {
        int updated = jdbcTemplate.update("update cached_variant_checks " +
                        "set check_id = ?, expires_at = ? " +
                        "where partner_id = ? and variant_id = ?",
                cacheRecord.getCheckId(), toTimestamp(cacheRecord.getExpiresAt()),
                cacheRecord.getPartnerId(), cacheRecord.getVariantId());
        Preconditions.checkState(updated == 1, "unexpected update result count: %s", updated);
    }

    /** Converts a nullable {@link Instant} into a JDBC {@link Timestamp}, preserving {@code null}. */
    private static Timestamp toTimestamp(Instant instant) {
        return instant != null ? Timestamp.from(instant) : null;
    }

    /** Converts a nullable JDBC {@link Timestamp} into an {@link Instant}, preserving {@code null}. */
    private static Instant toInstant(Timestamp timestamp) {
        return timestamp != null ? timestamp.toInstant() : null;
    }
}
