#include "custom_audience_service_impl.h"

#include <crypta/lab/lib/native/segments_user_data_stats_aggregator.h>
#include <crypta/lab/lib/native/user_data_stats_aggregator.h>
#include <crypta/lab/lib/native/user_data_stats_decoder.h>
#include <crypta/lib/native/time/scope_timer.h>
#include <crypta/siberia/bin/custom_audience/common/rule/encoded_ca_rule.h>
#include <crypta/siberia/bin/custom_audience/common/rule/filter.h>

#include <util/generic/hash_set.h>
#include <util/generic/refcount.h>

using namespace NCrypta::NSiberia::NCustomAudience;

// Constructs the service implementation.
//
//   threadCount       - forwarded to Executor.RunAdditionalThreads(); the executor is the
//                       one later handed to NPar::ParallelFor by the request handlers.
//   db                - stored in Db; handlers read packs and dictionaries from it.
//   describeThreshold - stored in DescribeThreshold; handlers stop collecting matches
//                       once this many records have been counted.
//   statsSettings     - settings passed on when obtaining the TStats instance.
TCustomAudienceServiceImpl::TCustomAudienceServiceImpl(
    size_t threadCount,
    const TDb& db,
    size_t describeThreshold,
    const TStats::TSettings& statsSettings
)
    : Db(db)
    , DescribeThreshold(describeThreshold)
    , Log(NLog::GetLog("custom_audience_service"))
    // The singleton is tagged with this class's type.
    // NOTE(review): presumably this yields one shared TStats per tag - confirm.
    , Stats(TaggedSingleton<TStats, decltype(*this)>("custom_audience_server", statsSettings))
{
    // Worker threads must be requested before any handler calls ParallelFor on Executor.
    // NOTE(review): assumes RunAdditionalThreads starts the workers immediately - confirm.
    Executor.RunAdditionalThreads(threadCount);
}

// Aggregates user-data statistics over all records that match `rule`.
//
// Each pack of Db is scanned through NPar::ParallelFor on Executor and accumulates into
// its own slot; the slots are then folded into one aggregator, merged into `response`
// and decoded with the Db dictionaries.
//
// A shared counter of matched records is checked before every record and incremented
// after every match; a pack stops being scanned once the counter reaches
// DescribeThreshold. The check and the increment are separate operations, so packs
// scanned concurrently may push the total somewhat past the threshold.
//
// Returns:
//   UNAVAILABLE - Db.IsReady() is false;
//   INTERNAL    - any exception was caught (its what() is the message when available);
//   OK          - otherwise, with `response` filled in.
grpc::Status TCustomAudienceServiceImpl::GetStatsBase(const TCaRule* rule, NLab::TUserDataStats* response) {
    using TEncodedStatsAggregator = NLab::TUserDataStatsAggregator<NLab::TAffinitiesEncodedOptions>;

    try {
        if (!Db.IsReady()) {
            return grpc::Status(grpc::StatusCode::UNAVAILABLE, "Service is not ready");
        }

        TScopeTimer requestTimer(Stats.Percentile, "timing.get_stats");
        Stats.Count->Add("request.get_stats.total.received");

        auto&& packs = Db.GetPacks();
        auto&& dicts = Db.GetDicts();

        // One output slot per pack, so parallel iterations never write to the same element.
        TVector<NLab::TUserDataStats> perPackStats(packs.size());

        TEncodedCaRule encodedRule(
            *rule,
            dicts.WordDicts.ReversedDict,
            dicts.HostDicts.ReversedDict,
            dicts.AppDicts.ReversedDict);

        // Shared across all packs: number of records that passed the filter so far.
        TAtomicCounter matchedCount = 0;

        NPar::ParallelFor(
            Executor,
            0,
            packs.size(),
            [this, &packs, &dicts, &perPackStats, &encodedRule, &matchedCount](int packIdx) {
                TEncodedStatsAggregator packAggregator({.MaxTokensCount = 1000000, .MinSampleRatio = 0.01, .AccumulateAffinities = true});
                NNativeYT::TProtoState<NLab::TUserDataStatsOptions> protoState;
                NLab::TSegmentsUserDataStatsAggregator converter(protoState);
                NLab::TUserDataStats recordStats;

                for (const auto& record : packs.at(packIdx)) {
                    if (static_cast<size_t>(matchedCount.Val()) >= DescribeThreshold) {
                        break;
                    }
                    if (!NFilter::Filter(record, encodedRule)) {
                        continue;
                    }

                    converter.ConvertUserDataToUserDataStats(recordStats, record, false);
                    packAggregator.UpdateWith(recordStats);
                    matchedCount.Inc();
                }

                packAggregator.MergeInto(
                    perPackStats[packIdx],
                    &dicts.WordDicts.Dict,
                    &dicts.HostDicts.Dict,
                    &dicts.AppDicts.Dict);
            });

        // Fold the per-pack partial results into a single aggregate.
        TEncodedStatsAggregator totalAggregator({.MaxTokensCount = 1000000, .MinSampleRatio = 0.01, .AccumulateAffinities = true});
        for (const auto& partial : perPackStats) {
            totalAggregator.UpdateWith(partial);
        }

        totalAggregator.MergeInto(
            *response,
            &dicts.WordDicts.Dict,
            &dicts.HostDicts.Dict,
            &dicts.AppDicts.Dict);
        NLab::NUserDataStatsDecoder::Decode(
            *response,
            dicts.WordDicts.Dict,
            dicts.HostDicts.Dict,
            dicts.AppDicts.Dict);

        return grpc::Status::OK;
    } catch (const yexception& e) {
        return grpc::Status(grpc::StatusCode::INTERNAL, e.what());
    } catch (const std::exception& e) {
        return grpc::Status(grpc::StatusCode::INTERNAL, e.what());
    } catch (...) {
        return grpc::Status(grpc::StatusCode::INTERNAL, "Unknown exception");
    }
}

