#include "broadcast.h"
#include "searchproxyclient.h"
#include "searchproxyserver.h"

#include <saas/library/authorization/auth.h>
#include <saas/library/searchserver/exception.h>
#include <saas/library/searchserver/replier.h>
#include <saas/library/searchserver/delay.h>
#include <saas/library/searchserver/auth.h>
#include <saas/searchproxy/unistat_signals/factory.h>
#include <saas/searchproxy/unistat_signals/signals.h>
#include <saas/searchproxy/common/cgi.h>
#include <saas/searchproxy/configs/searchproxyconfig.h>
#include <saas/searchproxy/logging/error_log.h>
#include <saas/searchproxy/common/auth.h>
#include <saas/searchproxy/common/nameextractors.h>
#include <saas/searchproxy/tvm/tvm_auth.h>

#include <search/session/compression/report.h>

#include <library/cpp/logger/global/global.h>
#include <library/cpp/http/io/stream.h>
#include <library/cpp/http/misc/httpcodes.h>
#include <library/cpp/json/json_reader.h>
#include <library/cpp/json/json_value.h>
#include <library/cpp/json/json_writer.h>
#include <library/cpp/protobuf/json/proto2json.h>

#include <util/stream/mem.h>
#include <util/generic/guid.h>

using namespace NSearchProxyCgi;

namespace {
    // Per-request state of one TVM authorization attempt made by
    // TSearchProxyRequestFeatures::ProcessTvmAuth.
    struct TTvmAuth {
        using TResult = NSearchProxy::TTvmAuthResult;

        // The outcome is reported (status string, signals) but not enforced.
        bool DryRun = false;
        // NSearchProxy::AuthorizeTvm threw instead of producing a Result.
        bool Exception = false;
        // Filled when NSearchProxy::AuthorizeTvm returned normally.
        TMaybe<TResult> Result;

        // Value for the "tvm_auth_status" cgi parameter: "exception" or "<status>[.dryrun]".
        TString MakeStatusString() const;
        // Reason text and HTTP code for a rejected request; Result must hold a rejecting status.
        NSaas::TNotAuthorizedInfo MakeNotAuthorizedInfo(const THttpStatusManagerConfig& conf) const;
        // Bumps the TVM unistat counter that matches the outcome, tagged with service `srv`.
        void SendUnistatSignalIfNeed(const TString& srv) const;
    };

    TString TTvmAuth::MakeStatusString() const {
        if (!Exception) {
            Y_ASSERT(Result);
            Y_ASSERT(Result->Status != TResult::Initial);
            TString status = ToString(Result->Status);
            if (DryRun) {
                status += ".dryrun";
            }
            return status;
        }
        return "exception";
    }

    NSaas::TNotAuthorizedInfo TTvmAuth::MakeNotAuthorizedInfo(const THttpStatusManagerConfig& conf) const {
        Y_ASSERT(Result);

        // Missing or malformed tickets are "unauthorized"; a valid ticket for a
        // principal that is not allowed is "permission denied".
        NSaas::TNotAuthorizedInfo info;
        info.HttpCode = conf.UnauthorizedStatus;

        switch (Result->Status) {
        case TResult::NoServiceTicketFromClient:
            info.Reason = "No TVM ServiceTicket in user request params";
            return info;
        case TResult::NoUserTicketFromClient:
            info.Reason = "No TVM UserTicket in user request params";
            return info;
        case TResult::BadServiceTicket:
            info.Reason = "Got malformed TVM ServiceTicket from user";
            return info;
        case TResult::BadUserTicket:
            info.Reason = "Got malformed TVM UserTicket from user";
            return info;
        case TResult::NotAllowedSrcClientId:
            info.Reason = TStringBuilder() << "TVM ServiceTicket source service id is not allowed: "
                << ToString(*Result->SrcId);
            info.HttpCode = conf.PermissionDeniedStatus;
            return info;
        case TResult::NotAllowedUid:
            info.Reason = TStringBuilder() << "TVM UserTicket UID is not allowed: "
                << ToString(*Result->Uid);
            info.HttpCode = conf.PermissionDeniedStatus;
            return info;
        case TResult::Initial:
        case TResult::Authorized:
            // Only rejecting statuses may be turned into a not-authorized reply.
            Y_FAIL();
        }

        return info;
    }

