#include "exporter.h"
#include "loader.h"
#include "tools.h"
#include "traffic_sign_groups.h"
#include "utils.h"

#include <maps/libs/log8/include/log8.h>
#include <maps/libs/sql_chemistry/include/batch_load.h>

#include <maps/wikimap/mapspro/services/mrc/libs/common/include/algorithm/collection.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/exif.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/frame_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/object_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/recognition_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/fb/include/objects_writer.h>
#include <maps/wikimap/mapspro/services/mrc/libs/fb/include/traffic_sign_groups_writer.h>


namespace maps::mrc::signs_export_gen {

namespace tbl = db::eye::table;

namespace {

// Batch size passed to parallelForEachBatch: presumably the maximum number of
// object ids handed to one invocation of the per-batch callback.
constexpr auto BATCH_SIZE = 5000;
// Concurrency level passed to parallelForEachBatch (presumably the number of
// worker threads — TODO confirm against its definition).
constexpr auto NUM_THREADS = 6;
// Upper bound applied separately to the "detected" proofs and to the
// "not detected" proofs collected for a single object.
constexpr auto MAX_FRAMES_PER_OBJECT = 3;

// Empty fallback used when an object has no additional tables attached.
const auto NO_ADD_TABLES = std::vector<traffic_signs::TrafficSign>{};

/// Makes `path` an empty directory: any existing file or directory tree at
/// that location is deleted, then the directory (and missing parents) is
/// created. Filesystem failures are reported via std::filesystem_error.
void prepareExportDirectory(const std::string& path)
{
    namespace fs = std::filesystem;

    const fs::path exportDir{path};
    fs::remove_all(exportDir);        // drop stale output from a previous run
    fs::create_directories(exportDir);
}

/// Truncates `frameDetections` so that at most `maxCount` entries remain,
/// keeping the leading ones in their original order.
/// The selection policy is deliberately naive; a smarter one could rank
/// frames by date, privacy, quality etc.
void keepLimitedNumber(
    std::vector<fb::FrameWithOptDetection>& frameDetections,
    size_t maxCount)
{
    if (frameDetections.size() <= maxCount) {
        return; // already within the limit
    }

    auto firstDropped = frameDetections.begin() + maxCount;
    frameDetections.erase(firstDropped, frameDetections.end());
}

void writeTrafficSignGroups(
    const SignGroups& signGroups,
    const std::string& version,
    const std::string& datasetDir)
{
    fb::TrafficSignGroupsWriter writer;

    for (const auto& group : signGroups) {
        auto orientation = group.alignment == GroupAlignment::Horizontal
            ? fb::GroupOrientation_Horizontal
            : fb::GroupOrientation_Vertical;

        writer.add(group.groupId, group.objectIds, orientation);
    }
    writer.dump(version, datasetDir);
    INFO() << "Written traffic sign groups: " << signGroups.size();
}

} // namespace

/// Builds the traffic-sign export dataset in `datasetDir`.
///
/// Steps:
///  1. Wipes and recreates `datasetDir`.
///  2. Collects traffic sign groups and writes them to the dataset.
///  3. Loads all traffic sign ids and processes them batch by batch through
///     parallelForEachBatch<NUM_THREADS, BATCH_SIZE>. For every object of a
///     batch it gathers up to MAX_FRAMES_PER_OBJECT frames carrying a
///     detection of the object, and up to MAX_FRAMES_PER_OBJECT frames
///     returned by loadFramesWithMissingObjects (no detection attached).
///     Objects with at least one such frame are added to the objects writer.
///  4. Dumps the collected objects into `datasetDir`.
///
/// @param pool        database pool; each batch opens its own slave transaction.
/// @param version     dataset version passed to the writers' dump().
/// @param datasetDir  output directory; any previous contents are removed.
void generateSignsExport(
    pgpool3::Pool& pool,
    const std::string& version,
    const std::string& datasetDir)
{
    fb::ObjectsWriter objectsWriter;
    size_t processedCount = 0;
    // Protects objectsWriter and processedCount, which are shared by the batch workers.
    std::mutex guard;

    prepareExportDirectory(datasetDir);
    auto groupsInfo = collectTrafficSignGroups(pool);
    writeTrafficSignGroups(groupsInfo.first, version, datasetDir);

    auto allObjectIds = loadSignIds(*pool.slaveTransaction());
    INFO() << "Loaded traffic sign ids: " << allObjectIds.size();

    parallelForEachBatch<NUM_THREADS, BATCH_SIZE>(
        allObjectIds, [&](std::span<const db::TId> batch) {
            auto txn = pool.slaveTransaction();

            auto objects = loadObjects(*txn, {batch.begin(), batch.end()});
            removeAdditionalTables(objects);
            auto objectIds = transform(objects, &db::eye::Object::id);
            auto objectLocations = loadObjectLocations(*txn, objectIds);

            auto [detections, detectionIdsByObjectId] = loadObjectDetections(*txn, objects);

            auto [frameIdsWithDetection, frameIdByDetectionId]
                = loadFramesWithDetections(*txn, detections);
            auto [frameIdsWithoutDetection, frameIdsByMissingObjectId]
                = loadFramesWithMissingObjects(*txn, objectIds);
            auto frameIds = joinUnique(frameIdsWithDetection, frameIdsWithoutDetection);
            auto frameById = loadFrames(*txn, frameIds);
            auto frameLocations = loadFrameLocations(*txn, frameIds);

            auto detectionById = common::byId(std::move(detections));

            auto additionalTablesByObjectId = loadAdditionalTables(*txn, objectIds);

            for (const auto& object : objects) {
                // Frames on which this object was detected (frame + detection).
                std::vector<fb::FrameWithOptDetection> detectionProofs;
                auto detectionIdsIt = detectionIdsByObjectId.find(object.id());
                if (detectionIdsIt != detectionIdsByObjectId.end()) {
                    for (auto detectionId : detectionIdsIt->second) {
                        auto frameId = findOptional(frameIdByDetectionId, detectionId);
                        if (!frameId) {
                            continue;
                        }
                        auto frameIt = frameById.find(*frameId);
                        if (frameIt == frameById.end()) {
                            continue;
                        }
                        detectionProofs.push_back({
                            frameIt->second,
                            frameLocations.at(*frameId),
                            detectionById.at(detectionId)
                        });
                    }
                }

                // Frames reported for this object by loadFramesWithMissingObjects;
                // they carry no detection.
                std::vector<fb::FrameWithOptDetection> noDetectionProofs;
                auto missingFramesIt = frameIdsByMissingObjectId.find(object.id());
                if (missingFramesIt != frameIdsByMissingObjectId.end()) {
                    for (auto frameId : missingFramesIt->second) {
                        auto frameIt = frameById.find(frameId);
                        if (frameIt == frameById.end()) {
                            continue;
                        }
                        noDetectionProofs.push_back({
                            frameIt->second,
                            frameLocations.at(frameId),
                            std::nullopt
                        });
                    }
                }

                keepLimitedNumber(detectionProofs, MAX_FRAMES_PER_OBJECT);
                keepLimitedNumber(noDetectionProofs, MAX_FRAMES_PER_OBJECT);

                // Objects without any proof frame are not exported.
                if (detectionProofs.empty() && noDetectionProofs.empty()) {
                    continue;
                }

                // Build every writer argument before taking the lock so the
                // critical section covers only the shared objectsWriter.
                const auto& objectLocation = objectLocations.at(object.id());
                auto proofs = join(detectionProofs, noDetectionProofs);
                auto additionalTables =
                    findOptional(additionalTablesByObjectId, object.id()).value_or(NO_ADD_TABLES);
                auto groupInfo = findOptional(groupsInfo.second, object.id());

                auto lock = std::lock_guard{guard};
                objectsWriter.add(
                    object,
                    objectLocation,
                    std::move(proofs),
                    std::move(additionalTables),
                    std::move(groupInfo)
                );
            }

            auto lock = std::lock_guard{guard};
            processedCount += objects.size();
            INFO() << "Processed objects: " << processedCount;
    });

    objectsWriter.dump(version, datasetDir);
    INFO() << "Done writing objects";
}

}  // namespace maps::mrc::signs_export_gen
