#pragma once

#include "counter.h"

#include <util/digest/sequence.h>
#include <util/generic/maybe.h>
#include <util/generic/hash_set.h>

#include <functional>
#include <utility>

// Hash for TVector<TString> keys (used by the point-indexed hash maps in this
// header); the vector is hashed as a range via TSimpleRangeHash.
template <>
struct THash<TVector<TString>> : TSimpleRangeHash {};

namespace NMonitor {

// Coordinate value that QueryTables puts in place of a dimension's value to
// denote the aggregate over all ticks of that dimension.
const TString TOTAL_VALUE("_TOTAL_");

/// A collection of TCounters instances indexed by points of a multi-dimensional space.
///
/// Every dimension has a name (DimensionNames_) and a set of observed values,
/// its ticks (DimensionTicks_). Counters_ maps points (one value per dimension)
/// to their counters; registering a new tick through GetOrCreate creates
/// counters for the points of the cross product that use it. QueryTables
/// reports one table per point plus aggregate tables in which coordinates are
/// replaced by TOTAL_VALUE.
///
/// GetOrCreate, QueryTables and Clear synchronize on Lock_.
template <class TCounters>
class TCounterHypercube : public TAbstractCounterSource {
public:
    using TCountersRef = TAtomicSharedPtr<TCounters>;

    /// @param dimensionNames names of the dimensions, in coordinate order.
    explicit TCounterHypercube(const TVector<TString>& dimensionNames);

    /// Reports nothing: per-point data is exposed through QueryTables.
    void QueryCounters(TCounterTable*) const override;

    /// Appends one table per point and per TOTAL_VALUE aggregate to `tables`;
    /// each table name is `name` extended with (dimension name, coordinate) pairs.
    void QueryTables(const TCounterTableName& name, TNamedCounterTables* tables) const override;

    /// Returns the counters for `point` (one coordinate per dimension), creating
    /// them on first use. Returns nullptr for a point of the wrong rank.
    TCountersRef GetOrCreate(const TVector<TString>& point);

    /// Removes all registered ticks and their counters.
    void Clear();

private:
    /// Pins one dimension to a single value while points are enumerated.
    struct TFixedDimension {
        size_t DimensionIndex = 0;
        TString Value;
    };

    using TOnPointGenerated = std::function<void(const TVector<TString>&)>;

    /// Enumerates points by extending `pointPrefix` one dimension at a time and
    /// invokes `callback` for every complete point.
    void GeneratePoints(const TMaybe<TFixedDimension>& fixedDimension, bool withTotals,
                        const TOnPointGenerated& callback, TVector<TString>& pointPrefix) const;