    void TTvmAuth::SendUnistatSignalIfNeed(const TString& srv) const {
        if (!Exception) {
            Y_ASSERT(Result);
            switch (Result->Status) {
            case TResult::Authorized:
                TSaasSearchProxySignals::DoUnistatTvmAccessGranted(srv, DryRun);
                return;
            case TResult::NoServiceTicketFromClient:
            case TResult::NoUserTicketFromClient: // TODO
                TSaasSearchProxySignals::DoUnistatTvmNoTicket(srv, DryRun);
                return;
            case TResult::BadServiceTicket:
            case TResult::BadUserTicket: // TODO
                TSaasSearchProxySignals::DoUnistatTvmBadTicket(srv, DryRun);
                return;
            case TResult::NotAllowedSrcClientId:
            case TResult::NotAllowedUid: // TODO
                TSaasSearchProxySignals::DoUnistatTvmAccessDenied(srv, DryRun);
                return;
            case TResult::Initial:
                // A finished authorization never stays in the initial state.
                Y_FAIL();
                return;
            }
            return;
        }
        TSaasSearchProxySignals::DoUnistatTvmException(srv, DryRun);
    }

}

TSearchProxyRequestFeatures::TSearchProxyRequestFeatures(TSearchProxyServer* server)
    : Server(server)
    , Config(server->GetConfig())
{}

// Returns the service name that TServiceNameExtractor derives from the request cgi.
TString TSearchProxyRequestFeatures::CalculateName(const TCgiParameters& cgi) {
    TString name = TServiceNameExtractor::ExtractServiceName(cgi);
    return name;
}

// Records an exception signal for Service and writes the current exception message
// to the error log. Does nothing when no reply context is available. Callers invoke
// it from a catch handler, so CurrentExceptionMessage() refers to the caught exception.
void TSearchProxyRequestFeatures::DoHandleException(IReplyContext* context) {
    if (!context) {
        return;
    }
    TSaasSearchProxySignals::DoUnistatExceptionRecord(Service);
    NSearchProxy::NLogging::TErrorLog::Log(context->GetRequestId(), context->GetRequestData(), CurrentExceptionMessage());
}

// Builds the replier that serves the current request for Service.
//
// Throws TSearchException with:
//  - HTTP_BAD_REQUEST when no service is given or the server does not know it;
//  - HTTP_SERVICE_UNAVAILABLE when Server->IsGoodService() rejects it (bad searchmap).
ISearchReplier::TPtr TSearchProxyRequestFeatures::DoSelectHandlerImpl(IReplyContext::TPtr context) {
    if (!Service) {
        ythrow TSearchException(HTTP_BAD_REQUEST) << "service is not specified";
    }
    if (!Server->HasService(Service)) {
        ythrow TSearchException(HTTP_BAD_REQUEST) << "service " << Service << " is not found";
    }
    if (!Server->IsGoodService(Service)) {
        // Named code instead of the bare 503 literal, consistent with the checks above.
        ythrow TSearchException(HTTP_SERVICE_UNAVAILABLE) << "service " << Service << " has incorrect searchmap";
    }
    return Server->BuildReplier(Service, context);
}

// Selects the replier for a client request. On failure, the exception is recorded
// via DoHandleException and then propagated unchanged.
ISearchReplier::TPtr TSearchProxyClient::DoSelectHandler(IReplyContext::TPtr context) {
    try {
        return DoSelectHandlerImpl(context);
    } catch (...) {
        DoHandleException(context.Get());
        throw;
    }
}

void TSearchProxyClient::AdjustClientRequest(const TString& service) {
    const TServiceConfig& serviceConfig = Server->GetConfig().GetServiceConfig(service);
    Output().EnableCompression(serviceConfig.UseCompression());
}

void TSearchProxyClient::ProcessBroadcast(const TString& srv, const TString& request) {
    TString service = srv;
    if (!service) {
        service = RD.CgiParam.Get(service);
    }
    if (!service) {
        service = CalculateName(RD.CgiParam);
    }

    ProcessTvmAuth(service);

    TString req = request ? request : ("?" + RD.CgiParam.Print());
    TDuration timeout = RD.CgiParam.Has("timeout")
        ? TDuration::MicroSeconds(FromString(RD.CgiParam.Get("timeout")))
        : TDuration::Seconds(1);

    NSearchProxy::TBroadcastContextPtr ctx = Server->GetBroadcaster().Broadcast(service, req, timeout);
    if (!ctx) {
        throw TSearchException(HTTP_INTERNAL_SERVER_ERROR) << "Broadcast error";
    }

    NJson::TJsonValue result;
    for (auto&& reply : ctx->Replies) {
        result.AppendValue(NSearchProxy::ToJson(reply));
    }

    HttpCodes code = HTTP_OK;
    if (ctx->Empty()) {
        code = HTTP_NOT_FOUND;
    } else if (!ctx->AnswerIsComplete()) {
        code = HTTP_SERVICE_UNAVAILABLE;
    }

    THttpReplyContext::MakeSimpleReplyImpl(Output(), result.GetStringRobust(), code);
}

