#include "calc_taxi_company_stats.h"

#include <drive/backend/cars/car.h>
#include <drive/backend/cars/car_model.h>
#include <drive/backend/data/leasing/company.h>
#include <drive/backend/data/leasing/leasing.h>
#include <drive/backend/data/scoring/scoring.h>
#include <drive/backend/tags/history.h>
#include <drive/backend/tags/tags_manager.h>

#include <rtline/library/storage/sql/transaction.h>

#include <type_traits>

namespace NDrivematics {
    /// Bitmask selecting which sections TScoringManager::ScoringProcess adds to the report.
    using TReportStats = ui32;
    enum EReportStats : ui32 {
        Medians = 0x1,          // "median_ranks" section
        ByObject = 0x2,         // "by_object" section
        Distribution = 0x4,     // "distribution" section
        CurrentAverage = 0x8    // "today_score" section
    };
    /// All bits set: request every report section.
    static constexpr TReportStats ReportAllStats = ~TReportStats{0};

    /// Default number of bins in the "distribution" section.
    static constexpr size_t DefaultScoringDistributionBinsAmount = 11;
    template <typename TEntityScoringTag>
    class TScoringManager {
    private:
        using TScoresByObject = TVector<std::pair<TString, TVector<double>>>;
        using TObjects = TSet<TString>;
        using TScoresByTimestamp = TMap<std::pair<TInstant, TString>, double>;

    public:
        TScoringManager(
            EHistoryScoreType type,
            size_t scoringLimit = DefaultScoringDistributionBinsAmount
        ) : HistoryScoreType(type), ScoringLimit(scoringLimit) {
        }

        bool ScoringProcess(TJsonReport::TGuard& g, const NSQL::TStringContainer& tagNames, const TRange<TInstant> statsRange, TMaybe<ui32> limit, TReportStats reportTraits, const NDrive::IServer& server, NDrive::TEntitySession& tx) {
            const auto driveApi = Yensured(server.GetDriveAPI());
            auto& tagsManager = driveApi->GetEntityTagsManager(Entity);

            auto optionalCompanyDBTags = tagsManager.RestoreTagsRobust({}, tagNames, tx);
            R_ENSURE(optionalCompanyDBTags, {}, "can't restore tags", tx);
            const auto& companyDBTags = *optionalCompanyDBTags;

            TObjects objectIds;
            {
                for (const auto& dbTag : companyDBTags) {
                    objectIds.insert(dbTag.GetObjectId());
                }
            }

            TScoresByTimestamp timestampToRank;
            {
                auto queryOptions = TTagEventsManager::TQueryOptions{};
                queryOptions.SetObjectIds(std::cref(objectIds));
                queryOptions.SetTags({ TEntityScoringTag::TypeName });
                auto events = tagsManager.GetEvents({}, statsRange, tx, queryOptions);
                R_ENSURE(events, {}, "can't restore events", tx);

                for (const auto& ev : *events) {
                    auto tag = ev.GetTagAs<TEntityScoringTag>();
                    R_ENSURE(tag, HTTP_INTERNAL_SERVER_ERROR, "can't cast tag as scoring tag: " << ev.GetTagId(),  tx);
                    if (tag->HasRank()) {
                        timestampToRank[std::make_pair(tag->GetTimestamp(), ev.GetObjectId())] = tag->GetRankRef();
                    }
                }
            }

            if (reportTraits & EReportStats::Medians) {
                g.MutableReport().AddReportElement("median_ranks",
                    GetMedianReport(timestampToRank, limit)
                );
            }

            if (!(reportTraits & ~EReportStats::Medians)) {
                return true;
            }

            TScoresByObject historyScores;
            {
                TMap<TString, TVector<double>> scoresPerObject;
                for (auto&& [key, score] : timestampToRank) {
                    auto&& [_, carId] = key;
                    scoresPerObject[carId].push_back(score);
                }
                for (auto& [_, scores] : scoresPerObject) {
                    std::sort(scores.begin(), scores.end());
                }
                for (auto&& [objectId, scores] : scoresPerObject) {
                    historyScores.push_back({objectId, std::move(scores)});
                }

                auto cmp = [this](const auto& first, const auto& second) -> bool {
                    return
                        GetHistoryScore(first.second)
                        < GetHistoryScore(second.second);
                };
                std::sort(historyScores.begin(), historyScores.end(), cmp);
            }

            if (reportTraits & EReportStats::CurrentAverage) {
                g.MutableReport().AddReportElement("today_score",
                    GetCurrentScoringReport(objectIds, server, tx)
                );
            }
            if (reportTraits & EReportStats::Distribution) {
                g.MutableReport().AddReportElement("distribution",
                    GetDistributionReport(historyScores)
                );
            }
            if (reportTraits & EReportStats::ByObject) {
                g.MutableReport().AddReportElement("by_object",
                    GetPerObjectMediansReport<Entity>(objectIds, historyScores, server, tx)
                );
            }

            return true;
        }

