#include "normal.h"

#include <util/generic/algorithm.h>
#include <util/generic/utility.h>
#include <util/generic/yexception.h>

#include <cmath>

using namespace NZoom::NHgram;

// Builds a histogram from an already populated bucket vector (taken by move),
// the number of non-positive samples (see MulFloat) and the power that bucket 0
// corresponds to (see MulPositiveFloat).
TNormal::TNormal(TVector<double>&& buckets, const size_t zeros, const i16 startPower)
    : Buckets(std::move(buckets))
    , Zeros(zeros)
    , StartPower(startPower)
{
    // All state comes straight from the arguments; no validation is done here.
}

size_t TNormal::Len() const noexcept {
    return Buckets.size();
}

// Adds a single sample. Non-positive samples have no logarithm, so they only
// bump the zeros counter; everything else goes to the logarithmic buckets.
void TNormal::MulFloat(const double value) {
    if (value <= 0) {
        ++Zeros;
        return;
    }
    // NOTE(review): NaN also ends up here, because `NaN <= 0` is false.
    MulPositiveFloat(value);
}

void TNormal::MulSlice(const TVector<double>& values) {
    for (const double value: values) {
        MulFloat(value);
    }
}

void TNormal::MulSmallHgram(const TVector<double>& values, const size_t zeros) {
    Zeros += zeros;
    for (const double value: values) {
        MulPositiveFloat(value);
    }
}

// Adds one strictly positive sample to the bucket whose power is
// floor(log_HGRAM_BASE(value)); bucket i corresponds to power StartPower + i.
// The bucket range is grown when the sample falls outside it.
// NOTE(review): for inf/NaN input the double-to-integer conversion below is
// undefined behaviour; callers are expected to pass finite positive values.
void TNormal::MulPositiveFloat(const double value) {
    const auto logBase = std::log(HGRAM_BASE);
    ssize_t target = std::floor(std::log(value) / logBase) - StartPower;

    if (target >= Buckets.ysize()) {
        // Above the current top bucket: grow (or shift) the range upwards.
        target = ExtendUp(target);
    } else if (target <= 0) {
        // At or below bucket 0: prepend buckets so the sample lands at index 1,
        // keeping bucket 0 free; if the size limit prevents that, the clamp puts
        // the sample into bucket 0.
        target = Max<ssize_t>(0, ExtendDown(target, 1));
    }

    Buckets[target] += 1.0;
}

void TNormal::MulNormalHgram(const TNormal& other) {
    Zeros += other.Zeros;

    const i32 firstIdxBeforeExtend = other.StartPower - StartPower;
    const i32 lastIdxBeforeExtend = firstIdxBeforeExtend + other.Buckets.size() - 1;
    if (firstIdxBeforeExtend > Max<i16>() || firstIdxBeforeExtend < Min<i16>()) {
        ythrow yexception() << "i16 overflow on first index";
    }

    if (lastIdxBeforeExtend > Max<i16>() || lastIdxBeforeExtend < Min<i16>()) {
        ythrow yexception() << "i16 overflow on last index";
    }

    i64 firstIdx = ExtendBounds(firstIdxBeforeExtend, lastIdxBeforeExtend, 0).first;

    size_t toMerge = Min<size_t>(Max<i64>(-firstIdx, 0), other.Buckets.size());  // if negative value, we should merge exceeded buckets

    if (toMerge) {
        Buckets[0] += Accumulate(other.Buckets.cbegin(), other.Buckets.cbegin() + toMerge, 0.0);
        firstIdx = 0;
    };

    for (auto it = other.Buckets.cbegin() + toMerge; it != other.Buckets.cend(); ++it) {
        Buckets[firstIdx] += *it;
        ++firstIdx;
    }
}

// Exact equality: same start power, same zero count and bucket-by-bucket
// identical weights (doubles compared with ==).
bool TNormal::operator ==(const TNormal& other) const noexcept {
    if (StartPower != other.StartPower || Zeros != other.Zeros) {
        return false;
    }
    return Buckets == other.Buckets;
}

// Read-only view of the bucket weights. Index i corresponds to power
// StartPower + i; bucket 0 may also hold weight folded in from lower powers
// when the size limit was reached.
const TVector<double>& TNormal::GetBuckets() const noexcept {
    return this->Buckets;
}

size_t TNormal::GetZerosCount() const noexcept {
    return Zeros;
}

// Power (exponent of HGRAM_BASE) that bucket 0 corresponds to.
i16 TNormal::GetStartPower() const noexcept {
    return this->StartPower;
}

