#include "actors.h"

#include <solomon/tools/data-comparison/lib/util/wait_context.h>

#include <optional>
#include <algorithm>
#include <deque>
#include <unordered_map>

// Serializing wrapper ("strand") around an IActorEngine: tasks passed to
// SafeAddFunc are executed on Executor_ one at a time, never concurrently with
// each other. Blocked_ works as a try-lock: the caller that flips it from 0 to 1
// becomes responsible for scheduling a drain task on Executor_; everybody else
// only enqueues into Queue_ and returns.
//
// Must be owned by a std::shared_ptr (see MakeStrand): every drain task keeps the
// strand alive through shared_from_this().
class TStrand : public std::enable_shared_from_this<TStrand> {
public:
    using TTask = std::function<void()>;

    TStrand(IActorEnginePtr executor)
        : Executor_(std::move(executor)) {
        // Before C++20 std::atomic's default constructor leaves the value
        // uninitialized, so the "unlocked" state is stored explicitly.
        Blocked_.store(0);
    }

    // Enqueues `task`; if no drain task currently holds the lock, schedules one.
    void SafeAddFunc(TTask&& task) {
        Queue_.Enqueue(std::move(task));

        if (!TryLock()) {
            return;
        }

        Executor_->SafeAddFunc(MakeWrappedTask());
    }

private:
    // Returns true iff the caller now holds the lock AND there is work to drain;
    // the caller must then schedule a drain task. Returns false if someone else
    // holds the lock (the holder will pick the task up), or if the queue turned
    // out to be empty (in which case the lock has been released).
    bool TryLock() {
        bool blocked = Blocked_.exchange(true);
        if (blocked) {
            return false;
        }

        return NeedContinue();
    }

    // Must be called while holding the lock. Keeps the lock and returns true if the
    // queue has work. Otherwise releases the lock and re-checks the queue: a producer
    // that enqueued right before the release saw the lock taken and returned, so the
    // re-check (and re-acquire) keeps its task from being stranded.
    bool NeedContinue() {
        if (!Queue_.IsEmpty()) {
            return true;
        }

        ReleaseLock();
        if (!Queue_.IsEmpty() && !Blocked_.exchange(true)) {
            return true;
        }

        return false;
    }

    void ReleaseLock() {
        Blocked_.store(false);
    }

    // Builds the drain task: runs the tasks handed out by DequeueAll, then either
    // reschedules itself (more work arrived; lock kept) or exits with the lock released.
    // NOTE(review): nothing here catches exceptions; if a task throws, Blocked_ is
    // never released and the strand stops draining.
    TTask MakeWrappedTask() {
        return [this, guard = shared_from_this()]() {

            // TApplier exposes only emplace_back, which DequeueAll uses to hand
            // over each dequeued task; instead of storing it, run it immediately.
            struct TApplier {
                void emplace_back(TTask&& task) {
                    task();
                }
            };
            TApplier applier;

            Queue_.DequeueAll(&applier);

            if (!NeedContinue()) {
                return;
            }

            Executor_->SafeAddFunc(MakeWrappedTask());
        };
    }

private:
    IActorEnginePtr Executor_;
    // Non-zero while some thread holds the strand's lock (normally: while a drain
    // task is scheduled or running).
    std::atomic<ui8> Blocked_;

    TLockFreeQueue<TTask> Queue_;
};

// Strands are shared: drain tasks hold a reference via shared_from_this().
using TStrandPtr = std::shared_ptr<TStrand>;

// Creates the thread pool that actors run their strands on.
//   threads - number of worker threads.
// The pool is created with Catching_ = false (presumably: exceptions thrown by
// tasks are not caught by the pool) and a second CreateThreadPool argument of 0
// (presumably the queue size limit, 0 meaning unlimited - confirm).
IActorEnginePtr CreateActorEngine(size_t threads) {
    TThreadPoolParams poolParams;
    poolParams.Catching_ = false;

    auto pool = CreateThreadPool(threads, 0, poolParams);
    return IActorEnginePtr(pool.Release());
}

// Creates a strand serializing tasks on top of `executor`. TStrand must be owned
// by a shared_ptr because its drain tasks call shared_from_this().
// The executor handle is taken by value and moved on, avoiding an extra
// (atomic) reference-count increment/decrement.
TStrandPtr MakeStrand(IActorEnginePtr executor) {
    return std::make_shared<TStrand>(std::move(executor));
}

