#include "module.h"

#include <balancer/kernel/coro/cleanable_coro_storage.h>
#include <balancer/kernel/custom_io/queue.h>
#include <balancer/kernel/custom_io/rewind.h>
#include <balancer/kernel/helpers/default_instance.h>
#include <balancer/kernel/http/parser/response_builder.h>
#include <balancer/kernel/io/iobase.h>
#include <balancer/kernel/log/errorlog.h>
#include <balancer/kernel/module/iface.h>
#include <balancer/kernel/module/module.h>
#include <balancer/kernel/requester/requester.h>

#include <balancer/kernel/regexp/regexp_pire.h>

#include <util/generic/ptr.h>

#include <utility>


using namespace NConfig;
using namespace NSrvKernel;

namespace {
// DECLARE_HEADER_NAME_FSM(ClassName, pattern) declares an FSM type that matches
// an HTTP header name case-insensitively. The type exposes a shared instance via
// TWithDefaultInstance (used as ClassName::Instance()). SetSurround is not
// enabled here, unlike the FSMs used to search inside header values.
#define DECLARE_HEADER_NAME_FSM(ClassName, pattern)                        \
    struct ClassName final : TFsm, TWithDefaultInstance<ClassName> {       \
        ClassName()                                                        \
            : TFsm(pattern, TFsm::TOptions().SetCaseInsensitive(true)) {   \
        }                                                                  \
    }

DECLARE_HEADER_NAME_FSM(TCacheControlFsm, "Cache-Control");
DECLARE_HEADER_NAME_FSM(TPragmaFsm, "Pragma");
DECLARE_HEADER_NAME_FSM(TAuthorizationFsm, "Authorization");
DECLARE_HEADER_NAME_FSM(TVaryFsm, "Vary");
DECLARE_HEADER_NAME_FSM(TExpiresFsm, "Expires");

#undef DECLARE_HEADER_NAME_FSM

// DECLARE_HEADER_TOKEN_FSM(ClassName, pattern) declares an FSM type, shared via
// TWithDefaultInstance, for spotting a directive inside a header value. It is
// case-insensitive and built with SetSurround(true), presumably so the pattern
// may occur anywhere within the value (e.g. inside "public, max-age=60").
#define DECLARE_HEADER_TOKEN_FSM(ClassName, pattern)                                        \
    struct ClassName final : TFsm, TWithDefaultInstance<ClassName> {                         \
        ClassName()                                                                          \
            : TFsm(pattern, TFsm::TOptions().SetCaseInsensitive(true).SetSurround(true)) {   \
        }                                                                                    \
    }

DECLARE_HEADER_TOKEN_FSM(TNoCacheFsm, "no[- ]cache");  // "no-cache" or "no cache"
DECLARE_HEADER_TOKEN_FSM(TNoStoreFsm, "no-store");
DECLARE_HEADER_TOKEN_FSM(TOnlyIfCachedFsm, "only-if-cached");
DECLARE_HEADER_TOKEN_FSM(TPrivateFsm, "private");
DECLARE_HEADER_TOKEN_FSM(TMaxAgeFsm, "max-age");
DECLARE_HEADER_TOKEN_FSM(TSMaxAgeFsm, "s-maxage");

#undef DECLARE_HEADER_TOKEN_FSM

// Union of the request header names that PreprocessHeaders cares about.
// The operand order below defines the matched-alternative index used there:
//   0 - Cache-Control, 1 - Pragma, 2 - Authorization.
struct TCacheHeaderNamesFsm : TFsm, TWithDefaultInstance<TCacheHeaderNamesFsm> {
    TCacheHeaderNamesFsm()
        : TFsm(TCacheControlFsm::Instance() | TPragmaFsm::Instance() | TAuthorizationFsm::Instance())
    {
    }
};

// Union of the backend response header names that MayStoreInCache cares about.
// The operand order below defines the matched-alternative index used there:
//   0 - Cache-Control, 1 - Vary, 2 - Expires.
// The constructor is intentionally not noexcept, matching TCacheHeaderNamesFsm:
// combining the FSMs builds a new automaton, which presumably allocates and may
// therefore throw.
class TBackendHeaderNamesFsm : public TFsm, public TWithDefaultInstance<TBackendHeaderNamesFsm> {
public:
    TBackendHeaderNamesFsm()
        : TFsm(
                TCacheControlFsm::Instance() // 0
              | TVaryFsm::Instance()         // 1
              | TExpiresFsm::Instance()      // 2
        )
    {}
};

// Facts about a client request, collected by PreprocessHeaders.
struct THeadersPreProcessResult {
    bool MayHaveBody = false;               // chunked transfer or Content-Length > 0
    bool MethodIsCacheable = false;         // GET or HEAD
    bool IsHeadRequest = false;             // HEAD
    bool CacheControlNoCache = false;       // Cache-Control contains no-cache
    bool PragmaNoCache = false;             // Pragma contains no-cache
    bool CacheControlOnlyIfCached = false;  // Cache-Control contains only-if-cached
    bool CacheControlNoStore = false;       // Cache-Control contains no-store
    bool HasAuthorizationHeader = false;    // an Authorization header is present

