#include "tags_search.h"

#include <drive/backend/database/transaction/assert.h>
#include <drive/backend/logging/events.h>
#include <drive/backend/logging/evlog.h>

#include <rtline/library/json/exception.h>

#include <util/string/builder.h>

// JSON view of a cached tag statistic: {"count": <Count>, "timestamp": <Timestamp>}.
template <>
NJson::TJsonValue NJson::ToJson(const TTagsSearch::TTagStat& object) {
    NJson::TJsonValue json;
    json["count"] = object.Count;
    json["timestamp"] = NJson::ToJson(object.Timestamp);
    return json;
}

// True when the request consists of exactly one "all of" tag and nothing else:
// no "one of" tags, no "none of" tags and no performer filter.
bool TTagsSearchRequest::IsSimpleQuery() const {
    if (HasAllOf.size() != 1) {
        return false;
    }
    if (!HasOneOf.empty() || !HasNoneOf.empty()) {
        return false;
    }
    return !HasPerformer();
}

// Returns the "all of" tag names followed by the "one of" tag names.
// "None of" tags are intentionally not part of the observed list.
TVector<TString> TTagsSearchRequest::GetObservedTagNames() const {
    TVector<TString> observed;
    observed.reserve(HasAllOf.size() + HasOneOf.size());
    observed.insert(observed.end(), HasAllOf.begin(), HasAllOf.end());
    observed.insert(observed.end(), HasOneOf.begin(), HasOneOf.end());
    return observed;
}

// Expands "special" entries of a raw tag list.
// An entry starting with '@' is treated as a tag type: it is replaced by the
// names of all tags that tagsMeta.GetTagsByType() returns for the text after
// the '@'. Every other entry is copied through unchanged. Input order is kept.
TVector<TString> TTagsSearchRequest::ResolveSpecials(const TVector<TString>& rawTagsList, const ITagsMeta& tagsMeta) const {
    TVector<TString> resolved;
    for (const auto& rawTag : rawTagsList) {
        if (!rawTag.StartsWith("@")) {
            // Plain tag name.
            resolved.push_back(rawTag);
            continue;
        }
        // Strip the leading '@' and expand the type into concrete tag names.
        auto descriptions = tagsMeta.GetTagsByType(rawTag.substr(1, rawTag.size() - 1));
        for (auto&& description : descriptions) {
            resolved.push_back(description->GetName());
        }
    }
    return resolved;
}

// Normalizes the request in place before it is executed:
//  1. expands '@type' entries in all three tag lists;
//  2. if there are no "all of" tags and exactly one "one of" tag, moves that
//     single tag into the "all of" list.
void TTagsSearchRequest::Prepare(const ITagsMeta& tagsMeta) {
    HasAllOf = ResolveSpecials(HasAllOf, tagsMeta);
    HasOneOf = ResolveSpecials(HasOneOf, tagsMeta);
    HasNoneOf = ResolveSpecials(HasNoneOf, tagsMeta);

    const bool singleAlternativeOnly = HasAllOf.empty() && HasOneOf.size() == 1;
    if (singleAlternativeOnly) {
        HasOneOf.swap(HasAllOf);
    }
}

// Fills the request from JSON fields "has_all_of", "has_none_of", "has_one_of",
// "limit" and "performer" (parsed in that order).
// Stops at the first field that fails to parse and returns false; fields parsed
// before the failure keep their new values.
bool TTagsSearchRequest::DeserializeFromJson(const NJson::TJsonValue& value) {
    if (!NJson::ParseField(value["has_all_of"], HasAllOf)) {
        return false;
    }
    if (!NJson::ParseField(value["has_none_of"], HasNoneOf)) {
        return false;
    }
    if (!NJson::ParseField(value["has_one_of"], HasOneOf)) {
        return false;
    }
    if (!NJson::ParseField(value["limit"], Limit)) {
        return false;
    }
    if (!NJson::ParseField(value["performer"], Performer)) {
        return false;
    }
    return true;
}

// Writes the request into a JSON value using the same field names that
// DeserializeFromJson reads.
NJson::TJsonValue TTagsSearchRequest::SerializeToJson() const {
    NJson::TJsonValue json;
    NJson::InsertField(json, "has_all_of", HasAllOf);
    NJson::InsertField(json, "has_none_of", HasNoneOf);
    NJson::InsertField(json, "has_one_of", HasOneOf);
    NJson::InsertField(json, "limit", Limit);
    NJson::InsertField(json, "performer", Performer);
    return json;
}

