#include "combine.h"

#include <solomon/libs/cpp/ts_math/aggregation.h>
#include <solomon/libs/cpp/ts_math/error.h>

#include <solomon/libs/cpp/ts_model/aggregator.h>
#include <solomon/libs/cpp/ts_model/aggregator_avg.h>
#include <solomon/libs/cpp/ts_model/aggregator_count.h>
#include <solomon/libs/cpp/ts_model/aggregator_last.h>
#include <solomon/libs/cpp/ts_model/aggregator_min.h>
#include <solomon/libs/cpp/ts_model/aggregator_max.h>
#include <solomon/libs/cpp/ts_model/aggregator_sum.h>
#include <solomon/libs/cpp/ts_model/visit.h>

namespace NSolomon::NTsMath {

namespace {

/// Operation that merges several time series into a single one.
///
/// At every timestamp present in any input, the points of all inputs that share
/// that timestamp are fed into an NTsModel aggregator, and the aggregated value
/// becomes the output point. All inputs must have the same point type; the
/// aggregation function comes from the operation settings.
class TCombine: public IOperation {
    /// Lazily merges the inputs in timestamp order.
    ///
    /// Each NextPoint() call finds the smallest pending timestamp among the live
    /// inputs, aggregates every input point carrying exactly that timestamp, and
    /// advances only those inputs. An input is dropped (reset to nullopt) once its
    /// iterator reports no more points.
    template <typename TAggregator>
    class TIterator: public NTsModel::IIterator<typename TAggregator::TOutput> {
        using TInput = typename TAggregator::TInput;
        using TOutput = typename TAggregator::TOutput;

    public:
        explicit TIterator(const TVector<TTimeSeries>& source) {
            Source_.reserve(source.size());
            for (auto& ts: source) {
                // Turn the type-erased iterator into one over the concrete input point
                // type. Apply() has already checked that all inputs share one type.
                THolder<NTsModel::IIterator<TInput>> it{ts.Data->Iterator().Release()->As<TInput>()};
                TInput point;

                // Prime each input with its first point; inputs with no points are skipped.
                if (it->NextPoint(&point)) {
                    Source_.push_back({{std::move(point), std::move(it)}});
                }
            }
        }

    public:
        bool NextPoint(TOutput* point) override {
            TInstant time = TInstant::Max();
            bool hasPoint = false;

            // Pass 1: find the earliest pending timestamp among the live inputs.
            for (auto& ts: Source_) {
                if (ts.has_value()) {
                    time = std::min(time, ts->first.Time);
                    hasPoint = true;
                }
            }

            // Every input is exhausted. Do not call Finish() on an aggregator that
            // received no points, and leave *point untouched.
            if (!hasPoint) {
                return false;
            }

            // Pass 2: aggregate every point at that timestamp and advance its input.
            for (auto& ts: Source_) {
                if (ts.has_value() && time == ts->first.Time) {
                    Aggregator_.Add(ts->first);
                    if (!ts->second->NextPoint(&ts->first)) {
                        ts.reset();
                    }
                }
            }

            *point = Aggregator_.Finish();
            return true;
        }

    private:
        // Per input: its current (not yet aggregated) point and its iterator;
        // nullopt once that input is exhausted.
        TVector<std::optional<std::pair<TInput, THolder<NTsModel::IIterator<TInput>>>>> Source_;
        TAggregator Aggregator_ = {};
    };

    /// Iterable view over the combined series. It owns the input series and
    /// creates a fresh merging TIterator on every Iterator() call.
    template <typename TAggregator>
    class TIterable: public NTsModel::IIterable {
    public:
        explicit TIterable(TVector<TTimeSeries> source)
            : Source_(std::move(source))
        {
            // The combined window runs from the earliest input WindowBegin to the
            // latest input WindowEnd.
            for (const auto& ts: Source_) {
                WindowBegin_ = std::min(WindowBegin_, ts.Data->WindowBegin());
                WindowEnd_ = std::max(WindowEnd_, ts.Data->WindowEnd());
            }
        }

    public:
        NTsModel::EPointType Type() const override {
            return TAggregator::TOutput::Type;
        }

        NTsModel::TPointColumns Columns() const override {
            return TAggregator::OutputColumns();
        }

        THolder<NTsModel::IGenericIterator> Iterator() const override {
            return MakeHolder<TIterator<TAggregator>>(Source_);
        }

        TInstant WindowBegin() const override {
            return WindowBegin_;
        }

        TInstant WindowEnd() const override {
            return WindowEnd_;
        }

