#include "aggregates.h"

#include <util/generic/yexception.h>

namespace NSolomon::NDataProxy {
namespace {

using yandex::solomon::stockpile::MetricData;
using yandex::solomon::math::Aggregate;

template <typename TProto>
void CommonFieldsFromProto(const TProto& proto, TAggregateVariant* aggr) {
    aggr->Count = proto.count();
    aggr->Mask = proto.mask();
}

/**
 * Writes the fields shared by every aggregate kind (point count and mask)
 * from the variant wrapper into a typed protobuf aggregate message.
 *
 * @param aggr   source of the Count and Mask values
 * @param proto  message exposing set_count() and set_mask() mutators
 */
template <typename TProto>
void CommonFieldsToProto(const TAggregateVariant& aggr, TProto* proto) {
    auto& message = *proto;
    message.set_mask(aggr.Mask);
    message.set_count(aggr.Count);
}

/**
 * Reads a scalar aggregate (max, min, sum, avg, last) from its protobuf form.
 *
 * @param proto  message exposing max(), min(), sum(), avg() and last()
 * @param value  destination whose Max, Min, Sum, Avg and Last members are overwritten
 */
template <typename TProto, typename TScalar>
void ScalarFromProto(const TProto& proto, TScalar* value) {
    // Each field is assigned in its own statement; the assignments used to be
    // glued together with the comma operator, which only worked by accident
    // of operator precedence.
    value->Max = proto.max();
    value->Min = proto.min();
    value->Sum = proto.sum();
    value->Avg = proto.avg();
    value->Last = proto.last();
}

/**
 * Writes a scalar aggregate (max, min, sum, avg, last) into its protobuf form.
 *
 * @param value  source exposing Max, Min, Sum, Avg and Last members
 * @param proto  message exposing the matching set_*() mutators
 */
template <typename TScalar, typename TProto>
void ScalarToProto(const TScalar& value, TProto* proto) {
    auto& message = *proto;
    message.set_last(value.Last);
    message.set_avg(value.Avg);
    message.set_sum(value.Sum);
    message.set_min(value.Min);
    message.set_max(value.Max);
}

/**
 * Reads a summary value (count, sum, min, max, last) from its protobuf form.
 *
 * @param proto  message exposing count(), sum(), min(), max() and last()
 * @param value  destination whose CountValue, Sum, Min, Max and Last members are overwritten
 */
template <typename TProto, typename TSummary>
void SummaryFromProto(const TProto& proto, TSummary* value) {
    auto& summary = *value;
    summary.Last = proto.last();
    summary.Max = proto.max();
    summary.Min = proto.min();
    summary.Sum = proto.sum();
    summary.CountValue = proto.count();
}

/**
 * Writes a summary value (count, sum, min, max, last) into its protobuf form.
 *
 * @param value  source exposing CountValue, Sum, Min, Max and Last members
 * @param proto  message exposing the matching set_*() mutators
 */
template <typename TAggr, typename TProto>
void SummaryToProto(const TAggr& value, TProto* proto) {
    auto& message = *proto;
    message.set_last(value.Last);
    message.set_max(value.Max);
    message.set_min(value.Min);
    message.set_sum(value.Sum);
    message.set_count(value.CountValue);
}

/**
 * Reads a log-histogram from its protobuf form.
 *
 * The bucket list of the destination is replaced (not appended to) with the
 * buckets of the message; the remaining scalar parameters are copied as is.
 *
 * @param proto  source message
 * @param value  destination histogram, fully overwritten
 */
void LogHistFromProto(const yandex::solomon::model::LogHistogram& proto, NTs::NValue::TLogHistogram* value) {
    const auto& protoBuckets = proto.buckets();
    auto& hist = *value;

    hist.Values.assign(protoBuckets.begin(), protoBuckets.end());
    hist.Base = proto.base();
    hist.MaxBucketCount = proto.max_buckets_size();
    hist.StartPower = proto.start_power();
    hist.ZeroCount = proto.zeroes();
}

/**
 * Writes a log-histogram into its protobuf form.
 *
 * Bucket values are appended to the message's `buckets` field; the scalar
 * parameters (zeroes, start power, max bucket count, base) are overwritten.
 *
 * @param value  source histogram
 * @param proto  destination message
 */
void LogHistToProto(const NTs::NValue::TLogHistogram& value, yandex::solomon::model::LogHistogram* proto) {
    const int count = static_cast<int>(value.Values.size());
    auto* buckets = proto->mutable_buckets();

    // Reserve() takes the TOTAL capacity, so account for elements the field may
    // already hold; otherwise AddNAlreadyReserved() would hand out storage
    // beyond the reserved region when the message is being reused.
    buckets->Reserve(buckets->size() + count);
    std::copy(value.Values.begin(), value.Values.end(), buckets->AddNAlreadyReserved(count));

    proto->set_zeroes(value.ZeroCount);
    proto->set_start_power(value.StartPower);
    proto->set_max_buckets_size(value.MaxBucketCount);
    proto->set_base(value.Base);
}

/**
 * Reads a histogram from its protobuf form.
 *
 * The message stores bucket upper bounds and bucket values as two parallel
 * repeated fields; they are zipped into a single list of buckets. Any buckets
 * previously held by the destination are discarded, so the result reflects the
 * message only (the same replace semantics LogHistFromProto has).
 *
 * @param proto  source message
 * @param value  destination histogram, fully overwritten
 * @throws yexception (via Y_ENSURE) when the number of bounds differs from the
 *         number of bucket values
 */
void HistFromProto(const yandex::solomon::model::Histogram& proto, NTs::NValue::THistogram* value) {
    const int bucketCount = proto.buckets_size();
    Y_ENSURE(bucketCount == proto.bounds_size(),
             "histogram with inconsistent buckets(" << proto.buckets_size()
             << ") and bounds(" << proto.bounds_size() << ") sizes");

    value->Denom = proto.denom();

    // Drop stale content first: Denom is overwritten above, so keeping old
    // buckets around would produce an inconsistent histogram.
    value->Buckets.clear();
    value->Buckets.reserve(bucketCount);

    for (int i = 0; i < bucketCount; ++i) {
        value->Buckets.emplace_back(NTs::NValue::THistogram::TBucket{proto.bounds(i), proto.buckets(i)});
    }
}

/**
 * Writes a histogram into its protobuf form.
 *
 * Every bucket is split into an upper bound and a value that are appended to
 * the parallel `bounds` and `buckets` repeated fields; `denom` is overwritten.
 *
 * @param value  source histogram
 * @param proto  destination message
 */
void HistToProto(const NTs::NValue::THistogram& value, yandex::solomon::model::Histogram* proto) {
    const int count = static_cast<int>(value.Buckets.size());
    auto* bounds = proto->mutable_bounds();
    auto* buckets = proto->mutable_buckets();

    // Reserve() takes the TOTAL capacity, so include whatever the fields already
    // hold; AddAlreadyReserved() performs no growth and must never run out of
    // reserved slots, even when the message is being reused.
    bounds->Reserve(bounds->size() + count);
    buckets->Reserve(buckets->size() + count);

    for (const auto& b: value.Buckets) {
        bounds->AddAlreadyReserved(b.UpperBound);
        buckets->AddAlreadyReserved(b.Value);
    }

    proto->set_denom(value.Denom);
}

/**
 * Shared implementation of the FromProto() overloads: converts whichever
 * aggregate is set in the message's oneof into the matching alternative of
 * TAggregateVariant::Value and copies the common count/mask fields.
 *
 * When the oneof is not set (type case equals 0) the destination is left
 * untouched.
 *
 * @param proto     message exposing double_(), int64(), log_histogram(),
 *                  histogram(), summary_double() and summary_int64()
 * @param typeCase  value of the message's oneof case enum
 * @param aggr      destination variant
 * @throws yexception when the type case is not one of the handled kinds, or
 *         when a histogram payload fails the HistFromProto() consistency check
 */
template <typename TProto, typename TTypeCase>
void FromProto(const TProto& proto, TTypeCase typeCase, TAggregateVariant* aggr) {
    // type case is not set
    if (static_cast<int>(typeCase) == 0) {
        return;
    }

    switch (typeCase) {
        case TTypeCase::kInt64: {
            const auto& src = proto.int64();
            ScalarFromProto(src, &aggr->Value.emplace<TAggregateInt64>());
            CommonFieldsFromProto(src, aggr);
            break;
        }

        case TTypeCase::kDouble: {
            const auto& src = proto.double_();
            ScalarFromProto(src, &aggr->Value.emplace<TAggregateDouble>());
            CommonFieldsFromProto(src, aggr);
            break;
        }

        case TTypeCase::kSummaryInt64: {
            const auto& src = proto.summary_int64();
            auto& dst = aggr->Value.emplace<TAggregateSummaryInt64>();
            SummaryFromProto(src.last(), &dst.Last);
            SummaryFromProto(src.sum(), &dst.Sum);
            CommonFieldsFromProto(src, aggr);
            break;
        }

        case TTypeCase::kSummaryDouble: {
            const auto& src = proto.summary_double();
            auto& dst = aggr->Value.emplace<TAggregateSummaryDouble>();
            SummaryFromProto(src.last(), &dst.Last);
            SummaryFromProto(src.sum(), &dst.Sum);
            CommonFieldsFromProto(src, aggr);
            break;
        }

        case TTypeCase::kHistogram: {
            const auto& src = proto.histogram();
            auto& dst = aggr->Value.emplace<TAggregateHistogram>();
            HistFromProto(src.last(), &dst.Last);
            HistFromProto(src.sum(), &dst.Sum);
            CommonFieldsFromProto(src, aggr);
            break;
        }

        case TTypeCase::kLogHistogram: {
            const auto& src = proto.log_histogram();
            auto& dst = aggr->Value.emplace<TAggregateLogHistogram>();
            LogHistFromProto(src.last(), &dst.Last);
            LogHistFromProto(src.sum(), &dst.Sum);
            CommonFieldsFromProto(src, aggr);
            break;
        }

        default:
            ythrow yexception() << "unknown aggregate type case: " << static_cast<int>(typeCase);
    }
}

struct TToProtoVisitor {
    const TAggregateVariant& Aggr;
    Aggregate* Proto;

