#include "make_expert_sample.h"
#include "make_expert_sample_impl.h"

#include "common.h"
#include "limits.h"
#include "sample.h"

#include <maps/libs/log8/include/log8.h>
#include <maps/wikimap/mapspro/libs/common/include/yandex/maps/wiki/common/robot.h>
#include <maps/wikimap/mapspro/libs/assessment/include/gateway.h>

namespace maps::wiki::assessment::sampler {

namespace {

/// Looks up the single basic-qualification sample that matches `mode`
/// (entity domain, staff type, hypotheses mode) and was created within
/// `createdAt`.
///
/// Exactly one matching sample is required. Zero matches or more than one
/// match fail via REQUIRE (presumably an exception; see the macro definition).
SampleInfo getBasicSample(
    const Console& console,
    const Mode& mode,
    const social::DateTimeCondition& createdAt)
{
    // Base constraints: domain, basic qualification and creation period.
    auto sampleFilter = SampleFilter()
        .entityDomain(mode.domain)
        .qualification(Qualification::Basic)
        .createdAt(createdAt);

    // Staff type is encoded as a prefix of the sample name.
    // NOTE(review): this relies on SampleFilter accumulating several
    // nameLike/nameNotLike patterns rather than keeping only the last one.
    // Confirm this, since the hypotheses switch below adds another pattern.
    switch (mode.staffType) {
        case StaffType::Staff:
            sampleFilter.nameNotLike(PIECEWORK_PREFIX + "%");
            sampleFilter.nameNotLike(OUTSOURCE_PREFIX + "%");
            break;
        case StaffType::Piecework:
            sampleFilter.nameLike(PIECEWORK_PREFIX + "%");
            break;
        case StaffType::Outsource:
            sampleFilter.nameLike(OUTSOURCE_PREFIX + "%");
            break;
    }

    // Hypotheses samples carry a marker anywhere in the sample name.
    switch (mode.hypothesesMode) {
        case HypothesesMode::Yes:
            sampleFilter.nameLike("%" + HYPOTHESES_PREFIX + "%");
            break;
        case HypothesesMode::No:
            sampleFilter.nameNotLike("%" + HYPOTHESES_PREFIX + "%");
            break;
    }

    // Request a one-item page. With a single item per page, hasMore() reports
    // whether a second match exists. Zero for before/after presumably means
    // "no cursor"; confirm against SampleFeedParams.
    constexpr TId cursorBefore = 0;
    constexpr TId cursorAfter = 0;
    constexpr size_t pageSize = 1;

    const auto feed = console.sampleFeed(
        SampleFeedParams(cursorBefore, cursorAfter, pageSize), sampleFilter);

    REQUIRE(
        !feed.hasMore(),
        "Too many " << toString(mode.domain) << " samples since "
            << chrono::formatIsoDateTime(*createdAt.first()) << " till "
            << chrono::formatIsoDateTime(*createdAt.last()));

    REQUIRE(
        !feed.samples().empty(),
        "No " << toString(mode.domain) << " samples since "
            << chrono::formatIsoDateTime(*createdAt.first()) << " till "
            << chrono::formatIsoDateTime(*createdAt.last()));

    return feed.samples()[0];
}

/// Splits the units of sample `sampleId` into three groups by their grade
/// counters:
///  - notGraded: no correct and no incorrect grades;
///  - inconsistent: at least one correct and at least one incorrect grade;
///  - consistent: everything else (graded, and all grades agree).
impl::UnitGroups getUnitGroups(const assessment::Gateway& gateway, TId sampleId)
{
    impl::UnitGroups result;
    for (const auto& unitStats : gateway.sampleGradeStats(sampleId)) {
        const bool hasNoGrades = !unitStats.correct && !unitStats.incorrect;
        const bool hasMixedGrades = unitStats.correct > 0 && unitStats.incorrect > 0;

        // The two conditions are mutually exclusive: mixed grades imply a
        // non-zero `correct` counter. The check order therefore does not matter.
        if (hasMixedGrades) {
            result.inconsistent.push_back(unitStats.unitId);
        } else if (hasNoGrades) {
            result.notGraded.push_back(unitStats.unitId);
        } else {
            result.consistent.push_back(unitStats.unitId);
        }
    }
    return result;
}

/// Randomly shuffles `unitIds` with `rndGen` and keeps at most
/// `maxUnitsToPick` of them.
///
/// The vector is taken by value (a sink parameter). Callers pass
/// `std::move(...)`, so no copy is made on entry, and a by-value parameter is
/// implicitly moved on return. A `std::vector<TId>&&` parameter would instead
/// be copied by `return unitIds;` under C++17 rules, because implicit move does
/// not apply to reference-typed parameters before C++20.
///
/// The whole vector is shuffled before truncation, so the RNG is consumed
/// exactly as in a full std::shuffle of the input.
std::vector<TId>
shuffleAndPickUnits(std::mt19937& rndGen, std::vector<TId> unitIds, size_t maxUnitsToPick)
{
    std::shuffle(unitIds.begin(), unitIds.end(), rndGen);
    if (unitIds.size() > maxUnitsToPick) {
        unitIds.resize(maxUnitsToPick);
    }
    return unitIds;
}

} // namespace

namespace impl {

/// Number of units across all three groups (inconsistent, consistent,
/// not graded).
size_t UnitGroups::totalUnitCount() const
{
    const size_t gradedUnits = inconsistent.size() + consistent.size();
    return gradedUnits + notGraded.size();
}

/// Logs the size of each unit group of basic sample `sampleName` at `level`.
/// Emits one header line followed by one tab-indented "<group>: <count>" line
/// per group.
void UnitGroups::logStats(maps::log8::Level level, const std::string& sampleName) const
{
    MAPS_LOG(level) << "Basic sample '" << sampleName << "' with unit stats:";
    MAPS_LOG(level) << "\tinconsistent: " << inconsistent.size();
    // Colon added so this line matches the "<group>: <count>" format of its siblings.
    MAPS_LOG(level) << "\tconsistent: " << consistent.size();
    MAPS_LOG(level) << "\tnot-graded: " << notGraded.size();
}

std::vector<TId>
pickExpertUnits(
    std::mt19937& rndGen,
    UnitGroups&& unitGroups,
    double unitsRatio,
    double inconsistentUnitsRatio)
{
    std::vector<TId> result;

    const size_t totalUnits = std::ceil(unitGroups.totalUnitCount() * unitsRatio);
    const size_t maxInconsistentUnitsToPick = std::ceil(totalUnits * inconsistentUnitsRatio);
    result = shuffleAndPickUnits(rndGen, std::move(unitGroups.inconsistent), maxInconsistentUnitsToPick);

    const size_t maxConsistentUnitsToPick = totalUnits - result.size();
    const auto consistentUnitIds = shuffleAndPickUnits(rndGen, std::move(unitGroups.consistent), maxConsistentUnitsToPick);
    std::move(consistentUnitIds.begin(), consistentUnitIds.end(), std::back_inserter(result));

    return result;
}

} // namespace impl

/// Builds an expert-qualification sample named `sampleName` from the single
/// basic sample that matches the given domain, staff type and hypotheses mode
/// and was created within [timepointMin, timepointMax].
///
/// Steps: find the basic sample, group its units by grade consistency, log
/// the group sizes, pick units with `rndGen`, then create the expert sample
/// with EXPERT_TASKS_PER_UNIT tasks per unit.
void makeExpertSample(
    pqxx::transaction_base& txn,
    Entity::Domain entityDomain,
    chrono::TimePoint timepointMin,
    chrono::TimePoint timepointMax,
    std::mt19937& rndGen,
    StaffType staffType,
    HypothesesMode hypothesesMode,
    const std::string& sampleName)
{
    const Mode sampleMode = {entityDomain, staffType, hypothesesMode};
    const social::DateTimeCondition period(timepointMin, timepointMax);

    assessment::Gateway gateway{txn};

    const auto basicSample = getBasicSample(
        gateway.console(common::ROBOT_UID), sampleMode, period);

    auto groups = getUnitGroups(gateway, basicSample.id);
    groups.logStats(maps::log8::Level::INFO, basicSample.name);

    const auto expertUnitIds = pickExpertUnits(
        rndGen,
        std::move(groups),
        EXPERT_UNITS_RATIO,
        INCONSISTENT_UNITS_RATIO);

    makeSample(
        gateway,
        entityDomain,
        Qualification::Expert,
        sampleName,
        expertUnitIds,
        EXPERT_TASKS_PER_UNIT);
}

} // maps::wiki::assessment::sampler
