#include "searchproxycgi.h"
#include "searchproxyserver.h"

#include <saas/searchproxy/common/cgi.h>
#include <saas/searchproxy/configs/searchproxyconfig.h>
#include <saas/searchproxy/experiment/context.h>

#include <search/session/reqenv.h>
#include <search/request/data/reqdata.h>

#include <kernel/reqid/reqid.h>

#include <library/cpp/lua/eval.h>

#include <util/digest/city.h>
#include <util/string/hex.h>
#include <util/random/random.h>

namespace {
    // Request-id class tag used by TBaseCgiCorrector::FormRequestId.
    const TString ReqIdClass = "SAAS";

    // Applies the configured extra CGI parameters to `cgi`, one parameter at a time in
    // configuration order, according to each parameter's policy:
    //   Add                 - append every configured value;
    //   AddIfEmpty          - same as Add, but only when `cgi` has no such parameter yet;
    //   Calculate           - evaluate every configured value with TLuaEval (variables are taken
    //                         from the current `cgi`) and replace the parameter with the
    //                         non-empty results;
    //   CalculateIfEmpty    - same as Calculate, but only when `cgi` has no such parameter yet;
    //   CalculateAdditional - same as Calculate, but append the results instead of replacing;
    //   Remove              - erase the entries of that name whose value equals a configured value;
    //   Replace             - erase the parameter, then append every configured value except
    //                         the "__remove" marker.
    // For the Calculate* policies, if every result is empty `cgi` is left untouched, and a
    // yexception thrown during evaluation is logged rather than propagated.
    // An unknown policy is reported through FAIL_LOG.
    void FormExtraCgi(const TServiceConfig::TExtraCgi& extraCgi, TCgiParameters& cgi) {
        using EPolicy = TServiceConfig::TExtraCgi::EPolicy;

        for (auto&& parameter : extraCgi.Get()) {
            switch (parameter.Policy) {
            case EPolicy::AddIfEmpty:
                if (cgi.Has(parameter.Name)) {
                    break;
                }
                [[fallthrough]];
            case EPolicy::Add:
                for (auto&& value : parameter.Values) {
                    cgi.InsertUnescaped(parameter.Name, value);
                }
                break;
            case EPolicy::Calculate:
            case EPolicy::CalculateAdditional:
            case EPolicy::CalculateIfEmpty:
                if (parameter.Policy == EPolicy::CalculateIfEmpty && cgi.Has(parameter.Name)) {
                    break;
                }
                try {
                    // Variables are loaded from `cgi` as already modified by the parameters
                    // processed earlier in this loop.
                    TLuaEval lua;
                    lua.SetVars(cgi);

                    TVector<TString> values;
                    values.reserve(parameter.Values.size());
                    for (auto&& value : parameter.Values) {
                        if (TString calculatedValue = lua.EvalExpression(value)) {
                            values.emplace_back(std::move(calculatedValue));
                        }
                    }

                    if (!values.empty()) {
                        if (parameter.Policy != EPolicy::CalculateAdditional) {
                            cgi.ReplaceUnescaped(parameter.Name, values.cbegin(), values.cend());
                        } else {
                            for (auto&& value : values) {
                                cgi.InsertUnescaped(parameter.Name, value);
                            }
                        }
                    }
                } catch (const yexception& e) {
                    ERROR_LOG << "cannot calculate Lua: " << e.what() << Endl;
                }
                break;
            case EPolicy::Remove:
                for (auto&& value : parameter.Values) {
                    auto range = cgi.equal_range(parameter.Name);
                    for (auto i = range.first; i != range.second;) {
                        // Step past the entry before erasing it: a multimap erase
                        // invalidates only the erased iterator.
                        auto tested = i++;
                        if (tested->second == value) {
                            cgi.erase(tested);
                        }
                    }
                }
                break;
            case EPolicy::Replace:
                cgi.EraseAll(parameter.Name);
                for (auto&& value : parameter.Values) {
                    // "__remove" values are skipped, so a Replace whose only value is
                    // "__remove" simply deletes the parameter.
                    if (value != "__remove") {
                        cgi.InsertUnescaped(parameter.Name, value);
                    }
                }
                break;
            default:
                FAIL_LOG("unknown policy");
            }
        }
    }
}

