#include "config_updater.h"

#include "diff.h"
#include "config_holder.h"

#include <solomon/services/fetcher/lib/cluster/cluster.h>
#include <solomon/services/fetcher/lib/shard_manager/events.h>

#include <solomon/libs/cpp/backoff/backoff.h>
#include <solomon/libs/cpp/backoff/jitter.h>
#include <solomon/libs/cpp/logging/logging.h>

#include <library/cpp/actors/core/actor_bootstrapped.h>
#include <library/cpp/actors/core/hfunc.h>
#include <library/cpp/actors/core/log.h>

#include <library/cpp/monlib/metrics/metric_registry.h>

#include <util/generic/algorithm.h>
#include <util/generic/hash.h>
#include <util/generic/hash_set.h>
#include <util/generic/ptr.h>
#include <util/generic/scope.h>
#include <util/string/builder.h>
#include <util/string/split.h>
#include <util/system/env.h>
#include <util/system/hostname.h>

#include <atomic>
#include <functional>
#include <stdexcept>
#include <tuple>
#include <utility>

using namespace std::string_view_literals;

namespace NSolomon::NFetcher {
namespace {
    using namespace NActors;
    using namespace NMonitoring;
    using namespace NDb;
    using namespace NSolomon::NTableLoader;
    using NDb::NModel::IsWriteable;

    struct TEqualById {
        bool operator()(const TFetcherShard& lhs, const TFetcherShard& rhs) const {
            return lhs.Id() == rhs.Id();
        }
    };

    // Reads the comma-separated list of shard ids from the SOLOMON_FETCHER_SHARDS
    // environment variable. Empty items (e.g. produced by ",," or a trailing comma)
    // are skipped; the remaining tokens are taken verbatim (no whitespace trimming).
    THashSet<TString> GetShardsFromEnv() {
        const TString envValue = GetEnv("SOLOMON_FETCHER_SHARDS");

        THashSet<TString> shardIds;
        for (auto&& part: StringSplitter(envValue).Split(',').SkipEmpty()) {
            shardIds.emplace(part.Token());
        }
        return shardIds;
    }

    // Mutable numId -> (shard id, location) mapping. It also implements IShardLocator,
    // so the config holder's filters can query shard locations through it.
    struct TShardMap: IShardLocator {
        // A single mapping entry: the string shard id and its cluster location.
        struct TEntry {
            TEntry() = default;
            TEntry(TString id, TClusterNode loc)
                : Id{std::move(id)}
                , Location{std::move(loc)}
            {
            }

            // An entry is valid only when both the id and the location are known.
            bool IsValid() const {
                return !(Id.empty() || Location.IsUnknown());
            }

            TString Id;
            TClusterNode Location;
        };

        // Returns the entry for numId, or nullptr if there is none.
        const TEntry* Find(ui32 numId) const {
            return ShardMap_.FindPtr(numId);
        }

        // Returns the entry for numId, or nullptr if there is none.
        TEntry* Find(ui32 numId) {
            return ShardMap_.FindPtr(numId);
        }

        // Removes the entry for numId (no-op if absent).
        void Erase(ui32 numId) {
            ShardMap_.erase(numId);
        }

        // Inserts an entry for numId through the map's emplace; like std::unordered_map::emplace,
        // this is expected not to overwrite an entry that already exists for numId.
        // The by-value arguments are sink parameters and are moved into the new entry.
        void Emplace(ui32 numId, TString id = {}, TClusterNode loc = {}) {
            ShardMap_.emplace(
                std::piecewise_construct,
                std::forward_as_tuple(numId),
                std::forward_as_tuple(std::move(id), std::move(loc))
            );
        }

        // IShardLocator: returns the stored location of the shard (which may itself be an
        // "unknown" location), or an empty value if the numId is not mapped at all.
        TMaybeLocation Location(ui32 shardId) const override {
            if (auto* entry = Find(shardId)) {
                return entry->Location;
            }

            return {};
        }

        // Read-only view of all entries.
        const THashMap<ui32, TEntry>& Values() const {
            return ShardMap_;
        }

    private:
        THashMap<ui32, TEntry> ShardMap_;
    };

    // Returns true when `locator` returns a location for the shard's NumId and IsLocal()
    // accepts that location. If the locator returns no location, the shard is not local.
    bool IsLocalShard(const NModel::TShardConfig& shard, IShardLocator& locator) {
        auto location = locator.Location(shard.NumId);
        return location && IsLocal(*location);
    }

