#include "ru_yandex_solomon_ts_codec_TsStreamNative.h"
#include "aggr_point.h"
#include "heap_bit_buf.h"
#include "histogram.h"
#include "summary.h"

#include <solomon/libs/cpp/ts_codec/double_ts_codec.h>
#include <solomon/libs/cpp/ts_codec/counter_ts_codec.h>
#include <solomon/libs/cpp/ts_codec/gauge_int_ts_codec.h>
#include <solomon/libs/cpp/ts_codec/summary_int_ts_codec.h>
#include <solomon/libs/cpp/ts_codec/summary_double_ts_codec.h>
#include <solomon/libs/cpp/ts_codec/hist_ts_codec.h>
#include <solomon/libs/cpp/ts_codec/hist_log_ts_codec.h>

#include <vector>

using namespace NSolomon;
using namespace NTs;

namespace {

/// Transfers the type-specific value of a native point to/from a Java AggrPoint.
/// Only the explicit specializations in this file are defined; instantiating the
/// primary template for any other point type fails at compile time.
template <typename T>
struct TPointConverter;

/// Double points: Num and Denom travel through separate AggrPoint accessors
/// (GetValueNum/GetValueDenom and SetValueNum/SetValueDenom).
template <>
struct TPointConverter<TDoublePoint> {
    static void FromJava(JNIEnv* env, TDoublePoint& dst, jobject src, NJava::TAggrPointClass& cls) {
        const auto num = cls.GetValueNum(env, src);
        const auto denom = cls.GetValueDenom(env, src);
        dst.Num = num;
        dst.Denom = denom;
    }

    static void ToJava(JNIEnv* env, const TDoublePoint& src, jobject dst, NJava::TAggrPointClass& cls) {
        cls.SetValueNum(env, dst, src.Num);
        cls.SetValueDenom(env, dst, src.Denom);
    }
};

template <>
struct TPointConverter<TLongPoint> {
    static void FromJava(JNIEnv* jenv, TLongPoint& point, jobject pointJava, NJava::TAggrPointClass& clazz) {
        point.Value = clazz.GetLongValue(jenv, pointJava);
    }

    static void ToJava(JNIEnv* jenv, const TLongPoint& point, jobject pointJava, NJava::TAggrPointClass& clazz) {
        clazz.SetLongValue(jenv, pointJava, point.Value);
    }
};

template <>
struct TPointConverter<TSummaryIntPoint> {
    static void FromJava(JNIEnv* jenv, TSummaryIntPoint& point, jobject pointJava, NJava::TAggrPointClass& pointClass) {
        jobject summary = pointClass.GetSummaryInt64(jenv, pointJava);

        NJava::TSummaryInt64Class summaryClass{jenv};
        point.CountValue = summaryClass.GetCount(jenv, summary);
        point.Sum = summaryClass.GetSum(jenv, summary);
        point.Min = summaryClass.GetMin(jenv, summary);
        point.Max = summaryClass.GetMax(jenv, summary);
        point.Last = summaryClass.GetLast(jenv, summary);
    }

    static void ToJava(JNIEnv* jenv, const TSummaryIntPoint& point, jobject pointJava, NJava::TAggrPointClass& pointClass) {
        NJava::TSummaryInt64Class summaryClass{jenv};
        jobject summary = summaryClass.New(
                jenv,
                point.CountValue,
                point.Sum,
                point.Min,
                point.Max,
                point.Last);
        pointClass.SetSummaryInt64(jenv, pointJava, summary);
    }
};

template <>
struct TPointConverter<TSummaryDoublePoint> {
    static void FromJava(JNIEnv* jenv, TSummaryDoublePoint& point, jobject pointJava, NJava::TAggrPointClass& pointClass) {
        jobject summary = pointClass.GetSummaryDouble(jenv, pointJava);

        NJava::TSummaryDoubleClass summaryClass{jenv};
        point.CountValue = summaryClass.GetCount(jenv, summary);
        point.Sum = summaryClass.GetSum(jenv, summary);
        point.Min = summaryClass.GetMin(jenv, summary);
        point.Max = summaryClass.GetMax(jenv, summary);
        point.Last = summaryClass.GetLast(jenv, summary);
    }

    static void ToJava(JNIEnv* jenv, const TSummaryDoublePoint& point, jobject pointJava, NJava::TAggrPointClass& pointClass) {
        NJava::TSummaryDoubleClass summaryClass{jenv};
        jobject summary = summaryClass.New(
                jenv,
                point.CountValue,
                point.Sum,
                point.Min,
                point.Max,
                point.Last);
        pointClass.SetSummaryDouble(jenv, pointJava, summary);
    }
};

template <>
struct TPointConverter<THistogramPoint> {
    static void FromJava(JNIEnv* jenv, THistogramPoint& point, jobject pointJava, NJava::TAggrPointClass& pointClass) {
        NJava::THistogramClass histClass{jenv};
        NValue::THistogram hist = histClass.ToHist(jenv, pointClass.GetHistogram(jenv, pointJava));

        point.Denom = hist.Denom;
        point.Buckets = std::move(hist).Buckets;
    }