// Describes the request fields as a scheme: three string arrays for the tag
// lists, a string "performer" and a numeric "limit" whose default is 100.
// The server argument is not used.
NDrive::TScheme TTagsSearchRequest::GetScheme(const IServerBase& /*server*/) {
    NDrive::TScheme requestScheme;
    requestScheme.Add<TFSArray>("has_all_of", "has_all_of").SetElement<TFSString>();
    requestScheme.Add<TFSArray>("has_none_of", "has_none_of").SetElement<TFSString>();
    requestScheme.Add<TFSArray>("has_one_of", "has_one_of").SetElement<TFSString>();
    requestScheme.Add<TFSString>("performer", "performer");
    requestScheme.Add<TFSNumeric>("limit", "limit").SetDefault(100);
    return requestScheme;
}

// True once the number of collected ids is at least ResultsLimit.
bool TTagsSearchResult::IsLimitReached() const {
    const auto matchedCount = MatchedIds.size();
    return matchedCount >= ResultsLimit;
}

// Builds a searcher over the tags stored by `observedObject`.
// Tag metadata is taken from tagsManager.GetTagsMeta(). The three signals are
// named after the observed table: "<table>-search-count",
// "<table>-search-errors" and "<table>-search-times".
TTagsSearch::TTagsSearch(const IEntityTagsManager& observedObject, const IDriveTagsManager& tagsManager)
    : TagStatCache(1024) // presumably the cache capacity in entries — TODO confirm
    , ObservedObject(observedObject)
    , TagsMeta(tagsManager.GetTagsMeta())
    // NOTE(review): meaning of the trailing `false` flag is defined by the signal type, not visible here.
    , SearchCount({ observedObject.GetTableName() + "-search-count" }, false)
    , SearchErrors({ observedObject.GetTableName() + "-search-errors" }, false)
    , SearchTimes({ observedObject.GetTableName() + "-search-times" })
{
}

// Runs "COUNT(DISTINCT object_id)" over the observed table with the supplied
// query options and returns the resulting number.
// Fails through R_ENSURE if the query cannot be executed or if it does not
// return exactly one row.
ui64 TTagsSearch::Count(NDrive::TEntitySession& tx, const NSQL::TQueryOptions& queryOptions) const {
    auto sql = queryOptions.PrintQuery(*tx, ObservedObject.GetTableName(), {
        "COUNT(DISTINCT object_id) as count"
    });

    TRecordsSet rows;
    auto execResult = tx->Exec(sql, &rows);
    R_ENSURE(ParseQueryResult(execResult, tx), {}, "cannot Exec CountQuery", tx);
    R_ENSURE(rows.size() == 1, HTTP_INTERNAL_SERVER_ERROR, "bad CountQuery rows count: " << rows.size(), tx);

    const auto& onlyRow = rows.GetRecords()[0];
    return FromString<ui64>(onlyRow.Get("count"));
}

// Returns the number of distinct objects carrying `tag`, served from
// TagStatCache when the cached entry is younger than 1000 seconds.
// On a miss (or a stale entry) the miss is reported to the thread event logger
// if one is present, the count is recomputed with Count() inside `tx`, and the
// cache entry is refreshed with the time taken at the start of this call.
ui64 TTagsSearch::GetCachedCount(const TString& tag, NDrive::TEntitySession& tx) const {
    auto evlog = NDrive::GetThreadEventLogger();
    const auto requestTime = Now();
    const auto freshnessBound = requestTime - TDuration::Seconds(1000);

    auto cached = TagStatCache.find(tag);
    const bool isFresh = cached && cached->Timestamp > freshnessBound;
    if (isFresh) {
        return cached->Count;
    }

    if (evlog) {
        evlog->AddEvent(NJson::TMapBuilder
            ("event", "TagStatCacheMiss")
            ("cached", NJson::ToJson(cached))
            ("tag", tag)
        );
    }

    // Recount directly from the table, filtering by this single tag.
    NSQL::TQueryOptions byTag;
    byTag.AddGenericCondition("tag", tag);

    TTagStat refreshed;
    refreshed.Count = Count(tx, byTag);
    refreshed.Timestamp = requestTime;
    TagStatCache.update(tag, refreshed);
    return refreshed.Count;
}