template <class T>
class TBatchVector {
public:
    TBatchVector(size_t batchSize)
        : BatchSize_(batchSize)
    {
        Y_ENSURE(batchSize > 0);
    }

    TVector<T>& Front() {
        Y_VERIFY(!Deque_.empty());
        return Deque_.front();
    }

    void PopFront() {
        Deque_.pop_front();
    }

    void PushBack(T value) {
        if (Deque_.empty() || Deque_.back().size() == BatchSize_) {
            Deque_.emplace_back(TVector<T>{});
        }

        Deque_.back().emplace_back(std::move(value));
    }

    bool Empty() {
        return Deque_.empty();
    }

private:
    std::deque<TVector<T>> Deque_;
    size_t BatchSize_;
};

// Mixin that tracks outstanding asynchronous work through a wait context.
//
// Tasks that may touch the owning object capture a copy of Lock(). SyncClose()
// drops the object's own reference to the wait context and blocks on the future
// it obtained from it (presumably ready once every captured copy is released -
// see wait_context.h).
//
// NOTE(review): the destructor below runs only after the derived class's members
// have been destroyed, so a derived class whose callbacks touch its members
// should call SyncClose() from its own destructor.
class TActorDestructor {
protected:
    TActorDestructor()
        : Lock_(MakeWaitContext())
    {}

    // Waits for the work started under the current wait context, having first
    // swapped in a fresh context so the actor stays usable afterwards.
    void SyncWaitTasks() {
        auto pending = Lock_->GetFuture();
        Lock_ = MakeWaitContext();
        pending.Wait();
    }

    // One-shot shutdown: stop handing out the wait context and block until the
    // work that captured it is done. Calls after the first one return immediately.
    void SyncClose() {
        if (!IsClosed()) {
            auto pending = Lock_->GetFuture();
            Lock_.Reset(nullptr);
            pending.Wait();
        }
    }

    // Current wait context; null once SyncClose() has been called.
    TWaitContextPtr Lock() {
        return Lock_;
    }

    bool IsClosed() const {
        return !Lock_;
    }

    // Protected and non-virtual: the type is only meant to be used as a base.
    ~TActorDestructor() {
        SyncClose();
    }

private:
    TWaitContextPtr Lock_;
};

class TResolveDcActor: public IResolveDcActor, public TActorDestructor {
public:
    TResolveDcActor(
            IActorEnginePtr engine,
            IResolveDcConsumer& consumer,
            THostResolverFactory hFactory,
            TGroupResolverFactory gFactory)
            : Engine_(MakeStrand(std::move(engine)))
            , Consumer_(consumer)
            , HFactory_(std::move(hFactory))
            , GFactory_(std::move(gFactory))
    {
    }

    void ResolveHostDc(TString host, TActorMetaData meta) override {
        Engine_->SafeAddFunc([this, host = std::move(host), meta = std::move(meta), lock = Lock()]() mutable {
            DoResolveHostDc(std::move(host), std::move(meta), std::move(lock));
        });
    }

    void ResolveGroupDc(TString group, TActorMetaData meta) override {
        Engine_->SafeAddFunc([this, group, meta = std::move(meta), lock = Lock()]() mutable {
            DoResolveGroupDc(std::move(group), std::move(meta));
        });
    }

    void Close() override {
        TActorDestructor::SyncClose();
    }

private:
    struct TValue {
        std::optional<TDcResolveResult> ResultDc;
        TVector<TActorMetaData> Subscribers;
        size_t Retries{0};
    };

private:
    void InitHostResolver() {
        if (!HostResolver_) {
            HostResolver_ = HFactory_();
        }
    }

    void InitGroupResolver() {
        if (!GroupResolver_) {
            GroupResolver_ = GFactory_();
        }
    }

    void DoResolveHostDc(TString host, TActorMetaData&& meta, TWaitContextPtr lock) {
        auto [it, s] = DataTable_.emplace(
                std::make_pair(host, TValue{std::nullopt, {}})
        );

        TValue& value = it->second;

        if (!s && !value.ResultDc.has_value()) {
            value.Subscribers.emplace_back(std::move(meta));
            return;
        }

        DoResolveHostDc(std::move(host), value, std::move(meta), std::move(lock));
    }

