#include "multithreading.h"

#include <balancer/kernel/process/children/worker_process.h>
#include <balancer/kernel/process/children/pinger_process.h>
#include <balancer/kernel/process/children/updater_process.h>
#include <balancer/kernel/process/children/quota_syncer_process.h>
#include <balancer/kernel/process/children/storage_gc_process.h>
#include <balancer/kernel/process/children/connection_manager_process.h>

#include <util/system/thread.h>

namespace {
    // Returns how long the worker with index `iteration` should wait before starting.
    // A zero `workerStartDelay` disables staggering entirely. Otherwise each index is shifted by
    // one step (the configured delay, but at least 10ms). Every index except 0 also gets a
    // random extra delay of up to a quarter step.
    TDuration CalculateStartDelay(size_t iteration, TDuration workerStartDelay) {
        if (workerStartDelay) {
            const auto step = Max(workerStartDelay.MicroSeconds(), 10000ul);
            const auto jitter = iteration == 0 ? 0 : RandomNumber(1 + step / 4);
            return TDuration::FromValue(iteration * step + jitter);
        }
        return TDuration::Seconds(0);
    }

    void SetThreadName(const TString& suffix) {
        TString name = TThread::CurrentThreadName();
        if (name.StartsWith("balancer")) {
            name = TString("b") + TStringBuf(name).Skip(8);
        }

        TThread::SetCurrentThreadName((TStringBuilder() << name << "-" << suffix).c_str());
    }

    // Builds a log line listing every index whose status is false.
    // Index 0 is never reported; the loop starts at 1.
    TString CreateLogMessageFromStatuses(const TVector<bool>& statuses) {
        TStringBuilder message;
        message << "Threads with ids";
        for (size_t id = 1; id < statuses.size(); ++id) {
            if (statuses[id]) {
                continue;
            }
            message << " " << id;
        }
        message << " didn't update their states";
        return message;
    }
}

namespace NSrvKernel {
    using namespace NProcessCore;

    class TChild : public ISimpleThread {
    public:
        TChild(TMainTask& mainTask, TW2WChannel<TM2CMessage>& m2cChannel, TW2WChannel<TC2MMessage>& c2mChannel,
               TChildProcessType type, size_t workerId, TDuration startDelay)
            : MainTask_(mainTask)
            , M2CChannel_(m2cChannel)
            , C2MChannel_(c2mChannel)
            , Type_(type)
            , WorkerId_(workerId)
            , StartDelay_(startDelay)
        {}

        void* ThreadProc() override {
            TString threadName(::ToString(Type_)[0]);
            SetThreadName(threadName);

            Sleep(StartDelay_);

            Y_DEFER {
                AtomicSet(Finished_, 1);
            };

            THolder<TBaseProcess> process;

            try {
                switch (Type_) {
                    case TChildProcessType::Default:
                        process = MakeHolder<TWorkerProcess>(MainTask_, MainTask_.MainOptions(), WorkerId_, M2CChannel_, C2MChannel_);
                        break;
                    case TChildProcessType::Pinger:
                        process = MakeHolder<TPingerProcess>(MainTask_, MainTask_.MainOptions(), M2CChannel_);
                        break;
                    case TChildProcessType::Updater:
                        process = MakeHolder<TUpdaterProcess>(MainTask_, MainTask_.MainOptions(), M2CChannel_);
                        break;
                    case TChildProcessType::QuotaSyncer:
                        process = MakeHolder<TQuotaSyncerProcess>(MainTask_, MainTask_.MainOptions(), M2CChannel_);
                        break;
                    case TChildProcessType::StorageGC:
                        process = MakeHolder<TStorageGCProcess>(MainTask_, MainTask_.MainOptions(), M2CChannel_);
                        break;
                    case TChildProcessType::ConnectionManager:
                        process = MakeHolder<TConnectionManagerProcess>(MainTask_, MainTask_.MainOptions(), M2CChannel_,
                                MainTask_.ConnectionManager_.Get());
                        break;
                }

                process->Execute();
            } catch (...) {
                MainTask_.ShutdownWithError(CurrentExceptionMessage());
            }

            return nullptr;
        }

