#include "helpers.h"
#include "connect.h"

#include <balancer/kernel/client_request/backend_config.h>
#include <balancer/kernel/helpers/misc.h>
#include <balancer/kernel/log/errorlog.h>

using namespace NSrvKernel;

namespace {
    // Computes the list of socket addresses to connect to for `hostInfo`.
    //
    // Resolution order:
    //  1. A source-rewrite host (`hostInfo.IsSrcRwr`) that carries a `CachedIp` is returned directly
    //     as `CachedIp:Port`, without touching the resolver.
    //  2. If `config.need_resolve()` is set, or the host is a source-rewrite host without a cached IP,
    //     the name is resolved through `descr.Process().Resolver()` with the given `deadline`.
    //     On failure, if a fallback is allowed (non-empty `config.cached_ip()` and not a source-rewrite
    //     host), the failure is logged at INFO level and the config's cached address is used instead;
    //     otherwise the failure is logged at ERR level and the error is returned.
    //  3. When no resolution happened (or it failed with a fallback allowed), the config's cached
    //     address (`config.GetCachedAddr()`) is returned.
    //
    // Returns an error if:
    //  - resolution is required but `descr` has no process (hence no resolver);
    //  - the cached IP cannot be parsed;
    //  - resolution fails without a fallback;
    //  - no cached address is configured when one is needed.
    TErrorOr<TSockAddrInfo> Addr(const TConnDescr& descr, const THostInfo& hostInfo, const TBackendConfig& config, TInstant deadline) noexcept {
        bool needResolve = config.need_resolve();
        if (hostInfo.IsSrcRwr) {
            if (hostInfo.CachedIp) {
                TSockAddr addr;
                Y_PROPAGATE_ERROR(TSockAddr::FromIpPort(hostInfo.CachedIp, hostInfo.Port).AssignTo(addr));
                return TSockAddrInfo{{addr}};
            }
            // A source-rewrite host without a cached IP must always be resolved.
            needResolve = true;
        }

        // The statically configured address may stand in for a failed resolution,
        // but never for source-rewrite hosts. Computed once so the resolver call and
        // both error branches below agree on the same condition.
        const bool hasCachedFallback = !config.cached_ip().empty() && !hostInfo.IsSrcRwr;

        std::optional<TSockAddrInfo> result;
        if (needResolve) {
            Y_REQUIRE(descr.HasProcess(), yexception{} << "can not resolve without resolver");
            Y_TRY(TError, error) {
                // NOTE(review): the second argument tells the resolver whether a cached fallback
                // exists; its exact effect is defined by Resolve() — confirm there.
                return descr.Process().Resolver().Resolve(
                    NDns::TResolveInfo{hostInfo.Host, config.GetFamily(), hostInfo.Port, NDns::TResolveInfo::GENERAL},
                    hasCachedFallback, deadline).AssignTo(result);
            } Y_CATCH {
                if (const auto* e = error.GetAs<TNetworkResolutionError>()) {
                    if (hasCachedFallback) {
                        Y_HTTPD_LOG_IMPL(descr.ErrorLog, TLOG_INFO, "proxy", descr, "DNS cache address resolving failed " << e->what() << " " << hostInfo.Host << ':' << hostInfo.Port <<
                                                                                    ", using cached address " << config.cached_ip() << " instead");
                    } else {
                        Y_HTTPD_LOG_IMPL(descr.ErrorLog, TLOG_ERR, "proxy", descr, "address resolving failed " << e->what() << " " << hostInfo.Host << ':' << hostInfo.Port);
                        return error;
                    }
                } else if (hasCachedFallback) {
                    Y_HTTPD_LOG_IMPL(descr.ErrorLog, TLOG_INFO, "proxy", descr, "address resolving failed for " << hostInfo.Host << ':' << hostInfo.Port << ", cached address is used");
                } else {
                    Y_HTTPD_LOG_IMPL(descr.ErrorLog, TLOG_ERR, "proxy", descr, "address resolving failed with general error for " << hostInfo.Host << ':' << hostInfo.Port);
                    return error;
                }
            }
        }

        if (result) {
            return std::move(*result);
        }

        // Either resolution was not requested, or it failed and the fallback is allowed.
        // Guard the dereference instead of trusting the configuration to always provide it.
        const auto& cachedAddr = config.GetCachedAddr();
        if (!cachedAddr) {
            return Y_MAKE_ERROR(yexception{} << "no cached address for " << hostInfo.Host << ':' << hostInfo.Port);
        }
        return *cachedAddr;
    }
}

namespace NModProxy {
    // Wraps a non-empty error into a TBackendError tagged with `fastError`.
    // An empty error, or one that already is a TBackendError, is returned unchanged,
    // so repeated wrapping never nests backend errors.
    TError WrapIntoBackendError(TError err, bool fastError) noexcept {
        if (err && !err.GetAs<TBackendError>()) {
            return Y_MAKE_ERROR(TBackendError(std::move(err), fastError));
        } else {
            return err;
        }
    }