// Convenience overload: opens a read-only, repeatable-read transaction on the
// observed object and runs the search in it. `timeout` is passed as both the
// lock timeout and the statement timeout.
TTagsSearchResult TTagsSearch::Search(TTagsSearchRequest& request, TDuration timeout) const {
    auto session = ObservedObject.BuildTx<NSQL::ReadOnly | NSQL::RepeatableRead>(/*lockTimeout=*/timeout, /*statementTimeout=*/timeout);
    return Search(request, session);
}

// Instrumented entry point around Search2().
// Every call bumps SearchCount; a successful call reports its wall-clock
// duration in milliseconds to SearchTimes. Any exception is logged as a
// "TagSearchError" event together with the serialized request, counted in
// SearchErrors and then rethrown unchanged.
TTagsSearchResult TTagsSearch::Search(TTagsSearchRequest& request, NDrive::TEntitySession& tx) const {
    try {
        SearchCount.Signal(1);
        const auto startedAt = Now();
        auto found = Search2(request, tx);
        const auto finishedAt = Now();
        SearchTimes.Signal((finishedAt - startedAt).MilliSeconds());
        return found;
    } catch (...) {
        NDrive::TEventLog::Log("TagSearchError", NJson::TMapBuilder
            ("request", request.SerializeToJson())
            ("exception", CurrentExceptionInfo())
        );
        SearchErrors.Signal(1);
        throw;
    }
}

