#include "make_basic_sample.h"
#include "make_basic_sample_impl.h"

#include "common.h"
#include "limits.h"
#include "uids.h"

#include <maps/wikimap/mapspro/libs/assessment/include/gateway.h>
#include <maps/wikimap/mapspro/libs/social/include/yandex/maps/wiki/social/feedback/gateway_ro.h>
#include <maps/wikimap/mapspro/libs/social/include/yandex/maps/wiki/social/gateway.h>
#include <maps/wikimap/mapspro/services/tasks_social/src/assessment_sampler/lib/load_units/edits.h>
#include <maps/wikimap/mapspro/services/tasks_social/src/assessment_sampler/lib/load_units/feedback.h>
#include <maps/wikimap/mapspro/services/tasks_social/src/assessment_sampler/lib/load_units/moderation.h>
#include <maps/wikimap/mapspro/services/tasks_social/src/assessment_sampler/lib/load_units/tracker.h>


namespace maps::wiki::assessment::sampler {

namespace impl {

// Randomly selects the units that make up a basic sample.
//
// Each group is shuffled in place with `rndGen` before cropping. Cropping then
// keeps a random subset of every group rather than the loader's original order.
// The upper bound is maxUnits(mode, total), where `total` is the number of units
// across all groups counted *before* cropping. The ratio is unitsRatio(mode).
//
// mode             - sampling mode forwarded to unitsRatio() / maxUnits()
// rndGen           - random generator used for shuffling; its state advances
// groupNameToUnits - candidate units grouped by name; consumed by this call
// returns          - whatever cropGroups() produces for the shuffled groups
GroupNameToUnits
pickBasicUnits(Mode mode, std::mt19937& rndGen, GroupNameToUnits&& groupNameToUnits)
{
    size_t totalUnits = 0;
    for (auto& [_, groupUnits]: groupNameToUnits) {
        totalUnits += groupUnits.size();
        std::shuffle(groupUnits.begin(), groupUnits.end(), rndGen);
    }

    const auto ratio = unitsRatio(mode);
    const auto limit = maxUnits(mode, totalUnits);
    return cropGroups(std::move(groupNameToUnits), ratio, limit);
}


// Resolves every unit of every group to an assessment unit id.
//
// A unit whose `id` is already known is used as is. A unit whose `id` equals
// UNKNOWN_UNIT_ID is resolved via assessmentGw.getOrCreateUnit(entity, action).
// Ids are returned group by group, in iteration order, keeping each group's
// internal order.
//
// assessmentGw     - gateway used to look up / create missing units
// groupNameToUnits - units grouped by name (not modified)
// returns          - one id per input unit
std::vector<TId>
getOrCreateUnits(assessment::Gateway& assessmentGw, const GroupNameToUnits& groupNameToUnits)
{
    std::vector<TId> unitIds;

    for (const auto& [_, groupUnits]: groupNameToUnits) {
        for (const auto& unit: groupUnits) {
            if (unit.id == UNKNOWN_UNIT_ID) {
                // Not yet registered: ask the gateway to find or create it.
                unitIds.emplace_back(assessmentGw.getOrCreateUnit(unit.entity, unit.action));
            } else {
                unitIds.emplace_back(unit.id);
            }
        }
    }

    return unitIds;
}

} // namespace impl


// Builds a basic-qualification assessment sample for `domain` and stores it
// through makeSample() under `sampleName`.
//
// Steps:
//   1. Load candidate units for the domain within [timepointMin, timepointMax].
//      - Edits, Feedback and Moderation restrict them to
//        loadAllowedUids(staffType, workerConfigPath).
//      - Feedback additionally receives `hypothesesMode`.
//   2. Randomly pick a subset with impl::pickBasicUnits(), using `rndGen`.
//   3. Resolve/create unit ids in the assessment DB and create the sample with
//      tasksPerUnit(mode) tasks per unit.
//
// NOTE(review): the Tracker branch does not use `staffType` / `workerConfigPath`.
// It also passes only `timepointMin` to loadTrackerUnits(). Both ticket loaders
// receive the full range. Confirm this is intentional.
//
// Throws maps::RuntimeError for domains without a loader.
void makeBasicSample(
    pqxx::transaction_base& txn,
    Entity::Domain domain,
    chrono::TimePoint timepointMin,
    chrono::TimePoint timepointMax,
    std::mt19937& rndGen,
    StaffType staffType,
    HypothesesMode hypothesesMode,
    const std::string& sampleName,
    const std::optional<std::string>& workerConfigPath)
{
    const Mode mode{domain, staffType, hypothesesMode};

    // Domain-specific loading; every branch yields the same grouped structure.
    auto candidateUnits = [&]() -> GroupNameToUnits {
        switch (domain) {
        case Entity::Domain::Edits: {
            social::Gateway socialGw{txn};
            return loadEditsUnits(
                socialGw,
                timepointMin,
                timepointMax,
                loadAllowedUids(staffType, workerConfigPath));
        }
        case Entity::Domain::Feedback: {
            social::feedback::GatewayRO fbGw{txn};
            return loadFeedbackUnits(
                fbGw,
                timepointMin,
                timepointMax,
                loadAllowedUids(staffType, workerConfigPath),
                hypothesesMode);
        }
        case Entity::Domain::Moderation: {
            social::Gateway socialGw{txn};
            return loadModerationUnits(
                socialGw,
                timepointMin,
                timepointMax,
                loadAllowedUids(staffType, workerConfigPath));
        }
        case Entity::Domain::Tracker:
            return loadTrackerUnits(
                loadStaffLoginToPuidFromYt(),
                timepointMin,
                loadClosedTickets(timepointMin, timepointMax),
                loadOnSupportSideTickets(timepointMin, timepointMax));
        default:
            throw maps::RuntimeError() << "Entity domain '" << domain << "' is not supported";
        }
    }();

    auto sampledUnits = impl::pickBasicUnits(mode, rndGen, std::move(candidateUnits));

    assessment::Gateway assessmentGw{txn};
    const auto unitIds = impl::getOrCreateUnits(assessmentGw, sampledUnits);
    makeSample(
        assessmentGw,
        domain,
        Qualification::Basic,
        sampleName,
        unitIds,
        tasksPerUnit(mode));
}

} // maps::wiki::assessment::sampler