        bool ScoringProcess(TJsonReport::TGuard& g, const NSQL::TStringContainer& tagNames, ui32 limit, TReportStats reportTraits, const NDrive::IServer& server, NDrive::TEntitySession& tx) {
            const auto now = Now();
            TRange<TInstant> statsRange(now - TDuration::Days(limit + 1), now);
            return ScoringProcess(g, tagNames, statsRange, limit, reportTraits, server, tx);
        }

        bool ScoringProcess(TJsonReport::TGuard& g, const NSQL::TStringContainer& tagNames, const TRange<TInstant> statsRange, TReportStats reportTraits, const NDrive::IServer& server, NDrive::TEntitySession& tx) {
            return ScoringProcess(g, tagNames, statsRange, {}, reportTraits, server, tx);
        }

        static TVector<TString> GetUserOrganizationAffiliationTagNames(const TString& taxiCompanyTagName, const NDrive::IServer& server) {
            auto tagDescriptionsPtr = server.GetDriveAPI()->GetTagsManager().GetTagsMeta().GetRegisteredTags(NEntityTagsManager::EEntityType::User, { NDrivematics::TUserOrganizationAffiliationTag::TypeName });
            TVector<TString> result;
            ForEach(tagDescriptionsPtr.begin(), tagDescriptionsPtr.end(), [&result, &taxiCompanyTagName](const auto& elem) {
                auto tagDesc = std::dynamic_pointer_cast<const NDrivematics::TUserOrganizationAffiliationTag::TDescription>(elem.second);
                if (tagDesc && tagDesc->GetOwningCarTagName() == taxiCompanyTagName) {
                    result.push_back(tagDesc->GetName());
                }
            });
            return std::move(result);
        }

    private:
        template<EHistoryScoreType ScoreType>
        double GetHistoryScoreImpl(const TVector<double>& scores);

        template<>
        double GetHistoryScoreImpl<EHistoryScoreType::Average>(const TVector<double>& scores) {
            if (!scores) {
                return 0;
            }

            double avgScore = 0;
            for (auto score : scores) {
                avgScore += score;
            }
            return avgScore / scores.size();
        }

        template<>
        double GetHistoryScoreImpl<EHistoryScoreType::Median>(const TVector<double>& scores) {
            if (scores.empty()) {
                return 0;
            }
            return scores[scores.size() / 2];
        }

        double GetHistoryScore(const TVector<double>& scores) {
            switch (HistoryScoreType) {
                case EHistoryScoreType::Average:
                    return GetHistoryScoreImpl<EHistoryScoreType::Average>(scores);
                case EHistoryScoreType::Median:
                    return GetHistoryScoreImpl<EHistoryScoreType::Median>(scores);
            }
        }

        template <typename TTag>
        static constexpr NEntityTagsManager::EEntityType GetEntity() noexcept {
            return NEntityTagsManager::EEntityType::Undefined;
        }

        template <>
        static constexpr NEntityTagsManager::EEntityType GetEntity<TScoringUserTag>() noexcept {
            return NEntityTagsManager::EEntityType::User;
        }

        template <>
        static constexpr NEntityTagsManager::EEntityType GetEntity<TScoringCarTag>() noexcept {
            return NEntityTagsManager::EEntityType::Car;
        }

    private:
        NJson::TJsonValue GetMedianReport(
            const TScoresByTimestamp& timestampToRank,
            TMaybe<ui32>& limit
        ) {
            TMap<TInstant, TVector<double>> timestampToRanks;
            for (const auto& [id, rank] : timestampToRank) {
                const auto& [timestamp, _] = id;
                timestampToRanks[timestamp].push_back(rank);
            }
            if (limit && timestampToRanks.size() > *limit) {
                auto end = timestampToRanks.begin();
                std::advance(end, timestampToRanks.size() - *limit);
                timestampToRanks.erase(timestampToRanks.begin(), end);
            }

            NJson::TJsonValue medianRanks = NJson::JSON_ARRAY;
            for (auto& [timestamp, ranks] : timestampToRanks) {
                std::sort(ranks.begin(), ranks.end());
                medianRanks.AppendValue(NJson::TMapBuilder
                    ("day", timestamp.Days())
                    ("timestamp", timestamp.Seconds())
                    ("rank",  ranks[ranks.size() / 2])
                );
            }

            return std::move(medianRanks);
        }

