#include "ops_common.h"
#include "math_util.h"

#include <saas/tools/rty_ops/lib/args.pb.h>

#include <kernel/urlid/doc_handle.h>

#include <library/cpp/http/io/stream.h>

#include <search/idl/meta.pb.h>

#include <util/generic/hash.h>
#include <util/network/socket.h>
#include <util/stream/file.h>

using namespace NSaas;

namespace NRTYServer {

//
// TQueryOpBase
//
// Sends `query` to the endpoint prepared in Init() (using OpCtx.Host) through
// THttpOpHelper::DoQuery, which fills `resultBuf` with the reply.
// Returns the HTTP status code reported by the helper.
ui32 TQueryOpBase::DoQuery(const TString& query, TString* resultBuf) {
    auto& endpoint = *EndPoint;
    const ui32 status = THttpOpHelper::DoQuery(endpoint, OpCtx.Host, query, resultBuf);
    return status;
}

// Populates OpCtx from the host and builds the network address used by DoQuery()
// and WaitForSearcherStart().
void TQueryOpBase::Init(IRtyOpHost& host) {
    // The context must be filled first: the endpoint is derived from its Host/Port.
    host.InitContext(OpCtx);
    const auto& ctx = OpCtx;
    EndPoint = MakeHolder<TNetworkAddress>(ctx.Host, ctx.Port);
}

// Replaces the contents of `plan` with the queries read from config.Queries
// (TSV with request ids). `maxItemsToRead` is forwarded to the reader as its limit.
// The plan is cleared before the file is checked, so it ends up empty if the file is missing.
void TQueryOpBase::LoadPlan(TPlan& plan, const TRtyOpsConfig& config, ui32 maxItemsToRead) {
    plan.clear();

    TFsPath queriesPath = config.Queries;
    queriesPath.CheckExists();
    TFileInput input(queriesPath);

    TDolbilkaPlanReader reader;
    reader.Mode = TDolbilkaPlanReader::EPlanFormat::TsvWithIds;
    reader.MaxItemsToRead = maxItemsToRead;
    reader.ReadAll(input, plan);
}

// Rewrites every query of `plan` in place:
//  - appends "&ms=proto" unless the query already contains it;
//  - appends "&pc=none" when the query already carries a "&pc=" parameter
//    (presumably to override it; NOTE(review): confirm this is not meant to be the inverse check);
//  - appends `extraCgi` (callers validate that it starts with '&') when it is non-empty.
void TQueryOpBase::PatchQueries(TVector<TDolbilkaPlan::TQueryData>& plan, const TString& extraCgi) {
    for (auto& item : plan) {
        TString& q = item.Query;

        const bool hasProtoFormat = q.find("&ms=proto") != TString::npos;
        if (!hasProtoFormat) {
            q += "&ms=proto";
        }

        const bool hasPc = q.find("&pc=") != TString::npos;
        if (hasPc) {
            q += "&pc=none";
        }

        if (extraCgi) {
            q += extraCgi;
        }
    }
}

// Diagnostics helper: re-sends `req` asking for a human-readable answer (hr=da,
// long timeout, waitall) and dumps whatever comes back to Cout. A non-200 status
// is logged but does not stop the dump.
void TQueryOpBase::DebugRetryAsText(const TDolbilkaPlan::TQueryData& req) {
    WARNING_LOG << "Retrying reqid=" << req.RequestId << Endl;

    TString humanReadable;
    const ui32 status = DoQuery(req.Query + "&ms=proto&hr=da&timeout=10000000&waitall=da", &humanReadable);
    if (status != 200) {
        ERROR_LOG << "Retrying failed, httpCode=" << status << Endl;
    }
    Cout << humanReadable << Endl;

    WARNING_LOG << "Retrying reqid=" << req.RequestId << " done" << Endl;
}

void TQueryOpBase::WaitForSearcherStart() {
    ui32 delay = 1;
    ui32 prevDelay = 1;

    for (size_t attemptsLeft = 10; attemptsLeft; --attemptsLeft) {
        if (THttpOpHelper::SendCommand(*EndPoint, OpCtx.Host, "/status")) {
            return;
        }

        DEBUG_LOG << "The server is not ready yet, will retry in " << delay << " seconds" << Endl;

        Sleep(TDuration::Seconds(delay));
        ui32 newDelay = delay + prevDelay;
        prevDelay = delay;
        delay = newDelay;
    }

    Y_ENSURE(false, "WaitForSearcherStart timed out");
}

void TQueryOpBase::Prepare() {
    WaitForSearcherStart();
}

//
// TSimplePerfOp
//
class TSimplePerfOp: public TQueryOpBase {
protected:
    NRTYServer::TParamsOpPerf Params;

public:
    struct TSimpleStat {
        bool Valid;
        ui64 ExecUs;
        ui32 NDocs;
    };