// Returns HGRAM_BASE raised to `power`.
double TNormal::GetHgramPower(i32 power) {
    const auto result = std::pow(HGRAM_BASE, power);
    return result;
}

// Prepends empty buckets so that `indexPretendent` ends up at index `margin`.
// The number added is limited by the free capacity up to
// MAX_NORMAL_HGRAM_SIZE, and StartPower is lowered by that number.
// Returns the index that now corresponds to `indexPretendent`. This is still
// below `margin` (possibly negative) when the size limit prevented a full
// extension.
ssize_t TNormal::ExtendDown(ssize_t indexPretendent, ui8 margin) {
    // Computed in i64 so indexPretendent is never narrowed (an i16 cast wraps
    // for indices outside the i16 range).
    const i64 wanted = static_cast<i64>(margin) - static_cast<i64>(indexPretendent);
    // Clamped at zero: when nothing can or should be inserted (no free capacity,
    // or the index is already at/above the margin), the index must not be
    // shifted either.
    const i64 toAdd = Max<i64>(0, Min<i64>(MAX_NORMAL_HGRAM_SIZE - Buckets.size(), wanted));
    if (toAdd > 0) {
        Buckets.insert(Buckets.begin(), static_cast<size_t>(toAdd), 0.0);
        StartPower = static_cast<i16>(StartPower - toAdd);
    }
    return indexPretendent + toAdd;
}

// Grows the bucket vector upwards so that `indexPretendent` becomes a valid
// index. Callers in this file only call it with
// indexPretendent >= Buckets.size().
// If the resulting size would exceed MAX_NORMAL_HGRAM_SIZE, the lowest buckets
// are dropped and their total weight is added to the new bucket 0, so no
// weight is lost. StartPower is then raised by the number of positions
// shifted. Returns the (possibly shifted) index that now corresponds to
// `indexPretendent`.
ssize_t TNormal::ExtendUp(ssize_t indexPretendent) {
    const size_t bucketSize = Buckets.size();
    const size_t toAdd = indexPretendent + 1 - bucketSize;
    const size_t newSize = bucketSize + toAdd;

    if (newSize <= MAX_NORMAL_HGRAM_SIZE) {
        Buckets.resize(newSize, 0.0);
        return indexPretendent;
    }

    const size_t toRemove = newSize - MAX_NORMAL_HGRAM_SIZE;
    const size_t actualToRemove = Min(toRemove, bucketSize);
    if (actualToRemove > 0) {
        const double removedWeight = Accumulate(Buckets.cbegin(), Buckets.cbegin() + actualToRemove, 0.0);
        Buckets.erase(Buckets.cbegin(), Buckets.cbegin() + actualToRemove);
        if (Buckets.empty()) {
            Buckets.push_back(removedWeight);
        } else {
            // The surviving lowest bucket absorbs the dropped ones. It must be
            // added, not assigned, or that bucket's own weight would be discarded.
            Buckets[0] += removedWeight;
        }
    }
    Buckets.resize(newSize - toRemove, 0.0);
    StartPower += toRemove;
    return indexPretendent - toRemove;
}

// Makes room for the index range [startIdx, endIdx], given in the current
// bucket indexing, by growing the bucket vector on both sides (ExtendUp, then
// ExtendDown with `margin`). Returns the same two positions re-expressed in the
// indexing that is valid after the extension. When MAX_NORMAL_HGRAM_SIZE is
// reached, the returned start index can be negative; MulNormalHgram folds such
// positions into bucket 0.
std::pair<ssize_t, ssize_t> TNormal::ExtendBounds(ssize_t startIdx, ssize_t endIdx, ui8 margin) {
    ssize_t realEndIdx = endIdx;
    if (endIdx >= Buckets.ysize()) {
        // Pre-allocate for the requested span, capped by the size limit. The
        // template argument is explicit so deduction does not clash with a `long`
        // literal on platforms where ssize_t is not `long`.
        Buckets.reserve(Max<ssize_t>(Min<ssize_t>(endIdx - startIdx + 1, MAX_NORMAL_HGRAM_SIZE), 0));
        realEndIdx = ExtendUp(endIdx);
        // ExtendUp may have dropped low buckets; shift the start by the same amount.
        startIdx += realEndIdx - endIdx;
    }

    ssize_t realStartIdx = startIdx;
    if (startIdx < 1) {
        realStartIdx = ExtendDown(startIdx, margin);
        // Prepending buckets shifts every index, including the end.
        realEndIdx += realStartIdx - startIdx;
    }
    return std::make_pair(realStartIdx, realEndIdx);
}