    void DoResolveHostDc(TString host, TValue& value, TActorMetaData&& meta, TWaitContextPtr lock) {
        if (value.ResultDc.has_value()) {
            DoResolveHostDc(std::move(host), value, std::move(meta));
            return;
        }

        InitHostResolver();
        HostResolver_->ResolveDc(host).Apply(
                [this, host = std::move(host), meta = std::move(meta), &value, lock = std::move(lock)](auto f) mutable
                {
                    if (f.HasException()) {
                        f.TryRethrow();
                    }

                    auto dc = f.ExtractValue();

                    Engine_->SafeAddFunc([this, dc, host = std::move(host), meta = std::move(meta), &value, lock = std::move(lock)]() mutable {
                        SetHostDc(host, value, dc, std::move(meta), std::move(lock));
                    });
                });
    }

    void SetHostDc(TString host, TValue& value, TDcResolveResult dc, TActorMetaData&& meta, TWaitContextPtr lock) {
        if (dc.Dc == NSolomon::EDc::UNKNOWN && dc.Status == EDcResolveStatus::TIMEOUT && value.Retries < MAX_RETRIES) {
            ++value.Retries;
            DoResolveHostDc(std::move(host), value, std::move(meta), std::move(lock));
            return;
        }

        value.ResultDc = dc;
        DoResolveHostDc(std::move(host), value, std::move(meta));
    }

    void DoResolveHostDc(TString host, TValue& value, TActorMetaData&& meta) {
        Consumer_.OnResolveHostDc(value.ResultDc.value(), host, std::move(meta));
        for (auto&& sub: value.Subscribers) {
            Consumer_.OnResolveHostDc(value.ResultDc.value(), host, std::move(sub));
        }

        value.Subscribers.clear();
    }

    void DoResolveGroupDc(const TString& group, TActorMetaData&& meta) {
        InitGroupResolver();

        Consumer_.OnResolveGroupDc({GroupResolver_->ResolveDc(group), EDcResolveStatus::OK}, group, std::move(meta));
    }

private:
    static constexpr size_t MAX_RETRIES = 5;

    TStrandPtr Engine_;
    IResolveDcConsumer& Consumer_;

    IHostResolverPtr HostResolver_;
    IGroupResolverPtr GroupResolver_;

    THostResolverFactory HFactory_;
    TGroupResolverFactory GFactory_;

    std::unordered_map<TString, TValue, THash<TString>> DataTable_;
};

// Creates a DC resolve actor able to resolve both hosts and groups. The resolvers
// themselves are created lazily from the factories, on first use.
IResolveDcActorPtr CreateResolveDcActor(
        IActorEnginePtr engine,
        IResolveDcConsumer& consumer,
        THostResolverFactory hFactory,
        TGroupResolverFactory gFactory)
{
    IResolveDcActorPtr actor = MakeHolder<TResolveDcActor>(
            std::move(engine),
            consumer,
            std::move(hFactory),
            std::move(gFactory));
    return actor;
}

// Host-only variant: the group resolver factory throws, so a group request fails
// when it tries to create the group resolver.
IResolveDcActorPtr CreateResolveDcActor(IActorEnginePtr engine, IResolveDcConsumer& consumer, THostResolverFactory hFactory) {
    auto missingGroupFactory = []() -> IGroupResolverPtr {
        ythrow yexception() << "group dc resolver factory is not defined";
    };

    return MakeHolder<TResolveDcActor>(
            std::move(engine),
            consumer,
            std::move(hFactory),
            std::move(missingGroupFactory));
}

// Group-only variant: the host resolver factory throws, so a host request fails
// when it tries to create the host resolver.
IResolveDcActorPtr CreateResolveDcActor(IActorEnginePtr engine, IResolveDcConsumer& consumer, TGroupResolverFactory gFactory) {
    auto missingHostFactory = []() -> IHostResolverPtr {
        ythrow yexception() << "host dc resolver factory is not defined";
    };

    return MakeHolder<TResolveDcActor>(
            std::move(engine),
            consumer,
            std::move(missingHostFactory),
            std::move(gFactory));
}

class IStockpileQueryActorShard: public TActorDestructor {
protected:
    IStockpileQueryActorShard(IActorEnginePtr engine, TStockpileClientFactory factory)
    : Engine_(MakeStrand(std::move(engine)))
    , Factory_(std::move(factory))
    {
    }

    virtual ~IStockpileQueryActorShard() = default;

    void Close() {
        Closed_ = true;

        Engine_->SafeAddFunc([this, lock = Lock()]() {
            for (auto& [key, value]: DataTable_) {
                QueryImpl(value, lock);
            }
        });

        TActorDestructor::SyncClose();
    }

