#include "module.h"
#include "headers.h"

#include <util/generic/string.h>
#include <util/string/subst.h>
#include <balancer/kernel/http/parser/common_headers.h>
#include <balancer/kernel/http/parser/httpencoder.h>
#include <balancer/kernel/http/parser/request_builder.h>
#include <balancer/kernel/http/parser/response_builder.h>
#include <balancer/kernel/log/errorlog.h>
#include <balancer/kernel/module/module.h>
#include <balancer/kernel/requester/requester.h>

using namespace NSrvKernel;

namespace {

enum class EAuthResult {
    OK,
    UNAUTHORIZED,
    FORBIDDEN,
    ERROR
};

struct TAuth {
    EAuthResult AuthResult;
    TMaybe<TString> AppId;
    TMaybe<TString> CsrfState;
    TMaybe<TString> CsrfToken;
};

TStringBuf GetYandexDomainFromHost(const TConnDescr& descr) {
    TStringBuf hostHeader = descr.Request->Headers().GetFirstValue("host");

    if (hostHeader.EndsWith("yandex.ru")) {
        return "yandex.ru";
    } else if (hostHeader.EndsWith("yandex-team.ru")) {
        return "yandex-team.ru";
    }

    return "";
}

}

// `webauth` module: before running its child module, asks the `checker`
// submodule whether the client is authorized.
//
// The checker receives a GET to `auth_path` carrying a copy of the client's
// headers plus `Webauth-Retpath` (and `Webauth-Idm-Role` when `role` is set).
// Its response status selects what happens next:
//   200                     -> run the child module;
//   401                     -> 302 to `unauthorized_redirect` with `unauthorized_set_cookie`,
//                              or `on_forbidden` when the bypass header is present
//                              or no redirect is configured;
//   403                     -> run `on_forbidden`;
//   other / request failure -> run `on_error`, falling back to `on_forbidden`.
// OPTIONS requests skip the check when `allow_options_passthrough` is set.
MODULE(webauth) {
public:
    TModule(const TModuleParams& mp)
        : TModuleBase(mp)
    {
        Config->ForEach(this);

        if (!Submodule_) {
            ythrow TConfigParseError() << "no submodule configured";
        }

        if (!OnForbidden_) {
            ythrow TConfigParseError() << "no on_forbidden module configured";
        }

        // Both are emitted by the same 302 response, so configuring only one
        // of them is a mistake.
        if (!AuthCookie_ != !RedirectUrl_) {
            ythrow TConfigParseError() << "set cookie and redirect url must be configured together";
        }

        if (!Checker_) {
            ythrow TConfigParseError() << "no auth module configured";
        }

        // The request line of the checker subrequest is fixed; headers are
        // filled in per request by Authorize().
        AuthRequest_ = BuildRequest().Version11().Method(EMethod::GET).Path(AuthPath_);
    }

private:
    START_PARSE {
        ON_KEY("unauthorized_redirect", RedirectUrl_) {
            return;
        }

        ON_KEY("unauthorized_set_cookie", AuthCookie_) {
            return;
        }

        ON_KEY("auth_path", AuthPath_) {
            if (!AuthPath_.StartsWith('/')) {
                ythrow TConfigParseError{} << "auth request path must start with /";
            }
            return;
        }

        ON_KEY("role", Role_) {
            return;
        }

        ON_KEY("header_name_redirect_bypass", HeaderNameRedirectByPass_) {
            return;
        }

        ON_KEY("allow_options_passthrough", AllowOptionsPassThrough_) {
            return;
        }

        if (key == "checker") {
            if (Checker_) {
                ythrow TConfigParseError{} << "duplicate auth module found";
            }

            TSubLoader(Copy(value->AsSubConfig())).Swap(Checker_);
            return;
        }

        if (key == "on_forbidden") {
            if (OnForbidden_) {
                ythrow TConfigParseError{} << "duplicate on_forbidden module found";
            }

            TSubLoader(Copy(value->AsSubConfig())).Swap(OnForbidden_);
            return;
        }

        if (key == "on_error") {
            if (OnError_) {
                ythrow TConfigParseError{} << "duplicate on_error module found";
            }

            TSubLoader(Copy(value->AsSubConfig())).Swap(OnError_);
            return;
        }

        // Any other key is treated as the (single) child module.
        {
            if (Submodule_) {
                ythrow TConfigParseError{} << "duplicate child module found";
            }

            Submodule_.Reset(Loader->MustLoad(key, Copy(value->AsSubConfig())).Release());
            return;
        }
    } END_PARSE

private:
    // Sends the authorization subrequest to the checker and interprets its
    // response. Never throws: a failed subrequest is logged (when it carries a
    // yexception) and reported as EAuthResult::ERROR.
    TAuth Authorize(const TConnDescr& descr) const noexcept {
        TRequest authRequest = AuthRequest_;
        // The checker sees a copy of every client header (presumably so it can
        // read the client's credentials).
        // NOTE(review): client-supplied `Webauth-Retpath` / `Webauth-Idm-Role`
        // headers are copied too and reach the checker next to (or, when `role`
        // is unset, instead of) the values added below. Consider deleting them
        // from the copy first.
        authRequest.Headers() = descr.Request->Headers();

        authRequest.Headers().Add("Webauth-Retpath", MakeReturnPath(descr));

        if (Role_) {
            authRequest.Headers().Add("Webauth-Idm-Role", Role_);
        }

        TRequester requester{*Checker_, descr};
        TResponse response;

        Y_TRY(TError, error) {
            return requester.Request(std::move(authRequest), response);
        } Y_CATCH {
            if (const auto& e = error.GetAs<yexception>()) {
                LOG_ERROR(TLOG_INFO, descr, "auth error:" << e->what());
            }
            return { EAuthResult::ERROR, Nothing(), Nothing(), Nothing() };
        }

        return ParseAuthResponse(response);
    }

    // Rebuilds the absolute URL the client asked for: scheme of the client
    // connection + Host header + request-line URL (a '/' is inserted when the
    // URL does not start with one). Returns "undefined" without a Host header.
    // NOTE(review): assumes an origin-form request target; an absolute-form
    // target ("http://h/p") would be appended after the host verbatim.
    TString MakeReturnPath(const TConnDescr& descr) const noexcept {
        const TRequest* const request = descr.Request;
        TStringBuf hostHeader = request->Headers().GetFirstValue("host");

        if (!hostHeader) {
            return "undefined";
        }

        TStringBuilder location;

        if (descr.Properties->UserConnIsSsl) {
            location << "https://";
        } else {
            location << "http://";
        }

        location << hostHeader;

        auto url = request->RequestLine().GetURL();
        if (url.Empty() || url[0] != '/') {
            location << "/";
        }

        location << url;

        return location;
    }

    // Maps the checker's status code: 200 -> OK, 401 -> UNAUTHORIZED,
    // 403 -> FORBIDDEN, everything else -> ERROR.
    EAuthResult ParseAuthStatus(const TResponse& response) const noexcept {
        const auto status = response.ResponseLine().StatusCode;
        switch (status) {
        case 200:
            return EAuthResult::OK;
        case 401:
            return EAuthResult::UNAUTHORIZED;
        case 403:
            return EAuthResult::FORBIDDEN;
        default:
            return EAuthResult::ERROR;
        }
    }

    // Combines the status verdict with the app id / CSRF state / CSRF token the
    // checker may have returned in its response headers.
    TAuth ParseAuthResponse(const TResponse& response) const noexcept {
        NWebAuth::THeaderMatcher matcher;
        matcher.Match(response.Headers());
        return {
            .AuthResult = ParseAuthStatus(response),
            .AppId = matcher.AppId() ? TMaybe<TString>(*matcher.AppId()) : Nothing(),
            .CsrfState = matcher.State() ? TMaybe<TString>(*matcher.State()) : Nothing(),
            .CsrfToken = matcher.Token() ? TMaybe<TString>(*matcher.Token()) : Nothing()
        };
    }

    // Answers the client directly with a bodiless 302 to the filled-in redirect
    // URL and sets the filled-in auth cookie.
    TError ResponseUnauthorized(const TConnDescr& descr, const TAuth& auth) const noexcept {
        TResponse response = BuildResponse().Code(302).Version11();
        response.Headers().Add("Location", FillRedirectUrlParams(descr, RedirectUrl_, auth));
        response.Headers().Add("Set-Cookie", FillAuthParams(AuthCookie_, auth));
        Y_PROPAGATE_ERROR(descr.Output->SendHead(std::move(response), false, TInstant::Max()));
        return descr.Output->SendEof(TInstant::Max());
    }

    TError DoRun(const TConnDescr& descr) const noexcept override {
        if (AllowOptionsPassThrough_ && descr.Request->RequestLine().Method == EMethod::OPTIONS) {
            descr.ExtraAccessLog << " passthrough";
            return Submodule_->Run(descr);
        }

        const TAuth auth = Authorize(descr);

        switch (auth.AuthResult) {
        case EAuthResult::OK:
            descr.ExtraAccessLog << " ok";
            return Submodule_->Run(descr);

        case EAuthResult::UNAUTHORIZED: {
            descr.ExtraAccessLog << " unauthorized";
            Y_DEFER {
                descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "unauthorized");
            };
            // Clients sending the bypass header get `on_forbidden` instead of a redirect.
            const bool bypassRedirect =
                descr.Request->Headers().FindValues(HeaderNameRedirectByPass_) != descr.Request->Headers().end();
            if (bypassRedirect || !RedirectUrl_ || !AuthCookie_) {
                return OnForbidden_->Run(descr);
            }
            return ResponseUnauthorized(descr, auth);
        }

        case EAuthResult::FORBIDDEN: {
            descr.ExtraAccessLog << " forbidden";
            Y_DEFER {
                descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "forbidden");
            };
            return OnForbidden_->Run(descr);
        }

        case EAuthResult::ERROR:
            break;
        }

        // EAuthResult::ERROR lands here. So would any value the switch does not
        // list, instead of falling off the end of a function returning TError.
        // The failure is handled conservatively: on_error if configured,
        // otherwise on_forbidden.
        descr.ExtraAccessLog << " error";
        Y_DEFER {
            descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "auth error");
        };
        if (OnError_) {
            return OnError_->Run(descr);
        }
        return OnForbidden_->Run(descr);
    }

    // Substitutes the checker-provided values into `data`. A placeholder whose
    // value the checker did not return is left as-is.
    TString FillAuthParams(TString data, const TAuth& auth) const noexcept {
        if (auth.CsrfToken) {
            SubstGlobal(data, "{csrf_token}", *auth.CsrfToken);
        }
        if (auth.CsrfState) {
            SubstGlobal(data, "{csrf_state}", *auth.CsrfState);
        }
        if (auth.AppId) {
            SubstGlobal(data, "{app_id}", *auth.AppId);
        }
        return data;
    }

    // Fills the redirect URL template. Besides the auth params it supports
    // `{yandex_domain}` (only when the host is a Yandex domain) and `{retpath}`.
    TString FillRedirectUrlParams(const TConnDescr& descr, TString data, const TAuth& auth) const noexcept {
        data = FillAuthParams(std::move(data), auth);

        const TStringBuf yandexDomain = GetYandexDomainFromHost(descr);
        if (yandexDomain) {
            SubstGlobal(data, "{yandex_domain}", yandexDomain);
        }

        // Should be replaced at the end, because it contains unfiltered user data.
        // NOTE(review): the return path is inserted without URL-encoding. If the
        // template places it inside a query string, a '&' in the client's URL
        // injects extra parameters. Confirm whether encoding is expected here.
        SubstGlobal(data, "{retpath}", MakeReturnPath(descr));
        return data;
    }

    bool DoExtraAccessLog() const noexcept override {
        return true;
    }

private:
    THolder<IModule> Checker_;      // `checker`: receives the authorization subrequest.
    THolder<IModule> Submodule_;    // Child module run for authorized requests.
    THolder<IModule> OnForbidden_;  // `on_forbidden`: 403, bypassed 401, and error fallback.
    THolder<IModule> OnError_;      // `on_error` (optional): checker failures / unexpected statuses.

    TString RedirectUrl_;           // Location template for 401; empty disables the redirect.
    TString AuthCookie_;            // Set-Cookie template for 401; set together with RedirectUrl_.
    TRequest AuthRequest_;          // Prebuilt `GET <AuthPath_>` request line for the checker.
    TString Role_;                  // Sent as Webauth-Idm-Role when non-empty.
    TString AuthPath_ = "/";
    TString HeaderNameRedirectByPass_ = "Webauth-Authorization";

    bool AllowOptionsPassThrough_ = false;
};

// Exposes the handle of the TModule class defined by MODULE(webauth) in this file.
IModuleHandle* NModWebAuth::Handle() {
    IModuleHandle* const webAuthHandle = TModule::Handle();
    return webAuthHandle;
}