        bool Finished() const noexcept {
            return AtomicGet(Finished_);
        }

    private:
        TMainTask& MainTask_;
        TW2WChannel<TM2CMessage>& M2CChannel_;
        TW2WChannel<TC2MMessage>& C2MChannel_;
        TChildProcessType Type_ = TChildProcessType::Default;
        size_t WorkerId_ = 0;
        TDuration StartDelay_;
        TAtomic Finished_ = 0;
    };

    TChildrenManager::TChildrenManager(TContExecutor& masterExecutor, TMainTask& mainTask,
                                       TW2WChannel<TC2MMessage>& masterChannel, TChildrenManagerOpts opts,
                                       TChildProcessMask mask)
        : Executor_(masterExecutor)
        , MainTask_(mainTask)
        , Log_(mainTask.Log_)
        , MasterChannel_(masterChannel)
        , Opts_(std::move(opts))
        , Stats_(mainTask)
        , Children_(Reserve(mainTask.GetCountOfChildren()))
        , ChildrenChannels_(Reserve(mainTask.GetCountOfChildren()))
    {
        // Creates one master->child channel (capacity 32) per future child:
        // Opts_.WorkersCount channels for the Default (worker) type and one for each other type
        // enabled in `mask`. The channel order here is the order in which SpawnChildren() starts threads.
        for (const auto& type : GetEnumAllValues<TChildProcessType>()) {
            if (!mask.HasFlags(type)) {
                continue;
            }

            const size_t copies = type == TChildProcessType::Default ? Opts_.WorkersCount : 1;
            for (size_t n = 0; n < copies; ++n) {
                ChildrenChannels_.push_back({MakeHolder<TW2WChannel<TM2CMessage>>(32u), type});
            }
        }
    }

    // Defaulted out of line in this file, where TChild is a complete type. Presumably this is
    // needed because Children_ owns TChild objects and TChild is not defined in the header.
    TChildrenManager::~TChildrenManager() = default;

    // Runs Run() in a "children manager" coroutine on the master executor and blocks the caller
    // until that coroutine finishes.
    void TChildrenManager::Start() {
        SpawningCoroutine_ = TCoroutine("children manager", &Executor_, &TChildrenManager::Run, this);
        SpawningCoroutine_.Join();
    }

    void TChildrenManager::Shutdown(const TGracefulShutdownOpts& opts) {
        const TInstant fullDeadline = Now() + opts.CoolDown + opts.Timeout + opts.CloseTimeout + TDuration::Seconds(10);
        ShutDown_ = true;
        Log_ << "Master is broadcasting shutdown message (cooldown: " << opts.CoolDown << ", timeout: "
             << opts.Timeout << ", close timeout: " << opts.CloseTimeout << ", skip block requests: "
             << opts.SkipBlockRequests << ")" << Endl;

        SendMessage(TShutDown{
            .CoolDown = opts.CoolDown,
            .Timeout = opts.Timeout,
            .CloseTimeout = opts.CloseTimeout,
            .SkipBlockRequests = opts.SkipBlockRequests,
        }, TDuration::Max());
        OnShutdown(fullDeadline);
    }

    void TChildrenManager::ResetDnsCache() {
        SendMessage(TResetDnsCache(), TDuration::Max());
    }