    void Query(TStockpileIds id, TActorMetaData meta) {
        Engine_->SafeAddFunc([this, id, meta = std::move(meta), lock = Lock()]() mutable {
            QueryImpl(id, std::move(meta), std::move(lock));
        });
    }

protected:
    static constexpr size_t BATCH_SIZE = 5000u;

    struct TItem {
        ui64 LocalId;
        TActorMetaData Meta;
    };

    struct TValue {
        TBatchVector<TItem> Queue;
        ui32 ShardId;
        bool Locked{false};
    };

protected:
    void QueryEnd(TValue& value, TWaitContextPtr lock) {
        Engine_->SafeAddFunc([this, &value, lock = std::move(lock)]() {
            OnQueryEnd(value, std::move(lock));
        });
    }

    template <class TContainer, class TLocalIdProvider, class TLess, class TEx, class TEq>
    void Merge(
            TVector<TItem>& items,
            TContainer& response,
            TLocalIdProvider provider,
            TLess less,
            TEq eq,
            TEx ex)
    {
        auto lhsIt = items.begin();
        auto rhsIt = response.begin();

        while (lhsIt < items.end() && rhsIt < response.end()) {
            auto firstId = lhsIt->LocalId;
            auto secondId = provider(*rhsIt);;

            if (firstId == secondId) {
                eq(*lhsIt, *rhsIt);
                ++lhsIt;
                ++rhsIt;
            } else if (firstId < secondId) {
                less(*lhsIt);
                ++lhsIt;
            } else {
                ex(*rhsIt);
                ++rhsIt;
            }
        }

        while (lhsIt < items.end()) {
            less(*lhsIt);
            ++lhsIt;
        }

        while (rhsIt < response.end()) {
            ex(*rhsIt);
            ++rhsIt;
        }
    }

protected:
    IStockpileClientPtr Client_;

private:
    void InitClient() {
        if (!Client_) {
            Client_ = Factory_();
        }
    }

    void QueryImpl(TStockpileIds id, TActorMetaData&& meta, TWaitContextPtr lock) {
        auto [it, _] = DataTable_.emplace(
                std::make_pair(id.ShardId, TValue{TBatchVector<TItem>(BATCH_SIZE), id.ShardId, false})
        );
        auto& value = it->second;
        value.Queue.PushBack(TItem{id.LocalId, std::move(meta)});

        QueryImpl(value, std::move(lock));
    }

    void QueryImpl(TValue& value, TWaitContextPtr lock) {
        if (value.Locked || value.Queue.Empty()) {
            return;
        }

        if (!Closed_ && value.Queue.Front().size() < BATCH_SIZE) {
            return;
        }

        value.Locked = true;

        TVector<TItem> itemsCopy = std::move(value.Queue.Front());
        value.Queue.PopFront();
        TVector<ui64> localIds;
        localIds.reserve(itemsCopy.size());
        for (const auto& item: itemsCopy) {
            localIds.emplace_back(item.LocalId);
        }

        Y_VERIFY(localIds.size() <= BATCH_SIZE);

        InitClient();
        ClientQuery(std::move(itemsCopy), std::move(localIds), value, std::move(lock));
    }

    virtual void ClientQuery(TVector<TItem>&& items, TVector<ui64>&& localIds, TValue& value, TWaitContextPtr lock) = 0;

    void OnQueryEnd(TValue& value, TWaitContextPtr lock) {
        value.Locked = false;

        QueryImpl(value, std::move(lock));
    }

private:
    TStrandPtr Engine_;

    TStockpileClientFactory Factory_;

    std::unordered_map<ui32, TValue> DataTable_;

    bool Closed_{false};
};

class TReadMetricsMetaActorShard: public IReadMetricsMetaActor, public IStockpileQueryActorShard {
public:
    using TBase = IStockpileQueryActorShard;

    TReadMetricsMetaActorShard(IActorEnginePtr engine, IReadMetricsMetaConsumer& consumer, TStockpileClientFactory factory)
        : TBase(std::move(engine), std::move(factory))
        , Consumer_(consumer)
    {}

    void Close() override {
        TBase::Close();
    }