    using TBase = TQueryOpBase;

public:
    void Init(IRtyOpHost& host) override {
        TBase::Init(host);
        Params = GetExtraParams<NRTYServer::TParamsOpPerf>(host.GetConfig().VirtualArgs);
        Y_ENSURE(!Params.HasExtraCgi() || Params.GetExtraCgi().StartsWith("&"));
    }

    void Run(IRtyOpHost& host) override {
        TVector<TDolbilkaPlan::TQueryData> plan;
        LoadPlan(plan, host.GetConfig(), Params.GetMaxQueries());
        PatchQueries(plan, Params.GetExtraCgi());

        ui32 nRuns = Params.HasRetries() ? Params.GetRetries() : 1;

        for (const auto& req : plan) {
            for (ui32 nRun = 0; nRun < nRuns; ++nRun) {
                TString result;
                const ui32 httpCode = DoQuery(req.Query, &result);
                if (httpCode > 299) {
                    break;
                }
                if (httpCode != 200) {
                    continue;
                }

                NMetaProtocol::TReport report;

                const bool parsed = report.ParseFromString(result) && !report.HasCompressedReport();
                if (!parsed) {
                    DebugRetryAsText(req);
                }
                Y_ENSURE(parsed, "failed to parse answer for reqid=" << req.RequestId);

                PrintStats(req.RequestId, httpCode, report);
            }
        }
    }

    static TSimpleStat ResultToStats(ui32 httpCode, const NMetaProtocol::TReport& parsedReport) {
        Y_ENSURE(!parsedReport.HasCompressedReport(), "an uncompressed report expected");
        TSimpleStat res;

        using namespace NMetaProtocol;

        const TBalancingInfo& bal = parsedReport.HasBalancingInfo() ? parsedReport.GetBalancingInfo() : Default<TBalancingInfo>();
        res.ExecUs = bal.GetElapsed() - bal.GetWaitInQueue();

        const TDebugInfo& di = parsedReport.HasDebugInfo() ? parsedReport.GetDebugInfo() : Default<TDebugInfo>();
        const bool answerIsComplete = httpCode == 200 && (!di.HasAnswerIsComplete() || di.GetAnswerIsComplete());

        res.Valid = answerIsComplete && parsedReport.TotalDocCountSize();

        res.NDocs = res.Valid ? parsedReport.GetTotalDocCount(0) : 0;
        return res;
    }
    void PrintStats(NSaas::TRequestId reqid, ui32 httpCode, const NMetaProtocol::TReport& parsedReport) {
        const TSimpleStat s = ResultToStats(httpCode, parsedReport);
        Cout << reqid << "\t" << httpCode << "\t" << s.ExecUs << "\t" << s.NDocs << Endl;
    }

protected:
private:
    static IRtyOp::TFactory::TRegistrator<TSimplePerfOp> Registrator;
};

IRtyOp::TFactory::TRegistrator<TSimplePerfOp> TSimplePerfOp::Registrator("SimplePerf");

//
// @brief TSdPerfOp  Measure CPU time by sending the request several times
//
class TSdPerfOp final: public TSimplePerfOp {
public:
    using TBase = TSimplePerfOp;

    /// Capacity of TStat::Values, i.e. the most samples kept per request.
    static constexpr ui32 MaxRuns = 32;

    /// Outlier filtering, mean and sample standard deviation over float samples.
    struct TAvgStat {
    public:
        ui32 Count = 0;
        float Average = 0;
        float Sigma = 0;