// gRPC handler: aggregates user-data statistics over all records matching an extended rule.
//
// The rule is encoded into a TEncodedExtendedCaRule with the reversed Db dictionaries.
// Each pack of Db is then scanned through NPar::ParallelFor on Executor into its own
// slot; the slots are folded into one aggregator, merged into `response` and decoded
// with the Db dictionaries.
//
// Scanning of a pack stops once the shared counter of matched records reaches
// DescribeThreshold. The counter is read before each record and incremented after each
// match, so concurrently scanned packs may overshoot the threshold somewhat.
//
// Returns:
//   UNAVAILABLE - Db.IsReady() is false;
//   INTERNAL    - any exception was caught (its what() is the message when available);
//   OK          - otherwise, with `response` filled in.
grpc::Status TCustomAudienceServiceImpl::GetStatsByExtendedRule(grpc::ServerContext*, const TExtendedCaRule* rule, NLab::TUserDataStats* response) {
    using TEncodedStatsAggregator = NLab::TUserDataStatsAggregator<NLab::TAffinitiesEncodedOptions>;

    try {
        if (!Db.IsReady()) {
            return grpc::Status(grpc::StatusCode::UNAVAILABLE, "Service is not ready");
        }

        TScopeTimer requestTimer(Stats.Percentile, "timing.get_stats");
        Stats.Count->Add("request.get_stats.total.received");

        auto&& packs = Db.GetPacks();
        auto&& dicts = Db.GetDicts();

        // One output slot per pack keeps the parallel iterations write-disjoint.
        TVector<NLab::TUserDataStats> perPackStats(packs.size());

        TEncodedExtendedCaRule encodedRule(
            *rule,
            dicts.WordDicts.ReversedDict,
            dicts.HostDicts.ReversedDict,
            dicts.AppDicts.ReversedDict);

        // Shared across all packs: number of records that passed the filter so far.
        TAtomicCounter matchedCount = 0;

        NPar::ParallelFor(
            Executor,
            0,
            packs.size(),
            [this, &packs, &dicts, &perPackStats, &encodedRule, &matchedCount](int packIdx) {
                TEncodedStatsAggregator packAggregator({.MaxTokensCount = 1000000, .MinSampleRatio = 0.01, .AccumulateAffinities = true});
                NNativeYT::TProtoState<NLab::TUserDataStatsOptions> protoState;
                NLab::TSegmentsUserDataStatsAggregator converter(protoState);
                NLab::TUserDataStats recordStats;

                for (const auto& record : packs.at(packIdx)) {
                    if (static_cast<size_t>(matchedCount.Val()) >= DescribeThreshold) {
                        break;
                    }
                    if (!NFilter::Filter(record, encodedRule)) {
                        continue;
                    }

                    converter.ConvertUserDataToUserDataStats(recordStats, record, false);
                    packAggregator.UpdateWith(recordStats);
                    matchedCount.Inc();
                }

                packAggregator.MergeInto(
                    perPackStats[packIdx],
                    &dicts.WordDicts.Dict,
                    &dicts.HostDicts.Dict,
                    &dicts.AppDicts.Dict);
            });

        // Fold the per-pack partial results into a single aggregate.
        TEncodedStatsAggregator totalAggregator({.MaxTokensCount = 1000000, .MinSampleRatio = 0.01, .AccumulateAffinities = true});
        for (const auto& partial : perPackStats) {
            totalAggregator.UpdateWith(partial);
        }

        totalAggregator.MergeInto(
            *response,
            &dicts.WordDicts.Dict,
            &dicts.HostDicts.Dict,
            &dicts.AppDicts.Dict);
        NLab::NUserDataStatsDecoder::Decode(
            *response,
            dicts.WordDicts.Dict,
            dicts.HostDicts.Dict,
            dicts.AppDicts.Dict);

        return grpc::Status::OK;
    } catch (const yexception& e) {
        return grpc::Status(grpc::StatusCode::INTERNAL, e.what());
    } catch (const std::exception& e) {
        return grpc::Status(grpc::StatusCode::INTERNAL, e.what());
    } catch (...) {
        return grpc::Status(grpc::StatusCode::INTERNAL, "Unknown exception");
    }
}