// Entry point from the search-context hook: pulls the request environment out of the search
// context and normalizes its CGI parameters in place via FormCgi.
void TBaseCgiCorrector::FormContextCgi(IRequestContext* /*requestContext*/, ISearchContext** searchContext) {
    ISearchContext* context = *searchContext;
    auto* reqEnv = static_cast<TReqEnv*>(context->ReqEnv());
    FormCgi(reqEnv->CgiParam, reqEnv->RequestData);
}

// Builds a corrector for one service configuration.
//  - The wizard is considered in use when the meta-search config exists and has a WizardConfig;
//    in that case the query language (with all spaces stripped) is remembered as the
//    service-specific "restrict_config" parameter.
//  - Every configured global corrector rule name is resolved through the rule factory;
//    names the factory does not know are logged and skipped.
TBaseCgiCorrector::TBaseCgiCorrector(const TServiceConfig& config)
    : Config(config)
    , UsingWizard(config.GetMetaSearchConfig() && !!config.GetMetaSearchConfig()->WizardConfig)
{
    if (UsingWizard) {
        TString restrictConfig(Config.GetMetaSearchConfig()->QueryLanguage);
        RemoveAll(restrictConfig, ' ');
        ServiceSpecificCgi.emplace("restrict_config", restrictConfig);
    }

    for (const auto& ruleName : config.GetCgiCorrectorRules()) {
        ICorrectorRule::TPtr rule = ICorrectorRule::TFactory::Construct(ruleName);
        if (!rule.Get()) {
            ERROR_LOG << "Unknown global cgi corrector rule " << ruleName << Endl;
            continue;
        }
        Rules.push_back(rule);
    }
}

// Adds the service-dependent parameters to `cgi`:
//  - wizclient=saas-<Service>, when a service name is set and the wizard is in use;
//  - every ServiceSpecificCgi entry for which cgi.Get(name) is empty (parameter absent or
//    empty), so client-supplied non-empty values win;
//  - finally, the service's configured extra CGI (Config.GetExtraCgi()).
void TBaseCgiCorrector::FormServiceRelated(TCgiParameters& cgi) {
    if (!!Service && UsingWizard) {
        cgi.InsertUnescaped("wizclient", "saas-" + Service);
    }

    for (const auto& entry : ServiceSpecificCgi) {
        if (cgi.Get(entry.first).empty()) {
            // Pass the strings as-is rather than via .data(): that keeps their full length
            // instead of re-measuring them with strlen (which stops at an embedded NUL).
            cgi.InsertUnescaped(entry.first, entry.second);
        }
    }

    FormExtraCgi(Config.GetExtraCgi(), cgi);
}

// Fills in the request-identification parameters of `cgi`.
//  queryid - only if absent:
//            "<X-Req-Id header>-<ReqIdHostSuffix()>-SAAS" when `rd` carries that header,
//            otherwise ReqIdGenerate("SAAS"); "-<Service>" is appended in both cases.
//  raid    - only if absent: hex-encoded XOR of CityHash64 over the values of every
//            parameter listed in NSearchProxyCgi::HashParams.
//  uid and yandexuid - both are replaced by one id, chosen as: the uid value if non-empty;
//            else the yandexuid value, prefixed with 'y' if it lacks one; else a random
//            "saas_fake<N>" id, in which case a fake_uid flag parameter is also added.
// `rd` may be null.
void TBaseCgiCorrector::FormRequestId(TCgiParameters& cgi, const TBaseServerRequestData* rd) {
    if (!cgi.Has(NSearchProxyCgi::queryid)) {
        TString reqId;
        const TString* externalReqId = rd ? rd->HeaderIn("X-Req-Id") : nullptr;
        if (externalReqId) {
            reqId = *externalReqId + "-" + ReqIdHostSuffix() + "-" + ReqIdClass;
        } else {
            reqId = ReqIdGenerate(ReqIdClass.data());
        }
        reqId += "-" + Service;
        cgi.InsertUnescaped(NSearchProxyCgi::queryid, reqId);
    }

    if (!cgi.Has(NSearchProxyCgi::raid)) {
        ui64 hash = 0;
        for (auto&& p : NSearchProxyCgi::HashParams) {
            hash ^= CityHash64(cgi.Get(p));
        }
        cgi.InsertUnescaped(NSearchProxyCgi::raid, HexEncode(&hash, sizeof(hash)));
    }

    // Take a copy, not a reference: the uid entry is replaced at the end of this function.
    // uid always goes unmodified and has priority.
    TString finalId = cgi.Get(NSearchProxyCgi::uid);
    if (!finalId) {
        finalId = cgi.Get(NSearchProxyCgi::yandexuid);
        // LOGSTAT requires the 'y' prefix on yandexuid values.
        if (finalId && !finalId.StartsWith('y')) {
            finalId = "y" + finalId;
        }
    }
    if (!finalId) {
        finalId = "saas_fake" + ToString(RandomNumber<ui64>());
        cgi.InsertUnescaped(NSearchProxyCgi::fake_uid, ToString(true));
    }
    cgi.ReplaceUnescaped(NSearchProxyCgi::uid, finalId);
    cgi.ReplaceUnescaped(NSearchProxyCgi::yandexuid, finalId);
}