    public:
        /// Sorts [begin, end) ascending and returns a new end that drops high outliers.
        /// A sample counts as an outlier when it exceeds
        /// min(min + 15000, min + 30% of (max - min)). The value 15000 is 15 ms if the
        /// samples are microseconds. At most min(size - 3, 66% of size) samples are
        /// dropped, and nothing is dropped when there are 3 or fewer samples.
        template <typename TIter>
        static TIter Filter(TIter begin, TIter end) {
            if (!(begin < end))
                return end;

            Sort(begin, end);
            const size_t sz = end - begin;
            TIter pMax = begin + sz - 1;
            const float v0 = *begin; // smallest sample
            const float v1 = *pMax;  // largest sample (`end` itself must never be dereferenced)

            const float threshold = Min(v0 + 15000, v1 - (v1 - v0) * 0.7f);
            // Guard the subtraction: `sz - 3` wraps around for sz < 3.
            size_t maxRemoved = sz > 3 ? Min(sz - 3, size_t(0.66 * sz)) : 0;
            while (maxRemoved && *pMax > threshold) {
                maxRemoved--;
                pMax--;
            }

            return ++pMax;
        }

        /// Arithmetic mean of [begin, end); 0 for an empty range.
        template <typename TIter>
        static float CalcAverage(TIter begin, TIter end) {
            // Empty ranges are legal: a request with no valid sample yields one.
            Y_ASSERT(begin <= end);
            const size_t sz = end - begin;
            float res = 0;
            if (!sz) {
                return res;
            }
            for (; begin < end; ++begin) {
                res += *begin;
            }
            res /= sz;
            return res;
        }

        /// Sample standard deviation (n - 1 denominator) of [begin, end) around `avg`;
        /// 0 for fewer than two samples or a non-positive variance.
        template <typename TIter>
        static float CalcSigma(TIter begin, TIter end, float avg) {
            Y_ASSERT(begin <= end);
            const size_t sz = end - begin;
            float res = 0;
            if (sz < 2) {
                return res;
            }
            for (; begin < end; ++begin) {
                const float vl = *begin - avg;
                res += vl * vl;
            }
            res /= sz - 1;
            if (res > 0) {
                res = sqrtf(res);
            } else {
                res = 0;
            }
            return res;
        }
    };

    /// Samples collected for one request id.
    struct TStat {
        bool IsValid = true;             // turns false once NDocs differs between runs
        ui32 Count = 0;
        ui32 LastNDocs = Max<ui32>();    // Max<ui32>() means "no sample yet"
        float Values[MaxRuns] = {0};

        Y_FORCE_INLINE void Push(float vl, ui32 nDocs) {
            if (IsValid) {
                if (LastNDocs == Max<ui32>()) {
                    LastNDocs = nDocs;
                } else if (Y_UNLIKELY(LastNDocs != nDocs)) {
                    IsValid = false;
                    return;
                }

                Y_ASSERT(Count < MaxRuns);
                // Y_ASSERT is gone in release builds. Never write past Values, which can
                // otherwise happen when the plan repeats a request id.
                if (Y_UNLIKELY(Count >= MaxRuns)) {
                    return;
                }
                Values[Count++] = vl;
            }
        }
    };

    using TStats = THashMap<TRequestId, TStat>;

public:
    void Init(IRtyOpHost& host) override {
        TBase::Init(host);
    }