    // Guards DimensionTicks_ and Counters_. Mutable because the const
    // QueryTables must still take a read lock on it.
    mutable TRWMutex Lock_;
    TVector<TString> DimensionNames_;
    TVector<THashSet<TString>> DimensionTicks_;
    THashMap<TVector<TString>, TCountersRef> Counters_;
};

template <class TCounters>
TCounterHypercube<TCounters>::TCounterHypercube(const TVector<TString>& dimensionNames)
    : DimensionNames_(dimensionNames)
    , DimensionTicks_(dimensionNames.size())
{
    // With no dimensions the hypercube has exactly one point, the empty
    // coordinate vector; create its counters up front so QueryTables can
    // report it immediately.
    if (DimensionNames_.empty()) {
        // Take ownership before inserting so the allocation cannot leak if
        // inserting into the map throws.
        Counters_.emplace(TVector<TString>(), TCountersRef(new TCounters()));
    }
}

// Flat counter query: this source contributes nothing here; all of its data is
// reported per point by QueryTables.
template <class TCounters>
void TCounterHypercube<TCounters>::QueryCounters(TCounterTable* /*table*/) const {
    // Intentionally empty.
}

// Depth-first enumeration of points. `pointPrefix` holds the coordinates chosen
// for the first pointPrefix.size() dimensions; every concrete tick of the next
// dimension is tried in turn and, when `withTotals` is set, TOTAL_VALUE is tried
// last. QueryTables relies on that order: a point with TOTAL_VALUE in a
// dimension is only produced after all points with concrete values there.
//
// When `fixedDimension` is set, that dimension only takes the pinned value
// (and only if it is a registered tick), plus TOTAL_VALUE when `withTotals`.
// `pointPrefix` is restored to its original contents on return.
template <class TCounters>
void TCounterHypercube<TCounters>::GeneratePoints(const TMaybe<TFixedDimension>& fixedDimension, bool withTotals,
                                                  const TOnPointGenerated& callback, TVector<TString>& pointPrefix) const
{
    const size_t dimensionIndex = pointPrefix.size();
    if (dimensionIndex == DimensionTicks_.size()) {
        callback(pointPrefix);
        return;
    }

    const auto& ticks = DimensionTicks_[dimensionIndex];
    if (fixedDimension && dimensionIndex == fixedDimension->DimensionIndex) {
        // Only the pinned value may appear here: a single lookup replaces
        // scanning (and skipping) every other tick of this dimension.
        if (ticks.contains(fixedDimension->Value)) {
            pointPrefix.push_back(fixedDimension->Value);
            GeneratePoints(fixedDimension, withTotals, callback, pointPrefix);
            pointPrefix.pop_back();
        }
    } else {
        for (const auto& value : ticks) {
            pointPrefix.push_back(value);
            GeneratePoints(fixedDimension, withTotals, callback, pointPrefix);
            pointPrefix.pop_back();
        }
    }

    if (withTotals) {
        pointPrefix.push_back(TOTAL_VALUE);
        GeneratePoints(fixedDimension, withTotals, callback, pointPrefix);
        pointPrefix.pop_back();
    }
}

// Appends to `tables` one counter table for every point of the hypercube and
// for every aggregate point (a point with TOTAL_VALUE in one or more
// dimensions). Each table is named `name` followed by (dimension name,
// coordinate) pairs.
//
// Aggregates are built incrementally: a point's raw (pre-FinishTotals) table
// is stored in `accumulatedTotals`, and an aggregate point sums the stored
// tables obtained by substituting every tick into its last TOTAL_VALUE
// dimension. GeneratePoints visits TOTAL_VALUE after all concrete ticks, so
// those tables are always available when needed.
template <class TCounters>
void TCounterHypercube<TCounters>::QueryTables(const TCounterTableName& name, TNamedCounterTables* tables) const {
    TReadGuard rg(Lock_);

    THashMap<TVector<TString>, TCounterTable> accumulatedTotals;
    // The callback runs synchronously below, so capturing `name` by reference is safe.
    auto callback = [this, &name, tables, &accumulatedTotals](const TVector<TString>& point) {
        static const size_t NoDimensionIndex = Max<size_t>();

        // Find the last TOTAL_VALUE dimension and the number of concrete points
        // the aggregate spans (product of tick counts of all TOTAL dimensions).
        size_t totalDimensionIndex = NoDimensionIndex;
        size_t numberOfPoints = 1;
        for (size_t dimensionIndex = 0; dimensionIndex < DimensionTicks_.size(); ++dimensionIndex) {
            if (point[dimensionIndex] == TOTAL_VALUE) {
                totalDimensionIndex = dimensionIndex;
                numberOfPoints *= DimensionTicks_[dimensionIndex].size();
            }
        }

        TCounterTable counterTable;
        if (totalDimensionIndex == NoDimensionIndex) {
            Counters_.at(point)->QueryCounters(&counterTable);
        } else {
            auto mutablePoint = point;
            for (const auto& value : DimensionTicks_[totalDimensionIndex]) {
                mutablePoint[totalDimensionIndex] = value;
                AddToTotals(accumulatedTotals[mutablePoint], &counterTable);
            }
        }
        // Store the raw sums before FinishTotals post-processes this copy.
        accumulatedTotals[point] = counterTable;

        FinishTotals(&counterTable, numberOfPoints);
        auto counterTableName = name;
        for (size_t dimensionIndex = 0; dimensionIndex < DimensionTicks_.size(); ++dimensionIndex) {
            counterTableName.emplace_back(DimensionNames_[dimensionIndex], point[dimensionIndex]);
        }
        // Both locals are dead after this line; move instead of copying.
        tables->emplace_back(std::move(counterTableName), std::move(counterTable));
    };

    TVector<TString> pointPrefix;
    GeneratePoints(Nothing(), true, callback, pointPrefix);
}

// Returns the counters for `point` (one coordinate per dimension), creating them
// on first use. When a coordinate is not yet a tick of its dimension, the tick is
// registered and counters are created for every point of the updated cross
// product that uses it.
//
// Returns nullptr when `point` has the wrong number of coordinates or uses the
// reserved TOTAL_VALUE as a coordinate.
template <class TCounters>
typename TCounterHypercube<TCounters>::TCountersRef TCounterHypercube<TCounters>::GetOrCreate(const TVector<TString>& point) {
    // TOTAL_VALUE marks aggregate points in QueryTables; a real tick with that
    // value would be treated as a total there and double-counted.
    for (const auto& coordinate : point) {
        if (coordinate == TOTAL_VALUE) {
            return nullptr;
        }
    }

    // Fast path: the point already exists.
    {
        TReadGuard rg(Lock_);

        if (point.size() != DimensionTicks_.size()) {
            return nullptr;
        }

        const auto it = Counters_.find(point);
        if (it != Counters_.end()) {
            return it->second;
        }
    }

    // Slow path: register missing ticks and create the point under a single
    // write lock. Holding it throughout (rather than dropping back to a read
    // lock between dimensions) keeps Clear() and other writers from
    // interleaving, so the point cannot disappear before it is returned.
    TWriteGuard wg(Lock_);

    for (size_t dimensionIndex = 0; dimensionIndex < DimensionTicks_.size(); ++dimensionIndex) {
        // insert() reports whether the tick is new; if another thread already
        // registered it, the points that use it already exist.
        if (!DimensionTicks_[dimensionIndex].insert(point[dimensionIndex]).second) {
            continue;
        }

        TFixedDimension fixedDimension;
        fixedDimension.DimensionIndex = dimensionIndex;
        fixedDimension.Value = point[dimensionIndex];
        auto callback = [this](const TVector<TString>& newPoint) {
            // Take ownership before inserting so the allocation cannot leak.
            Counters_.emplace(newPoint, TCountersRef(new TCounters()));
        };

        TVector<TString> pointPrefix;
        GeneratePoints(fixedDimension, false, callback, pointPrefix);
    }

    // Defensive fallback for a point whose ticks all exist but which is absent
    // from Counters_, e.g. the empty point of a zero-dimensional hypercube,
    // which no tick insertion can create.
    auto it = Counters_.find(point);
    if (it == Counters_.end()) {
        it = Counters_.emplace(point, TCountersRef(new TCounters())).first;
    }
    return it->second;
}

// Removes every registered tick and the counters of every point.
template <class TCounters>
void TCounterHypercube<TCounters>::Clear() {
    TWriteGuard wg(Lock_);

    for (auto& ticks : DimensionTicks_) {
        ticks.clear();
    }
    Counters_.clear();

    // A zero-dimensional hypercube always consists of its single empty point,
    // which the constructor creates and no tick can ever re-create. Restore
    // it; otherwise QueryTables (Counters_.at on the empty point) would throw.
    if (DimensionNames_.empty()) {
        Counters_.emplace(TVector<TString>(), TCountersRef(new TCounters()));
    }
}

} // NMonitor