    static void ToJava(JNIEnv* jenv, const THistogramPoint& point, jobject pointJava, NJava::TAggrPointClass& pointClass) {
        NJava::THistogramClass histClass{jenv};
        jobject hist = histClass.ToJava(jenv, point);
        pointClass.SetHistogram(jenv, pointJava, hist);
    }
};

template <>
struct TPointConverter<TLogHistogramPoint> {
    static void FromJava(JNIEnv* jenv, TLogHistogramPoint& point, jobject pointJava, NJava::TAggrPointClass& pointClass) {
        NJava::TLogHistogramClass histClass{jenv};
        NValue::TLogHistogram hist = histClass.ToHist(jenv, pointClass.GetLogHistogram(jenv, pointJava));

        point.ZeroCount = hist.ZeroCount;
        point.StartPower = hist.StartPower;
        point.MaxBucketCount = hist.MaxBucketCount;
        point.Base = hist.Base;
        point.Values = std::move(hist).Values;
    }

    static void ToJava(JNIEnv* jenv, const TLogHistogramPoint& point, jobject pointJava, NJava::TAggrPointClass& pointClass) {
        NJava::TLogHistogramClass histClass{jenv};
        jobject hist = histClass.ToJava(jenv, point);
        pointClass.SetLogHistogram(jenv, pointJava, hist);
    }
};

template <typename TEncoder, typename TPoint>
jobject Encode(JNIEnv* jenv, jint columnsMask, jobjectArray points) {
    NJava::TAggrPointClass aggrPointClass{jenv};
    TColumnSet columns{static_cast<TColumnsMask>(columnsMask)};

    TBitBuffer buffer;
    TBitWriter w{&buffer};
    TEncoder encoder{columns, &w};

    for (jsize i = 0, len = jenv->GetArrayLength(points); i < len; ++i) {
        jobject pointJava = jenv->GetObjectArrayElement(points, i);

        TPoint point;
        point.Time = TInstant::MilliSeconds(aggrPointClass.GetTsMillis(jenv, pointJava));

        TPointConverter<TPoint>::FromJava(jenv, point, pointJava, aggrPointClass);

        if (columns.IsSet(EColumn::COUNT)) {
            point.Count = aggrPointClass.GetCount(jenv, pointJava);
        }

        if (columns.IsSet(EColumn::MERGE)) {
            point.Merge = aggrPointClass.GetMerge(jenv, pointJava);
        }

        if (columns.IsSet(EColumn::STEP)) {
            point.Step = TDuration::MilliSeconds(aggrPointClass.GetStepMillis(jenv, pointJava));
        }

        encoder.EncodePoint(point);
    }

    encoder.Flush();
    return NJava::ToHeapBitBuffer(jenv, buffer);
}

template <typename TDecoder, typename TPoint>
jobjectArray Decode(JNIEnv* jenv, jint columnsMask, jobject bitBuf) {
    NJava::TAggrPointClass aggrPointClass{jenv};
    TColumnSet columns{static_cast<TColumnsMask>(columnsMask)};

    auto buffer = NJava::FromHeapBitBuffer(jenv, bitBuf);
    TDecoder decoder{columns, buffer};

    std::vector<TPoint> points;
    for (TPoint point; decoder.NextPoint(&point); ) {
        points.push_back(point);
    }

    jobjectArray pointsJava = jenv->NewObjectArray(points.size(), aggrPointClass.Class(), nullptr);
    for (size_t i = 0; i < points.size(); ++i) {
        const auto& point = points[i];

        jobject pointJava = aggrPointClass.New(jenv);
        aggrPointClass.SetColumnSet(jenv, pointJava, columns.Mask());
        aggrPointClass.SetTsMillis(jenv, pointJava, point.Time.MilliSeconds());

        TPointConverter<TPoint>::ToJava(jenv, point, pointJava, aggrPointClass);

        if (columns.IsSet(EColumn::COUNT)) {
            aggrPointClass.SetCount(jenv, pointJava, point.Count);
        }

        if (columns.IsSet(EColumn::MERGE)) {
            aggrPointClass.SetMerge(jenv, pointJava, point.Merge);
        }

        if (columns.IsSet(EColumn::STEP)) {
            aggrPointClass.SetStepMillis(jenv, pointJava, point.Step.MilliSeconds());
        }

        jenv->SetObjectArrayElement(pointsJava, i, pointJava);
    }

    return pointsJava;
}

} // namespace