    // The cache may be consulted only if the client did not ask to bypass it
    // (no-cache in Cache-Control or Pragma) and the request carries no body.
    bool CanRequestCache() const noexcept {
        const bool bypassRequested = CacheControlNoCache || PragmaNoCache;
        return !(bypassRequested || MayHaveBody);
    }
};

// Facts about a backend response, collected by MayStoreInCache.
struct TBackendHeadersProcessResult {
    ui32 Status = 0;                   // HTTP status code of the response
    bool CacheControlNoStore = false;  // Cache-Control contains no-store
    bool CacheControlNoCache = false;  // Cache-Control contains no-cache
    bool CacheControlPrivate = false;  // Cache-Control contains private
    bool HasVary = false;              // a Vary header is present
    bool HasExpires = false;           // an Expires header is present
    bool CacheControlMaxAge = false;   // Cache-Control contains max-age
    bool CacheControlSMaxAge = false;  // Cache-Control contains s-maxage

    // A lifetime for the cached record can be derived from Expires, max-age or
    // s-maxage; only their presence is checked, not their values.
    bool MayCalculateRecordLifetime() const noexcept {
        if (HasExpires) {
            return true;
        }
        return CacheControlMaxAge || CacheControlSMaxAge;
    }

    // Storable: status 200 or 206, no no-store / no-cache / private directive,
    // no Vary header, and a derivable record lifetime.
    bool MayStoreInCache() const noexcept {
        const bool statusAllowed = Status == 200 || Status == 206;
        const bool storageForbidden = CacheControlNoStore || CacheControlNoCache
            || CacheControlPrivate || HasVary;
        return statusAllowed && !storageForbidden && MayCalculateRecordLifetime();
    }
};

// Inspects a client request once and records what the caching logic needs:
//  * the method: GET and HEAD are cacheable, and HEAD is also flagged;
//  * whether a body may follow (chunked transfer or Content-Length > 0);
//  * no-cache / only-if-cached / no-store inside Cache-Control values;
//  * no-cache inside Pragma values;
//  * the presence of an Authorization header.
THeadersPreProcessResult PreprocessHeaders(const TRequest& request) noexcept {
    THeadersPreProcessResult result;

    const auto method = request.RequestLine().Method;
    result.IsHeadRequest = (method == EMethod::HEAD);
    result.MethodIsCacheable = result.IsHeadRequest || method == EMethod::GET;

    const auto& props = request.Props();
    result.MayHaveBody = props.ChunkedTransfer || props.ContentLength > 0;

    const TCacheHeaderNamesFsm& namesFsm = TCacheHeaderNamesFsm::Instance();
    for (const auto& header : request.Headers()) {
        TMatcher matcher(namesFsm);
        if (!Match(matcher, header.first.AsStringBuf()).Final()) {
            continue;  // not one of the headers we inspect
        }

        for (const auto& headerValue : header.second) {
            const auto value = headerValue.AsStringBuf();
            // Alternatives are numbered in TCacheHeaderNamesFsm operand order.
            switch (*matcher.MatchedRegexps().first) {
                case 0:  // Cache-Control
                    // TODO: scanning could stop once no-cache has been seen.
                    result.CacheControlNoCache = result.CacheControlNoCache
                        || Match(TNoCacheFsm::Instance(), value);
                    result.CacheControlOnlyIfCached = result.CacheControlOnlyIfCached
                        || Match(TOnlyIfCachedFsm::Instance(), value);
                    result.CacheControlNoStore = result.CacheControlNoStore
                        || Match(TNoStoreFsm::Instance(), value);
                    break;
                case 1:  // Pragma
                    // TODO: scanning could stop once no-cache has been seen.
                    result.PragmaNoCache = result.PragmaNoCache
                        || Match(TNoCacheFsm::Instance(), value);
                    break;
                case 2:  // Authorization
                    result.HasAuthorizationHeader = true;
                    break;
            }
        }
    }

    return result;
}

bool MayStoreInCache(const TResponse& response) noexcept {
    TBackendHeadersProcessResult retval;
    const TResponseLine& responseLine = response.ResponseLine();
    const ui32 status = responseLine.StatusCode;
    retval.Status = status;
    if (status != 200 && status != 206) {
        return false;
    }

    for (const auto& header: response.Headers()) {
        TMatcher matcher(TBackendHeaderNamesFsm::Instance());
        if (Match(matcher, header.first.AsStringBuf()).Final()) {
            for (const auto& headerValue : header.second) {
                switch (*matcher.MatchedRegexps().first) {
                    case 0: // cache-control
                        if (!retval.CacheControlNoStore && Match(TNoStoreFsm::Instance(), headerValue.AsStringBuf())) {
                            retval.CacheControlNoStore = true;
                        }
                        if (!retval.CacheControlNoCache && Match(TNoCacheFsm::Instance(), headerValue.AsStringBuf())) {
                            retval.CacheControlNoCache = true;
                        }
                        if (!retval.CacheControlPrivate && Match(TPrivateFsm::Instance(), headerValue.AsStringBuf())) {
                            retval.CacheControlPrivate = true;
                        }
                        if (!retval.CacheControlMaxAge && Match(TMaxAgeFsm::Instance(), headerValue.AsStringBuf())) {
                            retval.CacheControlMaxAge = true;
                        }
                        if (!retval.CacheControlSMaxAge && Match(TSMaxAgeFsm::Instance(), headerValue.AsStringBuf())) {
                            retval.CacheControlSMaxAge = true;
                        }
                        break;
                    case 1: // vary
                        retval.HasVary = true;
                        break;
                    case 2: // expires
                        retval.HasExpires = true;
                        break;
                }
            }
        }

    }

    return retval.MayStoreInCache();
}

// A cacher answer is relayed to the client only for these statuses. With any
// other status the caller discards it and falls back to the backend.
bool CacheResponseValid(const TResponse& response) {
    switch (response.ResponseLine().StatusCode) {
        case 200:
        case 206:
        case 304:
            return true;
        default:
            return false;
    }
}

}  // namespace