// gRPC handler for plain rules: ignores the server context and forwards `rule` and
// `response` to GetStatsBase, returning its status unchanged.
grpc::Status TCustomAudienceServiceImpl::GetStats(grpc::ServerContext* /*context*/, const TCaRule* rule, NLab::TUserDataStats* response) {
    grpc::Status status = GetStatsBase(rule, response);
    return status;
}

// gRPC handler: collects the CryptaID of records matching `rule`, at most
// DescribeThreshold ids in total.
//
// Each pack of Db is scanned through NPar::ParallelFor on Executor into its own id
// list. A shared counter of matches is read before each record and incremented after
// each match, so concurrently scanned packs may collect somewhat more than
// DescribeThreshold ids between them; the final merge into `response` enforces the
// exact cap, taking ids pack by pack in pack order.
//
// Returns:
//   UNAVAILABLE - Db.IsReady() is false;
//   INTERNAL    - any exception was caught (its what() is the message when available);
//   OK          - otherwise, with `response` ids filled in.
grpc::Status TCustomAudienceServiceImpl::GetIds(grpc::ServerContext*, const TCaRule* rule, TPlainIds* response) {
    try {
        if (!Db.IsReady()) {
            return grpc::Status(grpc::StatusCode::UNAVAILABLE, "Service is not ready");
        }

        TScopeTimer scopeTimer(Stats.Percentile, "timing.get_ids");
        Stats.Count->Add("request.get_ids.total.received");

        auto&& packs = Db.GetPacks();
        auto&& dicts = Db.GetDicts();
        const size_t packCount = packs.size();

        // One id list per pack, so parallel iterations never touch the same vector.
        TVector<TVector<TString>> results;
        results.resize(packCount);

        // Per-pack capacity hint: the threshold split over (packCount - 1) packs.
        // The divisor is clamped to 1 so that 0 or 1 packs never divide by zero or wrap.
        const size_t reserveDivisor = packCount <= 1 ? 1 : packCount - 1;
        for (auto& ids : results) {
            ids.reserve(DescribeThreshold / reserveDivisor);
        }

        TEncodedCaRule encodedRule(*rule, dicts.WordDicts.ReversedDict, dicts.HostDicts.ReversedDict, dicts.AppDicts.ReversedDict);
        TAtomicCounter filteredRecordCounter = 0;

        NPar::ParallelFor(Executor, 0, packCount, [this, &packs, &results, &encodedRule, &filteredRecordCounter](int i) {
            auto& ids = results[i];
            for (const auto& userData : packs.at(i)) {
                if (static_cast<size_t>(filteredRecordCounter.Val()) >= DescribeThreshold) {
                    break;
                }

                if (NFilter::Filter(userData, encodedRule)) {
                    ids.push_back(userData.GetCryptaID());
                    filteredRecordCounter.Inc();
                }
            }
        });

        // Reserve only what will actually be returned instead of always DescribeThreshold.
        size_t collectedTotal = 0;
        for (const auto& ids : results) {
            collectedTotal += ids.size();
        }
        const size_t responseSize = collectedTotal < DescribeThreshold ? collectedTotal : DescribeThreshold;

        auto* responseIds = response->MutableIds();
        responseIds->Reserve(static_cast<int>(responseSize));

        for (auto& ids : results) {
            for (auto& id : ids) {
                // Leave both loops as soon as the cap is hit; `>=` also covers a response
                // that somehow already holds more than the cap.
                if (static_cast<size_t>(responseIds->size()) >= DescribeThreshold) {
                    return grpc::Status::OK;
                }
                responseIds->Add(std::move(id));
            }
        }

        return grpc::Status::OK;
    } catch (const yexception& e) {
        return grpc::Status(grpc::StatusCode::INTERNAL, e.what());
    } catch (const std::exception& e) {
        return grpc::Status(grpc::StatusCode::INTERNAL, e.what());
    } catch (...) {
        return grpc::Status(grpc::StatusCode::INTERNAL, "Unknown exception");
    }
}

// gRPC liveness handler: always answers "OK", regardless of whether Db is ready.
// The call is timed under "timing.ping" and counted under "request.ping.status.ok".
grpc::Status TCustomAudienceServiceImpl::Ping(grpc::ServerContext* /*context*/, const google::protobuf::Empty* /*request*/, TPingResponse* response) {
    TScopeTimer pingTimer(Stats.Percentile, "timing.ping");

    response->SetMessage("OK");
    Stats.Count->Add("request.ping.status.ok");

    return grpc::Status::OK;
}

// gRPC readiness handler: answers "OK" only when Db.IsReady() is true; otherwise
// returns UNAVAILABLE and leaves `response` untouched. The call is timed under
// "timing.ready"; only successful checks are counted ("request.ready.status.ok").
grpc::Status TCustomAudienceServiceImpl::Ready(grpc::ServerContext* /*context*/, const google::protobuf::Empty* /*request*/, TPingResponse* response) {
    TScopeTimer readyTimer(Stats.Percentile, "timing.ready");

    if (Db.IsReady()) {
        Stats.Count->Add("request.ready.status.ok");
        response->SetMessage("OK");
        return grpc::Status::OK;
    }

    return grpc::Status(grpc::StatusCode::UNAVAILABLE, "Service is not ready");
}