// Runs every configured corrector rule over `cgi`, in the order the rules were registered by
// the constructor; each rule receives `cgi` by mutable reference.
void TBaseCgiCorrector::ApplyCorrectorRules(TCgiParameters& cgi) const {
    for (auto&& correctorRule : Rules) {
        correctorRule->Apply(cgi);
    }
}

// Defaults the metaservice parameter to this corrector's Service; an explicit value is kept.
void TBaseCgiCorrector::FormMetaServiceRelated(TCgiParameters& cgi) {
    if (cgi.Has(NSearchProxyCgi::metaservice)) {
        return;
    }
    cgi.InsertUnescaped(NSearchProxyCgi::metaservice, Service);
}

// Cleans up client-facing parameters before the request goes further:
//  - every "format" entry is removed; its (first) value is moved to "report_format" unless
//    "report_format" is already present;
//  - the standard suggest parameter is copied into the text parameter when no text was given
//    (SAAS-2297);
//  - "tvm_auth_status" is always dropped.
void TBaseCgiCorrector::SanitizeRequest(TCgiParameters& cgi) {
    if (cgi.Has("format")) {
        // Copy before erasing: the value lives inside `cgi`.
        TString requestedFormat = cgi.Get("format");
        cgi.EraseAll("format");
        const bool hasReportFormat = cgi.Has("report_format");
        if (!hasReportFormat) {
            cgi.InsertUnescaped("report_format", requestedFormat);
        }
    }

    const bool textFromSuggest = cgi.Has(NSearchProxyCgi::suggest_part) && !cgi.Has(NSearchProxyCgi::text);
    if (textFromSuggest) {
        cgi.InsertUnescaped(NSearchProxyCgi::text, cgi.Get(NSearchProxyCgi::suggest_part));
    }

    cgi.EraseAll("tvm_auth_status");
}

// Defaults the gta parameter to DP_ALLDOCINFOS; a client-supplied gta is left as is.
void TBaseCgiCorrector::GetAllDocInfos(TCgiParameters& cgi) {
    if (cgi.Has(NSearchProxyCgi::gta)) {
        return;
    }
    cgi.InsertUnescaped(NSearchProxyCgi::gta, DP_ALLDOCINFOS);
}

// Derives the kps parameter from its service-scoped form "<Service>_<kps>".
// An explicit kps always wins. Otherwise, if a service-scoped entry exists, the value returned
// by cgi.Find for it becomes kps, and all service-scoped entries are removed.
void TBaseCgiCorrector::FormKeyPrefixField(TCgiParameters& cgi) {
    if (cgi.Has(NSearchProxyCgi::kps)) {
        return;
    }

    const TString serviceKpsName = Service + "_" + NSearchProxyCgi::kps;
    const auto serviceKps = cgi.Find(serviceKpsName);
    if (serviceKps == cgi.end()) {
        return;
    }

    // Insert before erasing: serviceKps->second refers into `cgi`.
    cgi.InsertUnescaped(NSearchProxyCgi::kps, serviceKps->second);
    cgi.EraseAll(serviceKpsName);
}

// Full CGI normalization pipeline for TServiceSearchCgiCorrector.
// The call order matters: each step sees `cgi` as left by the previous ones. For example,
// corrector rules run on the parameters before any defaults are added, FormRequestId hashes
// NSearchProxyCgi::HashParams after the earlier steps, and the auth parameters are stripped last.
// `rd` may be null (FormRequestId handles that).
void TServiceSearchCgiCorrector::FormCgi(TCgiParameters& cgi, const TBaseServerRequestData* rd) {
    ApplyCorrectorRules(cgi);
    GetAllDocInfos(cgi);
    FormKeyPrefixField(cgi);
    FormServiceRelated(cgi);
    FormRequestId(cgi, rd);
    SanitizeRequest(cgi);
    EraseAuthParams(cgi);
}