bool TSearchProxyClient::ProcessSpecialRequest() {
    TSearchRequestDelay::Process(RD, "sp");
    TStringBuf script(RD.ScriptName());
    if (script.StartsWith(NSearchProxyCgi::broadcast)) {
        ProcessBroadcast(TString(script.Tail(NSearchProxyCgi::broadcast.size())));
        return true;
    } else if (script.StartsWith(NSearchProxyCgi::global_ping)) {
        ProcessBroadcast(TString(script.Tail(NSearchProxyCgi::global_ping.size())), "ping");
        return true;
    } else {
        return TSearchClient::ProcessSpecialRequest();
    }
}

void TSearchProxyClient::ProcessAuth() {
    ProcessTvmAuth();
}

bool TSearchProxyClient::ProcessPreAuth() {
    const NSearchMapParser::TSearchMap::TServiceMap::const_iterator i = Config.GetSearchMap().GetServiceMap().find(Service);
    if (i == Config.GetSearchMap().GetServiceMap().end() || !i->second.RequireAuth)
        return true;

    ui64 prefix;
    TString auth = RD.CgiParam.Get(NSearchProxyCgi::auth);
    if (!!auth && TryFromString<ui64>(RD.CgiParam.Get(kps), prefix) &&
        Server->GetAuthorizer().CheckAuthorization(auth, Service, prefix))
    {
        return true;
    } else {
        THttpReplyContext::MakeSimpleReplyImpl(Output(), TStringBuf("Unauthorized"), HTTP_UNAUTHORIZED);
        return false;
    }
}

void TSearchProxyRequestFeatures::ProcessTvmAuth() {
    ProcessTvmAuth(Service);
}

// TVM-authorizes the current request on behalf of `service`.
//
// The effective mode is the service's TVM mode. A request for which
// NSearchProxy::TTvmTraits::IsAuthForced() holds upgrades DryRun to Enabled.
// Authorization is skipped when there is no main TVM client or the mode is Disabled.
//
// Effects, all attributed to `service` (previously the signals and the HTTP status
// config used the member Service even when a different service was passed in):
//  - sends the matching TVM unistat signal;
//  - inserts the cgi parameter "tvm_auth_status" ("-" when authorization was skipped);
//  - outside dry-run, rethrows AuthorizeTvm exceptions and throws
//    NSaas::TNotAuthorizedException for a non-authorized result.
//    In dry-run, exceptions are only logged.
void TSearchProxyRequestFeatures::ProcessTvmAuth(const TString& service) {
    const auto& tvmConfig = Config.GetTvmParams();
    const auto& serviceConfig = Config.GetServiceConfig(service);
    const auto& serviceTvmConfig = serviceConfig.GetTvmParams();

    auto& rd = GetRD();

    const auto authMode = NSearchProxy::TTvmTraits::IsAuthForced(rd.CgiParam) && serviceTvmConfig.Mode == TServiceConfig::TTvmParams::DryRun
        ? TServiceConfig::TTvmParams::Enabled
        : serviceTvmConfig.Mode;

    TMaybe<TTvmAuth> tvmAuth;
    const auto& tvmClients = Server->GetTvmClients();
    if (tvmClients.Main && authMode != TServiceConfig::TTvmParams::Disabled) {
        tvmAuth.ConstructInPlace();
        tvmAuth->DryRun = authMode == TServiceConfig::TTvmParams::DryRun;
    }

    TString statusString = "-";
    if (tvmAuth) {
        try {
            // Allowed user ids come from the union of service-level and global ABC groups.
            auto* abcResolver = Server->GetAbcResolver();
            TVector<ui64> allowedUids;
            if (abcResolver) {
                TVector<ui32> allowedAbcGroups = serviceTvmConfig.AllowedUserAbcGroups;
                for (const ui32 g : tvmConfig.AllowedUserAbcGroups) {
                    allowedAbcGroups.push_back(g);
                }
                allowedUids = abcResolver->Resolve(allowedAbcGroups);
            }
            NSearchProxy::TTvmAuthParams params{
                .Clients = tvmClients,
                .Tickets = NSearchProxy::TTvmTraits::GetTickets(rd),
                .AllowedSourceIds = serviceTvmConfig.AllowedSourceTvmIds,
                .TvmProxyId = tvmConfig.TvmProxyTvmId,
                .AllowedUids = {allowedUids.begin(), allowedUids.end()},
            };
            tvmAuth->Result = NSearchProxy::AuthorizeTvm(params);
        } catch (...) {
            tvmAuth->Exception = true;
            if (tvmAuth->DryRun) {
                // Dry-run must not break the request: log and continue with status "exception".
                NSearchProxy::NLogging::TErrorLog::Log(/* TODO request_id = */0, rd, "Auth failed: " + CurrentExceptionMessage());
            } else {
                tvmAuth->SendUnistatSignalIfNeed(service);
                throw;
            }
        }
        statusString = tvmAuth->MakeStatusString();
        tvmAuth->SendUnistatSignalIfNeed(service);
    }

    rd.CgiParam.InsertUnescaped("tvm_auth_status", statusString);

    if (tvmAuth && !tvmAuth->DryRun && tvmAuth->Result && !tvmAuth->Result->IsAuthorized()) {
        const auto info = tvmAuth->MakeNotAuthorizedInfo(serviceConfig.GetHttpStatusManagerConfig());
        ythrow NSaas::TNotAuthorizedException(info);
    }
}