jobject Java_ru_yandex_solomon_ts_1codec_TsStreamNative_encodeDouble(
        JNIEnv* jenv, jclass, jint columnsMask, jobjectArray points)
{
    return Encode<TDoubleTsEncoder, TDoublePoint>(jenv, columnsMask, points);
}

jobjectArray Java_ru_yandex_solomon_ts_1codec_TsStreamNative_decodeDouble(
        JNIEnv* jenv, jclass, jint columnsMask, jobject bitBuf)
{
    return Decode<TDoubleTsDecoder, TDoublePoint>(jenv, columnsMask, bitBuf);
}

jobject Java_ru_yandex_solomon_ts_1codec_TsStreamNative_gaugeInt64Encode(
        JNIEnv* jenv, jclass, jint columnsMask, jobjectArray points)
{
    return Encode<TGaugeIntTsEncoder, TLongPoint>(jenv, columnsMask, points);
}

jobjectArray Java_ru_yandex_solomon_ts_1codec_TsStreamNative_gaugeInt64Decode(
        JNIEnv* jenv, jclass, jint columnsMask, jobject bitBuf)
{
    return Decode<TGaugeIntTsDecoder, TLongPoint>(jenv, columnsMask, bitBuf);
}

jobject Java_ru_yandex_solomon_ts_1codec_TsStreamNative_counterEncode(
        JNIEnv* jenv, jclass, jint columnsMask, jobjectArray points)
{
    return Encode<TCounterTsEncoder, TLongPoint>(jenv, columnsMask, points);
}

JNIEXPORT jobjectArray JNICALL Java_ru_yandex_solomon_ts_1codec_TsStreamNative_counterDecode(
        JNIEnv* jenv, jclass, jint columnsMask, jobject bitBuf)
{
    return Decode<TCounterTsDecoder, TLongPoint>(jenv, columnsMask, bitBuf);
}

jobject Java_ru_yandex_solomon_ts_1codec_TsStreamNative_summaryInt64Encode(
        JNIEnv* jenv, jclass, jint columnsMask, jobjectArray points)
{
    return Encode<TSummaryIntTsEncoder, TSummaryIntPoint>(jenv, columnsMask, points);
}

jobjectArray Java_ru_yandex_solomon_ts_1codec_TsStreamNative_summaryInt64Decode(
        JNIEnv* jenv, jclass, jint columnsMask, jobject bitBuf)
{
    return Decode<TSummaryIntTsDecoder, TSummaryIntPoint>(jenv, columnsMask, bitBuf);
}

jobject Java_ru_yandex_solomon_ts_1codec_TsStreamNative_summaryDoubleEncode(
        JNIEnv* jenv, jclass, jint columnsMask, jobjectArray points)
{
    return Encode<TSummaryDoubleTsEncoder, TSummaryDoublePoint>(jenv, columnsMask, points);
}

jobjectArray Java_ru_yandex_solomon_ts_1codec_TsStreamNative_summaryDoubleDecode(
        JNIEnv* jenv, jclass, jint columnsMask, jobject bitBuf)
{
    return Decode<TSummaryDoubleTsDecoder, TSummaryDoublePoint>(jenv, columnsMask, bitBuf);
}

jobject JNICALL Java_ru_yandex_solomon_ts_1codec_TsStreamNative_histogramEncode(
        JNIEnv* jenv, jclass, jint columnsMask, jobjectArray points)
{
    return Encode<THistogramTsEncoder, THistogramPoint>(jenv, columnsMask, points);
}

jobjectArray JNICALL Java_ru_yandex_solomon_ts_1codec_TsStreamNative_histogramDecode(
        JNIEnv* jenv, jclass, jint columnsMask, jobject bitBuf)
{
    return Decode<THistogramTsDecoder, THistogramPoint>(jenv, columnsMask, bitBuf);
}

jobject JNICALL Java_ru_yandex_solomon_ts_1codec_TsStreamNative_logHistogramEncode(
        JNIEnv* jenv, jclass, jint columnsMask, jobjectArray points)
{
    return Encode<TLogHistogramTsEncoder, TLogHistogramPoint>(jenv, columnsMask, points);
}

jobjectArray JNICALL Java_ru_yandex_solomon_ts_1codec_TsStreamNative_logHistogramDecode(
        JNIEnv* jenv, jclass, jint columnsMask, jobject bitBuf)
{
    return Decode<TLogHistogramTsDecoder, TLogHistogramPoint>(jenv, columnsMask, bitBuf);
}