    void ReadMetricMeta(TStockpileIds id, TActorMetaData meta) override {
        Query(id, std::move(meta));
    }

private:
    void ClientQuery(TVector<TItem>&& items, TVector<ui64>&& localIds, TValue& value, TWaitContextPtr lock) override {
        Client_->ReadMetricsMeta(value.ShardId, localIds).Subscribe(
                [this, shardId = value.ShardId, itemsCopy = std::move(items), &value, lock = std::move(lock)] (auto f) mutable
        {
            TReadMetricsMetaResponseOrError valueOrError = f.ExtractValue();

            if (valueOrError.Fail()) {
                for (auto&& item: itemsCopy) {
                    Consumer_.OnError(valueOrError.Error(), {shardId, item.LocalId}, std::move(item.Meta));
                }

                return;
            }

            std::sort(itemsCopy.begin(), itemsCopy.end(), [](const auto& lhs, const auto& rhs) {
                return lhs.LocalId < rhs.LocalId;
            });

            auto result = std::move(valueOrError.Value());
            std::sort(result.begin(), result.end(), [](const auto& lhs, const auto& rhs) {
               return lhs.LocalId < rhs.LocalId;
            });

            if(result.size() != itemsCopy.size()) {
                Merge(itemsCopy, result, shardId);
            } else {
                for (size_t i = 0; i < result.size(); ++i) {
                    Y_ENSURE(result[i].LocalId == itemsCopy[i].LocalId);
                    Consumer_.OnMetricMeta(TStockpileIds{shardId, result[i].LocalId}, result[i].LastTsMillis,
                                           std::move(itemsCopy[i].Meta));
                }
            }

            QueryEnd(value, std::move(lock));
        });
    }

    void Merge(TVector<TItem>& items, TReadMetricsMetaResponse& response, ui32 shardId) {
        auto provider = [](const TLastUpdate& lastUpdate) {
            return lastUpdate.LocalId;
        };
        auto less = [this, shardId](TItem& item) {
            Consumer_.OnError("failed to get metric meta", TStockpileIds{shardId, item.LocalId}, std::move(item.Meta));
        };
        auto eq = [this, shardId](TItem& item, TLastUpdate& lastUpdate) {
            Consumer_.OnMetricMeta(TStockpileIds{shardId, item.LocalId}, lastUpdate.LastTsMillis, std::move(item.Meta));
        };
        auto ex = [](TLastUpdate&) {
            Y_FAIL("fatal error: some metrics is lost");
        };
        IStockpileQueryActorShard::Merge(items, response, provider, less, eq, ex);
    }

private:
    IReadMetricsMetaConsumer& Consumer_;
};

// Creates an actor reading metric metadata from Stockpile; the client is created
// lazily from `factory` when the first batch is sent.
IReadMetricsMetaActorPtr CreateReadMetricsMetaActor(
        IActorEnginePtr engine,
        IReadMetricsMetaConsumer& consumer,
        TStockpileClientFactory factory)
{
    IReadMetricsMetaActorPtr actor = MakeHolder<TReadMetricsMetaActorShard>(
            std::move(engine),
            consumer,
            std::move(factory));
    return actor;
}

class TDownloadMetricsActorShard: public IDownloadMetricsActor, public IStockpileQueryActorShard {
public:
    using TBase = IStockpileQueryActorShard;

    TDownloadMetricsActorShard(
            IActorEnginePtr engine,
            IDownloadMetricsConsumer& consumer,
            TStockpileClientFactory factory,
            ui64 fromMillis,
            ui64 toMillis)
            : TBase(std::move(engine), std::move(factory))
            , Consumer_(consumer)
            , FromMillis_(fromMillis)
            , ToMillis_(toMillis)
    {}

    void Close() override {
        TBase::Close();
    }

    void Download(TStockpileIds id, TActorMetaData meta) override {
        Query(id, std::move(meta));
    }

private:
    void ClientQuery(TVector<TItem>&& items, TVector<ui64>&& localIds, TValue& value, TWaitContextPtr lock) override {
        Client_->CompressedReadMany(value.ShardId, localIds, FromMillis_, ToMillis_).Subscribe(
                [this, shardId = value.ShardId, itemsCopy = std::move(items), &value, lock = std::move(lock)] (auto f) mutable
                {
                    TReadManyResponseOrError valueOrError = f.ExtractValue();

                    if (valueOrError.Fail()) {
                        for (auto&& item: itemsCopy) {
                            Consumer_.OnError(
                                    valueOrError.Error(),
                                    TStockpileIds{value.ShardId, item.LocalId},
                                    std::move(item.Meta));
                        }

                        return;
                    }

                    std::sort(itemsCopy.begin(), itemsCopy.end(), [](const auto& lhs, const auto& rhs) {
                        return lhs.LocalId < rhs.LocalId;
                    });

                    auto result = std::move(valueOrError.Value());
                    std::sort(result.begin(), result.end(), [](const auto& lhs, const auto& rhs) {
                        return lhs.Id.LocalId < rhs.Id.LocalId;
                    });

                    if(result.size() != itemsCopy.size()) {
                        Merge(itemsCopy, result, shardId);
                    } else {
                        for (size_t i = 0; i < result.size(); ++i) {
                            Y_ENSURE(result[i].Id.LocalId == itemsCopy[i].LocalId);
                            Consumer_.OnMetric(std::move(result[i]), std::move(itemsCopy[i].Meta));
                        }
                    }

                    QueryEnd(value, std::move(lock));
                });
    }

