#include <mail/so/spamstop/tools/so-common/StorageBase.h>
#include <mail/so/spamstop/tools/general_shingler/data/cast.h>
#include <numeric>
#include <util/generic/algorithm.h>
#include "db.h"

namespace NGeneralShingler {

    /// Checked downcast from the generic scheme interface to TDbScheme.
    /// @param v scheme to cast; a null pointer also fails the cast.
    /// @return @p v as TDbScheme*, never null.
    /// @throws TWithBackTrace<yexception> when @p v is not a TDbScheme.
    static TDbScheme* DynamicCastSafe(TSchemeBase* v) {
        if (auto* dbScheme = dynamic_cast<TDbScheme*>(v)) {
            return dbScheme;
        }
        ythrow TWithBackTrace<yexception>() << "cannot cast to TDbScheme";
    }

    /// Copies the update keys of @p order into the order keys of @p scheme.
    /// Both arguments must be TDbScheme instances; otherwise DynamicCastSafe throws.
    /// @p order is cast first, matching the C++17 sequencing of `a = b`
    /// (the right operand is evaluated before the left one).
    void TDbScheme::SetOrder(TSchemeBase* scheme, TSchemeBase* order) {
        const TDbScheme* const source = DynamicCastSafe(order);
        TDbScheme* const target = DynamicCastSafe(scheme);
        target->order_keys = source->update_keys;
    }

    class TDbFuture : public IFuture {
    public:
        void DoRun(TCont * cont) override {
            const TFindResults & anyvaluesFromDB = internalFuture->Get(cont);

            for(const auto & v : anyvaluesFromDB)
                AnyvalueToJson(v, foundValues.emplace_back());
        }

        NJson::TJsonValue::TArray DoExtract() override {
            return std::move(foundValues);
        }

        explicit TDbFuture(THolder<TStorageBase::IFuture> f)
                : internalFuture(std::move(f)) {}
    private:
        NJson::TJsonValue::TArray foundValues;
        THolder<TStorageBase::IFuture> internalFuture;
    };

    /// Builds a find query with one branch per field set in @p sharded (via ProcessFind),
    /// scoped to @p sharded.shard. It then starts a non-blocking find against the
    /// configured collection, bounded by @p deadline.
    /// @return a TDbFuture wrapping the storage future.
    THolder<IFuture> TDbScheme::InternalFind(TInstant deadline, const TShardedFields& sharded) {
        TFindAction action{};
        for (const NJson::TJsonValue & fieldSet : sharded.fields) {
            *logger << TLOG_DEBUG << "trying find in db " << fieldSet << Endl;
            ProcessFind(fieldSet, action, sharded.shard);
        }

        auto storageFuture = dbContext.db->FindNonblock(deadline, dbContext.collection, action);
        return MakeHolder<TDbFuture>(std::move(storageFuture));
    }

    /// Removes from the configured collection everything matching any of the
    /// field sets in @p sharded (one find branch per field set, scoped to
    /// @p sharded.shard).
    /// @return the value reported by the storage's Remove (presumably the number of
    ///         removed records — NOTE(review): confirm against the storage API).
    size_t TDbScheme::Remove(TShardedFields sharded) {
        TFindAction removeFilter{};
        for (const NJson::TJsonValue & fieldSet : sharded.fields)
            ProcessFind(fieldSet, removeFilter, sharded.shard);

        return dbContext.db->Remove(dbContext.collection, removeFilter);
    }

    static size_t CalculateHash(const TVector<TAtomicSharedPtr<TKeyScheme>>& keys, const NJson::TJsonValue& value) {
        return std::accumulate(keys.begin(), keys.end(), size_t(0), [&value](size_t h, const auto& key) {
            return CombineHashes(h, key->Hash(value));
        });
    }

    static void ReorderUpdateQueue(const TVector<TAtomicSharedPtr<TKeyScheme>>& order_keys, TMap<size_t, TDeque<const NJson::TJsonValue *>>& data) {
        if (order_keys.empty())
            return;

        TMap<size_t, TDeque<const NJson::TJsonValue *>> local;
        for (auto& v : data) {
            local[CalculateHash(order_keys, *v.second.front())] = std::move(v.second);
        }

        local.swap(data);
    }

