
#include <tbb/tbb.h>

#include <util/random/random.h>
#include <util/string/split.h>
#include <library/cpp/string_utils/quote/quote.h>
#include <util/system/thread.h>
#include <util/generic/size_literals.h>

#include <library/cpp/json/json_reader.h>
#include <library/cpp/threading/future/future.h>
#include <library/cpp/threading/future/async.h>
#include <library/cpp/coroutine/engine/impl.h>

#include <mail/so/spamstop/tools/simple_shingler/handler_server.h>
#include <mail/so/spamstop/tools/so-common/StorageBase.h>
#include <mail/so/libs/talkative_config/config.h>
#include <mail/so/spamstop/tools/general_shingler/shingler/schemes/db.h>
#include <mail/so/spamstop/tools/general_shingler/data/cleanup.h>
#include <mail/so/spamstop/tools/general_shingler/data/helper.h>
#include <mail/so/spamstop/tools/general_shingler/data/shard.h>

#include "storages/storage.h"
#include "processor.h"

namespace NGeneralShingler {
    /// Per-scheme counters maintained by TImpl::AsyncUpdateMessages().
    /// Atomic so they can be bumped and read without extra locking.
    struct TUpdateInfo {
        /// Number of fields whose scheme update threw.
        std::atomic<std::size_t> failed{0};
        /// Number of fields whose scheme update completed.
        std::atomic<std::size_t> success{0};
    };

    /// TProcessor implementation that routes shingler messages to named schemes.
    ///
    /// Get messages are turned into TSchemeBase::Find() futures run as coroutines;
    /// update messages are either applied immediately (ProcessUpdateMessagesSync) or
    /// queued (ProcessUpdateMessages) and later merged and applied, scheme by scheme in
    /// `schemeOrder`, by AsyncUpdateMessages().
    class TImpl : public TProcessor {
    public:
        /// @param pool              thread pool kept by the processor.
        /// @param logger            shared logger.
        /// @param schemeOrder       order in which AsyncUpdateMessages() applies schemes.
        /// @param schemesByName     all known schemes, keyed by name.
        /// @param totalReadTimeout  budget for one ProcessGetMessages() call.
        /// @param shardersByScheme  optional sharder per scheme name.
        explicit TImpl(TAtomicSharedPtr<IThreadPool> pool,
            TAtomicSharedPtr<TLog> logger, TVector<TString> schemeOrder,
            THashMap<TString, TAtomicSharedPtr<TSchemeBase>> schemesByName,
            const TDuration& totalReadTimeout,
            THashMap<TString, TSharder> shardersByScheme)
            : pool(std::move(pool))
            , logger(std::move(logger))
            , schemeOrder(std::move(schemeOrder))
            , schemesByName(std::move(schemesByName))
            , totalReadTimeout(totalReadTimeout)
            , shardersByScheme(std::move(shardersByScheme))
        {
            // Pre-create one counter entry per scheme. The parameter `schemeOrder` shadows the
            // member and was already moved from in the initializer list, so the member must be
            // iterated here explicitly.
            for (const auto& item : this->schemeOrder)
                updateInfo[item]; //insert default value
        }

        TMessages ParseMessages(const TStringBuf &input) const final;
        TMessages ParseMessages(const THandleContext &handleContext) const final;

        NJson::TJsonValue ProcessGetMessages(TMessages &&messages) final;
        void ProcessUpdateMessagesSync(TMessages &&messages) final;
        void ProcessUpdateMessages(TMessages &&messages) final;
        /// Drains the update queue and applies the merged updates; driven by the scheduler.
        void AsyncUpdateMessages();
    private:
        /// Builds a Get message from the URL path (scheme) and CGI parameters (fields), if any.
        TMaybe<TMessage> CreateMessageFromRequest(const THandleContext &handleContext) const;
        /// Looks up a scheme by name; throws yexception if it is unknown.
        TSchemeBase & GetScheme(const TString & schemeName);
        /// Validates a JSON array of messages and converts it to TMessages (throws 400 on bad input).
        TMessages ParseAndValidateMessages(NJson::TJsonValue && jsMessages) const;
    private:
        TAtomicSharedPtr<IThreadPool> pool;
        TAtomicSharedPtr<TLog> logger;
        TVector<TString> schemeOrder;
        THashMap<TString, TAtomicSharedPtr<TSchemeBase>> schemesByName;
        TDuration totalReadTimeout;
        THashMap<TString, TSharder> shardersByScheme;
        /// Async-update counters per scheme name.
        THashMap<TString, TUpdateInfo> updateInfo;

