#include "loader.h"
#include "utils.h"

#include <maps/wikimap/mapspro/services/mrc/libs/common/include/algorithm/collection.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/frame_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/object_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/recognition_gateway.h>

namespace maps::mrc::signs_export_gen {

// Shorthand for db::eye::table, whose members build the filters passed to
// the gateways in this file.
namespace tbl = db::eye::table;

/// Returns the ids of all objects whose type is ObjectType::Sign and whose
/// `deleted` flag is false.
///
/// @param txn  transaction the ObjectGateway query is executed in
db::TIds
loadSignIds(pqxx::transaction_base& txn)
{
    db::eye::ObjectGateway objectGateway{txn};
    return objectGateway.loadIds(
        tbl::Object::type == db::eye::ObjectType::Sign &&
        tbl::Object::deleted.is(false));
}

/// Loads the objects with the given ids, skipping those marked as deleted.
///
/// @param txn        transaction the ObjectGateway query is executed in
/// @param objectIds  ids of the objects to fetch
/// @return the matching non-deleted objects
std::vector<db::eye::Object>
loadObjects(
    pqxx::transaction_base& txn,
    const db::TIds& objectIds)
{
    db::eye::ObjectGateway objectGateway{txn};
    return objectGateway.load(
        tbl::Object::id.in(objectIds) && tbl::Object::deleted.is(false));
}

/// Loads the locations of the given objects and passes them to toMapByKey
/// with ObjectLocation::objectId as the key accessor.
///
/// @param txn        transaction the ObjectLocationGateway query is executed in
/// @param objectIds  ids of the objects whose locations are requested
db::IdTo<db::eye::ObjectLocation>
loadObjectLocations(
    pqxx::transaction_base& txn,
    const db::TIds& objectIds)
{
    db::eye::ObjectLocationGateway locationGateway(txn);
    return toMapByKey(
        locationGateway.load(tbl::ObjectLocation::objectId.in(objectIds)),
        &db::eye::ObjectLocation::objectId);
}

/// Loads every non-deleted detection attached to the given objects.
///
/// A detection is attached to an object either as the object's primary
/// detection or through a non-deleted PrimaryDetectionRelation pointing at
/// that primary detection.
///
/// @param txn      transaction all queries are executed in
/// @param objects  objects whose detections are requested
/// @return the loaded detections, and for each object id the ids of its
///         loaded detections
std::pair<db::eye::Detections, ObjectIdToDetectionIds>
loadObjectDetections(
    pqxx::transaction_base& txn,
    const db::eye::Objects& objects)
{
    // Ids of every detection to fetch; seeded with the primary detections.
    db::TIds allDetectionIds;
    // Owning object id for each collected detection id.
    std::unordered_map<db::TId, db::TId> ownerByDetectionId;

    for (const auto& object : objects) {
        const auto primaryId = object.primaryDetectionId();
        allDetectionIds.push_back(primaryId);
        ownerByDetectionId[primaryId] = object.id();
    }

    // At this point allDetectionIds holds primary detection ids only.
    db::eye::PrimaryDetectionRelationGateway relationGateway(txn);
    auto relations = relationGateway.load(
        tbl::PrimaryDetectionRelation::primaryDetectionId.in(allDetectionIds) &&
        tbl::PrimaryDetectionRelation::deleted.is(false));

    // Related detections belong to the object that owns their primary one.
    for (const auto& relation : relations) {
        const auto relatedId = relation.detectionId();
        // Copy the owner id before the map is modified below.
        const auto ownerId = ownerByDetectionId[relation.primaryDetectionId()];
        allDetectionIds.push_back(relatedId);
        ownerByDetectionId[relatedId] = ownerId;
    }

    db::eye::DetectionGateway detectionGateway(txn);
    auto detections = detectionGateway.load(
        tbl::Detection::id.in(allDetectionIds) && tbl::Detection::deleted.is(false));

    // Group only the detections that were actually loaded.
    std::unordered_map<db::TId, db::TIds> objectIdToDetectionIds;
    for (const auto& detection : detections) {
        const auto ownerId = ownerByDetectionId.at(detection.id());
        objectIdToDetectionIds[ownerId].push_back(detection.id());
    }

    return {std::move(detections), std::move(objectIdToDetectionIds)};
}

/// Resolves the frame of each detection through its detection group.
///
/// @param txn         transaction the DetectionGroupGateway query is executed in
/// @param detections  detections to resolve
/// @return the frame id of every resolved detection (one entry per detection,
///         so a frame id may repeat), and the detection id -> frame id mapping.
///         Detections whose group was not loaded are left out of both.
std::pair<FrameIds, DetectionIdToFrameId>
loadFramesWithDetections(
    pqxx::transaction_base& txn,
    const db::eye::Detections& detections)
{
    // Distinct detection group ids referenced by the input detections.
    std::unordered_set<db::TId> detGroupIds;
    for (const auto& detection : detections) {
        detGroupIds.insert(detection.groupId());
    }

    db::eye::DetectionGroupGateway groupGateway(txn);
    auto detGroups = groupGateway.load(
        tbl::DetectionGroup::id.in({detGroupIds.begin(), detGroupIds.end()}));

    auto frameIdByGroupId = makeMap(detGroups,
        &db::eye::DetectionGroup::id,
        &db::eye::DetectionGroup::frameId);

    db::TIds frameIds;
    DetectionIdToFrameId detectionIdToFrameId;
    for (const auto& detection : detections) {
        const auto found = frameIdByGroupId.find(detection.groupId());
        if (found == frameIdByGroupId.end()) {
            // The group was not loaded: leave this detection out.
            continue;
        }
        frameIds.push_back(found->second);
        detectionIdToFrameId.emplace(detection.id(), found->second);
    }

    return {std::move(frameIds), std::move(detectionIdToFrameId)};
}

/// Loads the non-deleted ObjectMissingOnFrame records of the given objects.
///
/// @param txn        transaction the ObjectMissingOnFrameGateway query is executed in
/// @param objectIds  ids of the objects to look up
/// @return the frame id of every loaded record (one entry per record), and
///         for each object id the frame ids taken from its records
std::pair<FrameIds, ObjectIdToFrameIds>
loadFramesWithMissingObjects(
    pqxx::transaction_base& txn,
    const db::TIds& objectIds)
{
    db::eye::ObjectMissingOnFrameGateway missingGateway{txn};
    auto records = missingGateway.load(
        tbl::ObjectMissingOnFrame::objectId.in(objectIds) &&
        tbl::ObjectMissingOnFrame::deleted.is(false));

    db::TIds frameIds;
    ObjectIdToFrameIds objectIdToFrameIds;
    for (const auto& record : records) {
        const auto frameId = record.frameId();
        frameIds.push_back(frameId);
        objectIdToFrameIds[record.objectId()].push_back(frameId);
    }

    return {std::move(frameIds), std::move(objectIdToFrameIds)};
}

/// Loads the non-deleted frames with the given ids and passes them to
/// common::byId to build the returned map.
///
/// @param txn       transaction the FrameGateway query is executed in
/// @param frameIds  ids of the frames to fetch
db::IdTo<db::eye::Frame>
loadFrames(
    pqxx::transaction_base& txn,
    const db::TIds& frameIds)
{
    db::eye::FrameGateway frameGateway(txn);
    return common::byId(frameGateway.load(
        tbl::Frame::id.in(frameIds) && tbl::Frame::deleted.is(false)));
}

/// Loads the locations of the given frames and passes them to toMapByKey
/// with FrameLocation::frameId as the key accessor.
///
/// @param txn       transaction the FrameLocationGateway query is executed in
/// @param frameIds  ids of the frames whose locations are requested
db::IdTo<db::eye::FrameLocation>
loadFrameLocations(
    pqxx::transaction_base& txn,
    const db::TIds& frameIds)
{
    db::eye::FrameLocationGateway locationGateway(txn);
    return toMapByKey(
        locationGateway.load(tbl::FrameLocation::frameId.in(frameIds)),
        &db::eye::FrameLocation::frameId);
}

/// For each given master object, collects the sign types of its slave objects.
///
/// Slave objects are found through non-deleted ObjectRelation records and are
/// fetched with loadObjects; a relation whose slave object was not loaded
/// contributes nothing.
///
/// @param txn        transaction all queries are executed in
/// @param objectIds  ids of the master objects
/// @return master object id -> SignAttrs::type of each loaded slave object
db::IdTo<std::vector<traffic_signs::TrafficSign>>
loadAdditionalTables(
    pqxx::transaction_base& txn,
    const db::TIds& objectIds)
{
    db::eye::ObjectRelationGateway relationGateway{txn};
    auto relations = relationGateway.load(
        tbl::ObjectRelation::masterObjectId.in(objectIds) &&
        tbl::ObjectRelation::deleted.is(false)
    );

    auto slaveObjectIds = transform(
        relations, &db::eye::ObjectRelation::slaveObjectId);
    auto slaveObjectsById = common::byId(loadObjects(txn, slaveObjectIds));

    db::IdTo<std::vector<traffic_signs::TrafficSign>> result;
    for (const auto& relation : relations) {
        const auto found = slaveObjectsById.find(relation.slaveObjectId());
        if (found == slaveObjectsById.end()) {
            // The slave object was not loaded: nothing to record.
            continue;
        }
        const auto& attrs = found->second.attrs<db::eye::SignAttrs>();
        result[relation.masterObjectId()].push_back(attrs.type);
    }
    return result;
}

/// Loads the detection groups of type DetectionType::Sign that belong to the
/// given frames.
///
/// @param txn       transaction the DetectionGroupGateway query is executed in
/// @param frameIds  ids of the frames whose groups are requested
db::eye::DetectionGroups
loadDetectionGroups(
    pqxx::transaction_base& txn,
    const db::TIds& frameIds)
{
    db::eye::DetectionGroupGateway groupGateway{txn};
    return groupGateway.load(
        tbl::DetectionGroup::frameId.in(frameIds) &&
        tbl::DetectionGroup::type == db::eye::DetectionType::Sign);
}

/// Loads the non-deleted detections that belong to the given detection groups.
///
/// @param txn          transaction the DetectionGateway query is executed in
/// @param detGroupIds  ids of the detection groups to look up
db::eye::Detections
loadDetections(
    pqxx::transaction_base& txn,
    const db::TIds& detGroupIds)
{
    db::eye::DetectionGateway detectionGateway{txn};
    return detectionGateway.load(
        tbl::Detection::groupId.in(detGroupIds) &&
        tbl::Detection::deleted.is(false));
}

/// Maps each given detection to the object whose primary detection it is
/// attached to.
///
/// A detection with a non-deleted PrimaryDetectionRelation is resolved
/// through that relation's primary detection; a detection without one is
/// treated as its own primary detection. Only non-deleted objects are
/// considered, and removeAdditionalTables is applied to them before the
/// lookup, so detections resolving to a removed or missing object are absent
/// from the result.
///
/// @param txn           transaction all queries are executed in
/// @param detectionIds  ids of the detections to resolve
/// @return detection id -> object id for every detection that was resolved
db::IdTo<db::TId>
loadDetectionIdToObjectId(
    pqxx::transaction_base& txn,
    const db::TIds& detectionIds)
{
    db::eye::PrimaryDetectionRelationGateway relationGateway{txn};
    auto relations = relationGateway.load(
        tbl::PrimaryDetectionRelation::detectionId.in(detectionIds) &&
        tbl::PrimaryDetectionRelation::deleted.is(false));

    auto primaryIdByDetectionId = makeMap(relations,
        &db::eye::PrimaryDetectionRelation::detectionId,
        &db::eye::PrimaryDetectionRelation::primaryDetectionId);

    auto primaryIds = transform(
        relations,
        &db::eye::PrimaryDetectionRelation::primaryDetectionId);

    // The input ids are searched too, since any of them may itself be a
    // primary detection.
    db::eye::ObjectGateway objectGateway{txn};
    auto objects = objectGateway.load(
        tbl::Object::primaryDetectionId.in(
            joinUnique(detectionIds, primaryIds)) &&
        tbl::Object::deleted.is(false));
    removeAdditionalTables(objects);

    auto objectIdByPrimaryId = makeMap(objects,
        &db::eye::Object::primaryDetectionId,
        &db::eye::Object::id);

    db::IdTo<db::TId> result;
    for (db::TId detectionId : detectionIds) {
        const auto primaryId =
            findOptional(primaryIdByDetectionId, detectionId)
                .value_or(detectionId);

        if (auto found = objectIdByPrimaryId.find(primaryId);
            found != objectIdByPrimaryId.end())
        {
            result[detectionId] = found->second;
        }
    }
    return result;
}

/// Erases, in place, every object for which traffic_signs::isAdditionalTable
/// returns true on its SignAttrs type.
///
/// @param objects  collection to filter; modified in place
void removeAdditionalTables(db::eye::Objects& objects)
{
    const auto isAdditionalTableObject = [](const db::eye::Object& candidate) {
        const auto& signAttrs = candidate.attrs<db::eye::SignAttrs>();
        return traffic_signs::isAdditionalTable(signAttrs.type);
    };
    std::erase_if(objects, isAdditionalTableObject);
}

}  // namespace maps::mrc::signs_export_gen
