#pragma once

#include "points.h"
#include "point_columns.h"

namespace NSolomon::NTsModel {

/**
 * Aggregation function implemented by an aggregator,
 * as reported by IGenericAggregator::Function().
 */
enum class EAggregationFunction {
    Min = 0,
    Max = 1,
    Sum = 2,
    Last = 3,
    Avg = 4,
    Count = 5,
};

namespace NPrivate {

struct TAggregatorTag {};

} // NPrivate

// Forward declaration so that IGenericAggregator::As() can name the typed
// interface. The default for TOutput_ is given here, so the definition below
// must not repeat it.
template <typename TInput_, typename TOutput_ = TInput_>
class IAggregator;

/**
 * Abstract base class for aggregators of different types.
 *
 * Reports the input/output point types and the aggregation function at
 * runtime; use As() to recover the statically typed IAggregator.
 * Instances are move-only (copy operations are deleted).
 */
class IGenericAggregator {
public:
    IGenericAggregator(IGenericAggregator&&) noexcept = default;
    IGenericAggregator& operator=(IGenericAggregator&&) noexcept = default;

    IGenericAggregator(const IGenericAggregator&) = delete;
    IGenericAggregator& operator=(const IGenericAggregator&) = delete;

    IGenericAggregator() = default;
    virtual ~IGenericAggregator() = default;

public:
    /**
     * Get type of this aggregator's input.
     */
    virtual EPointType InputType() const = 0;

    /**
     * Get type of this aggregator's output.
     */
    virtual EPointType OutputType() const = 0;

    /**
     * Get aggregation function.
     */
    virtual EAggregationFunction Function() const = 0;

    /**
     * Get columns of the output.
     */
    virtual TPointColumns Columns() const = 0;

    /**
     * Cast this aggregator base to a concrete aggregator type.
     * Panics (Y_VERIFY) if `TInput::Type` differs from InputType() or
     * `TOutput::Type` differs from OutputType() of this aggregator.
     */
    template <typename TInput, typename TOutput = TInput>
    IAggregator<TInput, TOutput>* As() noexcept {
        VerifyPointTypes<TInput, TOutput>();
        return CheckedCast<IAggregator<TInput, TOutput>*>(this);
    }

    /**
     * Const overload of As(); same checks, same panic on mismatch.
     */
    template <typename TInput, typename TOutput = TInput>
    const IAggregator<TInput, TOutput>* As() const noexcept {
        VerifyPointTypes<TInput, TOutput>();
        return CheckedCast<const IAggregator<TInput, TOutput>*>(this);
    }

private:
    /**
     * Shared by both As() overloads so the type checks cannot diverge:
     * panics unless this aggregator's runtime point types match the
     * requested static ones.
     */
    template <typename TInput, typename TOutput>
    void VerifyPointTypes() const noexcept {
        Y_VERIFY(InputType() == TInput::Type, "aggregator cast failed");
        Y_VERIFY(OutputType() == TOutput::Type, "aggregator cast failed");
    }
};

/**
 * Typed aggregator interface: accumulates points of type `TInput_` one at a
 * time via Add() and yields a single point of type `TOutput_` from Finish().
 * `TOutput_` defaults to `TInput_` (see the forward declaration above).
 */
template <typename TInput_, typename TOutput_ /* = TInput_ */>
class IAggregator: public IGenericAggregator, public NPrivate::TAggregatorTag {
    static_assert(IsPointV<TInput_>, "expected a time series point");
    static_assert(IsPointV<TOutput_>, "expected a time series point");

public:
    using TInput = TInput_;
    using TOutput = TOutput_;

public:
    EPointType InputType() const final {
        return TInput::Type;
    }

    EPointType OutputType() const final {
        return TOutput::Type;
    }

    /**
     * Columns of the output; by default the same as OutputColumns().
     */
    TPointColumns Columns() const override {
        return OutputColumns();
    }