    // Builds the shard filter used by TConfigHolder for the given updater mode:
    //  - OnlyLocal: writeable shards that `locator` places on a node accepted by IsLocal();
    //  - All:       every writeable shard;
    //  - Env:       writeable shards listed in SOLOMON_FETCHER_SHARDS (the variable is read once,
    //               on the first invocation of any Env filter);
    //  - None:      no shards at all;
    //  - OnlyYasm:  shards whose project id starts with "yasm_" (writeability is not checked).
    // The OnlyLocal filter keeps a pointer to `locator`, so the locator must outlive the filter.
    // Throws std::invalid_argument if `mode` is not one of the enumerators above.
    TConfigHolder::TShardFilter MakeFilter(EUpdaterMode mode, IShardLocator& locator) {
        switch (mode) {
            case EUpdaterMode::OnlyLocal:
                // Capture an explicit pointer: the filter is stored and outlives this call.
                return [locatorPtr = &locator] (const NModel::TShardConfig& shard) {
                    return IsWriteable(shard) && IsLocalShard(shard, *locatorPtr);
                };

            case EUpdaterMode::All:
                return [] (const NModel::TShardConfig& shard) {
                    return IsWriteable(shard);
                };

            case EUpdaterMode::Env:
                return [] (const NModel::TShardConfig& shard) {
                    static THashSet<TString> shards = GetShardsFromEnv();
                    return IsWriteable(shard) && shards.contains(shard.Id);
                };

            case EUpdaterMode::None:
                return [] (auto&&) {
                    return false;
                };

            case EUpdaterMode::OnlyYasm:
                return [] (const NModel::TShardConfig& shard) {
                    return shard.ProjectId.StartsWith("yasm_");
                };
        }

        // Only reachable for an out-of-range enum value; without this, control would fall off
        // the end of a non-void function, which is undefined behavior.
        throw std::invalid_argument{"MakeFilter: unknown EUpdaterMode value"};
    }

    // Actor that keeps the fetcher's view of shard, cluster and provider configs up to date and
    // broadcasts every change to its subscribers (see OnSubscribe / OnUnsubscribe).
    //
    // Inputs:
    //  - TConfigsPullerEvents::TConfigsResponse from the configs puller registered in Bootstrap():
    //    the config diff is turned into TEvConfigChanged, TEvClustersChanged and TEvProvidersChanged;
    //  - TEvLoadShardMapResponse, requested from MakeShardMapBuilderId() on every wakeup:
    //    shard location changes are turned into TEvConfigChanged;
    //  - TEvReloadShardRequest: makes subscribers drop and re-add a single shard.
    class TConfigUpdaterActor: public TActorBootstrapped<TConfigUpdaterActor> {
    public:
        explicit TConfigUpdaterActor(const TConfigUpdaterConfig& config)
            : ServiceDao_{config.ServiceDao}
            , ClusterDao_{config.ClusterDao}
            , ShardDao_{config.ShardDao}
            , AgentDao_{config.AgentDao}
            , ProviderDao_{config.ProviderDao}
            , UpdateInterval_{config.Interval}
            , UpdateBackoff_{config.Interval, 4 * config.Interval}
            , MetricRegistry_{config.MetricRegistry}
            , Configs_{ShardMap_, MakeFilter(config.Mode, ShardMap_), config.EnableYasmPulling}
        {
            // The registry is only referenced by this actor and may outlive it (the actor stops on a
            // PoisonPill), so the gauge callback shares ownership of the timestamp instead of capturing `this`.
            MetricRegistry_.LazyIntGauge({{"sensor", "configAgeSeconds"}}, [updatedAt = UpdatedAt_] {
                if (auto value = updatedAt->load(std::memory_order_relaxed)) {
                    auto age = TInstant::Now() - TInstant::FromValue(value);
                    return static_cast<i64>(age.Seconds());
                }
                return static_cast<i64>(0);
            });
            Counters_.ShardsAdded = MetricRegistry_.Counter({{"sensor", "shardsAdded"}});
            Counters_.ShardsRemoved = MetricRegistry_.Counter({{"sensor", "shardsRemoved"}});
            Counters_.ShardsUpdated = MetricRegistry_.Counter({{"sensor", "shardsUpdated"}});
        }

        // Starts the configs puller, subscribes to it and schedules the first shard map load.
        void Bootstrap() {
            Become(&TThis::StateWork);

            auto configsPullerActor = Register(CreateConfigsPuller({
                    .ClusterDao = ClusterDao_,
                    .ServiceDao = ServiceDao_,
                    .ShardDao = ShardDao_,
                    .AgentDao = AgentDao_,
                    .ProviderDao = ProviderDao_,
                    .UpdateInterval = UpdateInterval_,
                    .MaxUpdateInterval = 4 * UpdateInterval_,
                    .Registry = MetricRegistry_,
            }).release());
            Send(configsPullerActor, new TConfigsPullerEvents::TSubscribe);

            // the backoff is used here only to get a jittered interval
            Schedule(UpdateBackoff_(), new TEvents::TEvWakeup);
        }