void TSearchProxyRequestFeatures::ParseRequest() {
    TSearchRequestData& RD = GetRD();

    TStringBuf rawScript(RD.ScriptName());
    TStringBuf script = rawScript.After('/');

    if (script == "yandsearch")
        script.Clear();

    FilteredScript = script;

    RD.CgiParam.EraseAll("tvm_auth_status");

    // Simplify access to ticket in code snippets where we have only cgi
    {
        const auto tickets = NSearchProxy::TTvmTraits::GetTickets(RD);
        if (const TString& ticket = tickets.ServiceTicket) {
            RD.CgiParam.ReplaceUnescaped(NSearchProxyCgi::ya_service_ticket, ticket);
        }
        if (const TString& ticket = tickets.UserTicket) {
            RD.CgiParam.ReplaceUnescaped(NSearchProxyCgi::ya_user_ticket, ticket);
        }
    }

    bool specialRequest = rawScript.StartsWith(NSearchProxyCgi::broadcast) || rawScript.StartsWith(NSearchProxyCgi::global_ping);
    Service = CalculateName(RD.CgiParam);
    if (!Service && FilteredScript && !specialRequest) {
        Service = script;
        RD.CgiParam.InsertUnescaped(service, Service);
    }
    if (Service && FilteredScript && !specialRequest && Service != FilteredScript) {
        ythrow TSearchException(HTTP_BAD_REQUEST) << "mismatched service names: " << FilteredScript << " != " << Service;
    }

    Server->ApplyGlobalCorrections(Service, RD.CgiParam);
    Service = CalculateName(RD.CgiParam);
}

// Hands a non-empty raw request to the server's flow mirror, then parses the
// request and applies per-service output settings.
void TSearchProxyClient::OnBeginProcess(IReplyContext::TPtr /*context*/) {
    const bool hasRawRequest = !RequestString.empty();
    if (hasRawRequest) {
        auto&& flowMirror = Server->GetFlowMirror();
        flowMirror.ProcessRequest(RequestString, RD, Buf);
    }

    ParseRequest();
    AdjustClientRequest(Service);
}


// Logs an access-denied event for this client request.
void TSearchProxyClient::OnAccessDenied(const NSaas::TNotAuthorizedInfo& auth, IReplyContext::TPtr context) {
    auto& replyContext = *context;
    NSearchProxy::LogAccessDenied(auth, "", replyContext);
}

TSearchProxyClient::TSearchProxyClient(TSearchProxyServer* server)
    : TSearchProxyRequestFeatures(server)
{}

TSearchProxyNehRequest::TSearchProxyNehRequest(YandexHttpServer* yserver, TSearchProxyServer* server, NNeh::IRequestRef req)
    : TYsNehClientRequest(yserver, req)
    , TSearchProxyRequestFeatures(server)
    , Preprocessed(false)
    , RequestStartTime(TInstant::Now())
{
    TYsNehClientRequest::ScanQuery();
}

