#include "reed_solomon_decoder.h"

#include <stdexcept>
#include <string>

using namespace quasar;

/**
 * Creates a decoder that performs all arithmetic in the given Galois field.
 *
 * @param field the field the codewords are defined over; it is copied into
 *              field_ via the member-initializer list (initialization rather
 *              than default-construction followed by assignment).
 */
ReedSolomonDecoder::ReedSolomonDecoder(const GaloisField& field)
    : field_(field)
{
}

/**
 * Runs the extended Euclidean algorithm on a and b to obtain the error locator
 * (sigma) and error evaluator (omega) polynomials.
 *
 * Iteration stops once the remainder's degree is below R/2. The comparison is
 * done exactly as 2*deg(r) < R rather than with truncating integer division
 * (deg(r) < R/2), which would perform one step too many for odd R and yield a
 * locator whose degree exceeds the number of errors.
 *
 * @param a  as called from decode(), the monomial x^R; swapped with b if its
 *           degree is smaller.
 * @param b  as called from decode(), the syndrome polynomial.
 * @param R  number of error-correction symbols (twoS in decode()).
 * @return   {sigma, omega}, both scaled by the same factor so that sigma(0) == 1.
 * @throws std::runtime_error if a remainder becomes zero prematurely, the
 *         polynomial division fails to reduce the degree, or sigma(0) is zero.
 */
std::pair<GaloisFieldPolynomial, GaloisFieldPolynomial>
ReedSolomonDecoder::runEuclideanAlgorithm(GaloisFieldPolynomial a, GaloisFieldPolynomial b, int R) const {
    // Assume a's degree is >= b's
    if (a.getDegree() < b.getDegree())
    {
        GaloisFieldPolynomial temp = a;
        a = b;
        b = temp;
    }

    GaloisFieldPolynomial rLast = a;
    GaloisFieldPolynomial r = b;
    GaloisFieldPolynomial tLast = GaloisFieldPolynomial(field_, std::vector<int>{0});
    GaloisFieldPolynomial t = GaloisFieldPolynomial(field_, std::vector<int>{1});

    // Run Euclidean algorithm until r's degree is less than R/2 (exact, not
    // truncated: for odd R, R / 2 would round down and over-iterate).
    while (2 * r.getDegree() >= R)
    {
        GaloisFieldPolynomial rLastLast = rLast;
        GaloisFieldPolynomial tLastLast = tLast;
        rLast = r;
        tLast = t;

        // Divide rLastLast by rLast, with quotient in q and remainder in r
        if (rLast.isZero())
        {
            // Oops, Euclidean algorithm already terminated?
            throw std::runtime_error("r_{i-1} was zero");
        }
        r = rLastLast;
        GaloisFieldPolynomial q = GaloisFieldPolynomial(field_, std::vector<int>{0});
        int denominatorLeadingTerm = rLast.getCoefficient(rLast.getDegree());
        int dltInverse = field_.inverse(denominatorLeadingTerm);
        // Long division: repeatedly cancel r's leading term with a scaled shift of rLast.
        while (r.getDegree() >= rLast.getDegree() && !r.isZero())
        {
            int degreeDiff = r.getDegree() - rLast.getDegree();
            int scale = field_.multiply(r.getCoefficient(r.getDegree()), dltInverse);
            q = q.addOrSubtract(GaloisFieldPolynomial::buildMonomial(field_, degreeDiff, scale));
            r = r.addOrSubtract(rLast.multiplyByMonomial(degreeDiff, scale));
        }

        // Bezout coefficient update: t_i = q_i * t_{i-1} + t_{i-2}.
        t = q.multiply(tLast).addOrSubtract(tLastLast);

        if (r.getDegree() >= rLast.getDegree())
        {
            throw std::runtime_error("Division algorithm failed to reduce polynomial?");
        }
    }

    int sigmaTildeAtZero = t.getCoefficient(0);
    if (sigmaTildeAtZero == 0)
    {
        throw std::runtime_error("sigmaTilde(0) was zero");
    }

    // Normalize so that sigma(0) == 1; omega is scaled by the same factor.
    int inverse = field_.inverse(sigmaTildeAtZero);
    GaloisFieldPolynomial sigma = t.multiply(inverse);
    GaloisFieldPolynomial omega = r.multiply(inverse);
    return std::make_pair(sigma, omega);
}

/**
 * Corrects errors in a received codeword in place.
 *
 * @param received codeword symbols; index 0 holds the highest-degree
 *                 coefficient (an error at locator alpha^p is applied at index
 *                 received.size() - 1 - p). Modified in place.
 * @param twoS     number of error-correction symbols; syndromes are evaluated at
 *                 alpha^(generatorBase + i) for i in [0, twoS).
 * @throws std::runtime_error if the errors cannot be located or corrected
 *         (including a non-zero syndrome for which no error locations exist).
 */