    /**
     * Default output columns of an aggregator: Step and Count.
     */
    static TPointColumns OutputColumns() {
        return TPointColumns{TPointColumns::Step, TPointColumns::Count};
    }

public:
    /**
     * Add another point to the aggregator state.
     */
    virtual void Add(const TInput& point) = 0;

    /**
     * Get number of points added so far.
     */
    virtual i64 Count() const = 0;

    /**
     * Check if there are any points accumulated in this aggregator,
     * i.e. whether Count() is positive.
     */
    bool HasPoints() const {
        return Count() > 0;
    }

    /**
     * Finish aggregation and yield the result.
     * Implementations must reset the aggregator state after calling this function.
     */
    virtual TOutput Finish() = 0;

    /**
     * Reset aggregation state.
     */
    virtual void Reset() = 0;
};

/**
 * Check if the given type `T` is an aggregator, i.e. derives from
 * NPrivate::TAggregatorTag, which every IAggregator instantiation inherits.
 *
 * Declared `inline constexpr` (C++17) so all translation units share one
 * entity instead of each getting an internal-linkage copy.
 */
template <typename T>
inline constexpr bool IsAggregatorV = std::is_base_of_v<NPrivate::TAggregatorTag, T>;

/**
 * Check if the given type `T` is an aggregator that accepts points of type `TInput`
 * and yields points of type `TOutput` (defaults to `TInput`).
 *
 * Declared `inline constexpr` (C++17) so all translation units share one
 * entity instead of each getting an internal-linkage copy.
 */
template <typename T, typename TInput, typename TOutput = TInput>
inline constexpr bool IsAggregatorPointV = std::is_base_of_v<IAggregator<TInput, TOutput>, T>;

/**
 * Convenience base for aggregator implementations.
 *
 * Implements Count() and keeps bookkeeping common to aggregators:
 * - the number of points added;
 * - the sum of their counts;
 * - the time and step for the resulting point.
 *
 * Protected helpers AddAggr() / FillAggr() / ResetAggr() are provided for
 * derived classes to maintain and emit this state.
 */
template <typename TInput_, typename TOutput_ = TInput_>
class TAggregator: public IAggregator<TInput_, TOutput_> {
public:
    i64 Count() const final {
        return Count_;
    }

protected:
    /**
     * Record one more input point.
     *
     * @param point   common part of the input point.
     * @param setTime when true, take time and step from `point` even if points
     *                were already added; otherwise they are taken only from the
     *                first point after construction or ResetAggr().
     */
    void AddAggr(const NTs::TPointCommon& point, bool setTime = false) {
        if (Count_ == 0 || setTime) {
            Time_ = point.Time;
            Step_ = point.Step;
        }

        Count_ += 1;
        // A point contributes its own Count, but never less than 1. The explicit
        // template argument keeps std::max well-formed whatever the exact integer
        // type of point.Count is; a bare `1ul` only matches `unsigned long`.
        AggrCount_ += std::max<decltype(point.Count)>(point.Count, 1);
    }

    /**
     * Write the accumulated time, step and aggregated count into `*res` and
     * set its Merge flag to false. The aggregator state is not modified.
     */
    void FillAggr(NTs::TPointCommon* res) const {
        res->Time = Time_;
        res->Step = Step_;
        res->Count = AggrCount_;
        res->Merge = false;
    }

    /**
     * Clear all bookkeeping back to its initial (default-constructed) values.
     */
    void ResetAggr() {
        Count_ = 0;
        AggrCount_ = 0;
        Time_ = {};
        Step_ = {};
    }

protected:
    i64 Count_ = 0;         // AddAggr() calls since construction or ResetAggr()
    i64 AggrCount_ = 0;     // sum of max(point.Count, 1) over those calls
    TInstant Time_ = {};    // time for the output point
    TDuration Step_ = {};   // step for the output point
};

} // namespace NSolomon::NTsModel
