#include <maps/wikimap/mapspro/services/mrc/libs/video_frames_reader/include/frames_reader.h>

#include <cstdint>
#include <mutex>

namespace maps::mrc::video {

namespace {

// One-time, process-wide FFmpeg setup: calls av_register_all() and sets the
// FFmpeg log level to AV_LOG_QUIET. Safe to call from any number of threads
// and any number of times; the setup body runs only once.
void initializeFFMpeg()
{
    // Initialisation of a function-local static is thread-safe and happens
    // exactly once (C++11 "magic statics"), replacing an explicit mutex + flag.
    [[maybe_unused]] static const bool ffmpegReady = [] {
        av_register_all();
        av_log_set_level(AV_LOG_QUIET);
        return true;
    }();
}

} // namespace

// Constructs a closed reader. The global FFmpeg setup is requested here;
// initializeFFMpeg() guards itself so repeated construction is harmless.
VideoFramesReader::VideoFramesReader()
{
    initializeFFMpeg(); // idempotent global FFmpeg initialisation
}

void VideoFramesReader::open(const std::string& uri)
{
    close();
    avFormatContext_.open(uri);
    initVideoStream();
    initSwsContext();

    duration_ = std::chrono::milliseconds(avFormatContext_->duration / (AV_TIME_BASE / 1000));
    fps_ = av_q2d(avVideoStream_->codec->framerate);
    frameIndex_ = 0;
    isEOF_ = false;
}

// Returns the reader to its "nothing opened" state. Bookkeeping is reset
// first, then buffered frames and FFmpeg objects are released.
void VideoFramesReader::close()
{
    // Bookkeeping: no input, no position, nothing left to read.
    isEOF_ = true;
    fps_ = 0.;
    frameIndex_ = -1;
    videoStreamIndex_ = -1;
    duration_ = TimeInterval(0);

    // Resources: drop queued frames, then close the stream wrapper (it wraps
    // one of avFormatContext_'s streams) before the format context itself.
    frames_.clear();
    avVideoStream_.close();
    avFormatContext_.close();
}

// Returns the next decoded frame, decoding more input if the internal queue
// is empty. Returns std::nullopt once no further frame can be produced.
// Requires an opened input.
std::optional<Frame> VideoFramesReader::readFrame()
{
    REQUIRE(avFormatContext_.isOpened(), "Input is not opened");

    if (frames_.empty()) {
        grabFrames();
        if (frames_.empty()) {
            return std::nullopt; // decoder produced nothing more
        }
    }

    std::optional<Frame> next = frames_.front();
    frames_.pop_front();
    return next;
}

// Repositions the reader so that subsequent readFrame() calls continue from
// roughly `timeFromStart`. Requires an opened input and
// 0 <= timeFromStart < duration().
void VideoFramesReader::seek(TimeInterval timeFromStart)
{
    REQUIRE(avFormatContext_.isOpened(), "Input is not opened");
    REQUIRE(timeFromStart >= TimeInterval(0) && timeFromStart < duration_,
            "Seeking beyond video duration");

    const auto millis = std::chrono::duration_cast<std::chrono::milliseconds>(timeFromStart).count();
    // Milliseconds -> stream time base in a single rounding step.
    const int64_t target = av_rescale_q(millis, AVRational{1, 1000}, avVideoStream_->time_base);

    // AVSEEK_FLAG_ANY permits landing on non-key frames.
    const int ret = avformat_seek_file(
        avFormatContext_.get(), videoStreamIndex_, 0, target, target, AVSEEK_FLAG_ANY);
    checkReturnCode(ret, "avformat_seek_file");

    // Discard everything tied to the pre-seek position: the decoder's internal
    // buffers and frames decoded earlier but not yet handed out by readFrame().
    avcodec_flush_buffers(avVideoStream_->codec);
    frames_.clear();

    frameIndex_ = static_cast<decltype(frameIndex_)>(millis * fps_ / 1000);
    isEOF_ = false;
}

// Length of the currently opened input as computed in open().
// Fails the "Input is not opened" check when no input is open.
TimeInterval VideoFramesReader::duration() const
{
    REQUIRE(avFormatContext_.isOpened(), "Input is not opened");
    const TimeInterval total = duration_;
    return total;
}


void VideoFramesReader::initVideoStream()
{
    ASSERT(avFormatContext_.isOpened());

    int ret = avformat_find_stream_info(avFormatContext_.get(), nullptr);
    checkReturnCode(ret, "avformat_find_stream_info");

    AVCodec* videoCodec = nullptr;
    videoStreamIndex_ = av_find_best_stream(avFormatContext_.get(), AVMEDIA_TYPE_VIDEO, -1, -1, &videoCodec, 0);
    REQUIRE(videoStreamIndex_ >= 0, "Problem in av_find_best_stream");

    avVideoStream_.wrap(avFormatContext_->streams[videoStreamIndex_]);
    avVideoStream_.open(videoCodec);
}

void VideoFramesReader::initSwsContext()
{
    ASSERT(avVideoStream_.get());

    auto* codecCtx = avVideoStream_->codec;
    REQUIRE(codecCtx->pix_fmt != AV_PIX_FMT_NONE, "Invalid pix_fmt: " << codecCtx->pix_fmt);

    swsContext_ = sws_getCachedContext(
            swsContext_.get(), codecCtx->width, codecCtx->height, codecCtx->pix_fmt,
            codecCtx->width, codecCtx->height, AV_PIX_FMT_BGR24, SWS_BICUBIC,
            nullptr, nullptr, nullptr);

    REQUIRE(swsContext_.get(), "Problem in sws_getCachedContext");
}

void VideoFramesReader::grabFrames()
{
    while (frames_.empty() && !isEOF_) {
        AVPacketWrapper packetWrapper = readPacket();
        auto& packet = packetWrapper.get();

        while (!isEOF_) {
            int gotPicture = 0;
            int ret = avcodec_decode_video2(avVideoStream_->codec, avFrame_.get(), &gotPicture, &packet);
            checkReturnCode(ret, "avcodec_decode_video2");

            if (gotPicture) {
                std::chrono::milliseconds timeFromStart(static_cast<int>(frameIndex_ / fps_ * 1000));
                frames_.emplace_back(Frame{avFrameToCvMat(avFrame_.get()), timeFromStart});
            }
            ++frameIndex_;

            packet.data += ret;
            packet.size -= ret;

            if (packet.size <= 0) {
                break;
            }
        }
    }
}

// Reads packets until one belonging to the selected video stream is found
// and returns it. If the demuxer reports AVERROR_EOF, sets isEOF_ and returns
// an empty packet. Packets of other streams are released and skipped.
AVPacketWrapper VideoFramesReader::readPacket()
{
    REQUIRE(avFormatContext_.isOpened(), "Input is not opened");

    AVPacketWrapper packetWrapper;
    auto& packet = packetWrapper.get();
    packet.data = nullptr;
    packet.size = 0;

    while (!isEOF_) {
        int ret = av_read_frame(avFormatContext_.get(), &packet);
        if (ret == AVERROR_EOF) {
            isEOF_ = true;
            break;
        }
        checkReturnCode(ret, "av_read_frame");

        if (packet.stream_index == avVideoStream_->index) {
            break;
        }
        // Not our stream: release its payload before reusing the packet.
        // av_read_frame does not free a previously filled packet, so
        // skipping without unref leaks every audio/subtitle packet.
        av_packet_unref(&packet);
    }
    return packetWrapper;
}

// Converts a decoded picture into a freshly allocated BGR24 cv::Mat sized
// from the codec context, using the swscale context built in
// initSwsContext(). `avFrame` must be non-null.
cv::Mat VideoFramesReader::avFrameToCvMat(AVFrame* avFrame)
{
    ASSERT(avFrame);
    auto* codecCtx = avVideoStream_->codec;
    cv::Mat frame(codecCtx->height, codecCtx->width, CV_8UC3);

    // BGR24 is packed, so only the first plane/stride entry is used.
    // Mat::ptr<uint8_t>() yields a writable pointer directly, instead of
    // casting away the constness of Mat::datastart with a C-style cast.
    uint8_t* const dst[AV_NUM_DATA_POINTERS] = {frame.ptr<uint8_t>()};
    const int dstStride[AV_NUM_DATA_POINTERS] = {static_cast<int>(frame.step[0])};

    sws_scale(swsContext_.get(), avFrame->data, avFrame->linesize, 0, avFrame->height, dst, dstStride);
    return frame;
}

} // namespace maps::mrc::video