        STATEFN(StateWork) {
            switch (ev->GetTypeRewrite()) {
                hFunc(TConfigsPullerEvents::TConfigsResponse, OnConfigsResponse);
                hFunc(TEvLoadShardMapResponse, OnShardMapLoaded);
                hFunc(TEvents::TEvSubscribe, OnSubscribe);
                hFunc(TEvents::TEvUnsubscribe, OnUnsubscribe);
                hFunc(TEvReloadShardRequest, OnReloadRequest);
                cFunc(TEvents::TSystem::Wakeup, OnWakeup);
                cFunc(TEvents::TSystem::PoisonPill, PassAway);
            }
        }

    private:
        // Applies a loaded shard map (on success) and, in any case, schedules the next load.
        void OnShardMapLoaded(const TEvLoadShardMapResponse::TPtr& ev) {
            Y_DEFER {
                // reset backoff and schedule the next update
                UpdateBackoff_.Reset();
                auto delay = UpdateBackoff_();
                Schedule(delay, new TEvents::TEvWakeup);
                MON_INFO(ShardUpdater, "Schedule next shard map update after " << delay);
            };

            auto& result = ev->Get()->Result;
            if (!result.Success()) {
                MON_WARN(ShardUpdater, "Unable to load shard map: " << result.Error().Message());
                return;
            }

            UpdateShardMap(result.Extract());
        }

        // Makes every subscriber drop the requested shard and re-add it with its current config.
        // Replies TEvNotFound if the shard is unknown, TEvReloadShardResponse otherwise.
        void OnReloadRequest(const TEvReloadShardRequest::TPtr& ev) {
            const auto sender = ev->Sender;
            const auto numId = ev->Get()->ShardId.NumId();
            // Copy the id out of the event instead of capturing `ev` by value below.
            // NOTE(review): event TPtrs are TAutoPtr-based; copying one releases the source,
            // which would silently take the event away from the caller's handle.
            const auto shardId = ev->Get()->ShardId;
            auto shard = Configs_.GetShard(numId);

            if (!shard) {
                Send(sender, new TEvNotFound);
                return;
            }

            // ForEachSubscriber invokes the callback synchronously, so capturing by reference is safe.
            ForEachSubscriber([&] (auto id) {
                Send(id, new TEvConfigChanged{
                    {},
                    {},
                    // for now, only simple shards are supported in this API
                    { {shardId, EFetcherShardType::Simple} }
                });
                Send(id, new TEvConfigChanged{{*shard}, {}, {}});
            });

            Send(sender, new TEvReloadShardResponse);
        }

        // Registers the sender and immediately sends it all currently known shards and providers.
        void OnSubscribe(const TEvents::TEvSubscribe::TPtr& ev) {
            Subscribers_.insert(ev->Sender);
            const auto id = ev->Sender;
            MON_INFO(ShardUpdater, id << " subscribed to shard updates");

            // send all known shards to the new subscriber
            Send(id, new TEvConfigChanged{Configs_.AllShards(), {}, {}});
            Send(id, new TEvProvidersChanged{Configs_.AllProviders(), {}, {}});
        }

        void OnUnsubscribe(const TEvents::TEvUnsubscribe::TPtr& ev) noexcept {
            const auto it = Subscribers_.find(ev->Sender);

            if (it != Subscribers_.end()) {
                Subscribers_.erase(it);
                MON_INFO(ShardUpdater, ev->Sender << " unsubscribed from shard updates");
            } else {
                MON_INFO(ShardUpdater, ev->Sender << " requested to unsubscribe, but not found");
            }
        }

        // Requests a fresh shard map from the shard map builder.
        void OnWakeup() noexcept {
            MON_INFO(ShardUpdater, "Loading shard map");
            Send(MakeShardMapBuilderId(), new TEvLoadShardMap{{}});
        }