    /// Runs the whole plan nRuns times (Retries, default 12, capped at MaxRuns).
    /// Output rows:
    /// - a "Req" row per answer;
    /// - a "Query" row per request: filtered sample count, NDocs, Average -/+ 3*Sigma, and
    ///   the value coarsened by TCoarser;
    /// - a final TOTALS row: request count, sum of averages, and the combined deviation.
    void Run(IRtyOpHost& host) override {
        TPlan plan;
        LoadPlan(plan, host.GetConfig(), Params.GetMaxQueries());
        PatchQueries(plan, Params.GetExtraCgi());

        TStats data;
        data.reserve(plan.size());

        ui32 nRuns = Params.HasRetries() ? Params.GetRetries() : 12;
        nRuns = Min(nRuns, MaxRuns);

        // Interleave requests: each pass covers the whole plan once.
        for (ui32 nRun = 0; nRun < nRuns; ++nRun) {
            for (const auto& req : plan) {
                TString result;
                const ui32 httpCode = DoQuery(req.Query, &result);

                NMetaProtocol::TReport report;
                const bool parsed = report.ParseFromString(result) && !report.HasCompressedReport();
                Y_ENSURE(parsed, "failed to parse answer for reqid=" << req.RequestId);

                const TSimpleStat s = ResultToStats(httpCode, report);
                Cout << "Req\t" << req.RequestId << "\t" << httpCode << "\t" << s.ExecUs << "\t" << s.NDocs << Endl;

                if (s.Valid) {
                    TStat& item = data[req.RequestId];
                    item.Push(static_cast<float>(s.ExecUs), s.NDocs);
                }
            }
        }

        Cout << Endl;

        ui64 totalCnt = 0;
        float totalSum = 0;
        float totalSd = 0; // accumulates variances; converted to a deviation after the loop

        // Variance added per request: (1 ms)^2 = 1e6 when the samples are microseconds,
        // i.e. assume +-1ms wall-clock error on each measurement when making totals.
        constexpr float extraSdFromWallClock = 1e6;

        for (const auto& req : plan) {
            TStat& item = data[req.RequestId];
            auto nDocs = item.LastNDocs;
            TAvgStat avg;
            if (item.IsValid) {
                float* begin = item.Values;
                float* end = begin + item.Count;
                end = TAvgStat::Filter(begin, end); // sorts item.Values
                avg.Count = end - begin;
                avg.Average = TAvgStat::CalcAverage(begin, end);
                avg.Sigma = TAvgStat::CalcSigma(begin, end, avg.Average);
            }

            auto scaleDivision = TCoarser::GetScaleDivision((i64)avg.Average, (i64)(3*avg.Sigma), 1000u);

            Cout << "Query\t" << req.RequestId
                 << "\t" << avg.Count
                 << "\t" << nDocs
                 << "\t" << i64(avg.Average - 3 * avg.Sigma)
                 << "\t" << i64(avg.Average + 3 * avg.Sigma)
                 << "\t" << scaleDivision.first << "(" << scaleDivision.second << ")"
                 << Endl;

            totalCnt += 1;
            totalSum += avg.Average;
            totalSd += avg.Sigma * avg.Sigma + extraSdFromWallClock;
        }

        if (totalSd > 0) {
            totalSd = sqrtf(totalSd);
        }

        Cout << "TOTALS"
             << "\t" << totalCnt
             << "\t" << (ui64)totalSum
             << "\t" << (ui64)totalSd
             << Endl;
    }

private:
    static IRtyOp::TFactory::TRegistrator<TSdPerfOp> Registrator;
};

IRtyOp::TFactory::TRegistrator<TSdPerfOp> TSdPerfOp::Registrator("SdPerf");

//
// @brief TDocidsPoolOp  Dump the returned documents (doc id and SRelevance) of every query to check for zero diffs; factors are not dumped
//
class TDocidsPoolOp final: public TSimplePerfOp {
public:
    using TBase = TSimplePerfOp;

public:
    void Init(IRtyOpHost& host) override {
        TBase::Init(host);
    }

    void Run(IRtyOpHost& host) override {
        TPlan plan;
        LoadPlan(plan, host.GetConfig(), Params.GetMaxQueries());
        PatchQueries(plan, Params.GetExtraCgi());

        for (const auto& req : plan) {
            TString result;
            const ui32 httpCode = DoQuery(req.Query, &result);

            Y_ENSURE(httpCode == 200, "Request " << req.RequestId << " failed, httpCode=" << httpCode);

            NMetaProtocol::TReport report;
            const bool parsed = report.ParseFromString(result) && !report.HasCompressedReport();
            Y_ENSURE(parsed, "failed to parse answer for reqid=" << req.RequestId);

            for (const auto& grouping : report.get_arr_grouping()) {
                for (const auto& group : grouping.get_arr_group()) {
                    for (const auto& doc : group.get_arr_document()) {
                        TString docId;
                        if (doc.HasDocId()) {
                            docId = doc.GetDocId();
                        } else {
                            const TDocHandle dh(doc.GetDocHash(), doc.GetRoute());
                            docId = dh.ToString();
                        }
                        const i64 SRelevance = doc.GetSRelevance();

                        Cout << "ReqAns\t" << req.RequestId << "\t" << docId << "\t" << SRelevance << Endl;
                    }
                }
            }
        }
    }

private:
    static IRtyOp::TFactory::TRegistrator<TDocidsPoolOp> Registrator;
};

IRtyOp::TFactory::TRegistrator<TDocidsPoolOp> TDocidsPoolOp::Registrator("DocidsPool");

//
// @brief TSimpleQueryaskOp  Run one query from the plan (Params.QueryId, or the first query) with hr=da and print the human-readable answer
//
class TSimpleQueryaskOp final: public TQueryOpBase {
public:
    using TBase = TQueryOpBase;

protected:
    NRTYServer::TParamsOpViewer Params;

public:
    void Init(IRtyOpHost& host) override {
        TBase::Init(host);
        Params = GetExtraParams<NRTYServer::TParamsOpViewer>(host.GetConfig().VirtualArgs);
    }

