#include "module.h"

#include <balancer/kernel/custom_io/queue.h>
#include <balancer/kernel/http/parser/http.h>
#include <balancer/kernel/module/iface.h>
#include <balancer/kernel/module/module.h>
#include <balancer/kernel/ssl/sslio.h>
#include <balancer/kernel/log/errorlog.h>
#include <balancer/kernel/custom_io/stream.h>

#include <util/datetime/cputimer.h>

using namespace NModThreshold;
using namespace NConfig;
using namespace NSrvKernel;

namespace {
    /**
     * IIoInput decorator used by the threshold module.
     *
     * Every Recv() on the wrapped stream is bounded by a per-call timeout
     * (in addition to the caller's deadline). The first time a read timeout
     * (TSystemError with ETIMEDOUT) or a THttpError is observed, the
     * "mainline" coroutine passed at construction is cancelled, at most once.
     *
     * Neither `slave` nor `mainlineCont` is owned; both must outlive this object.
     * Kept in an unnamed namespace: the class is private to this translation
     * unit and must not clash (ODR) with same-named helpers elsewhere.
     */
    class TRecvTimeoutInput: public IIoInput {
    public:
        /**
         * @param slave        stream to read from (not owned).
         * @param timeout      per-Recv() timeout; TDuration::Zero() disables it,
         *                     leaving only the caller-provided deadline.
         * @param mainlineCont coroutine to cancel on the first failure (not owned).
         */
        TRecvTimeoutInput(IIoInput* slave, TDuration timeout, TCont* mainlineCont) noexcept
            : Slave_(slave)
            , Timeout_(timeout)
            , MainlineCont_(mainlineCont)
        {}

        /**
         * Reads from the wrapped stream with deadline = min(now + timeout, deadline)
         * (or just `deadline` when the timeout is zero). Errors from the wrapped
         * stream are returned unchanged after the failure bookkeeping.
         */
        TError DoRecv(TChunkList& lst, TInstant deadline) noexcept override {
            TInstant effectiveDeadline = deadline;

            if (Timeout_ != TDuration::Zero()) {
                effectiveDeadline = Min(Timeout_.ToDeadLine(), deadline);
            }

            Y_TRY(TError, error) {
                return Slave_->Recv(lst, effectiveDeadline);
            } Y_CATCH {
                if (const auto* e = error.GetAs<TSystemError>()) {
                    // Latched: once a timeout was seen the flag stays set.
                    Timeouted_ = Timeouted_ || (e->Status() == ETIMEDOUT);
                    if (Timeouted_) {
                        OnError();
                    }
                } else if (error.GetAs<THttpError>()) {
                    // strange thing for test_conn_close_on_partial_send,
                    // which, probably, should've been solved by proxy
                    OnError();
                }
                return error;
            }
            return {};
        }

        /// True once any Recv() has failed with ETIMEDOUT.
        bool Timeouted() const noexcept {
            return Timeouted_;
        }

    private:
        /// Cancels the mainline coroutine, only on the first call.
        void OnError() noexcept {
            if (!Failed_) {
                Failed_ = true;
                MainlineCont_->Cancel();
            }
        }

    private:
        IIoInput* const Slave_ = nullptr;
        const TDuration Timeout_;
        TCont* const MainlineCont_ = nullptr;
        bool Failed_ = false;     // MainlineCont_ already cancelled
        bool Timeouted_ = false;  // some Recv() failed with ETIMEDOUT
    };
}

namespace {
    // Bucket boundaries of the "threshold-pass_time" histogram.
    // NOTE(review): presumably microseconds, given Scale(1e6) below — confirm.
    // (No `static` needed: names in an unnamed namespace already have internal linkage.)
    const TVector<ui64> ThresholdIntervals = {
        0, 1000, 5000, 10000, 25000, 50000, 75000,
        100000, 125000, 150000, 175000, 200000, 250000,
        500000, 750000, 1000000,
        2500000, 5000000, 7500000, 10000000, 25000000,
    };