        // Applies new configs from the puller and broadcasts the resulting shard, cluster and
        // provider diffs to all subscribers.
        void OnConfigsResponse(TConfigsPullerEvents::TConfigsResponse::TPtr& ev) {
            MON_INFO(ShardUpdater, "Configs updated");
            UpdatedAt_->store(TActivationContext::Now().GetValue(), std::memory_order_relaxed);

            auto configs = std::move(ev->Get()->Configs);
            auto oldShards = Configs_.AllShards();
            TConfigDiff diff = Configs_.UpdateConfig(configs);
            if (diff.IsEmpty()) {
                MON_INFO(ShardUpdater, "Config diff is empty");
                return;
            }

            MON_INFO(ShardUpdater, "Config diff: " << diff.ToString());

            // Update the numid -> (id, loc) mapping. Removals are applied first so that a numId that
            // appears both in Removed and in Added ends up present in the map, not erased.
            auto& shardDiff = diff.ShardDiff;
            for (auto& [_, shard]: shardDiff.Removed) {
                ShardMap_.Erase(shard.NumId);
            }

            for (auto& [_, shard]: shardDiff.Added) {
                ShardMap_.Emplace(shard.NumId, shard.Id);
            }

            THashSet<TFetcherShard, THash<TFetcherShard>, TEqualById> oldShardSet;
            Copy(
                std::make_move_iterator(oldShards.begin()),
                std::make_move_iterator(oldShards.end()),
                std::inserter(oldShardSet, oldShardSet.end())
            );

            auto newShards = Configs_.AllShards();

            auto fetcherShardDiff = MakeDiff(oldShardSet, newShards, [] (auto&& lhs, auto&& rhs) {
                // don't compare locations here
                // XXX too precise?
                return lhs.IsEqual(rhs);
            });

            // These counters were registered in the constructor; keep them in sync with what is broadcast.
            Counters_.ShardsAdded->Add(fetcherShardDiff.Added.size());
            Counters_.ShardsUpdated->Add(fetcherShardDiff.Changed.size());
            Counters_.ShardsRemoved->Add(fetcherShardDiff.Removed.size());

            TVector<TInfoOfAShardToRemove> shardsToRemove(::Reserve(fetcherShardDiff.Removed.size()));
            for (const auto& shard: fetcherShardDiff.Removed) {
                shardsToRemove.emplace_back(shard.Id(), shard.Type());
            }

            // ForEachSubscriber is synchronous: capture by reference instead of copying the whole diff.
            ForEachSubscriber([&] (auto id) {
                Send(id, new TEvConfigChanged{
                    fetcherShardDiff.Added,
                    fetcherShardDiff.Changed,
                    shardsToRemove,
                });
            });

            // TODO: process removed clusters as well (SOLOMON-6594)
            TVector<IFetcherClusterPtr> changedClusters;

            if (!diff.ClusterDiff.Changed.empty() || !fetcherShardDiff.Changed.empty()) {
                for (auto&& [_, cluster]: diff.ClusterDiff.Changed) {
                    changedClusters.emplace_back(CreateCluster(cluster));
                }

                // Some of the clusters are constructed separately inside Agent shards --> update them as well.
                // Clusters inside fetcherShardDiff.Added will be correctly created on an initial shard creation.
                // TODO(ivanzhukov): log the diff
                for (const auto& fetcherShard: fetcherShardDiff.Changed) {
                    if (fetcherShard.Type() == EFetcherShardType::Agent) {
                        changedClusters.emplace_back(fetcherShard.Cluster());
                    }
                }
            }

            if (!changedClusters.empty()) {
                ForEachSubscriber([&] (auto id) {
                    Send(id, new TEvClustersChanged{changedClusters});
                });
            }

            if (auto& provDiff = diff.ProviderDiff; !provDiff.IsEmpty()) {
                TVector<TProviderConfigPtr> provAdded{::Reserve(provDiff.Added.size())};
                TVector<TProviderConfigPtr> provChanged{::Reserve(provDiff.Changed.size())};
                TVector<TProviderId> provRemoved{::Reserve(provDiff.Removed.size())};

                for (auto& [id, val]: provDiff.Added) {
                    provAdded.emplace_back(MakeAtomicShared<NDb::NModel::TProviderConfig>(val));
                }

                for (auto& [id, val]: provDiff.Changed) {
                    provChanged.emplace_back(MakeAtomicShared<NDb::NModel::TProviderConfig>(val));
                }

                for (auto& [id, val]: provDiff.Removed) {
                    provRemoved.emplace_back(id);
                }

                // TODO(ivanzhukov): send only to actors subscribed to this particular event
                ForEachSubscriber([&](auto id) {
                    Send(id, new TEvProvidersChanged{provAdded, provChanged, provRemoved});
                });
            }
        }

    private:
        // Invokes fn synchronously for every current subscriber.
        void ForEachSubscriber(const std::function<void(TActorId)>& fn) {
            for (auto&& subscriber: Subscribers_) {
                fn(subscriber);
            }
        }

