#include "sweep_tone_generator.h"

#include "private/wav_file.h"

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

namespace {

    // Output format: 32-bit samples, 2 channels, 48 kHz.
    constexpr int sampleSize = 4;
    constexpr int channelCount = 2;
    constexpr int sampleRate = 48000;
    constexpr int msInSecond = 1000;

    // Sweep range in Hz and peak level in dB relative to full scale (sampleMaxValue).
    constexpr double beginFreq = 20.;
    constexpr double endFreq = 8000.;
    constexpr double levelDb = -15.0;
    // dB -> linear amplitude: 10^(dB / 20). The peak sample value is amplitude * sampleMaxValue.
    const double amplitude = std::pow(10., levelDb / 20.);
    // Full-scale positive sample value. Spelled via numeric_limits: the hand-written shift
    // `1 << (bits - 1) - 1` parses as `1 << (bits - 2)` because '-' binds tighter than '<<'.
    constexpr int32_t sampleMaxValue = std::numeric_limits<int32_t>::max();

    constexpr std::chrono::seconds sweepDuration{4};
    // Length of each fade ramp and of the silence written after every sweep.
    constexpr std::chrono::milliseconds transitionDuration{100};

    constexpr int sweepCount = 3;

    // Instantaneous frequency of an exponential (log-spaced) sweep: beginFreq at
    // frameIndex == 0, approaching endFreq as frameIndex approaches frameCount.
    double exponentialFreq(size_t frameIndex, size_t frameCount) {
        if (frameCount == 0) {
            return beginFreq;  // avoid 0/0 -> NaN
        }
        const double progress = static_cast<double>(frameIndex) / static_cast<double>(frameCount);
        return beginFreq * std::pow(endFreq / beginFreq, progress);
    }

    // Number of bytes needed to hold `duration` of audio in the format described by `info`.
    // Works at millisecond resolution (finer durations are truncated). The result is always a
    // whole number of frames, so consecutive buffers never split a frame across channels.
    // Non-positive durations yield 0.
    template <class Rep, class Period>
    size_t calculateBufferSize(const SweepToneGen::WavFile::SoundInfo& info, const std::chrono::duration<Rep, Period>& duration) {
        const auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(duration).count();
        if (ms <= 0) {
            return 0;
        }
        const size_t frameSize = static_cast<size_t>(info.sampleSize) * info.channelCount;
        const size_t frames = static_cast<size_t>(info.sampleRate) * static_cast<size_t>(ms) / msInSecond;
        return frames * frameSize;
    }

} // namespace

bool quasar::generateSweepTone(std::string filename)
{
    SweepToneGen::WavFile::SoundInfo info{channelCount, sampleRate, sampleSize};
    SweepToneGen::WavFile file{info};
    if (!file.init(filename)) {
        return false;
    }

    size_t bufferSize = calculateBufferSize(info, sweepDuration);
    std::vector<uint8_t> buffer(bufferSize);

    const size_t frameSize = sampleSize * channelCount;
    const size_t frameCount = bufferSize / frameSize;

    const size_t transitionFrameCount = calculateBufferSize(info, transitionDuration) / frameSize;
    const double transitionAmplitudeStep = amplitude / transitionFrameCount;

    double phase = 0.;
    double freq = beginFreq;
    double delta = 2 * M_PI * beginFreq / sampleRate;
    double currAmplitude = 0.;

    for (size_t frameIndex = 0; frameIndex < frameCount; ++frameIndex) {
        float floatSample = currAmplitude * sin(phase);
        const int32_t sample = (int32_t)(floatSample * sampleMaxValue);

        for (size_t channelIndex = 0; channelIndex < channelCount; ++channelIndex) {
            *reinterpret_cast<int32_t*>(buffer.data() + frameIndex * frameSize + channelIndex * sampleSize) =
                sample;
        }
        phase += delta;

        freq = exponentialFreq(frameIndex, frameCount);
        delta = 2 * M_PI * freq / sampleRate;

        if (frameIndex < transitionFrameCount) {
            currAmplitude += transitionAmplitudeStep;
        } else if (frameIndex >= frameCount - transitionFrameCount) {
            currAmplitude -= transitionAmplitudeStep;
        }
    }

    std::vector<uint8_t> transition(calculateBufferSize(info, transitionDuration), 0);
    for (int sweepIndex = 0; sweepIndex < sweepCount; ++sweepIndex) {
        file.write(buffer.data(), bufferSize);
        file.write(transition.data(), transition.size());
    }

    return true;
}