    /**
     * Stat handles of the threshold module, registered in the shared stats
     * manager. AllowDuplicate() is requested for every signal — presumably so
     * several threshold module instances can register the same names; confirm.
     */
    struct TSharedCounters {
        explicit TSharedCounters(TSharedStatsManager& manager)
            : ThresholdReciveTimeoutCounter(
                manager.MakeCounter("threshold-receive_timedout").AllowDuplicate().Build())
            , ThresholdPassTimeoutCounter(
                manager.MakeCounter("threshold-timedout").AllowDuplicate().Build())
            , ThresholdPassedTime(
                manager.MakeHistogram("threshold-pass_time", ThresholdIntervals)
                    .Scale(1e6)
                    .AllowDuplicate()
                    .Build())
        {}

        TSharedCounter ThresholdReciveTimeoutCounter;  // "threshold-receive_timedout"
        TSharedCounter ThresholdPassTimeoutCounter;    // "threshold-timedout"
        TSharedHistogram ThresholdPassedTime;          // "threshold-pass_time"
    };
}

// Worker-local stat handles, each built from the matching shared holder and
// the worker id.
Y_TLS(threshold) {
    TTls(TSharedCounters& shared, size_t worker)
        : ThresholdReciveTimeoutCounter(shared.ThresholdReciveTimeoutCounter, worker)
        , ThresholdPassTimeoutCounter(shared.ThresholdPassTimeoutCounter, worker)
        , ThresholdPassedTime(shared.ThresholdPassedTime, worker)
    {}

    TSharedCounter ThresholdReciveTimeoutCounter;  // recv_timeout condition failures
    TSharedCounter ThresholdPassTimeoutCounter;    // pass_timeout condition failures
    TSharedHistogram ThresholdPassedTime;          // time spent before passing the request on
};