// Per-worker state of the cachalot module.
Y_TLS(cachalot) {
    // `counter` is the module-wide cache-hit counter; `workerId` presumably
    // selects this worker's slot in it.
    TTls(const TSharedCounter& counter, size_t workerId)
        : CacheHit(counter, workerId)
    {}

    // Incremented each time a response is served from the cacher.
    TSharedCounter CacheHit;
    // Background coroutines that upload backend responses to the cacher;
    // finished ones are erased at the start of each DoRun.
    TCleanableCoroStorage Runners;
};

// HTTP caching module.
//
// Cacheable requests (GET/HEAD) are first looked up in the `cacher` module.
// A 200/206/304 answer from it is relayed to the client and counted as a
// cache hit. Otherwise the request goes to the main submodule. A storable
// backend response (see MayStoreInCache) is also uploaded to the cacher: a
// background coroutine sends it as a PUT while it is streamed to the client.
//
// Config keys:
//  * `uuid`       - inserted into the cache_hit counter name;
//  * `collection` - path prefix for cache requests; a leading '/' is ensured
//                   and trailing '/' characters are stripped;
//  * `cacher`     - the cache module;
//  * exactly one other key, which names the main submodule.
MODULE_WITH_TLS_BASE(cachalot, TModuleWithSubModule) {
public:
    explicit TModule(const TModuleParams& mp)
        : TModuleBase(mp)
    {
        Config->ForEach(this);

        // Counter name: "cachalot-<uuid>-cache_hit", or "cachalot-cache_hit" without a uuid.
        CacheHitCounter_ = TSharedCounter(
            Control->SharedStatsManager().MakeCounter(
                TString::Join("cachalot-", (Uuid_ ? Uuid_ + "-": ""), "cache_hit")).AllowDuplicate().Build());
        Y_ENSURE_EX(Submodule_, TConfigParseError() << " no submodule configured");
        Y_ENSURE_EX(Cacher_, TConfigParseError() << " no cacher submodule configured");
    }

private:
    START_PARSE {
        ON_KEY("uuid", Uuid_) {
            return;
        }
        ON_KEY("collection", Collection_) {
            // Normalize to "/name": ensure a leading slash and drop trailing ones.
            if (!Collection_->StartsWith('/')) {
                Collection_->prepend('/');
            }

            while (Collection_->EndsWith('/')) {
                Collection_->pop_back();
            }
            return;
        }

        if (key == "cacher") {
            THolder<IModule> submodule;
            TSubLoader(Copy(value->AsSubConfig())).Swap(submodule);
            Cacher_.Reset(std::move(submodule));
            return;
        }

        // Any other key is the main (backend) submodule; only one is allowed.
        Y_ENSURE_EX(!Submodule_, TConfigParseError() << "several modules for cachalot");
        Submodule_.Reset(Loader->MustLoad(key, Copy(value->AsSubConfig())).Release());
        return;
    } END_PARSE

    THolder<TTls> DoInitTls(IWorkerCtl* process) override {
        return MakeHolder<TTls>(*CacheHitCounter_, process->WorkerId());
    }

    // Per-request dispatch:
    //  * methods other than GET/HEAD bypass the cache (504 for only-if-cached);
    //  * if a cache lookup is allowed (no no-cache in Cache-Control or Pragma,
    //    no request body), ServeFull tries the cacher first;
    //  * otherwise only-if-cached cannot be honoured (504), and ServeNonCached
    //    asks the backend, possibly storing its response.
    TError DoRun(const TConnDescr& descr, TTls& tls) const noexcept override {
        tls.Runners.EraseFinished();  // reap finished cache-upload coroutines
        Y_ASSERT(descr.Request);

        const THeadersPreProcessResult preprocessed = PreprocessHeaders(*descr.Request);

        if (!preprocessed.MethodIsCacheable) {
            if (preprocessed.CacheControlOnlyIfCached) {
                return Respond504(descr);
            }
            return Submodule_->Run(descr);
        }

        if (preprocessed.CanRequestCache()) {
            return ServeFull(descr, preprocessed, tls);
        }
        if (preprocessed.CacheControlOnlyIfCached) {
            return Respond504(descr);
        }
        return ServeNonCached(descr, preprocessed, tls);
    }

    bool DoExtraAccessLog() const noexcept override {
        return true;
    }

    // Sends an empty 504 Gateway Timeout (used when only-if-cached cannot be
    // satisfied), then drains whatever is left of the request input.
    TError Respond504(const TConnDescr& descr) const noexcept {
        descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "not in cache");
        TResponse response = BuildResponse().Version11().Code(HTTP_GATEWAY_TIME_OUT).ContentLength(0);
        Y_PROPAGATE_ERROR(descr.Output->SendHead(std::move(response), false, TInstant::Max()));
        Y_PROPAGATE_ERROR(descr.Output->SendEof(TInstant::Max()));
        return SkipAll(descr.Input, TInstant::Max());
    }

    // Copy of the client request whose path is prefixed with the collection,
    // if one is configured; used for the cache lookup.
    TRequest PrepareGetCacheRequest(const TConnDescr& descr) const {
        TRequest newRequest = *descr.Request;
        if (Collection_) {
            newRequest.RequestLine().Path = TStringStorage(*Collection_ + newRequest.RequestLine().Path.AsStringBuf());
        }
        return newRequest;
    }

    // Builds the upload request for the cacher. It keeps the client's request
    // line with method PUT, minor version 1 and the collection-prefixed path,
    // and uses a chunked body. Every backend response header is copied as a
    // request header.
    TRequest PreparePutCacheRequest(const TConnDescr& descr, const TResponse& response) const noexcept {
        TRequestLine requestLine = descr.Request->RequestLine();
        requestLine.Method = EMethod::PUT;
        requestLine.MinorVersion = 1;
        if (Collection_) {
            requestLine.Path = TStringStorage(*Collection_ + requestLine.Path.AsStringBuf());
        }
        TRequest request{std::move(requestLine)};

        request.Props().ContentLength.reset();
        request.Props().ChunkedTransfer = true;

        for (const auto& header: response.Headers()) {
            for (const auto& headerVal : header.second) {
                request.Headers().Add(header.first.AsString(), headerVal.AsString());
            }
        }
        return request;
    }

    // Asks the cacher for the response.
    // Returns true when the cacher produced a usable (CacheResponseValid) head
    // that was relayed to the client, together with its body and trailers.
    // Returns false when the cacher answered with another status, or failed
    // before a usable head; such errors are swallowed so the caller can fall
    // back to the backend. Errors after the head was relayed are returned.
    TErrorOr<bool> TryRespondFromCache(const TConnDescr& descr, TTls& tls) const {
        bool fromCache = false;
        auto output = MakeHttpOutput([&](TResponse&& response, bool forceClose, TInstant deadline) {
            if (!CacheResponseValid(response)) {
                return TError{};
            }
            fromCache = true;
            return descr.Output->SendHead(std::move(response), forceClose, deadline);
        }, [&](TChunkList lst, TInstant deadline) {
            if (fromCache) {
                return descr.Output->Send(std::move(lst), deadline);
            }
            return TError{};
        }, [&](THeaders&& trailers, TInstant deadline) {
            // Trailers of a discarded cacher response must be dropped as well.
            // Forwarding them would write to the client before any head, ahead
            // of the backend response that is sent afterwards.
            if (fromCache) {
                return descr.Output->SendTrailers(std::move(trailers), deadline);
            }
            return TError{};
        });

        Y_TRY(TError, error) {
            const TExtraAccessLogEntry cacheLog(descr, "cacher");
            TConnDescr newDescr = descr.CopyOut(output);
            TRequest req = PrepareGetCacheRequest(newDescr);
            newDescr.Request = &req;
            return Cacher_->Run(newDescr);
        } Y_CATCH {
            if (fromCache) {
                return error;
            }
        }
        if (fromCache) {
            ++tls.CacheHit;
        }
        return fromCache;
    }

    // Cache lookup with fallback. A hit is logged as "success". On a miss,
    // only-if-cached gets a 504; anything else goes to ServeNonCached.
    TError ServeFull(const TConnDescr& descr, const THeadersPreProcessResult& preprocessed, TTls& tls) const
    {
        bool responded = false;
        Y_PROPAGATE_ERROR(TryRespondFromCache(descr, tls).AssignTo(responded));
        if (responded) {
            descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "success");
            return {};
        }
        if (preprocessed.CacheControlOnlyIfCached) {
            return Respond504(descr);
        }

        return ServeNonCached(descr, preprocessed, tls);
    }

    // Serves the request from the main submodule.
    // The response is not stored when the request has Cache-Control: no-store,
    // carries Authorization, or is a HEAD request. Otherwise, if the response
    // passes MayStoreInCache:
    //  * a coroutine in tls.Runners PUTs it to the cacher, reading the body
    //    from a TQueueIo;
    //  * every body chunk sent to the client is also copied into that queue.
    // NOTE(review): if the submodule fails mid-body, nothing here signals the
    // queue, so the upload coroutine is left to the cacher's/queue's own
    // handling. Confirm it cannot wait indefinitely or store a truncated body.
    // NOTE(review): DoRun also routes requests that may have a body here; their
    // responses can still be stored under the request path. Confirm intended.
    TError ServeNonCached(const TConnDescr& descr, const THeadersPreProcessResult& preprocessed, TTls& tls) const
    {
        if (preprocessed.CacheControlNoStore
            || preprocessed.HasAuthorizationHeader
            || preprocessed.IsHeadRequest) // TODO: and check for public in cache-control
        {
            return Submodule_->Run(descr);
        }

        TSimpleSharedPtr<TQueueIo> io = nullptr;
        auto output = MakeHttpOutput([&](TResponse&& response, bool forceClose, TInstant deadline) {
            if (MayStoreInCache(response)) {
                IWorkerCtl& process = descr.Process();
                io = MakeSimpleShared<TQueueIo>(&process.Executor());
                tls.Runners.Emplace("cachalot_store_cont", &process.Executor(),
                    [this, io, &process](TRequest&& request) {
                        TAsyncRequester requester(*Cacher_, nullptr, process);
                        Y_UNUSED(requester.Requester().Request(std::move(request), *io.Get()));
                        io->Stop();
                    },
                    PreparePutCacheRequest(descr, response)
                );
            }
            return descr.Output->SendHead(std::move(response), forceClose, deadline);
        }, [&](TChunkList lst, TInstant deadline) {
            if (io) {
                Y_UNUSED(io->Send(lst.Copy(), deadline)); // TODO: queue limit
            }
            return descr.Output->Send(std::move(lst), deadline);
        }, [&](THeaders&& trailers, TInstant deadline) {
            return descr.Output->SendTrailers(std::move(trailers), deadline);
        });

        const TConnDescr newDescr = descr.CopyOut(output);
        return Submodule_->Run(newDescr);
    }

private:
    THolder<IModule> Cacher_;                 // the `cacher` module (cache backend)
    TMaybe<TSharedCounter> CacheHitCounter_;  // created in the constructor
    TMaybe<TString> Collection_;              // normalized path prefix, if configured
    TString Uuid_;                            // optional label for the counter name
};

// Public accessor for the cachalot module handle; forwards to the handle
// provided by the TModule defined in this file.
IModuleHandle* NModCachalot::Handle() {
    return TModule::Handle();
}