    private:
        NJson::TJsonValue GetCurrentScoringReport(
            const TObjects& objectIds,
            const NDrive::IServer& server,
            NDrive::TEntitySession& tx
        ) {
            const auto driveApi = Yensured(server.GetDriveAPI());
            const auto& tagsManager = driveApi->GetEntityTagsManager(Entity);

            TVector<TDBTag> scoringTags;
            R_ENSURE(
                tagsManager.RestoreTags(objectIds, {TEntityScoringTag::TypeName}, scoringTags, tx),
                HTTP_INTERNAL_SERVER_ERROR,
                "failed to restore scoring tags",
                tx
            );

            double currentScoringAvg = 0;
            size_t rankAmount = 0;
            for (const auto& dbTag : scoringTags) {
                auto tag = dbTag.template GetTagAs<TEntityScoringTag>();
                R_ENSURE(tag, HTTP_INTERNAL_SERVER_ERROR, "can't cast tag as scoring tag: " << dbTag.GetTagId(),  tx);

                if (tag->HasRank()) {
                    currentScoringAvg += tag->GetRankRef();
                    ++rankAmount;
                }
            }

            NJson::TJsonValue currentScoringAvgReport = NJson::JSON_DOUBLE;
            if (rankAmount == 0) {
                currentScoringAvgReport = 0;
                return std::move(currentScoringAvgReport);
            }

            currentScoringAvgReport = currentScoringAvg / rankAmount;
            return std::move(currentScoringAvgReport);
        }

    private:
        template <NEntityTagsManager::EEntityType Entity>
        NJson::TJsonValue GetPerObjectMediansReport(
            const TObjects& objectIds,
            const TScoresByObject& historyScores,
            const NDrive::IServer& server,
            NDrive::TEntitySession& tx
        );

        template <>
        NJson::TJsonValue GetPerObjectMediansReport<NEntityTagsManager::EEntityType::User>(
            const TObjects& /* objectIds */,
            const TScoresByObject& historyScores,
            const NDrive::IServer& /* server */,
            NDrive::TEntitySession& /* tx */
        ) {
            NJson::TJsonValue perUserMedians = NJson::JSON_ARRAY;
            for (auto&& [userId, scores] : historyScores) {
                if (scores.empty()) {
                    continue;
                }

                NJson::TJsonValue scoringReport;
                scoringReport["object_id"] = userId;
                scoringReport["rank"] = GetHistoryScore(scores);
                perUserMedians.AppendValue(std::move(scoringReport));
            }
            return std::move(perUserMedians);
        }

        template <>
        NJson::TJsonValue GetPerObjectMediansReport<NEntityTagsManager::EEntityType::Car>(
            const TObjects& objectIds,
            const TScoresByObject& historyScores,
            const NDrive::IServer& server,
            NDrive::TEntitySession& tx
        ) {
            const auto driveApi = Yensured(server.GetDriveAPI());
            const auto carsData = driveApi->GetCarsData();

            auto carsInfo = carsData->FetchInfo(objectIds, tx);
            auto modelNameCache = server.GetDriveDatabase().GetModelsDB().GetCached();
            auto modelName = modelNameCache.GetResult();

            NJson::TJsonValue perCarMedians = NJson::JSON_ARRAY;
            for (auto&& [carId, scores] : historyScores) {
                if (scores.empty()) {
                    continue;
                }

                NJson::TJsonValue scoringReport;
                scoringReport["object_id"] = carId;
                scoringReport["rank"] = GetHistoryScore(scores);
                if (auto carInfo = carsInfo.GetResultPtr(carId)) {
                    const auto model = carInfo->GetModel();
                    scoringReport["model_id"] = model;
                    scoringReport["name"] = modelName[model].GetName();
                    scoringReport["number"] = carInfo->GetNumber();
                }
                perCarMedians.AppendValue(std::move(scoringReport));
            }
            return std::move(perCarMedians);
        }