// Processes a neh request whose object owns itself.
//
// `this` is wrapped in this_ (a THolder), so the request is deleted if processing fails
// before ownership is handed off. Ownership is then released either into
// DoSelectHandlerImpl's IReplyContext::TPtr (non-common-search proxies) or, presumably,
// to TYsNehClientRequest::Process (common search).
void TSearchProxyNehRequest::Process(void* ts) {
    THolder<TSearchProxyNehRequest> this_(this);
    try {
        // One-time setup guarded by Preprocessed: parse the request into Service, spawn the
        // unistat frame for it, and attach an event-log frame when a meta searcher exists.
        if (!Preprocessed) {
            ParseRequest();
            StatFrame = NSaas::GetGlobalUnistatFrameFactory().SpawnFrame(Service);
            auto ptr = TSearchProxyRequestFeatures::Server->GetMetaSearcherSafe(Service);
            if (ptr)
                Frame = MakeIntrusive<TSelfFlushSearchEventLogFrame>(ptr->GetEventLog().Get());
            Preprocessed = true;
        }

        // Proxy chosen from the request cgi; the service's GetDefaultMetaSearch() is passed
        // as the second argument (presumably the fallback when cgi names none).
        IProxy::TPtr proxy = TSearchProxyRequestFeatures::Server->SelectProxy(RequestData().CgiParam, Config.GetServiceConfig(Service).GetDefaultMetaSearch());
        if (!proxy) {
            ythrow TSearchException(HTTP_BAD_REQUEST) << "incorrect proxy name " << RequestData().CgiParam.Print();
        }
        if (!proxy->IsCommonSearch()) {
            // The raw pointer released from this_ becomes an IReplyContext::TPtr (assumed to
            // take ownership). NOTE(review): the replier reference is released before Reply();
            // presumably the replier manages its own lifetime from here — confirm.
            ISearchReplier::TPtr replier = DoSelectHandlerImpl(this_.Release());
            replier.Release()->Reply();
        } else {
            // Ownership leaves this_ before the base-class processing takes over.
            Y_UNUSED(this_.Release());
            TYsNehClientRequest::Process(ts);
        }
    } catch (...) {
        // this_ is still set only if the failure happened before ownership was released.
        // In that case an error page is produced here, and the request is deleted when
        // this_ is destroyed during unwinding.
        // NOTE(review): failures after the hand-off propagate without MakeErrorPage/DoHandleException.
        if (this_) {
            MakeErrorPage(*this_, HTTP_INTERNAL_SERVER_ERROR, CurrentExceptionMessage());
            DoHandleException(this_.Get());
        }
        throw;
    }
}

void TSearchProxyNehRequest::OnBeforeRequestHandlerRun() {
    try {
        ProcessTvmAuth();
    } catch (const NSaas::TNotAuthorizedException& e) {
        NSearchProxy::LogAccessDenied(e.GetInfo(), "", *this);
        MakeErrorPage(*this, e.GetHttpCode(), e.what());
        throw;
    } catch (...) {
        MakeErrorPage(*this, HTTP_INTERNAL_SERVER_ERROR, CurrentExceptionMessage());
        DoHandleException(this);
        throw;
    }
}

// Attaches the meta searcher of Service to this request and bumps its search counter.
// Expected to be called once (asserted). Returns false when the service has no searcher.
bool TSearchProxyNehRequest::CreateSearcher(const char* /*name*/) {
    Y_ASSERT(!Searcher);
    if (!Searcher) {
        Searcher = TSearchProxyRequestFeatures::Server->GetMetaSearcherSafe(Service);
        if (!Searcher) {
            return false;
        }
        Searcher->IncreaseSearch();
    }
    return true;
}

// Future that completes when SendReply or SendError settles the promise.
NThreading::TFuture<void> TAppHostNehRequest::GetFuture() const {
    return Promise.GetFuture();
}

// Fails the promise with the message "Error <numeric code>: <details>".
void TAppHostNehRequest::SendError(TResponseError err, const TString& details /*= TString()*/) {
    Promise.SetException("Error " + ToString(static_cast<int>(err)) + ": " + details);
}