    // Creates a TLS client session over the already-connected backend stream and performs the handshake.
    //
    // SNI (only when settings.sni_on()):
    //  - If settings.sni_host() is empty and `req` is given, SniHost_ is set to the request's
    //    whitespace-stripped Host header.
    //  - The server name sent is SniHost_ when non-empty, otherwise settings.sni_host().
    //  - Nothing is sent when both are empty.
    //
    // ALPN: if the request carries a client protocol, the matching id ("http/1.1" or "h2") is
    // advertised in length-prefixed wire format (RFC 7301).
    //
    // @param settings  TLS settings; SslClientContext must be set.
    // @param req       request being proxied; may be null.
    // @param deadline  handshake deadline.
    // @return empty error on success, otherwise the session-creation, SNI, ALPN or handshake error.
    TError TKeepAliveData::SslConnect(const THttpsSettings& settings, const TRequest* req, TInstant deadline) noexcept {
        try {
            Y_ASSERT(settings.SslClientContext);
            SslIo_ = MakeHolder<TSslIo>(
                *settings.SslClientContext,
                BackendIo_->Input(),
                BackendIo_->Output()
            );
        } Y_TRY_STORE(TSslError, yexception);

        if (settings.sni_on()) {
            if (!settings.sni_host() && req) {
                // NOTE(review): a Host header may include ":port", which is not a valid SNI
                // HostName (RFC 6066 §3) — confirm whether the port should be stripped here.
                SniHost_ = StripString(req->Headers().GetFirstValue("host"));
            }

            // NOTE(review): SetSniServername receives a raw pointer, so this relies on the viewed
            // buffer being NUL-terminated (true when SniHost_ / sni_host() are TString-backed).
            if (TStringBuf host = SniHost_ ?: settings.sni_host()) {
                Y_PROPAGATE_ERROR(SslIo_->SetSniServername(host.data()));
            }
        }

        if (req && req->Props().ClientProto) {
            TVector<ui8> proto;
            switch (*req->Props().ClientProto) {
                case EClientProto::CP_HTTP:
                    proto = {8, 'h','t','t','p','/','1','.','1'};
                    break;
                case EClientProto::CP_HTTP2:
                    proto = {2, 'h', '2'};
                    break;
                default:
                    Y_FAIL();
            }
            Y_PROPAGATE_ERROR(SslIo_->SetClientAlpn(proto.data(), proto.size()));
        }

        return SslIo_->Connect(deadline);
    }

    // Returns true when request->Props().Version >= 1. Otherwise it returns true only if the request
    // has an Expect header whose value matches 100-continue. `request` must not be null.
    //
    // NOTE(review): if Version 0 denotes HTTP/1.0, returning true for an HTTP/1.0 request carrying
    // "Expect: 100-continue" conflicts with RFC 7231 §5.1.1/§6.2. Those sections say the server must
    // ignore that expectation and must not send 1xx responses to HTTP/1.0 clients. Confirm the
    // intended semantics.
    bool MayPass100Continue(const TRequest* request) noexcept {
        if (request->Props().Version >= 1) {
            return true;
        }

        TStringBuf expectHeaderValue = request->Headers().GetFirstValue(TFsm{TExpectFsm::Instance()});
        return expectHeaderValue && Match(T100ContinueFsm::Instance(), expectHeaderValue);
    }

    // Performs the TLS handshake on `keepAlive` when encryption is requested:
    //  - with an attempts holder, the decision comes from AttemptsHolder->UseEncryption();
    //  - without one, it comes from settings.use_by_default().
    // The handshake deadline is additionally capped by settings.tls_connect_timeout().
    //
    // @return empty error when TLS is not used or the handshake succeeds. Otherwise the failure is
    //         logged and returned wrapped into a TConnectError that mentions the backend address.
    TError SslConnect(TKeepAliveData& keepAlive, const TConnDescr& descr, const THttpsSettings& settings, TInstant deadline) {
        if (!descr.AttemptsHolder) {
            if (!settings.use_by_default()) {
                return {};
            }
        } else if (!descr.AttemptsHolder->UseEncryption()) {
            return {};
        }

        deadline = Min(settings.tls_connect_timeout().ToDeadLine(), deadline);
        auto err = keepAlive.SslConnect(settings, descr.Request, deadline);

        if (err) {
            Y_HTTPD_LOG_IMPL(descr.ErrorLog, TLOG_ERR, "proxy", descr, "SSL connect to " << keepAlive.AddrStr() << " failed " << GetErrorMessage(err));
            return Y_MAKE_ERROR(TConnectError(std::move(err), true) << "connect to " << keepAlive.AddrStr() << " ");
        }

        return {};
    }