// Executes a tag search inside `tx` and returns the matched object ids together
// with their associated tags.
//
// Outline:
//  1. Normalizes `request` via request.Prepare(TagsMeta) (the request is mutated).
//  2. Splits tags into allOf / anyOf / noneOf sets. Tags whose cached object
//     count is at or above the setting
//     "tags_manager.<entity type>.optimization.count_threshold" (default 100000)
//     are left out of the SQL ("optimized") and re-checked in memory afterwards.
//  3. Builds one SQL statement: a base subquery `has_any_of`, one JOIN per
//     remaining all-of tag, one LEFT OUTER JOIN per remaining none-of tag whose
//     `present` column must be NULL. LIMIT is appended only when nothing was
//     optimized away, because otherwise the in-memory filter may drop rows.
//  4. Restores the tags of the candidate objects and filters them in memory
//     against the full (non-optimized) allOf / anyOf / noneOf sets.
//  5. Fills TotalMatched with an extra COUNT query only when both allOf and
//     noneOf are empty.
//
// Failures are reported through R_ENSURE (query execution errors, tag restore
// errors, and internal consistency violations between SQL and in-memory filters).
TTagsSearchResult TTagsSearch::Search2(TTagsSearchRequest& request, NDrive::TEntitySession& tx) const {
    const auto& settings = NDrive::GetServer().GetSettings();
    const auto countThreshold = settings.GetValue<ui64>(
        TStringBuilder() << "tags_manager." << ObservedObject.GetEntityType() << ".optimization.count_threshold"
    ).GetOrElse(100000);
    const auto evlog = NDrive::GetThreadEventLogger();

    request.Prepare(TagsMeta);
    auto allOf = MakeSet(request.GetHasAllOf());
    auto anyOf = MakeSet(request.GetHasOneOf());
    auto noneOf = MakeSet(request.GetHasNoneOf());
    // A single mandatory tag with no alternatives is handled as an any-of set of size one.
    if (allOf.size() == 1 && anyOf.size() == 0) {
        std::swap(allOf, anyOf);
    }
    auto observed = MakeSet(request.GetObservedTagNames());

    // Every subquery selects the same shape: distinct object ids plus a constant
    // marker column used by the none-of "IS NULL" checks.
    auto fields = TVector<TString>{
        "object_id",
        "1 AS present",
    };
    auto groupBy = TVector<TString>{
        "object_id",
    };
    auto query = TStringBuilder() << "SELECT has_any_of.object_id AS object_id FROM";
    auto table = ObservedObject.GetTableName();

    auto effectiveAnyOf = anyOf;
    auto effectiveAllOf = allOf;
    bool optimizedAllOf = false;
    {
        // Order mandatory tags from the most to the least frequent one.
        TVector<std::pair<ui64, TString>> tags;
        for (auto&& tag : allOf) {
            auto count = GetCachedCount(tag, tx);
            tags.emplace_back(count, tag);
        }
        std::sort(tags.begin(), tags.end());
        std::reverse(tags.begin(), tags.end());
        for (auto&& [count, tag] : tags) {
            // Never drop the last tag that still constrains the SQL.
            if (effectiveAnyOf.size() + effectiveAllOf.size() < 2) {
                break;
            }
            // Sorted descending, so all following tags are below the threshold too.
            if (count < countThreshold) {
                break;
            }
            effectiveAllOf.erase(tag);
            optimizedAllOf = true;

            if (evlog) {
                evlog->AddEvent(NJson::TMapBuilder
                    ("event", "OptimizeAllOf")
                    ("count", count)
                    ("tag", tag)
                );
            }
        }
        // Without a performer filter the base subquery needs a tag condition of
        // its own: promote one remaining mandatory tag into it instead of
        // joining it separately.
        if (effectiveAnyOf.empty() && !effectiveAllOf.empty() && !request.HasPerformer()) {
            auto tag = *effectiveAllOf.begin();
            effectiveAnyOf.insert(tag);
            effectiveAllOf.erase(tag);
            if (evlog) {
                evlog->AddEvent(NJson::TMapBuilder
                    ("event", "OptimizeAnyOf")
                    ("tag", tag)
                );
            }
        }
    }

    auto effectiveNoneOf = noneOf;
    bool optimizedNoneOf = false;
    {
        // Frequent excluded tags are not joined in SQL; they are filtered in memory below.
        for (auto&& tag : noneOf) {
            auto count = GetCachedCount(tag, tx);
            if (count < countThreshold) {
                continue;
            }

            effectiveNoneOf.erase(tag);
            optimizedNoneOf = true;

            if (evlog) {
                evlog->AddEvent(NJson::TMapBuilder
                    ("event", "OptimizeNoneOf")
                    ("count", count)
                    ("tag", tag)
                );
            }
        }
    }

    // Tag names accepted by the base subquery. With a performer filter the
    // remaining mandatory tags are accepted too, since the performer condition
    // is applied to the same rows.
    auto effectiveObserved = effectiveAnyOf;
    if (request.HasPerformer()) {
        effectiveObserved.insert(effectiveAllOf.begin(), effectiveAllOf.end());
    }
    {
        NSQL::TQueryOptions queryOptions;
        queryOptions.SetGroupBy(groupBy);
        if (!effectiveObserved.empty()) {
            queryOptions.SetGenericCondition("tag", effectiveObserved);
        }
        if (request.HasPerformer()) {
            queryOptions.AddGenericCondition("performer", request.GetPerformerRef());
        }
        query << " (" << queryOptions.PrintQuery(*tx, table, fields) << ") AS has_any_of";
    }
    // One inner join per remaining mandatory tag.
    size_t i = 0;
    for (auto&& tag : effectiveAllOf) {
        NSQL::TQueryOptions queryOptions;
        queryOptions.AddGenericCondition("tag", tag);
        queryOptions.SetGroupBy(groupBy);
        query << " JOIN (" << queryOptions.PrintQuery(*tx, table, fields) << ") AS has_all_of_" << i << " ON has_any_of.object_id = has_all_of_" << i << ".object_id";
        i += 1;
    }
    // One left join per remaining excluded tag; the row survives only if the join found nothing.
    i = 0;
    for (auto&& tag : effectiveNoneOf) {
        NSQL::TQueryOptions queryOptions;
        queryOptions.AddGenericCondition("tag", tag);
        queryOptions.SetGroupBy(groupBy);
        query << " LEFT OUTER JOIN (" << queryOptions.PrintQuery(*tx, table, fields) << ") AS has_none_of_" << i << " ON has_any_of.object_id = has_none_of_" << i << ".object_id";
        i += 1;
    }
    query << " WHERE True";
    for (i = 0; i < effectiveNoneOf.size(); ++i) {
        query << " AND has_none_of_" << i << ".present IS NULL";
    }
    // LIMIT is only safe when SQL alone decides the result set.
    if (auto limit = request.GetLimit(); limit && !optimizedAllOf && !optimizedNoneOf) {
        query << " LIMIT " << limit;
    }

    TRecordsSet records;
    {
        auto g = NDrive::BuildEventGuard("SearchQuery");
        auto queryResult = tx->Exec(query, &records);
        R_ENSURE(ParseQueryResult(queryResult, tx), {}, "cannot Exec SearchQuery", tx);
    }

    // Collect candidate ids, skipping empty ones.
    TSet<TString> objectIds;
    for (auto&& record : records) {
        auto objectId = record.Get("object_id");
        if (objectId) {
            objectIds.insert(std::move(objectId));
        }
    }

    TTagsSearchResult result(request.GetLimit());
    if (objectIds.empty()) {
        result.SetTotalMatched(0);
        return result;
    }

    TMap<TString, TDBTags> objects;
    {
        auto g = NDrive::BuildEventGuard("RestoreFoundTags");
        auto tagNames = MakeVector(observed);
        for (auto&& tag : noneOf) {
            tagNames.push_back(tag);
        }
        // NOTE(review): only objects with at least one restored tag end up in
        // `objects`. When `observed` is empty but `noneOf` is not, tagNames holds
        // only none-of names — confirm RestoreTags semantics for that case.
        auto tags = TDBTags();
        // Report the number of requested ids: `objects` is still empty at this point.
        R_ENSURE(ObservedObject.RestoreTags(objectIds, tagNames, tags, tx), {}, "cannot restore " << objectIds.size() << " objects", tx);
        for (auto&& tag : tags) {
            objects[tag.GetObjectId()].push_back(std::move(tag));
        }
    }
    // In-memory verification against the full (non-optimized) tag sets.
    for (auto&& [id, tags] : objects) {
        size_t allOfCount = 0;
        size_t anyOfCount = 0;
        size_t noneOfCount = 0;
        TSet<TString> uniqueAllOf;
        TVector<TConstDBTag> associatedTags;
        for (auto&& tag : tags) {
            const auto& name = tag->GetName();
            const auto& performer = tag->GetPerformer();
            if (noneOf.contains(name)) {
                // An excluded tag may only show up here if its SQL check was optimized away.
                R_ENSURE(optimizedNoneOf, HTTP_INTERNAL_SERVER_ERROR, "object " << id << " passed filtration with tag " << name, tx);
                noneOfCount += 1;
                break;
            }
            if (!observed.empty() && !observed.contains(name)) {
                continue;
            }
            // Count each mandatory tag name once, even if the object carries it several times.
            if (allOf.contains(name) && uniqueAllOf.insert(name).second) {
                allOfCount += 1;
            }
            if (anyOf.contains(name)) {
                anyOfCount += 1;
            }
            // Tags of other performers still count for matching above but are not returned.
            if (request.HasPerformer() && request.GetPerformerRef() != performer) {
                continue;
            }
            associatedTags.push_back(std::move(tag));
        }
        if (noneOfCount) {
            continue;
        }
        if (allOfCount != allOf.size()) {
            // NOTE(review): the message prints "==" although this branch is taken on
            // inequality; text left unchanged in case log consumers match on it.
            R_ENSURE(optimizedAllOf, HTTP_INTERNAL_SERVER_ERROR, "all_of mismatch: " << allOfCount << " == " << allOf.size(), tx);
            continue;
        }
        R_ENSURE(anyOfCount > 0 || anyOf.empty(), HTTP_INTERNAL_SERVER_ERROR, "any_of mismatch: " << anyOfCount, tx);
        result.AddMatchedId(id, std::move(associatedTags));
        if (result.IsLimitReached()) {
            break;
        }
    }

    // The total is only computed for pure any-of queries (optionally with a performer).
    if (allOf.empty() && noneOf.empty()) {
        auto g = NDrive::BuildEventGuard("CountQuery");
        NSQL::TQueryOptions queryOptions;
        // NOTE(review): unlike the search query above, the condition is set even
        // when anyOf is empty — confirm SetGenericCondition behaviour for an empty set.
        queryOptions.SetGenericCondition("tag", anyOf);
        if (request.HasPerformer()) {
            queryOptions.AddGenericCondition("performer", request.GetPerformerRef());
        }
        result.SetTotalMatched(Count(tx, queryOptions));
    }

    return result;
}