void ReedSolomonDecoder::decode(std::vector<int>& received, int twoS) const {
    GaloisFieldPolynomial poly(field_, received);

    // S_i is stored at index (twoS - 1 - i) so that, assuming GaloisFieldPolynomial
    // takes coefficients highest-degree first, S_i is the coefficient of x^i.
    std::vector<int> syndromeCoefficients(twoS);
    bool noError = true;
    for (int i = 0; i < twoS; i++)
    {
        int eval = poly.evaluateAt(field_.exp(i + field_.getGeneratorBase()));
        syndromeCoefficients[syndromeCoefficients.size() - 1 - i] = eval;
        if (eval != 0)
        {
            noError = false;
        }
    }
    if (noError)
    {
        // Every syndrome vanished: nothing to correct.
        return;
    }
    GaloisFieldPolynomial syndrome(field_, syndromeCoefficients);
    std::pair<GaloisFieldPolynomial, GaloisFieldPolynomial> sigmaOmega =
        runEuclideanAlgorithm(GaloisFieldPolynomial::buildMonomial(field_, twoS, 1), syndrome, twoS);
    GaloisFieldPolynomial sigma = sigmaOmega.first;
    GaloisFieldPolynomial omega = sigmaOmega.second;
    std::vector<int> errorLocations = findErrorLocations(sigma);
    if (errorLocations.empty())
    {
        // Non-zero syndrome but a degree-0 locator: more errors than can be
        // corrected. Returning here would silently leave corrupted data.
        throw std::runtime_error("Non-zero syndrome but no error locations found");
    }
    std::vector<int> errorMagnitudes = findErrorMagnitudes(omega, errorLocations);

    for (size_t i = 0; i < errorLocations.size(); i++)
    {
        // Signed arithmetic so an out-of-range locator yields a negative
        // position rather than an unsigned wrap-around.
        const int position = static_cast<int>(received.size()) - 1 - field_.log(errorLocations[i]);
        if (position < 0)
        {
            throw std::runtime_error("Bad error location: " + std::to_string(position));
        }
        received[position] = field_.addOrSubtract(received[position], errorMagnitudes[i]);
    }
}

/**
 * Returns the error locations X_k, i.e. the reciprocals of the roots of the
 * error locator polynomial, found by evaluating it at every non-zero field
 * element (a brute-force form of Chien's search).
 *
 * @param errorLocator locator polynomial; its degree is taken as the expected
 *                     number of errors.
 * @return one location per root, in the order the roots were found.
 * @throws std::runtime_error if the number of roots differs from the degree.
 */
std::vector<int> ReedSolomonDecoder::findErrorLocations(const GaloisFieldPolynomial& errorLocator) const {
    const int expectedCount = errorLocator.getDegree();

    // Shortcut: when the locator is normalized to 1 + X*x (runEuclideanAlgorithm
    // scales it so the constant term is 1), the only location is its x coefficient.
    if (expectedCount == 1)
    {
        return std::vector<int>{errorLocator.getCoefficient(1)};
    }

    std::vector<int> locations(expectedCount);
    int found = 0;
    const int fieldSize = field_.getSize();
    for (int candidate = 1; found < expectedCount && candidate < fieldSize; ++candidate)
    {
        if (errorLocator.evaluateAt(candidate) != 0)
        {
            continue;
        }
        // A root at candidate corresponds to the location 1 / candidate.
        locations[found++] = field_.inverse(candidate);
    }

    if (found != expectedCount)
    {
        throw std::runtime_error("Error locator degree does not match number of roots");
    }
    return locations;
}

/**
 * Computes the error value at each error location using Forney's formula.
 *
 * For a locator normalized so that sigma(0) == 1 and syndromes evaluated at
 * alpha^(b), alpha^(b+1), ... (b = field generator base), the value at
 * location X_k in a characteristic-2 field is
 *     e_k = X_k^(-b) * omega(X_k^-1) / prod_{j != k} (1 + X_j * X_k^-1).
 * The X_k^(-b) factor is applied as b repeated multiplications by X_k^-1, so
 * any non-negative generator base is handled (not only 0 and 1).
 *
 * @param errorEvaluator the evaluator polynomial omega.
 * @param errorLocations the locations X_k (see findErrorLocations()).
 * @return error values, index-aligned with errorLocations.
 */
std::vector<int>
ReedSolomonDecoder::findErrorMagnitudes(const GaloisFieldPolynomial& errorEvaluator,
                                        const std::vector<int>& errorLocations) const {
    const size_t s = errorLocations.size();
    const int generatorBase = field_.getGeneratorBase();
    std::vector<int> result(s);
    for (size_t i = 0; i < s; i++)
    {
        const int xiInverse = field_.inverse(errorLocations[i]);
        int denominator = 1;
        for (size_t j = 0; j < s; j++)
        {
            if (i != j)
            {
                const int term = field_.multiply(errorLocations[j], xiInverse);
                // 1 + term: addition in a binary field flips the lowest bit.
                denominator = field_.multiply(denominator, term ^ 1);
            }
        }
        int magnitude = field_.multiply(errorEvaluator.evaluateAt(xiInverse),
                                        field_.inverse(denominator));
        // Multiply by X_i^(-generatorBase).
        for (int k = 0; k < generatorBase; k++)
        {
            magnitude = field_.multiply(magnitude, xiInverse);
        }
        result[i] = magnitude;
    }
    return result;
}