    // Sends the event to every child and writes their non-empty replies to eventData.RawOut().
    // With `jsonOut`, the replies are joined with "," and wrapped in "[" ... "]".
    // Collection stops early if a receive does not succeed.
    void TChildrenManager::CallEvent(TEventData& eventData, bool jsonOut) {
        // One slot per child, so every child can post its reply without waiting for the reader.
        auto replies = MakeAtomicShared<TW2WChannel<TString>>(ChildrenChannels_.size());

        auto collect = [&eventData, jsonOut, &replies](size_t expected, TDuration timeout, TCont* cont, TThreadLocalEventWaker* waker) {
            IOutputStream& out = eventData.RawOut();
            if (jsonOut) {
                out << "[";
            }

            bool wroteAny = false;
            for (size_t received = 0ul; received < expected; ++received) {
                TString reply;
                if (replies->Receive(reply, timeout.ToDeadLine(), cont, waker) != EChannelStatus::Success) {
                    break;
                }
                if (reply.empty()) {
                    continue; // empty replies count as received but produce no output
                }
                if (jsonOut && wroteAny) {
                    out << ",";
                }
                wroteAny = true;
                out << reply;
            }

            if (jsonOut) {
                out << "]";
            }
        };

        SendMessage(TEvent{TString(eventData.Event()), jsonOut, replies}, TDuration::Max(), collect);
    }

    // Asks every child for its event names. The union of the names, deduplicated and sorted by
    // TSet, is written to eventData.Out() as one array.
    void TChildrenManager::ListEventHandlers(TEventData& eventData) {
        auto replies = MakeAtomicShared<TW2WChannel<TVector<TString>>>(ChildrenChannels_.size());

        auto collect = [&eventData, &replies](size_t expected, TDuration timeout, TCont* cont, TThreadLocalEventWaker* waker) {
            TSet<TString> merged;
            for (size_t received = 0ul; received < expected; ++received) {
                TVector<TString> names;
                if (replies->Receive(names, timeout.ToDeadLine(), cont, waker) != EChannelStatus::Success) {
                    break;
                }
                for (auto& name : names) {
                    merged.emplace(std::move(name));
                }
            }

            eventData.Out().OpenArray();
            for (const auto& name : merged) {
                eventData.Out().Write(name);
            }
            eventData.Out().CloseArray();
            eventData.Out().Flush();
        };

        SendMessage(TListEvents{replies}, TDuration::Max(), collect);
    }

    void TChildrenManager::SpawnChildren() {
        Log_ << "Spawning " << ChildrenChannels_.size() << " children" << Endl;
        for (size_t i = 0; i < ChildrenChannels_.size() && !ShutDown_; ++i) {
            TChildProcessType type = ChildrenChannels_[i].second;
            const size_t id = i + 1; // 0 is master
            const TDuration startDelay = type == TChildProcessType::Default ? CalculateStartDelay(i, Opts_.WorkerStartDelay) : TDuration::Zero();
            Children_.emplace_back(MakeHolder<TChild>(MainTask_, *ChildrenChannels_[i].first, MasterChannel_, type, id, startDelay));
            Children_.back()->Start();

            Log_ << "Spawned thread child with workerId " << id << " and type " << ::ToString(type) <<  Endl;
        }

        Stats_.SetAliveChildren(Opts_.WorkersCount);
    }

    // Final stage of shutdown. It stops the watchdog coroutine, then waits for all child threads
    // (JoinThreads() aborts the process if `deadline` passes first).
    // The order matters: the watchdog is cancelled before the wait. Presumably this is because
    // children that are winding down stop updating their statuses, which ThreadsChecker() would
    // treat as a freeze and abort on.
    void TChildrenManager::OnShutdown(TInstant deadline) {
        ThreadsCheckerTask_.Cancel();
        JoinThreads(deadline);
    }

    // Polls every 10ms until all children report Finished(), then joins them. Aborts the process
    // if the children are still running after `deadline`. If the calling coroutine is cancelled,
    // polling stops and the threads are joined directly (a blocking join).
    void TChildrenManager::JoinThreads(TInstant deadline) {
        TCont* const self = Executor_.Running();

        const auto allFinished = [this]() {
            return std::all_of(Children_.begin(), Children_.end(), [](const auto& child) {
                return child->Finished();
            });
        };

        while (self->SleepT(TDuration::MilliSeconds(10)) != ECANCELED && !allFinished()) {
            if (deadline < Now()) {
                Y_FAIL("worker threads not finished until deadline");
            }
        }

        for (auto& child : Children_) {
            child->Join();
        }
    }