    // Resolves the backend address, opens a TCP connection, configures the socket and, when the
    // backend has HTTPS settings, performs the TLS handshake.
    //
    // Failed TCP connects are retried every config.connect_retry_delay() while the next attempt
    // still fits before connectStart + config.connect_retry_timeout(). Each attempt is bounded by
    // min(connect timeout, config.tcp_connect_timeout()).
    //
    // @param descr     connection descriptor (logs, access-log extras, request, attempts holder).
    // @param hostInfo  backend host/port to connect to.
    // @param config    backend settings: timeouts, retry policy, socket options, TLS.
    // @param pollMode  forwarded to TBackendIo.
    // @param start     reference instant. The resolve deadline is start + config.resolve_timeout(),
    //                  and access-log timings are measured from it.
    // @param connectDuration           [out] time from the first connect attempt until the socket is configured.
    // @param connectDeadline           [out] retry deadline + connect timeout; assigned only once TCP is up.
    // @param effectiveSessionDeadline  [out] completion time + config.backend_timeout(); assigned once TCP is up.
    // @return the connection wrapped in TKeepAliveData, or the first error encountered.
    TErrorOr<THolder<TKeepAliveData>> EstablishConnection(
            const TConnDescr& descr,
            const THostInfo& hostInfo,
            const TBackendConfig& config,
            EPollMode pollMode,
            const TInstant start,
            TDuration& connectDuration,
            TInstant& connectDeadline,
            TInstant& effectiveSessionDeadline
    ) noexcept
    {
        auto backendSocket = MakeHolder<TSocketHolder>();

        std::optional<TSockAddrInfo> addr;
        Y_PROPAGATE_ERROR(Addr(descr, hostInfo, config, start + config.resolve_timeout()).AssignTo(addr));
        auto addrFiltered = FilterAddr(*addr, config.GetFamily());

        Y_REQUIRE(!addrFiltered.empty(),
                  TBackendError(Y_MAKE_ERROR(yexception{} << "got zero addresses after filtering: "
                                                          << hostInfo.Host << ':' << hostInfo.Port << "; family == " << config.GetFamily()), true));

        const TSockAddr* connRes = nullptr;

        const TInstant connectStart = Now();
        const TInstant connectRetryDeadLine = connectStart + config.connect_retry_timeout();

        // A per-attempt connect timeout from the attempts holder, when set, overrides the config.
        TDuration connectTimeout = config.connect_timeout();
        if (descr.AttemptsHolder && descr.AttemptsHolder->GetConnectTimeout()) {
            connectTimeout = descr.AttemptsHolder->GetConnectTimeout();
        }
        const TDuration tcpConnectTimeout = Min(connectTimeout, config.tcp_connect_timeout());

        // Retry loop: exits on the first successful connect; errors either sleep-and-retry or return.
        while (true) {
            Y_TRY(TError, err) {
                return Connect(
                        RunningCont()->Executor(),
                        *backendSocket,
                        addrFiltered,
                        SOCK_STREAM,
                        IPPROTO_TCP,
                        tcpConnectTimeout.ToDeadLine(),
                        descr.Properties->ConnStats
                ).AssignTo(connRes);
            } Y_CATCH {
                if (Now() + config.connect_retry_delay() < connectRetryDeadLine) {
                    // Sleep on the current coroutine, the same one Connect() runs on. This does not
                    // require `descr` to carry a process (see the HasProcess() check in Addr).
                    Y_PROPAGATE_ERROR(CheckedSleepT(RunningCont(), config.connect_retry_delay()));
                    continue;
                }
                descr.ExtraAccessLog << ' ' << TDuration::Zero() << '/' << Now() - start << " system_error " << GetShortErrorMessage(err);
                Y_HTTPD_LOG_IMPL(descr.ErrorLog, TLOG_ERR, "proxy", descr, "Could not connect to " << hostInfo.Host << ':' << hostInfo.Port << " " << GetErrorMessage(err));
                return Y_MAKE_ERROR(TConnectError(std::move(err), true));
            }
            break;
        }

        Y_PROPAGATE_ERROR(WrapIntoBackendError(EnableNoDelay(*backendSocket), true));
        Y_PROPAGATE_ERROR(WrapIntoBackendError(EnableKeepalive(*backendSocket, config.GetTcpKeepalive()), true));
        Y_PROPAGATE_ERROR(WrapIntoBackendError(SetSockBufSize(*backendSocket, config.GetSockBufSize()), true));

        // Use a single timestamp so the reported duration and the session deadline agree.
        const TInstant connected = Now();
        connectDuration = connected - connectStart;
        effectiveSessionDeadline = connected + config.backend_timeout();

        auto backendIo = MakeHolder<TBackendIo>(std::move(backendSocket), RunningCont()->Executor(), pollMode);

        Y_ASSERT(connRes);

        auto keepAlive = MakeHolder<TKeepAliveData>(std::move(backendIo), std::move(*addr), *connRes);

        connectDeadline = connectRetryDeadLine + connectTimeout;
        if (config.GetHttpsSettings()) {
            if (TError err = SslConnect(*keepAlive, descr, *config.GetHttpsSettings(), connectDeadline)) {
                descr.ExtraAccessLog << ' ' << TDuration::Zero() << '/' << Now() - start << " ssl_connect_error";
                return err;
            }
        }

        return keepAlive;
    }
}
