#include "channel_splitter.h"

#include <util/system/yassert.h>

#include <cstring>
#include <stdexcept>
#include <string>

namespace quasar {

    namespace {
        /**
         * IChannelSplitter implementation for interleaved integer PCM.
         *
         * The raw byte buffer is a sequence of frames. Each frame holds
         * micChannels microphone samples followed by spkChannels speaker
         * samples, every sample being one `Sample` value.
         *
         * @tparam Sample integer type of a single sample (create() instantiates
         *         int16_t for 2-byte and int32_t for 4-byte samples).
         */
        template <typename Sample>
        class ChannelSplitterT final: public IChannelSplitter {
        public:
            ChannelSplitterT(int micChannels, int spkChannels);

            // Writes the first (micChannels - skipLastMics) mic channels to inputMic and all
            // speaker channels to inputSpk, frame-interleaved and converted to float.
            void splitAndSkip(const std::vector<uint8_t>& rawData, std::vector<float>& inputMic, std::vector<float>& inputSpk, int skipLastMics) const override;
            // Extracts one channel (index over mic channels first, then speaker channels)
            // as 16-bit values.
            void getSeparateChannel(const std::vector<uint8_t>& rawData, int channelNumber, std::vector<int16_t>& channel) const override;

        private:
            // Number of whole frames in rawData; verifies there is no partial trailing frame.
            size_t countFrames(const std::vector<uint8_t>& rawData) const;
            // Throws std::runtime_error unless channelNumber is in [0, channelsTotal_).
            void checkChannelNumber(int channelNumber) const;

        private:
            int micChannels_;
            int spkChannels_;
            int channelsTotal_; // micChannels_ + spkChannels_, i.e. samples per frame.
        };

        /**
         * @param micChannels number of microphone channels per frame (>= 0).
         * @param spkChannels number of speaker channels per frame (>= 0).
         * @throws std::runtime_error if either count is negative or the frame would
         *         contain no channels at all.
         */
        template <typename Sample>
        ChannelSplitterT<Sample>::ChannelSplitterT(int micChannels, int spkChannels)
            : micChannels_(micChannels)
            , spkChannels_(spkChannels)
            , channelsTotal_(micChannels + spkChannels)
        {
            // splitAndSkip and countFrames divide the buffer size by
            // channelsTotal_ * sizeof(Sample); a zero or negative total would make that
            // division undefined, so reject such layouts up front.
            if (micChannels_ < 0 || spkChannels_ < 0 || channelsTotal_ <= 0) {
                throw std::runtime_error("Invalid channel layout: micChannels = " + std::to_string(micChannels) +
                                         ", spkChannels = " + std::to_string(spkChannels));
            }
        }

        /**
         * Splits interleaved raw samples into mic and speaker float buffers.
         *
         * @param rawData      interleaved frames (mic samples, then speaker samples).
         * @param inputMic     resized to (micChannels - skipLastMics) * frames and filled
         *                     frame-interleaved with the kept mic channels.
         * @param inputSpk     resized to spkChannels * frames and filled frame-interleaved.
         * @param skipLastMics how many trailing mic channels to drop, in [0, micChannels].
         * @throws std::runtime_error if skipLastMics is out of range.
         */
        template <typename Sample>
        void ChannelSplitterT<Sample>::splitAndSkip(const std::vector<uint8_t>& rawData,
                                                    std::vector<float>& inputMic, std::vector<float>& inputSpk, int skipLastMics) const {
            // A negative skip would index past the mic block (and past the buffer on the
            // last frame); a skip larger than the mic count would make outputMics negative
            // and the resize size wrap to a huge value.
            if (skipLastMics < 0 || skipLastMics > micChannels_) {
                throw std::runtime_error("Cannot skip " + std::to_string(skipLastMics) +
                                         " last mic channels. Should be in range [0.." + std::to_string(micChannels_) + "]");
            }

            const size_t frames = countFrames(rawData);
            const size_t outputMics = static_cast<size_t>(micChannels_ - skipLastMics);
            const size_t micChannels = static_cast<size_t>(micChannels_);
            const size_t spkChannels = static_cast<size_t>(spkChannels_);
            const size_t channelsTotal = static_cast<size_t>(channelsTotal_);

            inputMic.resize(outputMics * frames);
            inputSpk.resize(spkChannels * frames);

            // Read each sample with memcpy: dereferencing the uint8_t buffer through a
            // Sample* would violate the strict-aliasing rule.
            const uint8_t* bytes = rawData.data();
            const auto sampleAt = [bytes](size_t index) {
                Sample value;
                std::memcpy(&value, bytes + index * sizeof(Sample), sizeof(Sample));
                return static_cast<float>(value);
            };

            for (size_t frame = 0; frame < frames; ++frame) {
                const size_t frameStart = frame * channelsTotal;
                for (size_t mic = 0; mic < outputMics; ++mic) {
                    inputMic[frame * outputMics + mic] = sampleAt(frameStart + mic);
                }
                for (size_t spk = 0; spk < spkChannels; ++spk) {
                    inputSpk[frame * spkChannels + spk] = sampleAt(frameStart + micChannels + spk);
                }
            }
        }