    /// Groups @p messageFields by the combined hash of @p update_keys. Within a group,
    /// elements keep their input order. The groups are then passed to
    /// ReorderUpdateQueue together with @p order_keys.
    /// The returned map stores pointers into @p messageFields, which must outlive it.
    static auto GetUpdatesByKey(const NJson::TJsonValue::TArray & messageFields, const TVector<TAtomicSharedPtr<TKeyScheme>>& update_keys, const TVector<TAtomicSharedPtr<TKeyScheme>>& order_keys) {
        using TGroups = TMap<size_t, TDeque<const NJson::TJsonValue *>>;

        TGroups groups;
        for (const NJson::TJsonValue & item : messageFields) {
            const size_t groupHash = CalculateHash(update_keys, item);
            groups[groupHash].push_back(&item);
        }

        ReorderUpdateQueue(order_keys, groups);
        return groups;
    }

    /// Converts every field set in @p sharded into update actions and sends all of
    /// them to storage in a single UpdateSeries call.
    ///
    /// The field sets of each TShardedFields are grouped by update-key hash (see
    /// GetUpdatesByKey). What each group becomes depends on aggregationType:
    ///  - Accumulate:  one action per group. The query is built from the group's first
    ///                 element, and every element of the group contributes its updates.
    ///  - Replace:     one action per group, built from the group's last element only.
    ///  - Concatenate: one action per element, in group order.
    /// Every action is scoped to the shard of the TShardedFields it came from.
    void TDbScheme::Update(const TDeque<TShardedFields>& sharded) {
        // The fold type is deduced from the initial value. Use size_t(0) rather than
        // 0ul, because unsigned long is only 32 bits on LLP64 targets.
        const size_t totalSize = Accumulate(sharded.cbegin(), sharded.cend(), size_t(0), [](size_t s, const TShardedFields & f){
            return s + f.fields.size();
        });

        // Exact for Concatenate; an upper bound for Accumulate/Replace.
        TActionSeries series(Reserve(totalSize));

        for (const auto & shardedFields : sharded) {
            const auto updatesByKey = GetUpdatesByKey(shardedFields.fields, update_keys, order_keys);

            switch (aggregationType) {
                case TAggregationType::Accumulate: {
                    for (const auto &p : updatesByKey) {
                        TUpdateAction action;
                        // All elements of a group share the same update-key hash,
                        // so the first one is used to build the query.
                        for (const auto &key : update_keys) {
                            key->Apply(action.query, *p.second.front());
                        }

                        for (const NJson::TJsonValue *f : p.second) {
                            for (const auto &update : updates) {
                                update.Apply(action.update, *f);
                            }
                        }

                        action.upsert = upsert;
                        action.query.shard = shardedFields.shard;

                        series.emplace_back(std::move(action));
                    }
                    break;
                }
                case TAggregationType::Replace: {
                    // Only the last element of each group is written.
                    for (const auto &p : updatesByKey) {
                        ProcessUpdate(*p.second.back(), series, shardedFields.shard);
                    }
                    break;
                }
                case TAggregationType::Concatenate: {
                    for (const auto &p : updatesByKey) {
                        for (const auto &f : p.second)
                            ProcessUpdate(*f, series, shardedFields.shard);
                    }
                    break;
                }
            }
        }

        // NOTE(review): the meaning of the trailing `true` is defined by the storage's
        // UpdateSeries API — confirm there.
        dbContext.db->UpdateSeries(dbContext.collection, series, true);
    }

    /// Appends to @p series one update action built from a single field set.
    /// The query is scoped to @p shard and filled by every update key. The update
    /// part is filled by every configured update. The upsert flag is copied from
    /// the scheme.
    void TDbScheme::ProcessUpdate(const NJson::TJsonValue &messageFields, TActionSeries & series, TMaybe<size_t> shard) const {
        TUpdateAction action;
        action.query.shard = std::move(shard);
        action.upsert = upsert;

        for (const auto & keyScheme : update_keys)
            keyScheme->Apply(action.query, messageFields);

        for (const auto & fieldUpdate : updates)
            fieldUpdate.Apply(action.update, messageFields);

        series.emplace_back(std::move(action));
    }

    /// Appends one new branch to @p action.query.ors and fills it from @p messageFields
    /// using every find key. Also sets the query's shard to @p shard, overwriting any
    /// shard stored by an earlier call.
    void TDbScheme::ProcessFind(const NJson::TJsonValue &messageFields, TFindAction & action, TMaybe<size_t> shard) const {
        auto & orBranch = action.query.ors.emplace_back();
        action.query.shard = std::move(shard);

        for (const auto & findKey : find_keys)
            findKey->Apply(orBranch, messageFields);
    }
}