// Publishes the backend reply into the AppHost context and completes the promise.
//
// `data` must hold a serialized NMetaProtocol::TReport; otherwise the promise is failed.
// The report is decompressed. The context gets the flag "empty_answer" when there is
// no grouping, no group, or the first group has no documents. The report is added as
// protobuf item "SAAS_proto" when the request asked for ms=proto (ProtobufResult),
// or as JSON item "SAAS" otherwise.
void TAppHostNehRequest::SendReply(NNeh::TData& data) {
    NMetaProtocol::TReport report;
    TMemoryInput mi(data.data(), data.size());
    if (!report.ParseFromArcadiaStream(&mi)) {
        // NNeh::TData is a sized byte buffer without a terminating NUL. Bound the payload
        // by size(); appending data() as a C string would read past the end of the buffer.
        Promise.SetException(TString::Join("data is not proto: ", TStringBuf(data.data(), data.size())));
        return;
    }
    NMetaProtocol::Decompress(report);
    if (report.GroupingSize() == 0 || report.GetGrouping(0).GroupSize() == 0 || report.GetGrouping(0).GetGroup(0).DocumentSize() == 0) {
        AppHostContext->AddFlag("empty_answer");
    }
    if (ProtobufResult) {
        AppHostContext->AddProtobufItem(std::move(report), "SAAS_proto");
    } else {
        NJson::TJsonValue ans;
        NProtobufJson::Proto2Json(report, ans);
        AppHostContext->AddItem(std::move(ans), "SAAS");
    }
    Promise.SetValue();
}

// AppHost requests never report cancellation.
bool TAppHostNehRequest::Canceled() const {
    return false;
}

// Request id: the AppHost request GUID rendered as a string.
TStringBuf TAppHostNehRequest::RequestId() const {
    return TStringBuf(RequestIdStr);
}

// Request payload: the cgi string built in the constructor.
TStringBuf TAppHostNehRequest::Data() const {
    return TStringBuf(RequestStr);
}

// Not available for AppHost requests: empty.
TStringBuf TAppHostNehRequest::Service() const {
    return {};
}

// Not available for AppHost requests: empty.
TString TAppHostNehRequest::RemoteHost() const {
    return {};
}

// Not available for AppHost requests: empty.
TStringBuf TAppHostNehRequest::Scheme() const {
    return {};
}

// Builds the search cgi (RequestStr) from AppHost items.
//
// Sources:
//  - "saas_input" keys become cgi parameters. Array values become repeated parameters,
//    except relev/rearr/pron, whose elements are joined with ';'.
//  - "qtree" from any "wizard" item replaces the cgi qtree.
//  - "reqid" from "app_host_params" items is added.
// The answer format is remembered in ProtobufResult (ms=proto). The request itself
// always asks for ms=proto, with "hr" and "format" removed.
TAppHostNehRequest::TAppHostNehRequest(NAppHost::TServiceContextPtr appHostContext)
    : Promise(NThreading::NewPromise<void>())
    , AppHostContext(appHostContext)
    , RequestIdStr(GetGuidAsString(appHostContext->GetRequestID()))
{
    TCgiParameters cgi;
    const NJson::TJsonValue& saasInput = AppHostContext->GetOnlyItem("saas_input");

    // Adds one cgi value. "service" values are also logged for setrace: SETRACE-447, SAAS-4920.
    const auto insertValue = [&cgi, &appHostContext](const auto& key, const NJson::TJsonValue& value) {
        if (key == "service") {
            appHostContext->AddLogLine(TString::Join("Service: ", value.GetStringRobust()));
        }
        cgi.InsertUnescaped(key, value.GetStringRobust());
    };

    for (const auto& [key, value] : saasInput.GetMap()) {
        if (!value.IsArray()) {
            insertValue(key, value);
        } else if (IsIn({"relev", "rearr", "pron"}, key)) {
            TStringBuilder joined;
            TStringBuf separator;
            for (const auto& elem : value.GetArray()) {
                joined << separator << elem.GetStringRobust();
                separator = ";";
            }
            cgi.InsertUnescaped(key, joined);
        } else {
            for (const auto& elem : value.GetArray()) {
                insertValue(key, elem);
            }
        }
    }

    for (const auto& wizard : AppHostContext->GetItemRefs("wizard")) {
        if (wizard.Has("qtree")) {
            cgi.ReplaceUnescaped("qtree", wizard["qtree"].GetStringRobust());
        }
    }

    // add reqid as a cgi parameter: SETRACE-845
    for (const auto& params : AppHostContext->GetItemRefs("app_host_params")) {
        if (params.Has("reqid")) {
            cgi.InsertUnescaped("reqid", params["reqid"].GetStringRobust());
        }
    }

    // Read the requested format before ms is overwritten below.
    ProtobufResult = (cgi.Get("ms") == "proto");
    cgi.EraseAll("hr");
    cgi.EraseAll("format");
    cgi.ReplaceUnescaped("ms", "proto");
    RequestStr = cgi.Print();
}