    private:
        TVector<TTimeSeries> Source_;
        TInstant WindowBegin_ = TInstant::Max();
        TInstant WindowEnd_ = TInstant::Zero();
    };

    /// Wraps the inputs into a label-less, summary-less series whose data is
    /// produced lazily by TIterable<TAggregator>.
    template <typename TAggregator>
    TTimeSeries CreateTs(TVector<TTimeSeries> source) {
        return TTimeSeries {
            .Labels = {},
            .Summary = {},
            .Data = MakeHolder<TIterable<TAggregator>>(std::move(source))
        };
    }

public:
    explicit TCombine(yandex::solomon::math::Aggregation aggregationFunction)
        : AggregationFunction_{aggregationFunction}
    {
    }

    /// Combines all input series into one.
    ///
    /// @param source  input series; all of them must have the same point type.
    /// @return        the input unchanged when it is empty, or when it holds a single
    ///                series and the aggregation is not COUNT; otherwise one combined series.
    /// @throws TTypeError            if the inputs have different point types.
    /// @throws TNotImplementedError  if the aggregation cannot be applied to the point type.
    TVector<TTimeSeries> Apply(TVector<TTimeSeries>&& source) override {
        // Nothing to combine. A lone series passes through unchanged, except under
        // COUNT, which is still routed through the aggregator.
        if (source.empty() || (source.size() == 1 && AggregationFunction_ != yandex::solomon::math::COUNT)) {
            return std::move(source);
        }

        auto type = source[0].Data->Type();

        for (auto& ts: source) {
            auto tsType = ts.Data->Type();
            if (tsType != type) {
                ythrow TTypeError{} << "unable to combine metrics of different types: " << type << " and " << tsType;
            }
        }

        auto aggregationFunction = AggregationFunction(AggregationFunction_, type);

        TVector<TTimeSeries> result;

        result.push_back(NTsModel::Visit(type, [this, aggregationFunction, type, &source](auto traits) -> TTimeSeries {
            using TPoint = typename decltype(traits)::TPoint;

            switch (aggregationFunction) {
                case NTsModel::EAggregationFunction::Min:
                    if constexpr (NTsModel::TPointTraits<TPoint>::IsScalar) {
                        return CreateTs<NTsModel::TMinAggregator<TPoint>>(std::move(source));
                    } else {
                        ythrow TNotImplementedError{}
                            << "aggregator 'min' cannot be applied to points of type " << type;
                    }
                case NTsModel::EAggregationFunction::Max:
                    if constexpr (NTsModel::TPointTraits<TPoint>::IsScalar) {
                        return CreateTs<NTsModel::TMaxAggregator<TPoint>>(std::move(source));
                    } else {
                        ythrow TNotImplementedError{}
                            << "aggregator 'max' cannot be applied to points of type " << type;
                    }
                case NTsModel::EAggregationFunction::Sum:
                    return CreateTs<NTsModel::TSumAggregator<TPoint>>(std::move(source));
                case NTsModel::EAggregationFunction::Last:
                    return CreateTs<NTsModel::TLastAggregator<TPoint>>(std::move(source));
                case NTsModel::EAggregationFunction::Avg:
                    if constexpr (NTsModel::TPointTraits<TPoint>::IsScalar) {
                        return CreateTs<NTsModel::TAvgAggregator<TPoint>>(std::move(source));
                    } else {
                        ythrow TNotImplementedError{}
                            << "aggregator 'avg' cannot be applied to points of type " << type;
                    }
                case NTsModel::EAggregationFunction::Count:
                    return CreateTs<NTsModel::TCountAggregator<TPoint>>(std::move(source));
            }

            // The enum can hold a value not handled above; flowing off the end of
            // this value-returning lambda would be undefined behavior.
            ythrow TNotImplementedError{}
                << "unsupported aggregation function: " << static_cast<int>(aggregationFunction);
        }));

        return result;
    }

private:
    // Aggregation requested in the operation settings (protobuf enum).
    yandex::solomon::math::Aggregation AggregationFunction_;
};

} // namespace

/// Factory for the "combine" operation.
///
/// @param settings  protobuf settings; only the aggregation field is read.
/// @return          a heap-allocated operation that merges its inputs with that aggregation.
THolder<IOperation> Combine(const yandex::solomon::math::OperationCombine& settings) {
    const auto aggregation = settings.aggregation();
    THolder<IOperation> operation = MakeHolder<TCombine>(aggregation);
    return operation;
}

} // namespace NSolomon::NTsMath