    private:
        // Applies shard -> node assignments from the shard map builder:
        //  - shards whose location changed get the new location and are sent as changed;
        //  - shards with a non-zero numId that appear in no assignment get their location cleared
        //    and are sent as removed (non-Simple fetcher shards are skipped).
        void UpdateShardMap(const TVector<TShardAssignment>& assignments) {
            TVector<TFetcherShard> changedShards;
            TVector<TInfoOfAShardToRemove> removedShards;
            THashSet<ui32> receivedNumIds;

            for (auto&& assignment: assignments) {
                for (auto numId: assignment.Shards) {
                    receivedNumIds.emplace(numId);

                    auto* shard = ShardMap_.Find(numId);
                    if (shard == nullptr)  {
                        // we don't have a config for this shard yet, skip
                        continue;
                    }

                    const auto& oldLocation = shard->Location;
                    if (oldLocation == assignment.Location) {
                        continue;
                    }

                    shard->Location = assignment.Location;

                    auto fetcherShard = Configs_.GetShard(numId);
                    if (!fetcherShard) {
                        // we don't have a config for this shard yet, skip
                        continue;
                    }

                    changedShards.push_back(*fetcherShard);
                }
            }

            for (const auto& [numId, _]: ShardMap_.Values()) {
                if (numId == 0 || receivedNumIds.contains(numId)) {
                    // numId equal to 0 is illegal, but "special" shards (yasm agent or service provider shards)
                    //  have it. Logic for these shards is different, therefore we do not process them in this section
                    continue;
                }

                const auto fetcherShard = Configs_.GetShard(numId);
                if (!fetcherShard) {
                    // so we have this shard in ShardMap_, but we don't have any config -- data is inconsistent
                    auto shardId = TShardId{"", numId}; // only numId should be used for searching in a hashtable
                    removedShards.emplace_back(shardId, EFetcherShardType::Simple);

                    if (auto* shard = ShardMap_.Find(numId); shard) {
                        shard->Location = {};
                    }
                } else {
                    // this should not happen, since we've checked that numId != 0, but just to be safe
                    if (fetcherShard->Type() != EFetcherShardType::Simple) {
                        continue;
                    }

                    removedShards.emplace_back(fetcherShard->Id(), fetcherShard->Type());
                    if (auto* shard = ShardMap_.Find(numId); shard) {
                        shard->Location = {};
                    }
                }
            }

            if (changedShards.empty() && removedShards.empty()) {
                return;
            }

            // ForEachSubscriber is synchronous: capture by reference instead of copying both vectors.
            ForEachSubscriber([&] (auto id) {
                Send(id, new TEvConfigChanged{{}, changedShards, removedShards});
            });
        }

    private:
        IServiceConfigDaoPtr ServiceDao_;
        IClusterConfigDaoPtr ClusterDao_;
        IShardConfigDaoPtr ShardDao_;
        IAgentConfigDaoPtr AgentDao_;
        IProviderConfigDaoPtr ProviderDao_;
        TDuration UpdateInterval_;
        TLinearBackoff<THalfJitter> UpdateBackoff_;
        // TInstant value of the last configs response (0 until the first one arrives). Shared with
        // the configAgeSeconds gauge callback, which may be invoked after this actor is destroyed.
        TAtomicSharedPtr<std::atomic<ui64>> UpdatedAt_{MakeAtomicShared<std::atomic<ui64>>(ui64{0})};

        TMetricRegistry& MetricRegistry_;
        struct {
            TCounter* ShardsAdded{nullptr};
            TCounter* ShardsRemoved{nullptr};
            TCounter* ShardsUpdated{nullptr};
        } Counters_;

        // ShardMap_ must be declared (hence constructed) before Configs_: Configs_ is constructed with
        // ShardMap_ converted to its IShardLocator base and with a filter that may query it. Converting
        // to a base before the object's construction has started is undefined behavior ([class.cdtor]).
        TShardMap ShardMap_;
        TConfigHolder Configs_;

        THashSet<TActorId> Subscribers_;
    };
} // namespace

    // Service actor id of the config updater: node id 0 plus the service name below.
    TActorId MakeConfigUpdaterId() {
        // 11 bytes: "ConfUpdSrv" followed by the explicit trailing '\0'.
        static constexpr TStringBuf SERVICE_NAME{"ConfUpdSrv\0"sv};
        const ui32 nodeId = 0;
        return TActorId(nodeId, SERVICE_NAME);
    }

    // Creates a new config updater actor. The returned raw pointer is owned by the caller,
    // presumably handed over to the actor system via Register().
    IActor* CreateConfigUpdaterActor(const TConfigUpdaterConfig& conf) {
        IActor* actor = new TConfigUpdaterActor(conf);
        return actor;
    }

} // namespace NSolomon::NFetcher