        /// Batches queued by ProcessUpdateMessages() and drained by AsyncUpdateMessages().
        tbb::concurrent_queue<TSimpleSharedPtr <TMessages>> messagesToUpdate;
    };

    TMaybe<TMessage> TImpl::CreateMessageFromRequest(const THandleContext &handleContext) const {
        // The last "/"-separated segment of the request path names the scheme.
        TVector<TStringBuf> pathParts;
        Split(handleContext.request, "/", pathParts);

        // No segments at all, or only the bare "v1" prefix: the URL names no scheme.
        const bool noScheme = pathParts.empty() || (pathParts.size() == 1 && pathParts.back() == "v1");
        if (noScheme)
            return Nothing();

        const TStringBuf schemeName = pathParts.back();

        // One field set assembled from the CGI parameters; each unescaped key is handed to
        // SetValueByPath, so it is interpreted as a path inside the resulting dict.
        NJson::TJsonValue::TArray fieldsArray(1);
        NJson::TJsonValue & requestFields = fieldsArray.front().SetType(NJson::JSON_MAP);
        for (const auto & param : handleContext.cgiParameters) {
            const TString key = CGIUnescapeRet(param.first);
            const TString value = CGIUnescapeRet(param.second);
            requestFields.SetValueByPath(key, value);
        }

        const auto sharderIt = shardersByScheme.find(schemeName);
        if (sharderIt == shardersByScheme.cend())
            return MakeMaybe<TMessage>(TString{schemeName}, TMessageType::Get, std::move(fieldsArray));

        auto sharded = sharderIt->second.Shardify(std::move(fieldsArray));

        // A non-empty `withoutShard` group wins and is sent on its own.
        if (!sharded.withoutShard.fields.empty())
            return MakeMaybe<TMessage>(TString{schemeName}, TMessageType::Get, TDeque<TShardedFields>{std::move(sharded.withoutShard)});

        // Otherwise every per-shard group becomes one entry of the message.
        TDeque<TShardedFields> perShard;
        for (auto & group : sharded.groupedByShard)
            perShard.emplace_back(std::move(group.second));
        return MakeMaybe<TMessage>(TString{schemeName}, TMessageType::Get, std::move(perShard));
    }

    TMessages TImpl::ParseMessages(const TStringBuf &input) const {
        // Parse the body as JSON; failure is reported through the return value and mapped to 400.
        NJson::TJsonValue parsed;
        const bool parsedOk = NJson::ReadJsonTree(input, &parsed, false);
        if (!parsedOk) {
            ythrow TWithBackTrace<THttpError>(HTTP_BAD_REQUEST) << "cannot parse json from " << input;
        }

        return ParseAndValidateMessages(std::move(parsed));
    }

    TMessages TImpl::ParseMessages(const THandleContext &handleContext) const {
        const auto body = handleContext.ExtractData();

        // Messages from the JSON body (if any) come first.
        TMessages result;
        if (!body.Empty())
            result = ParseMessages(TStringBuf(static_cast<const char *>(body.Data()), body.Size()));

        // A message derived from the URL path / CGI parameters, if present, is appended last.
        if (auto fromRequest = CreateMessageFromRequest(handleContext); fromRequest.Defined())
            result.emplace_back(std::move(fromRequest.GetRef()));

        return result;
    }

    NJson::TJsonValue TImpl::ProcessGetMessages(TMessages &&messages) {
        // One absolute deadline shared by every lookup issued for this request.
        const TInstant deadLine = Now() + totalReadTimeout;

        // Each entry pairs a pending lookup with the scheme that produced it. A deque keeps element
        // addresses stable while appending, which matters because the executor gets raw pointers.
        TDeque<std::pair<THolder<IFuture>, TString>> pending;
        // Coroutine names; also a deque, since the executor is handed c_str() pointers into them.
        TDeque<TString> coroutineNames;

        TContExecutor executor(32_MB);
        executor.SetFailOnError(true);

        for (const TMessage & message : messages) {
            if (message.GetShardedFields().empty())
                continue;

            for (const auto & schemeName : message.GetSchemes()) {
                TSchemeBase & scheme = GetScheme(schemeName);
                for (const auto & sharded : message.GetShardedFields()) {
                    auto & lookup = pending.emplace_back(scheme.Find(deadLine, sharded), schemeName);
                    const TString & coroutineName = coroutineNames.emplace_back(schemeName + ToString(RandomNumber<ui32>()));
                    executor.Create<IFuture, &IFuture::Run>(lookup.first.Get(), coroutineName.c_str());
                }
            }
        }

        // Run every lookup, then emit one {"scheme", "find"} object per lookup in submission order.
        executor.Execute();

        NJson::TJsonValue results(NJson::JSON_ARRAY);
        NJson::TJsonValue::TArray & out = results.GetArraySafe();
        for (auto & [future, schemeName] : pending) {
            auto finds = future->Extract();
            NJson::TJsonValue & item = out.emplace_back();
            item["scheme"] = std::move(schemeName);
            item["find"].SetType(NJson::JSON_ARRAY).GetArraySafe() = std::move(finds);
        }

        return results;
    }