    void operator()(const TAggregateDouble& value) const {
        auto* proto = Proto->mutable_double_();
        ScalarToProto(value, proto);
        CommonFieldsToProto(Aggr, proto);
    }

    void operator()(const TAggregateInt64& value) const {
        auto* proto = Proto->mutable_int64();
        ScalarToProto(value, proto);
        CommonFieldsToProto(Aggr, proto);
    }

    void operator()(const TAggregateSummaryDouble& value) const {
        auto* proto = Proto->mutable_summary_double();
        SummaryToProto(value.Last, proto->mutable_last());
        SummaryToProto(value.Sum, proto->mutable_sum());
        CommonFieldsToProto(Aggr, proto);
    }

    void operator()(const TAggregateSummaryInt64& value) const {
        auto* proto = Proto->mutable_summary_int64();
        SummaryToProto(value.Last, proto->mutable_last());
        SummaryToProto(value.Sum, proto->mutable_sum());
        CommonFieldsToProto(Aggr, proto);
    }

    void operator()(const TAggregateLogHistogram& value) const {
        auto* proto = Proto->mutable_log_histogram();
        LogHistToProto(value.Last, proto->mutable_last());
        LogHistToProto(value.Sum, proto->mutable_sum());
        CommonFieldsToProto(Aggr, proto);
    }

