#include "scoring.h"

#include <util/generic/algorithm.h>

namespace NInfra::NYpDnsApi {

////////////////////////////////////////////////////////////////////////////////

// Builds a reporter bound to the shared score entry of this descriptor's cluster.
// The entry is looked up with the checked `at()` accessor of the owning scoring's map,
// so the cluster id must be one the scoring was constructed with.
TClusterUseReporter TClusterDescriptor::GetUseReporter() const {
    const auto& clusterScore = Scoring_.Scores_.at(ClusterId_);
    return TClusterUseReporter{clusterScore};
}

////////////////////////////////////////////////////////////////////////////////

TScoringAlgorithmOptions<EScoreAlgorithm::PessimizeWithFails>::TScoringAlgorithmOptions(
    const TPessimizeWithFailsAlgoConfig& config
)
    : AfterFailPessimizeTime(TDuration::Parse(config.GetAfterFailPessimizeTime()))
{
}

////////////////////////////////////////////////////////////////////////////////

// Takes ownership of the options and creates one fresh score object per configured
// cluster id, so later `Scores_.at(id)` lookups succeed for every known cluster.
// NOTE(review): Scores_ is sized from Options_ in the member-init list, which relies on
// Options_ being declared before Scores_ in the class definition — confirm in the header.
TClustersScoring::TClustersScoring(TClustersScoringOptions options)
    : Options_(std::move(options))
    , Scores_(Options_.ClusterIds.size())
{
    for (const auto& id : Options_.ClusterIds) {
        // emplace keeps the first entry if the same id is listed more than once.
        auto freshScore = MakeAtomicShared<TClusterScore>();
        Scores_.emplace(id, std::move(freshScore));
    }
}

// Orders `clusters` according to the algorithm selected in `algoConfig` and wraps each
// cluster id into a TClusterDescriptor bound to this scoring.
//
// - kPessimizeWithFails: delegates to the typed Score overload with options parsed
//   from the config.
// - ALGO_NOT_SET (and, defensively, any case value not handled below): the input
//   order is kept as is.
//
// The switch intentionally has no `default:` label so that -Wswitch still reports
// newly added oneof cases. The trailing return guarantees the function never flows
// off its end (which would be undefined behaviour for a non-void function).
TVector<TClusterDescriptor> TClustersScoring::Score(
    TVector<TClusterId> clusters,
    const TScoringAlgorithmConfig& algoConfig
) const {
    switch (algoConfig.GetAlgoCase()) {
        case TScoringAlgorithmConfig::kPessimizeWithFails: {
            return Score(
                std::move(clusters),
                TScoringAlgorithmOptions<EScoreAlgorithm::PessimizeWithFails>(algoConfig.GetPessimizeWithFails())
            );
        }
        case TScoringAlgorithmConfig::ALGO_NOT_SET: {
            break;
        }
    }
    // No scoring algorithm requested (or an unhandled case): keep the caller's order.
    return CreateDescriptors(std::move(clusters));
}

// Wraps every cluster id, in the given order, into a descriptor that refers back to
// this scoring. The ids are moved out of `clusters`, which is left with moved-from
// elements.
TVector<TClusterDescriptor> TClustersScoring::CreateDescriptors(
    TVector<TClusterId>&& clusters
) const {
    TVector<TClusterDescriptor> descriptors;
    descriptors.reserve(clusters.size());
    for (size_t i = 0; i < clusters.size(); ++i) {
        descriptors.emplace_back(std::move(clusters[i]), *this);
    }
    return descriptors;
}

// Reorders `clusters` in place so that preferred clusters come first.
//
// A snapshot of each cluster's score (sequential fails, time since last try) is
// taken once before sorting. SequentialFails is an atomic read via load(), so the
// live values may presumably change while we sort. Comparing the frozen snapshot
// keeps the comparator a consistent strict weak ordering, as std::sort requires.
//
// Ordering (see TClusterCompareData::operator<):
//   1. Group() ascending;
//   2. TimeSinceLastTry descending (least recently tried first);
//   3. SequentialFails ascending.
// The sort is not stable: clusters with identical keys end up in unspecified order.
// Throws (via Scores_.at) if a cluster id is unknown to this scoring.
void TClustersScoring::PessimizeWithFails(
    TVector<TClusterId>& clusters,
    const TScoringAlgorithmOptions<EScoreAlgorithm::PessimizeWithFails>& options
) const {
    struct TClusterCompareData {
        const TScoringAlgorithmOptions<EScoreAlgorithm::PessimizeWithFails>& ScoringOptions;
        ui64 SequentialFails;
        TDuration TimeSinceLastTry;

        // Priority bucket, lower sorts first:
        //   0 - has sequential fails, but AfterFailPessimizeTime has already elapsed;
        //   1 - no sequential fails;
        //   2 - has sequential fails and is still inside the pessimization window.
        ui64 Group() const {
            if (SequentialFails > 0 && TimeSinceLastTry > ScoringOptions.AfterFailPessimizeTime) {
                return 0;
            } else if (SequentialFails == 0) {
                return 1;
            } else {
                return 2;
            }
        }

        // Compared field by field rather than through a tuple containing
        // -static_cast<i64>(TimeSinceLastTry.GetValue()): that cast wraps for durations
        // of 2^63 us or more (e.g. a TDuration::Max() sentinel), which would sort such
        // clusters last instead of first within their group. At exactly 2^63 the
        // negation is signed overflow. For all smaller durations the order is unchanged.
        bool operator<(const TClusterCompareData& other) const noexcept {
            const ui64 group = Group();
            const ui64 otherGroup = other.Group();
            if (group != otherGroup) {
                return group < otherGroup;
            }
            if (TimeSinceLastTry != other.TimeSinceLastTry) {
                // Longer time since the last try sorts first.
                return TimeSinceLastTry > other.TimeSinceLastTry;
            }
            return SequentialFails < other.SequentialFails;
        }
    };

    THashMap<TClusterId, TClusterCompareData> compareData(clusters.size());
    for (const TClusterId& cluster : clusters) {
        // Bind by reference: copying the shared pointer would cost an atomic inc/dec.
        const TClusterScorePtr& score = Scores_.at(cluster);
        compareData.emplace(
            cluster,
            TClusterCompareData{
                .ScoringOptions = options,
                .SequentialFails = score->SequentialFails.load(),
                .TimeSinceLastTry = score->TimeSinceLastTry()
            });
    }

    // Return the snapshot by reference so each comparison does not copy it.
    SortBy(clusters, [&compareData](const TClusterId& cluster) -> const TClusterCompareData& {
        return compareData.at(cluster);
    });
}

////////////////////////////////////////////////////////////////////////////////

} // namespace NInfra::NYpDnsApi