    void TImpl::ProcessUpdateMessagesSync(TMessages &&messages) {
        // Synchronous path: each message is applied to every scheme it names, in input order,
        // on the calling thread.
        for (const TMessage & message : messages) {
            const auto & targetSchemes = message.GetSchemes();
            for (const auto & name : targetSchemes) {
                TSchemeBase & target = GetScheme(name);
                target.Update(message.GetShardedFields());
            }
        }
    }

    void TImpl::ProcessUpdateMessages(TMessages &&messages) {
        // Asynchronous path: enqueue the whole batch; AsyncUpdateMessages() drains the queue later.
        auto batch = MakeSimpleShared<TMessages>(std::move(messages));
        messagesToUpdate.push(std::move(batch));
    }

    TSchemeBase & TImpl::GetScheme(const TString & schemeName) {
        // Unknown names are an internal error (not a client error), hence plain yexception.
        const auto found = schemesByName.find(schemeName);
        if (found != schemesByName.cend())
            return *found->second;

        ythrow TWithBackTrace<yexception>() << "cannot find scheme " << schemeName;
    }

    /// Rewrites a field dict so that every value becomes an array of one common length.
    ///
    /// The target length is the longest array among the dict's values. Shorter non-empty
    /// arrays are repeated cyclically (a 1-element array is broadcast), scalars are repeated,
    /// and null values and empty arrays are dropped. A dict without any non-empty array, or a
    /// non-dict value, is left untouched.
    ///
    /// @param val  a JSON value, rewritten in place when it is a dict containing arrays.
    /// @throws yexception if a value is itself a dict ("cannot align map").
    static void AlignArrays(NJson::TJsonValue & val) {
        if(!val.IsMap())
            return;

        auto & map = val.GetMapSafe();

        // Longest array length; remains 0 when there are no arrays or all of them are empty.
        // Iterating by `const auto &` binds directly to the map's pair<const TString, ...>
        // entries instead of materialising copies of every value.
        size_t maxSize = 0;
        for (const auto & p : map) {
            if (p.second.IsArray())
                maxSize = std::max(maxSize, p.second.GetArraySafe().size());
        }

        if(maxSize == 0)
            return;

        NJson::TJsonValue alignedVal(NJson::JSON_MAP);

        for(auto & p : map) {
            NJson::TJsonValue & value = p.second;
            NJson::TJsonValue newVal(NJson::JSON_ARRAY);
            auto & target = newVal.GetArraySafe();

            if(value.IsArray()) {
                auto & arr = value.GetArraySafe();
                const size_t size = arr.size();

                if(size == 0)
                    continue;

                for(size_t i = 0; i < maxSize; ++i) {
                    // Element `i % size` may be emitted several times when the array is shorter
                    // than maxSize. Moving it every time would leave later copies undefined, so
                    // copy it and move only on its final use.
                    auto & item = arr[i % size];
                    if (i + size >= maxSize)
                        target.emplace_back(std::move(item));
                    else
                        target.emplace_back(item);
                }
            } else if(value.IsMap()) {
                ythrow TWithBackTrace<yexception>() << "cannot align map";
            } else if(value.IsNull()) {
                continue;
            } else {
                for(size_t i = 0; i < maxSize; ++i)
                    target.emplace_back(value);
            }
            alignedVal.InsertValue(p.first, std::move(newVal));
        }

        val = std::move(alignedVal);
    }