    private:
        NJson::TJsonValue GetDistributionReport(const TScoresByObject& historyScores) {
            TVector<ui64> distribution(ScoringLimit);

            for (auto&& scoreEntry : historyScores) {
                auto&& [_, scores] = scoreEntry;
                if (!scores) {
                    continue;
                }

                const auto score = GetHistoryScore(scores);
                const auto binNum = static_cast<size_t>(score);
                if (binNum >= ScoringLimit || binNum < 0) {
                    WARNING_LOG << "Invalid scoring value " << binNum << ". Skipping" << Endl;
                    continue;
                }
                ++distribution[binNum];
            }

            NJson::TJsonValue report = NJson::JSON_ARRAY;
            for (size_t i = 0; i < ScoringLimit; ++i) {
                report.AppendValue(distribution[i]);
            }

            return std::move(report);
        }

    private:
        static constexpr NEntityTagsManager::EEntityType Entity = GetEntity<TEntityScoringTag>();

    private:
        const EHistoryScoreType HistoryScoreType;
        const size_t ScoringLimit;
    };
}

namespace NDrivematics {
    /// Reports median scoring ranks ("median_ranks") of a taxi company's cars or users.
    ///
    /// Exactly one of the CGI parameters `car_id`, `taxi_company_tag_name`, `user_id`,
    /// `user_taxi_company_tag_name` must be given; it determines the company tag. `limit`
    /// (default 10) bounds the number of reported timestamps. The caller must be allowed
    /// to observe the company tag.
    void TCalcTaxiCompanyStatsProcessor::ProcessServiceRequest(TJsonReport::TGuard& g, TUserPermissions::TPtr permissions, const NJson::TJsonValue& /* requestData */) {
        const auto& cgi = Context->GetCgiParameters();
        auto carId = GetString(cgi, "car_id", false);
        auto userId = GetString(cgi, "user_id", false);
        auto taxiCompanyTagNameExt = GetString(cgi, "taxi_company_tag_name", false);
        auto userTaxiCompanyTagNameExt = GetString(cgi, "user_taxi_company_tag_name", false);
        auto limit = GetValue<ui32>(cgi, "limit", false).GetOrElse(10);

        // Exactly one object selector must be present; any other combination is a malformed
        // request, i.e. a client error rather than a server failure.
        const int selectorsCount =
              !carId.empty()
            + !taxiCompanyTagNameExt.empty()
            + !userId.empty()
            + !userTaxiCompanyTagNameExt.empty();
        R_ENSURE(selectorsCount == 1, HTTP_BAD_REQUEST, "object params must be only one");

        auto tx = BuildTx<NSQL::ReadOnly>();
        const auto& api = *Yensured(Server->GetDriveAPI());

        // Resolve the company tag name from whichever selector was given.
        const auto companyTagName = [&]() -> TString {
            if (taxiCompanyTagNameExt) {
                return taxiCompanyTagNameExt;
            }
            if (userTaxiCompanyTagNameExt) {
                return userTaxiCompanyTagNameExt;
            }
            if (userId) {
                auto optionalTaggedUser = api.GetTagsManager().GetUserTags().RestoreObject(userId, tx);
                R_ENSURE(optionalTaggedUser, {}, "can't restore user tags", tx);
                auto userOrganizationAffiliationDBTag = optionalTaggedUser->GetFirstTagByClass<TUserOrganizationAffiliationTag>();
                R_ENSURE(userOrganizationAffiliationDBTag, HTTP_BAD_REQUEST, "user doesn't have affiliation company tag");
                auto tagDescription = api.GetTagsManager().GetTagsMeta().GetDescriptionByName(userOrganizationAffiliationDBTag->GetName());
                auto description = dynamic_cast<const TUserOrganizationAffiliationTag::TDescription*>(tagDescription.Get());
                R_ENSURE(description, HTTP_BAD_REQUEST, "user doesn't have affiliation company tag");
                return description->GetOwningCarTagName();
            }
            auto optionalTaggedCar = api.GetTagsManager().GetDeviceTags().RestoreObject(carId, tx);
            R_ENSURE(optionalTaggedCar, {}, "can't restore car", tx);
            auto taxiCompanyDBTag = optionalTaggedCar->GetFirstTagByClass<TTaxiCompanyTag>();
            R_ENSURE(taxiCompanyDBTag, HTTP_BAD_REQUEST, "car doesn't have taxi company tag");
            return taxiCompanyDBTag->GetName();
        }();

        R_ENSURE(permissions->GetTagNamesByAction(TTagAction::ETagAction::Observe).contains(companyTagName), HTTP_FORBIDDEN, "can't observe " << companyTagName);
        if (carId || taxiCompanyTagNameExt) {
            TScoringManager<TScoringCarTag>(EHistoryScoreType::Median)
                .ScoringProcess(g, {companyTagName}, limit, EReportStats::Medians, *Server, tx);
        }
        if (userId || userTaxiCompanyTagNameExt) {
            TScoringManager<TScoringUserTag> scoringManager(EHistoryScoreType::Median);
            // NOTE(review): tagNames may be empty when no affiliation tag is linked to the company;
            // confirm how RestoreTagsRobust treats an empty name filter.
            auto tagNames = TScoringManager<TScoringUserTag>::GetUserOrganizationAffiliationTagNames(companyTagName, *Server);
            scoringManager
                .ScoringProcess(g, tagNames, limit, EReportStats::Medians, *Server, tx);
        }
        g.SetCode(HTTP_OK);
    }