    void operator()(const TAggregateHistogram& value) const {
        auto* proto = Proto->mutable_histogram();
        HistToProto(value.Last, proto->mutable_last());
        HistToProto(value.Sum, proto->mutable_sum());
        CommonFieldsToProto(Aggr, proto);
    }
};

/**
 * Extracts the minimum from a scalar (double or int64) aggregate of the metric.
 *
 * @return the minimum converted to TResp, or std::nullopt when the metric
 *         carries neither a double nor an int64 aggregate
 */
template <typename TResp>
std::optional<TResp> GetMinAggregation(yandex::solomon::stockpile::MetricData* metricData) {
    std::optional<TResp> minimum;
    if (metricData->has_double_()) {
        minimum = static_cast<TResp>(metricData->double_().min());
    } else if (metricData->has_int64()) {
        minimum = static_cast<TResp>(metricData->int64().min());
    }
    return minimum;
}

/**
 * Extracts the maximum from a scalar (double or int64) aggregate of the metric.
 *
 * @return the maximum converted to TResp, or std::nullopt when the metric
 *         carries neither a double nor an int64 aggregate
 */
template <typename TResp>
std::optional<TResp> GetMaxAggregation(yandex::solomon::stockpile::MetricData* metricData) {
    std::optional<TResp> maximum;
    if (metricData->has_double_()) {
        maximum = static_cast<TResp>(metricData->double_().max());
    } else if (metricData->has_int64()) {
        maximum = static_cast<TResp>(metricData->int64().max());
    }
    return maximum;
}

/**
 * Extracts the sum from a scalar (double or int64) aggregate of the metric.
 *
 * @return the sum converted to TResp, or std::nullopt when the metric
 *         carries neither a double nor an int64 aggregate
 */
template <typename TResp>
std::optional<TResp> GetSumAggregation(yandex::solomon::stockpile::MetricData* metricData) {
    std::optional<TResp> total;
    if (metricData->has_double_()) {
        total = static_cast<TResp>(metricData->double_().sum());
    } else if (metricData->has_int64()) {
        total = static_cast<TResp>(metricData->int64().sum());
    }
    return total;
}

/**
 * Extracts the average from a scalar (double or int64) aggregate of the metric.
 *
 * @return the average converted to TResp, or std::nullopt when the metric
 *         carries neither a double nor an int64 aggregate
 */
template <typename TResp>
std::optional<TResp> GetAvgAggregation(yandex::solomon::stockpile::MetricData* metricData) {
    std::optional<TResp> average;
    if (metricData->has_double_()) {
        average = static_cast<TResp>(metricData->double_().avg());
    } else if (metricData->has_int64()) {
        average = static_cast<TResp>(metricData->int64().avg());
    }
    return average;
}

/**
 * Extracts the last value from a scalar (double or int64) aggregate of the metric.
 *
 * @return the last value converted to TResp, or std::nullopt when the metric
 *         carries neither a double nor an int64 aggregate
 */
template <typename TResp>
std::optional<TResp> GetLastAggregation(yandex::solomon::stockpile::MetricData* metricData) {
    std::optional<TResp> latest;
    if (metricData->has_double_()) {
        latest = static_cast<TResp>(metricData->double_().last());
    } else if (metricData->has_int64()) {
        latest = static_cast<TResp>(metricData->int64().last());
    }
    return latest;
}

/**
 * Extracts the point count from a scalar (double or int64) aggregate of the metric.
 *
 * @return the count converted to TResp, or std::nullopt when the metric
 *         carries neither a double nor an int64 aggregate
 */
template <typename TResp>
std::optional<TResp> GetCountAggregation(yandex::solomon::stockpile::MetricData* metricData) {
    std::optional<TResp> pointCount;
    if (metricData->has_double_()) {
        pointCount = static_cast<TResp>(metricData->double_().count());
    } else if (metricData->has_int64()) {
        pointCount = static_cast<TResp>(metricData->int64().count());
    }
    return pointCount;
}

/**
 * Placeholder extractor that ignores its input and never produces a value.
 *
 * @return always an empty optional
 */
template <typename TResp>
std::optional<TResp> GetAggregationFiller(yandex::solomon::stockpile::MetricData* metricData) {
    Y_UNUSED(metricData);
    return std::optional<TResp>{};
}

} // namespace

/**
 * Two aggregates are equal when their count and mask match, they hold the same
 * alternative, and the held values compare equal.
 */
bool operator==(const TAggregateVariant& lhs, const TAggregateVariant& rhs) {
    const bool headerDiffers = lhs.Count != rhs.Count || lhs.Mask != rhs.Mask;
    if (headerDiffers || lhs.Value.index() != rhs.Value.index()) {
        return false;
    }

    // Both sides hold the same alternative here, so std::get on rhs cannot fail.
    const auto sameAsRhs = [&rhs](const auto& held) {
        using THeld = std::decay_t<decltype(held)>;
        return held == std::get<THeld>(rhs.Value);
    };
    return std::visit(sameAsRhs, lhs.Value);
}

void FromProto(const Aggregate& proto, TAggregateVariant* aggr) {
    FromProto<Aggregate, Aggregate::TypeCase>(proto, proto.type_case(), aggr);
}

void FromProto(const MetricData& proto, TAggregateVariant* aggr) {
    FromProto<MetricData, MetricData::AggregateCase>(proto, proto.aggregate_case(), aggr);
}

void ToProto(const TAggregateVariant& aggr, Aggregate* proto) {
    std::visit(TToProtoVisitor{aggr, proto}, aggr.Value);
}

/**
 * Selects the function that extracts the requested aggregation from a metric.
 *
 * @param aggregation  requested aggregation kind
 * @return extractor for MIN/MAX/SUM/AVG/LAST/COUNT; DEFAULT_AGGREGATION maps to
 *         the LAST extractor; any other kind maps to a filler that never
 *         yields a value
 */
template <typename TResp>
TOptionalMetricProcessor<TResp> GetSummaryFieldExtractor(yandex::solomon::math::Aggregation aggregation) {
    using TKind = yandex::solomon::math::Aggregation;

    switch (aggregation) {
        case TKind::COUNT:
            return &GetCountAggregation<TResp>;
        case TKind::AVG:
            return &GetAvgAggregation<TResp>;
        case TKind::SUM:
            return &GetSumAggregation<TResp>;
        case TKind::MAX:
            return &GetMaxAggregation<TResp>;
        case TKind::MIN:
            return &GetMinAggregation<TResp>;
        // The extractor is only used with int & double aggregates, and for both
        // of them the default aggregation is "last".
        case TKind::LAST:
        case TKind::DEFAULT_AGGREGATION:
            return &GetLastAggregation<TResp>;
        default:
            return &GetAggregationFiller<TResp>;
    }
}

// Explicit instantiations of GetSummaryFieldExtractor for double and i64.
// The template body is defined in this source file, so these are needed for
// the linker to find an implementation for those two response types.
template TOptionalMetricProcessor<double> GetSummaryFieldExtractor<double>(yandex::solomon::math::Aggregation aggregation);
template TOptionalMetricProcessor<i64> GetSummaryFieldExtractor<i64>(yandex::solomon::math::Aggregation aggregation);

} // namespace NSolomon::NDataProxy