    void Merge(TVector<TItem>& items, TReadManyResponse& response, ui32 shardId) {
        auto provider = [](const TSeriesAndId& s) {
            return s.Id.LocalId;
        };
        auto less = [this, shardId](TItem& item) {
            Consumer_.OnError("failed to download metric", TStockpileIds{shardId, item.LocalId}, std::move(item.Meta));
        };
        auto eq = [this](TItem& item, TSeriesAndId& series) {
            Consumer_.OnMetric(std::move(series), std::move(item.Meta));
        };
        auto ex = [](TSeriesAndId&) {
            Y_FAIL("fatal error: some metrics is lost");
        };
        IStockpileQueryActorShard::Merge(items, response, provider, less, eq, ex);
    }

private:
    IDownloadMetricsConsumer& Consumer_;
    ui64 FromMillis_;
    ui64 ToMillis_;
};

// Creates a single download shard; the time bounds are converted from TInstant to
// milliseconds once, here.
IDownloadMetricsActorPtr CreateDownloadMetricsActorShard(
        IActorEnginePtr engine,
        IDownloadMetricsConsumer& consumer,
        TStockpileClientFactory factory,
        TInstant fromMillis,
        TInstant toMillis)
{
    const ui64 fromMs = fromMillis.MilliSeconds();
    const ui64 toMs = toMillis.MilliSeconds();

    IDownloadMetricsActorPtr shard = MakeHolder<TDownloadMetricsActorShard>(
            std::move(engine),
            consumer,
            std::move(factory),
            fromMs,
            toMs);
    return shard;
}

class TDownloadMetricsActor: public IDownloadMetricsActor {
public:
    TDownloadMetricsActor(
            IActorEnginePtr engine,
            IDownloadMetricsConsumer& consumer,
            TStockpileClientFactory factory,
            TInstant fromMillis,
            TInstant toMillis)
        : Engine_(std::move(engine))
        , Consumer_(consumer)
        , Factory_(std::move(factory))
        , FromMillis_(fromMillis)
        , ToMillis_(toMillis)
    {
        for (size_t i = 0; i < SHARDS; ++i) {
            Shards_.emplace_back(nullptr);
        }
    }

    void Close() override {
        for (auto& actor: Shards_) {
            if (actor) {
                actor->Close();
            }
        }
    }

    void Download(TStockpileIds id, TActorMetaData meta) override {
        ui32 idx = id.ShardId % SHARDS;
        if (!Shards_[idx]) {
            Shards_[idx] = CreateDownloadMetricsActorShard(Engine_, Consumer_, Factory_, FromMillis_, ToMillis_);
        }

        Shards_[idx]->Download(id, std::move(meta));
    }

private:
    TVector<IDownloadMetricsActorPtr> Shards_;
    IActorEnginePtr Engine_;
    IDownloadMetricsConsumer& Consumer_;
    TStockpileClientFactory Factory_;
    TInstant FromMillis_;
    TInstant ToMillis_;

    static constexpr size_t SHARDS = 128;
};

// Creates the sharded download actor (TDownloadMetricsActor) for the given time range.
IDownloadMetricsActorPtr CreateDownloadMetricsActor(
        IActorEnginePtr engine,
        IDownloadMetricsConsumer& consumer,
        TStockpileClientFactory factory,
        TInstant fromMillis,
        TInstant toMillis)
{
    IDownloadMetricsActorPtr actor = MakeHolder<TDownloadMetricsActor>(
            std::move(engine),
            consumer,
            std::move(factory),
            fromMillis,
            toMillis);
    return actor;
}