    /// Validates a JSON array of messages and converts it into TMessage objects.
    ///
    /// Each element must be a dict with "scheme" (string or array of strings) and "type"
    /// (a TMessageType name); optional "fields" must be an array of dicts. One or more TMessage
    /// objects are produced per scheme: a single message when the scheme has no sharder, otherwise
    /// one for the unsharded group followed by one per shard group. Field dicts are normalised with
    /// AlignArrays().
    ///
    /// @throws THttpError(HTTP_BAD_REQUEST) on malformed input.
    TVector<TMessage> TImpl::ParseAndValidateMessages(NJson::TJsonValue && jsMessages) const {
        if(!jsMessages.IsArray())
            ythrow TWithBackTrace<THttpError>(HTTP_BAD_REQUEST) << "messages must be array: " << jsMessages;

        // Wraps one shard group into the container TMessage expects. Moving into an explicitly
        // built deque avoids the copy a braced initializer list would make.
        const auto singleGroup = [](TShardedFields && group) {
            TDeque<TShardedFields> groups;
            groups.emplace_back(std::move(group));
            return groups;
        };

        TVector<TMessage> messages;
        messages.reserve(jsMessages.GetArray().size());

        for(NJson::TJsonValue & jsMessage : jsMessages.GetArraySafe()) {
            if(!jsMessage.IsMap())
                ythrow TWithBackTrace<THttpError>(HTTP_BAD_REQUEST) << "each message must be dict: " << jsMessage;

            if(!jsMessage.Has("scheme"))
                ythrow TWithBackTrace<THttpError>(HTTP_BAD_REQUEST) << "each message must contain scheme field: " << jsMessage;

            TVector<TString> schemes;
            {
                // Bound by reference to avoid copying the JSON value.
                const NJson::TJsonValue & jsScheme = jsMessage["scheme"];
                if(jsScheme.IsString()) {
                    schemes.emplace_back(jsScheme.GetString());
                } else if(jsScheme.IsArray()) {
                    for(const auto & s : jsScheme.GetArray()) {
                        if(!s.IsString())
                            ythrow TWithBackTrace<THttpError>(HTTP_BAD_REQUEST) << "scheme must be string: " << jsScheme;
                        schemes.emplace_back(s.GetString());
                    }
                } else {
                    ythrow TWithBackTrace<THttpError>(HTTP_BAD_REQUEST) << "scheme must be string or array: " << jsScheme;
                }
            }

            if(!jsMessage.Has("type"))
                ythrow TWithBackTrace<THttpError>(HTTP_BAD_REQUEST) << "each message must contain type field: " << jsMessage;

            const NJson::TJsonValue & jsType = jsMessage["type"];
            if(!jsType.IsString())
                ythrow TWithBackTrace<THttpError>(HTTP_BAD_REQUEST) << "type must be string: " << jsType;

            const auto & strType = jsType.GetString();
            TMessageType type{};
            if(!TryFromString(strType, type)) {
                TStringStream msg;
                for(const auto & p : GetEnumNames<TMessageType>())
                    msg << p.second << ',';
                ythrow TWithBackTrace<THttpError>(HTTP_BAD_REQUEST) << "type must one of: [" << msg.Str() << "]; got: " << strType;
            }

            // "fields" is optional. When present it must be an array of dicts. It is validated once,
            // before any sharder sees it, so malformed fields are always rejected with 400.
            const NJson::TJsonValue::TArray * sourceFields = nullptr;
            if (jsMessage.Has("fields")) {
                const NJson::TJsonValue & jsFields = jsMessage["fields"];
                if (!jsFields.IsArray())
                    ythrow TWithBackTrace<THttpError>(HTTP_BAD_REQUEST) << "fields must be array: " << jsFields;
                for (const NJson::TJsonValue & jsField : jsFields.GetArraySafe()) {
                    if (!jsField.IsMap())
                        ythrow TWithBackTrace<THttpError>(HTTP_BAD_REQUEST) << "each field must be dict: " << jsField;
                }
                sourceFields = &jsFields.GetArraySafe();
            }

            for(const auto & scheme : schemes) {
                if (sourceFields == nullptr) {
                    // No "fields": one default-constructed field entry.
                    messages.emplace_back(scheme, type, NJson::TJsonValue::TArray(1));
                    continue;
                }

                // Each scheme works on its own copy, because sharding and alignment consume and rewrite it.
                NJson::TJsonValue::TArray jsFieldsArray = *sourceFields;

                const auto sharderIt = shardersByScheme.find(scheme);
                if (sharderIt == shardersByScheme.cend()) {
                    for (NJson::TJsonValue & jsField : jsFieldsArray)
                        AlignArrays(jsField);
                    messages.emplace_back(scheme, type, std::move(jsFieldsArray));
                    continue;
                }

                auto sharded = sharderIt->second.Shardify(std::move(jsFieldsArray));

                // The unsharded group is always emitted (even if empty), then one message per shard group.
                for (NJson::TJsonValue & jsField : sharded.withoutShard.fields)
                    AlignArrays(jsField);
                messages.emplace_back(scheme, type, singleGroup(std::move(sharded.withoutShard)));

                for (auto & shardedFields : sharded.groupedByShard) {
                    for (NJson::TJsonValue & jsField : shardedFields.second.fields)
                        AlignArrays(jsField);
                    messages.emplace_back(scheme, type, singleGroup(std::move(shardedFields.second)));
                }
            }
        }

        return messages;
    }