        /**
         * Returns the number of whole frames in rawData.
         *
         * Aborts via Y_VERIFY if the splitter has no channels (zero frame size) or if
         * rawData does not consist of whole frames.
         */
        template <typename Sample>
        size_t ChannelSplitterT<Sample>::countFrames(const std::vector<uint8_t>& rawData) const {
            // Guard the division below: a frame with no channels has zero size.
            Y_VERIFY(channelsTotal_ > 0);
            // Keep the arithmetic in size_t instead of narrowing the frame size to int.
            const size_t frameSizeInBytes = static_cast<size_t>(channelsTotal_) * sizeof(Sample);
            Y_VERIFY(rawData.size() % frameSizeInBytes == 0);
            return rawData.size() / frameSizeInBytes;
        }

        /**
         * Validates a channel index covering mic channels followed by speaker channels.
         *
         * @throws std::runtime_error if channelNumber is outside [0, channelsTotal_ - 1].
         */
        template <typename Sample>
        void ChannelSplitterT<Sample>::checkChannelNumber(int channelNumber) const {
            const bool inRange = channelNumber >= 0 && channelNumber < channelsTotal_;
            if (inRange) {
                return;
            }

            const std::string lastValidIndex = std::to_string(channelsTotal_ - 1);
            throw std::runtime_error("Cannot get channel number " + std::to_string(channelNumber) +
                                     ". Should be in range [0.." + lastValidIndex + "]");
        }

        /**
         * Extracts a single 16-bit channel: channel receives one sample per frame.
         *
         * @throws std::runtime_error if channelNumber is out of range.
         */
        template <>
        void ChannelSplitterT<int16_t>::getSeparateChannel(const std::vector<uint8_t>& rawData, int channelNumber,
                                                           std::vector<int16_t>& channel) const {
            using Sample = int16_t;

            checkChannelNumber(channelNumber);
            const size_t frames = countFrames(rawData);
            channel.resize(frames);

            // Copy bytes with memcpy instead of reading through a reinterpret_cast'ed
            // Sample*, which would violate the strict-aliasing rule.
            const uint8_t* bytes = rawData.data();
            const size_t frameStride = static_cast<size_t>(channelsTotal_) * sizeof(Sample);
            const size_t channelOffset = static_cast<size_t>(channelNumber) * sizeof(Sample);
            for (size_t i = 0; i < frames; ++i) {
                std::memcpy(&channel[i], bytes + channelOffset + i * frameStride, sizeof(Sample));
            }
        }

        /**
         * Extracts a single 32-bit channel into a 16-bit output buffer.
         *
         * Each 32-bit sample is emitted as two consecutive int16_t values holding its
         * first and second byte pairs in memory order. The split therefore depends on
         * the host byte order, and channel ends up with 2 * frames elements.
         *
         * @throws std::runtime_error if channelNumber is out of range.
         */
        template <>
        void ChannelSplitterT<int32_t>::getSeparateChannel(const std::vector<uint8_t>& rawData, int channelNumber,
                                                           std::vector<int16_t>& channel) const {
            using Sample = int32_t;
            static_assert(sizeof(Sample) == 2 * sizeof(int16_t), "a sample must fill exactly two output slots");

            checkChannelNumber(channelNumber);
            const size_t frames = countFrames(rawData);
            channel.resize(frames * 2);

            // Copy the sample's raw bytes straight into its two output slots. This is
            // the same byte-level split as reinterpreting it as int16_t[2], but without
            // the strict-aliasing violation of reading through cast pointers.
            const uint8_t* bytes = rawData.data();
            const size_t frameStride = static_cast<size_t>(channelsTotal_) * sizeof(Sample);
            const size_t channelOffset = static_cast<size_t>(channelNumber) * sizeof(Sample);
            for (size_t i = 0; i < frames; ++i) {
                std::memcpy(&channel[2 * i], bytes + channelOffset + i * frameStride, sizeof(Sample));
            }
        }
    } // namespace

    /**
     * Factory for channel splitters.
     *
     * @param sampleSizeBytes 2 selects 16-bit samples, 4 selects 32-bit samples.
     * @throws std::runtime_error for any other sample size.
     */
    std::unique_ptr<IChannelSplitter> IChannelSplitter::create(int micChannels, int spkChannels, int sampleSizeBytes) {
        switch (sampleSizeBytes) {
            case 2:
                return std::make_unique<ChannelSplitterT<int16_t>>(micChannels, spkChannels);
            case 4:
                return std::make_unique<ChannelSplitterT<int32_t>>(micChannels, spkChannels);
            default:
                break;
        }
        throw std::runtime_error("Can't create ChannelSplitter for sampleSizeBytes = " + std::to_string(sampleSizeBytes));
    }

} // namespace quasar