    void Run(IRtyOpHost& host) override {
        TPlan plan0;
        LoadPlan(plan0, host.GetConfig(), 0);

        TRequestId queryId = TRequestId();
        if (Params.HasQueryId()) {
            queryId = Params.GetQueryId();
        } else if (!plan0.empty()) {
            queryId = plan0.front().RequestId;
        }

        auto iter = std::find_if(plan0.begin(), plan0.end(), [queryId](const TDolbilkaPlan::TQueryData& q) {
            return q.RequestId == queryId;
        });

        Y_ENSURE(iter != plan0.end(), "RequestId not found");

        TPlan plan;
        plan.push_back(*iter);
        PatchQueries(plan, Params.GetExtraCgi());
        Y_ASSERT(plan.size() == 1);

        const auto& req = plan.front();
        TString hrQuery = req.Query + "&ms=proto&hr=da&timeout=10000000&waitall=da";
        TString hrResult;

        INFO_LOG << "DoQuery: " << hrQuery << Endl;

        const ui32 httpCode = DoQuery(hrQuery, &hrResult);
        if (httpCode != 200) {
            ERROR_LOG << "Request failed, httpCode=" << httpCode << Endl;
        }

        Cout << Endl;
        Cout << hrResult << Endl;
    }

private:
    static IRtyOp::TFactory::TRegistrator<TSimpleQueryaskOp> Registrator;
};

IRtyOp::TFactory::TRegistrator<TSimpleQueryaskOp> TSimpleQueryaskOp::Registrator("Queryask");

//
// THrDump: Queryask performed multiple times in a row
//
// Queryask for every query of the plan: prints each human-readable (hr=da) answer.
class THrDumpOp: public TQueryOpBase {
protected:
    NRTYServer::TParamsOpPerf Params;

public:
    using TBase = TQueryOpBase;

public:
    void Init(IRtyOpHost& host) override {
        TBase::Init(host);
        Params = GetExtraParams<NRTYServer::TParamsOpPerf>(host.GetConfig().VirtualArgs);
        // ExtraCgi is appended verbatim to each query, so it must start a new CGI parameter.
        Y_ENSURE(!Params.HasExtraCgi() || Params.GetExtraCgi().StartsWith("&"),
                 "ExtraCgi must start with '&', got: " << Params.GetExtraCgi());
    }

    /// For each plan query, requests hr=da output and prints it under a "<reqid> HrDump:" header.
    /// A non-200 reply is reported as "<reqid>\tHTTP error: <code>" and the dump moves on to the
    /// next query; one failing request no longer silently aborts the rest of the dump.
    void Run(IRtyOpHost& host) override {
        TVector<TDolbilkaPlan::TQueryData> plan;
        LoadPlan(plan, host.GetConfig(), Params.GetMaxQueries());
        PatchQueries(plan, Params.GetExtraCgi());

        for (const auto& req : plan) {
            TString hrQuery = req.Query + "&ms=proto&hr=da&timeout=10000000&waitall=da";
            TString hrResult;

            INFO_LOG << "DoQuery: " << hrQuery << Endl;

            const ui32 httpCode = DoQuery(hrQuery, &hrResult);
            if (httpCode != 200) {
                Cout << req.RequestId << "\tHTTP error: " << httpCode << Endl;
                continue;
            }

            Cout << Endl;
            Cout << req.RequestId << " HrDump:" << Endl;
            Cout << hrResult << Endl;
        }
    }

private:
    static IRtyOp::TFactory::TRegistrator<THrDumpOp> Registrator;
};

IRtyOp::TFactory::TRegistrator<THrDumpOp> THrDumpOp::Registrator("HrDump");

}

namespace NRTYOpsApp {
    // Deliberately does nothing. Calling it ensures that the static TRegistrator
    // objects of the ops in this file are constructed (presumably by keeping this
    // translation unit linked into the binary).
    void RegisterModulePerf() {}
}