    void TImpl::AsyncUpdateMessages() {
        THashMap<TString, THashMap<TShard, TShardedFields>> fieldsBySchemeShard;

        TSimpleSharedPtr<TMessages> messages;
        while(messagesToUpdate.try_pop(messages)) {
            for (const auto &message : *messages) {
                const auto & schemesNames = message.GetSchemes();

                for (const auto &schemeName : schemesNames) {
                    auto & target = fieldsBySchemeShard[schemeName];
                    for(const auto & shardedFields : message.GetShardedFields()) {
                        const auto shard = shardedFields.shard;

                        auto targetShardedIt = target.find(shard);
                        if(targetShardedIt == target.end()) {
                            target.emplace(shard, shardedFields);
                        } else {
                            std::copy(shardedFields.fields.cbegin(), shardedFields.fields.cend(), std::back_inserter(targetShardedIt->second.fields));
                        }
                    }
                }
            }
        }

        size_t totalUpdated = 0;
        size_t toBeUpdatedSize = 0;
        for (const auto& schemeName : schemeOrder) {
            auto it = fieldsBySchemeShard.find(schemeName);
            if (it == fieldsBySchemeShard.end())
                continue;

            toBeUpdatedSize = 0;
            TDeque<TShardedFields> fields;
            for (auto & f : it->second) {
                toBeUpdatedSize += f.second.fields.size();
                fields.emplace_back(std::move(f.second));
            }

            auto & scheme = GetScheme(schemeName);
            try {
                scheme.Update(std::move(fields));
                totalUpdated += toBeUpdatedSize;
                updateInfo[schemeName].success += toBeUpdatedSize;
            }
            catch (const yexception & e) {
                updateInfo[schemeName].failed += toBeUpdatedSize;
                *logger << TLOG_WARNING << "async update message failed. scheme: " << schemeName << ", count: " << toBeUpdatedSize << ", error: " << e.what() << Endl;
            }
        }

        *logger << TLOG_INFO << "async updated: " << totalUpdated << Endl;
    }