MODULE_WITH_TLS_BASE(threshold, TModuleWithSubModule) {
public:
    /**
     * Parses the module config and validates option consistency.
     *
     * Throws TConfigParseError when lo_bytes > hi_bytes, when hi_bytes is set
     * while pass_timeout is zero, or when no submodule is configured.
     */
    TModule(const TModuleParams& mp)
        : TModuleBase(mp)
        , CountersHolders_(mp.Control->SharedStatsManager())
    {
        Config->ForEach(this);

        if (LoBytes_ > HiBytes_) {
            ythrow TConfigParseError() << "lo_bytes is greater than hi_bytes";
        }

        // hi_bytes is only consulted by the pass-timeout check, which DoRun
        // skips entirely when pass_timeout is zero.
        if (PassTimeout_ == TDuration::Zero() && HiBytes_ != 0) {
            ythrow TConfigParseError() << "hi_bytes = " << HiBytes_ << " is meaningless with pass_timeout = 0";
        }

        if (!Submodule_) {
            ythrow TConfigParseError() << "no submodule configured";
        }
    }

private:
    START_PARSE {
        // Scalar options; a value that fails to parse is reported with its key.
        try {
            ON_KEY("pass_timeout", PassTimeout_) {
                return;
            }

            ON_KEY("recv_timeout", RecvTimeout_) {
                return;
            }

            ON_KEY("lo_bytes", LoBytes_) {
                return;
            }

            ON_KEY("hi_bytes", HiBytes_) {
                return;
            }
        } catch (const yexception& e) {
            ythrow TConfigParseError() << "error processing " << key << ": " << e.what();
        }

        // Optional module that is run instead of the submodule when the
        // pass-timeout condition fails.
        if (key == "on_pass_timeout_failure") {
            TSubLoader(Copy(value->AsSubConfig())).Swap(OnPassTimeout_);
            return;
        }

        // Any other key is loaded as the wrapped submodule (a later such key
        // replaces an earlier one).
        {
            Submodule_.Reset(Loader->MustLoad(key, Copy(value->AsSubConfig())).Release());
            return;
        }
    } END_PARSE

    THolder<TTls> DoInitTls(IWorkerCtl* process) override {
        return MakeHolder<TTls>(CountersHolders_, process->WorkerId());
    }

    /**
     * Thresholds apply only to requests that carry a body (chunked transfer or
     * a positive Content-Length) and are not upgrade requests.
     */
    bool Applicable(const TConnDescr& descr) const noexcept {
        if (Y_LIKELY(descr.Request)) {
            const auto& props = descr.Request->Props();
            return (props.ChunkedTransfer || props.ContentLength > 0) && !props.UpgradeRequested;
        }

        return false;
    }

    /**
     * Reads from `in` into `out` until at least `limit` bytes were received,
     * `in` returns an empty chunk list, or `deadline` has passed.
     *
     * @param bytesTransferred set to the number of bytes appended to `out`
     *                         (also on error: data read before the failure stays in `out`).
     * @return true when the limit or an empty read was reached, false when the
     *         loop stopped because the deadline passed; Recv() errors are propagated.
     */
    static TErrorOr<bool> LimitedTransfer(IIoInput* in, TChunkList& out, TInstant deadline,
                                          size_t limit, size_t& bytesTransferred) noexcept
    {
        bool finished = false;
        bytesTransferred = 0;

        do {
            TChunkList lst;
            Y_PROPAGATE_ERROR(in->Recv(lst, deadline));
            bytesTransferred += lst.size();
            finished = bytesTransferred >= limit || lst.Empty();
            out.Append(std::move(lst));
        } while (!finished && Now() < deadline);

        return finished;
    }

    /**
     * Waits up to pass_timeout for request data and decides whether the
     * request passes.
     *
     * The request passes when LimitedTransfer completes (hi_bytes or an empty
     * read) before the timeout, or when at least lo_bytes were received.
     * An ETIMEDOUT read error is not fatal; it just ends the wait. All received
     * data is handed back to descr.Input via UnRecv().
     *
     * @return true  — pass condition met, the caller should run the submodule;
     *         false — condition failed and on_pass_timeout_failure was run instead;
     *         error — client read error, condition failed without
     *                 on_pass_timeout_failure, or that module's run failed.
     */
    TErrorOr<bool> ProcessPassTimeout(const TConnDescr& descr, TTls& tls, TInstant startTime) const {
        TChunkList recvd;

        bool transferCompleted = false;
        size_t bytesTransferred = 0;

        Y_TRY(TError, error) {
            return LimitedTransfer(descr.Input, recvd, PassTimeout_.ToDeadLine(),
                                   HiBytes_, bytesTransferred).AssignTo(transferCompleted);
        } Y_CATCH {
            if (const auto* e = error.GetAs<TSystemError>()) {
                // Timeout is expected here: fall through and evaluate lo_bytes.
                if (e->Status() != ETIMEDOUT) {
                    descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "client read error");
                    return error;
                }
            } else if (const auto* e = error.GetAs<TSslError>()) {
                descr.ExtraAccessLog << " bad_ssl_request";
                LOG_ERROR(TLOG_ERR, descr, "threshold: ssl read failed: " << e->what());
                descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "client read error");
                return error;
            } else {
                descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "client read error");
                return error;
            }
        }

        if (!transferCompleted && bytesTransferred < LoBytes_) {
            ++tls.ThresholdPassTimeoutCounter;
            descr.ExtraAccessLog << " pass timeout condition failed";
            if (!OnPassTimeout_) {
                descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "pass timeout condition failed");
                return Y_MAKE_ERROR(yexception{} << "threshold: pass condition failed");
            }
            descr.Input->UnRecv(std::move(recvd));
            tls.ThresholdPassedTime.AddValue((Now() - startTime).MicroSeconds());
            // Prefix the summary produced by on_pass_timeout_failure with this
            // module's name, whether or not that run succeeds.
            Y_DEFER {
                TAccessLogSummary *summary = descr.ExtraAccessLog.Summary();
                if (summary) {
                    summary->AnsweredModule = GetHandle()->Name() + " | " + summary->AnsweredModule;
                    summary->AnswerReason = "on_pass_timeout | " + summary->AnswerReason;
                }
            };
            Y_PROPAGATE_ERROR(OnPassTimeout_->Run(descr));
            return false;
        } else {
            descr.Input->UnRecv(std::move(recvd));
            return true;
        }
    }

    /**
     * Runs the submodule over an input whose every Recv() is bounded by
     * recv_timeout. If the submodule fails and a read timed out, the client
     * error and recv-timeout counters are incremented and the access log is
     * annotated. The submodule's error is always returned unchanged.
     */
    TError ProcessRecvTimeout(const TConnDescr& descr, TTls& tls, TInstant startTime) const {
        TCont* mainlineCont = descr.Process().Executor().Running();
        TRecvTimeoutInput subInput(descr.Input, RecvTimeout_, mainlineCont);

        Y_TRY(TError, error) {
            tls.ThresholdPassedTime.AddValue((Now() - startTime).MicroSeconds());
            return Submodule_->Run(descr.CopyIn(subInput));
        } Y_CATCH {
            if (subInput.Timeouted()) {
                descr.Properties->ConnStats.ClientError += 1;
                ++tls.ThresholdReciveTimeoutCounter;
                descr.ExtraAccessLog << " recv_timeout condition failed";
                descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "recv_timeout condition failed");
            }
            return error;
        }
        return {};
    }

    /**
     * Entry point. Requests without a body (see Applicable) go straight to the
     * submodule. Otherwise:
     *  - if pass_timeout is set, the pass condition is checked first;
     *  - then the submodule is run, with recv_timeout applied when it is set.
     */
    TError DoRun(const TConnDescr& descr, TTls& tls) const noexcept override {
        if (Applicable(descr)) {
            auto startTime = Now();
            if (PassTimeout_ != TDuration::Zero()) {
                bool processed = false;
                Y_TRY(TError, err) {
                    return ProcessPassTimeout(descr, tls, startTime).AssignTo(processed);
                } Y_CATCH {
                    // NOTE(review): this also counts the "pass condition failed"
                    // error and failures of on_pass_timeout_failure as client
                    // errors — confirm that is intended.
                    descr.Properties->ConnStats.ClientError += 1;
                    return err;
                }
                if (!processed) {
                    // on_pass_timeout_failure already handled the request.
                    return {};
                }
            }
            if (RecvTimeout_ != TDuration::Zero()) {
                Y_PROPAGATE_ERROR(ProcessRecvTimeout(descr, tls, startTime));
            } else {
                tls.ThresholdPassedTime.AddValue((Now() - startTime).MicroSeconds());
                Y_PROPAGATE_ERROR(Submodule_->Run(descr));
            }
        } else {
            Y_PROPAGATE_ERROR(Submodule_->Run(descr));
        }
        return {};
    }

    bool DoExtraAccessLog() const noexcept override {
        return true;
    }

private:
    TDuration PassTimeout_;                 // "pass_timeout"; zero disables the pass check
    TDuration RecvTimeout_;                 // "recv_timeout"; zero disables per-read timeouts
    size_t LoBytes_ = 0;                    // "lo_bytes"
    size_t HiBytes_ = 0;                    // "hi_bytes"
    THolder<IModule> OnPassTimeout_;        // "on_pass_timeout_failure" (optional)
    TSharedCounters CountersHolders_;
};

// Exposes the handle of this file's `threshold` module (TModule) through the
// NModThreshold namespace.
IModuleHandle* NModThreshold::Handle() {
    IModuleHandle* const handle = TModule::Handle();
    return handle;
}