// Strips every authentication-related parameter (tvm force flag, user ticket, service ticket)
// from `cgi`; the erasures are independent of each other.
void TServiceSearchCgiCorrector::EraseAuthParams(TCgiParameters& cgi) {
    cgi.EraseAll(NSearchProxyCgi::force_tvm_auth);
    cgi.EraseAll(NSearchProxyCgi::ya_user_ticket);
    cgi.EraseAll(NSearchProxyCgi::ya_service_ticket);
}

// Forwards the kps parameter (the entry found by requestCgi.Find) from the incoming request
// to the client request, if present.
void TServiceSearchCgiCorrector::SetClientCgi(const TCgiParameters& requestCgi, TCgiParameters& clientCgi, const TRequestParams& /*rp*/) {
    const auto found = requestCgi.Find(NSearchProxyCgi::kps);
    if (found == requestCgi.end()) {
        return;
    }
    clientCgi.InsertUnescaped(NSearchProxyCgi::kps, found->second);
}

// Applies only the globally configured extra CGI (Config.GetGlobalExtraCgi()); the request data
// is not used.
void TServiceGlobalCgiCorrector::FormCgi(TCgiParameters& cgi, const TBaseServerRequestData* /*rd*/) {
    FormExtraCgi(Config.GetGlobalExtraCgi(), cgi);
}

// Minimal normalization for TUpperSearchCgiCorrector: request identification, then
// sanitization. The order matters: raid is hashed before SanitizeRequest rewrites parameters.
void TUpperSearchCgiCorrector::FormCgi(TCgiParameters& cgi, const TBaseServerRequestData* rd) {
    FormRequestId(cgi, rd);
    SanitizeRequest(cgi);
}

// CGI normalization pipeline for TMetaSearchCgiCorrector.
// The call order matters: each step sees `cgi` as left by the previous ones. Unlike the
// service-search pipeline, auth parameters are kept here; they are forwarded to clients by
// SetClientCgi/ProxyAuthParams.
void TMetaSearchCgiCorrector::FormCgi(TCgiParameters& cgi, const TBaseServerRequestData* rd) {
    ApplyCorrectorRules(cgi);
    FormServiceRelated(cgi);
    FormMetaServiceRelated(cgi);
    FormRequestId(cgi, rd);
    SanitizeRequest(cgi);
}

// Prepares the client (downstream) request CGI from the incoming request:
//  - drops the service parameter, and drops "qtree" unless the request tree is external;
//  - copies every request parameter whose name contains NSearchProxyCgi::kps as a substring
//    (so both plain and "<service>_"-prefixed forms are forwarded);
//  - forwards the auth parameters via ProxyAuthParams.
void TMetaSearchCgiCorrector::SetClientCgi(const TCgiParameters& requestCgi, TCgiParameters& clientCgi, const TRequestParams& rp) {
    clientCgi.EraseAll(NSearchProxyCgi::service);

    const bool keepQueryTree = rp.IsRequestTreeExternal();
    if (!keepQueryTree) {
        clientCgi.EraseAll("qtree");
    }

    for (auto&& requestParam : requestCgi) {
        const bool isKpsLike = requestParam.first.find(NSearchProxyCgi::kps) != TString::npos;
        if (!isKpsLike) {
            continue;
        }
        clientCgi.InsertUnescaped(requestParam.first, requestParam.second);
    }

    ProxyAuthParams(requestCgi, clientCgi);
}

void TMetaSearchCgiCorrector::ProxyAuthParams(const TCgiParameters& requestCgi, TCgiParameters& clientCgi) {
    if (const TString ticket = requestCgi.Get(NSearchProxyCgi::ya_service_ticket)) {
        clientCgi.InsertUnescaped(NSearchProxyCgi::ya_service_ticket, ticket);
    }
    if (const TString ticket = requestCgi.Get(NSearchProxyCgi::ya_user_ticket)) {
        clientCgi.InsertUnescaped(NSearchProxyCgi::ya_user_ticket, ticket);
    }
    if (NSearchProxy::TTvmTraits::IsAuthForced(requestCgi)) {
        clientCgi.InsertUnescaped(NSearchProxyCgi::force_tvm_auth, "1");
    }
}