    /// Reports scoring statistics for the fleets of one or more taxi companies.
    ///
    /// CGI parameters: `company_tag_names` (default: every registered taxi company tag the caller
    /// may observe), `until` (default: now), `since` (default: 10 days before `until`),
    /// `report_type` (EReportStats bitmask, default: all sections), `scoring_limit` (number of
    /// distribution bins), `history_scoring_type` (default: median).
    void TFleetScoringStatsProcessor::ProcessServiceRequest(TJsonReport::TGuard& g, TUserPermissions::TPtr permissions, const NJson::TJsonValue& /* requestData */) {
        // Upper bound for the user-supplied number of distribution bins: the report allocates one
        // counter per bin, so an unbounded value would let a request trigger an arbitrarily large
        // allocation.
        constexpr size_t MaxScoringDistributionBinsAmount = 1000;

        const auto& cgi = Context->GetCgiParameters();
        auto companyTagNames = MakeSet(GetStrings(cgi, "company_tag_names", false));
        {
            auto observableTagsNames = permissions->GetTagNamesByAction(TTagAction::ETagAction::Observe);
            if (companyTagNames) {
                // Explicitly requested companies must all be observable.
                for (const auto& name : companyTagNames) {
                    R_ENSURE(
                        observableTagsNames.contains(name),
                        HTTP_FORBIDDEN,
                        "can't observe " << name
                    );
                }
            } else {
                // Default to every registered taxi company tag the caller may observe.
                const auto& api = *Yensured(Server->GetDriveAPI());
                auto allCompanyTagsNames = api.GetTagsManager().GetTagsMeta().GetRegisteredTagNames({NDrivematics::TTaxiCompanyTag::TypeName});
                for (const auto& name : allCompanyTagsNames) {
                    if (observableTagsNames.contains(name)) {
                        companyTagNames.insert(name);
                    }
                }
            }

            R_ENSURE(
                !companyTagNames.empty(),
                HTTP_BAD_REQUEST,
                "Company tag names are empty. Check your permissions"
            );
        }

        const TInstant until = GetTimestamp(cgi, "until", false).GetOrElse(TInstant::Now());
        const TInstant since = GetTimestamp(cgi, "since", false).GetOrElse(until - TDuration::Days(10));
        R_ENSURE(since <= until, HTTP_BAD_REQUEST, "since must not be later than until");

        TRange<TInstant> statsRange;
        statsRange.To = until;
        statsRange.From = since;

        const auto reportTraits = GetValue<TReportStats>(cgi, "report_type", false)
            .GetOrElse(ReportAllStats);
        const auto scoringLimit = GetValue<size_t>(cgi, "scoring_limit", false)
            .GetOrElse(DefaultScoringDistributionBinsAmount);
        R_ENSURE(
            scoringLimit <= MaxScoringDistributionBinsAmount,
            HTTP_BAD_REQUEST,
            "scoring_limit must not exceed " << MaxScoringDistributionBinsAmount
        );
        const auto historyScoringType = GetValue<NDrivematics::EHistoryScoreType>(cgi, "history_scoring_type", false)
            .GetOrElse(NDrivematics::EHistoryScoreType::Median);

        auto tx = BuildTx<NSQL::ReadOnly>();
        R_ENSURE(
            TScoringManager<TScoringCarTag>(historyScoringType, scoringLimit)
                .ScoringProcess(g, companyTagNames, statsRange, reportTraits, *Yensured(Server), tx),
            HTTP_INTERNAL_SERVER_ERROR,
            "failed to create stats report",
            tx
        );

        g.SetCode(HTTP_OK);
    }
}