    /// Builds the general-shingler processor described by `config`.
    ///
    /// Always reads "storage". Optionally reads "caches", "time_sets", "schemes" (together with
    /// "field_sets"), "scheme_order", "cleanup", "total_read_timeout" and "update_frequency".
    ///
    /// @param config     processor configuration section.
    /// @param pool       thread pool passed to the storage and to the processor.
    /// @param logger     logger used for warnings here and passed to the schemes and the processor.
    /// @param caches     out-parameter: every configured cache is also registered here as "<name>_ammx".
    /// @param scheduler  passed to CreateCache and TStorage, used for cleanup schedules and, when
    ///                   "update_frequency" is set, for the periodic AsyncUpdateMessages() task.
    /// @return the processor (shared with the scheduled async-update task, if any).
    /// @throws yexception if "scheme_order" names a scheme absent from "schemes"; other config
    ///         errors are reported by the helpers called here.
    TAtomicSharedPtr<TProcessor> ProcessorFactory(const NConfig::TConfig &config,
        const TAtomicSharedPtr<IThreadPool>& pool,
        const TAtomicSharedPtr<TLog>& logger,
        THashMap<TString, TAtomicSharedPtr<ICache>>& caches,
        TSimpleScheduler& scheduler
    ) {
        // Caches: one cache per "caches" entry, exposed both to the caller (via `caches`) and to
        // SchemeFactory (via cacheContextByName).
        THashMap<TString, TAtomicSharedPtr<TCacheContext>> cacheContextByName;
        if (config.Has("caches")) {
            for (const auto & v : NTalkativeConfig::Get<NConfig::TDict>(config, "caches")) {
                auto cache = CreateCache(v.second, &scheduler, "cache::" + v.first);
                auto item = MakeAtomicShared<TCacheContext>(v.second, cache);

                caches.emplace(v.first + "_ammx", cache);
                cacheContextByName.emplace(v.first, item);
            }
        }
        else {
            *logger << TLOG_WARNING << "there are no caches in config" << Endl;
        }

        THashMap<TString, TTimeSet> timeSetsByName;
        if (config.Has("time_sets")) {
            for (const auto & v : NTalkativeConfig::Get<NConfig::TDict>(config, "time_sets")) {
                timeSetsByName.emplace(v.first, TTimeSet(v.second));
            }
        }

        // NOTE(review): `storage` is a local passed by reference to SchemeFactory; confirm the
        // schemes do not keep a reference to it beyond this function.
        TStorage storage(config["storage"], pool, logger, scheduler);

        THashMap<TString, TSharder> shardersByScheme;
        THashMap<TString, TAtomicSharedPtr<TSchemeBase>> schemesByName;
        if (config.Has("schemes")) {
            TFieldSets sets(config["field_sets"]);
            const auto& schemes = NTalkativeConfig::Get<NConfig::TDict>(config, "schemes");
            for (const auto & v : schemes) {
                if(v.second.Has("sharder")) {
                    shardersByScheme.emplace(v.first, v.second["sharder"]);
                }

                auto schemeIt = schemesByName.emplace(v.first, SchemeFactory(v.second, timeSetsByName, cacheContextByName, sets, storage)).first;
                schemeIt->second->SetLogger(logger);
            }

            // Second pass, once every scheme exists: "use_order_from_scheme" makes a scheme reuse
            // another scheme's order (looked up by name).
            for (const auto& v : schemes) {
                if (v.second.Has("use_order_from_scheme")) {
                    const auto& scheme = NHelper::GetSafe(schemesByName, v.first);
                    const auto& order = NHelper::GetSafe(schemesByName, NTalkativeConfig::Get<TString>(v.second, "use_order_from_scheme"));
                    TDbScheme::SetOrder(scheme.Get(), order.Get());
                }
            }
        }
        else {
            *logger << TLOG_WARNING << "there are no schemes in config" << Endl;
        }

        // Update order for AsyncUpdateMessages: schemes listed in "scheme_order" first (each must exist)...
        TVector<TString> schemeOrder;
        if (config.Has("scheme_order")) {
            const auto& order = NTalkativeConfig::Get<NConfig::TArray>(config, "scheme_order");
            for (const auto & v : order) {
                schemeOrder.emplace_back(NTalkativeConfig::Get<TString>(v));
                if (!schemesByName.contains(schemeOrder.back()))
                    ythrow yexception() << "can't find scheme " << schemeOrder.back();
            }
        }

        // ...then every remaining scheme, in schemesByName (THashMap) iteration order, i.e. unspecified.
        for (const auto& item : schemesByName) {
            if (std::find(schemeOrder.begin(), schemeOrder.end(), item.first) == schemeOrder.end()) {
                schemeOrder.emplace_back(item.first);
            }
        }

        // Optional cleanup schedules; "statistic" (if present) selects the DB and collection passed along.
        if (config.Has("cleanup")) {
            TString collection;
            TAtomicSharedPtr<TStorageBase> db;

            const NConfig::TConfig& cleanup = config["cleanup"];
            if (cleanup.Has("statistic")) {
                const NConfig::TConfig& statistic = cleanup["statistic"];
                collection = NTalkativeConfig::Get<TString>(statistic, "collection");

                db = storage.GetDB(NTalkativeConfig::Get<TString>(statistic, "db"));
            }

            SetCleanupSchedules(scheduler, NTalkativeConfig::Get(cleanup, "items"), db, collection, schemesByName, timeSetsByName);
        }

        // Without "total_read_timeout" the read budget is TDuration::Max().
        const TDuration totalReadTimeout = config.Has("total_read_timeout") ? NTalkativeConfig::As<TDuration>(config, "total_read_timeout") : TDuration::Max();
        auto processor = MakeAtomicShared<TImpl>(pool, logger, std::move(schemeOrder), std::move(schemesByName), totalReadTimeout, std::move(shardersByScheme));

        // The scheduled task captures the shared pointer, so it keeps the processor alive.
        if(config.Has("update_frequency")) {
            scheduler.Add(([processor]() { processor->AsyncUpdateMessages(); }), NTalkativeConfig::As<TDuration>(config, "update_frequency"), "async update messages");
        } else {
            (*logger) << TLOG_WARNING << "shingler section hasn't 'update_frequency', so there is no thread for async updating" << Endl;
        }

        return processor;
    }
}