    // Sends a copy of `message` to every child channel. Each send gets its own deadline,
    // `timeout` from the moment it starts, and failures are only logged.
    // If `postProcess` is set, it is then called with the number of successful sends, the
    // timeout, and the current coroutine/waker, so it can receive replies in the same context.
    void TChildrenManager::SendMessage(TM2CMessage message, TDuration timeout, std::function<void(size_t, TDuration, TCont*, TThreadLocalEventWaker*)> postProcess) {
        TCont* const current = Executor_.Running();
        auto wakerHolder = ThreadLocalEventWaker(current);
        auto* const waker = wakerHolder.Get();

        size_t delivered = 0ul;
        for (auto& entry : ChildrenChannels_) {
            const auto status = entry.first->Send(message, timeout.ToDeadLine(), current, waker);
            if (status == EChannelStatus::Success) {
                ++delivered;
                continue;
            }
            Log_ << "can't send message" << Endl;
        }

        if (postProcess) {
            postProcess(delivered, timeout, current, waker);
        }
    }

    void TChildrenManager::Run() {
        TCont* cont = Executor_.Running();
        TUniversalGuard guard(MainTask_.TreesMutex);
        if (!guard.Lock()) {
            return;
        }

        SpawnChildren();

        const auto delay = Max(Opts_.WorkerStartDelay.MicroSeconds(), 10000ul) * (Opts_.WorkersCount ? Opts_.WorkersCount - 1 : 0);
        cont->SleepT(TDuration::FromValue(delay));

        // wait start of all children
        while (!cont->Cancelled() && !MainTask_.Failed_) {
            if (static_cast<size_t>(AtomicGet(MainTask_.LiveWorkersCounter_)) == Opts_.WorkersCount) {
                break;
            }
            cont->SleepT(TDuration::MilliSeconds(10));
        }

        if (cont->Cancelled()) {
            return;
        }

        if (Opts_.Watchdog) {
            ThreadsCheckerTask_ = TCoroutine("master_threads_checker", &Executor_, &TChildrenManager::ThreadsChecker, this);
        }
    }

    void TChildrenManager::ThreadsChecker() {
        TCont* cont = Executor_.Running();
        bool freeze = false;
        bool threadsAreFine = false;
        size_t attempts = 10 + RandomNumber(10u);
        const size_t threadsCount = Opts_.WorkersCount; // Children_.size(); TODO: return watchdog for special children
        TVector<bool> statuses;

        while (!cont->Cancelled()) {
            if (cont->SleepT(TDuration::Seconds(5)) == ECANCELED) {
                return;
            }

            MainTask_.ProcessStatOwner_->GetThreadsStatuses(statuses);
            ui8 summ = Accumulate(statuses.begin(), statuses.end(), 0u);
            threadsAreFine = summ == threadsCount;

            if (!threadsAreFine && freeze) {
                // Previosly was detected that one of the threads didn't update his state and now the same thing is happening
                if (!--attempts) {
                    Cerr << "Threads freeze detecting, master is about to abort" << Endl;
                    Y_FAIL();
                }
                Log_ << CreateLogMessageFromStatuses(statuses) << Endl;
            } else if (!threadsAreFine) {
                // One of the threads didn't update its state
                Stats_.SetThreadFreezing(true);
                freeze = true;
                Log_ << CreateLogMessageFromStatuses(statuses) << Endl;
            } else if (freeze) {
                // Previosly was detected that one of the threads didn't update his state and now all is OK
                Stats_.SetThreadFreezing(false);
                freeze = false;
                attempts = 10 + RandomNumber(10u);
            }
        }
    }
}
