#include "wmd.h"

#include <contrib/libs/intel/ipp/include/ipps.h>

#include <util/generic/algorithm.h>
#include <util/generic/ymath.h>
#include <util/system/yassert.h>

namespace NWmd {
    // Move constructor. All of the work (taking over the storage, copying the
    // dimensions and recomputing Data) is delegated to move assignment.
    TMatrix::TMatrix(TMatrix &&m) noexcept {
        *this = std::move(m);
    }

    // Creates a matrix with the given dimensions. Storage is constructed with
    // cols * rows + alignof(float) elements and defaultVal as the fill value;
    // Align() then points Data at the usable region of that storage.
    TMatrix::TMatrix(size_t cols, size_t rows, float defaultVal)
        : Cols(cols)
        , Rows(rows)
        , Storage(cols * rows + alignof(float), defaultVal)
    {
        Align();
    }

    // Creates a matrix with the given dimensions without an explicit fill
    // value. Storage is constructed with cols * rows + alignof(float) elements;
    // Align() then points Data at the usable region of that storage.
    TMatrix::TMatrix(size_t cols, size_t rows)
        : Cols(cols)
        , Rows(rows)
        , Storage(cols * rows + alignof(float))
    {
        Align();
    }

    // Move assignment: takes over m's storage and dimensions and recomputes
    // Data for the new owner.
    //
    // Self-assignment is a no-op. Without the guard, Storage would be moved
    // onto itself (which may leave it empty) and the following Align() would
    // then trip Y_VERIFY inside this noexcept function.
    //
    // After the call m is left as an empty matrix (0 x 0, Data == nullptr), so
    // it no longer carries stale dimensions or a pointer into the buffer that
    // now belongs to *this.
    TMatrix &TMatrix::operator=(TMatrix &&m) noexcept {
        if (this == &m) {
            return *this;
        }

        Storage = std::move(m.Storage);
        Cols = m.Cols;
        Rows = m.Rows;
        Align();

        // Detach the source from the buffer it no longer owns.
        m.Cols = 0;
        m.Rows = 0;
        m.Data = nullptr;

        return *this;
    }

    // Mutable access to the cell at (row, col). Cells are laid out row by row,
    // Cols values per row. No bounds checking is performed.
    float &TMatrix::operator()(size_t row, size_t col) {
        const size_t offset = row * Cols + col;
        return Data[offset];
    }

    // Read-only access: returns a copy of the cell at (row, col). Cells are
    // laid out row by row, Cols values per row. No bounds checking is performed.
    float TMatrix::operator()(size_t row, size_t col) const {
        const size_t offset = row * Cols + col;
        return Data[offset];
    }

    // Points Data at the first float-aligned address inside Storage that has
    // room for Cols * Rows values. Aborts via Y_VERIFY if std::align reports
    // that the request does not fit.
    //
    // NOTE(review): std::align interprets its size/space arguments as byte
    // counts, while Cols * Rows and Storage.size() look like element counts
    // here. Both sides use the same unit, so the fit check stays consistent,
    // but confirm this against the declared element type of Storage.
    void TMatrix::Align() {
        void *p = Storage.data();
        size_t space = Storage.size();
        // On success std::align updates p to the aligned address.
        Y_VERIFY(std::align(alignof(float), Cols * Rows, p, space));
        Data = static_cast<float *>(p);
    }

    // Correction terms applied when a distance matrix is not square
    // (filled in by NonSquareMatrixPenalty).
    struct TPenalty {
        // Larger of the two matrix dimensions; used as the normalising divisor.
        ui32 SquareSize;
        // maxDistance * (rows - cols) when there are fewer columns than rows, otherwise 0.
        float ColsPenalty;
        // maxDistance * (cols - rows) when there are fewer rows than columns, otherwise 0.
        float RowsPenalty;
    };

    // Computes the penalty for a cols x rows matrix that is not square.
    //
    // SquareSize is the larger dimension. The shorter dimension receives a
    // penalty of maxDistance for every missing row/column; the other penalty
    // (and both penalties for a square matrix) is zero.
    static TPenalty NonSquareMatrixPenalty(size_t cols, size_t rows, float maxDistance) {
        TPenalty result;
        result.ColsPenalty = 0.f;
        result.RowsPenalty = 0.f;

        if (rows > cols) {
            // Fewer columns than rows: the column side is penalised.
            result.SquareSize = rows;
            result.ColsPenalty = maxDistance * float(rows - cols);
        } else {
            result.SquareSize = cols;
            if (cols > rows) {
                // Fewer rows than columns: the row side is penalised.
                result.RowsPenalty = maxDistance * float(cols - rows);
            }
        }

        return result;
    }

    // Replaces a NaN result with 1.f; any other value is passed through unchanged.
    static float DeNan(float value) {
        return IsNan(value) ? 1.f : value;
    }

    // Relaxed WMD over a row-major cols x rows distance matrix.
    //
    // Sums the minimum of every row and, separately, the minimum of every
    // column; adds the non-square penalty to the matching side; returns the
    // larger of the two sums divided by the bigger dimension (NaN is mapped
    // to 1.f by DeNan).
    //
    // Returns -1.f when both dimensions are zero and 1.f when exactly one is.
    //
    // NOTE: the first row of `data` is overwritten - it is reused as the
    // buffer of running per-column minima.
    float RelaxedWmd(float *data, size_t cols, size_t rows, float maxDistance) {
        if (rows == 0 || cols == 0) {
            return rows == cols ? -1.f : 1.f;
        }

        // The first row doubles as the per-column minimum accumulator.
        float *const colMins = data;

        // Minimum of the first row; must be read before that row is modified below.
        float rowsCost = *MinElement(colMins, colMins + cols);

        for (size_t r = 1; r < rows; ++r) {
            const float *const row = data + r * cols;
            float rowMin = row[0];
            for (size_t c = 0; c < cols; ++c) {
                const float cell = row[c];
                if (colMins[c] > cell) {
                    colMins[c] = cell;
                }
                if (cell < rowMin) {
                    rowMin = cell;
                }
            }
            rowsCost += rowMin;
        }

        float colsCost = Accumulate(colMins, colMins + cols, 0.f);

        const auto penalty = NonSquareMatrixPenalty(cols, rows, maxDistance);
        colsCost += penalty.ColsPenalty;
        rowsCost += penalty.RowsPenalty;

        const float worstSide = Max(colsCost, rowsCost);
        return DeNan(worstSide / penalty.SquareSize);
    }

    // TMatrix convenience overload: unpacks the matrix and forwards to the
    // raw-pointer RelaxedWmd. The matrix is taken by non-const reference
    // because that overload receives a mutable pointer to its cells.
    float RelaxedWmd(TMatrix &distances, float maxDistance) {
        float *const cells = distances.Data;
        return RelaxedWmd(cells, distances.Cols, distances.Rows, maxDistance);
    }

    // Overload for temporaries: lets callers pass a freshly built matrix
    // directly. Computes the same value as the lvalue overload.
    float RelaxedWmd(TMatrix &&distances, float maxDistance) {
        return RelaxedWmd(distances.Data, distances.Cols, distances.Rows, maxDistance);
    }

    // A matrix cell value together with its coordinates; ordered by value only.
    // Member order (Value, Col, Row) matters: instances are brace-initialised
    // positionally.
    struct TPlacedValue {
        float Value;
        ui16 Col;
        ui16 Row;

        // Orders cells by ascending Value; coordinates do not take part.
        bool operator<(const TPlacedValue &v) const noexcept {
            return Value < v.Value;
        }
    };

    float GreedWmd(const float *data, size_t cols, size_t rows, float maxDistance) {
        if (rows == 0 || cols == 0) {
            if (rows == cols) {
                return -1.f;
            } else {
                return 1.f;
            }
        }

        bool* busyRows = (bool*) alloca(rows);
        bool* busyCols = (bool*) alloca(cols);
        memset(busyRows, 0, rows);
        memset(busyCols, 0, cols);

        size_t size = cols * rows;
        size_t valuesToCount = Min(cols, rows);
        float cost = 0;
        // Empirically calculated threshold when std::sort(...) works faster
        // than ippsSortRadixIndexAscend_32f(...)
        if (size <= 580) {
            TPlacedValue* placedValues = (TPlacedValue*) alloca(size * sizeof(TPlacedValue));
            size_t i = 0;
            for (ui16 row = 0; row < rows; row++) {
                const float *srcRow = data + cols * row;
                for (ui16 col = 0; col < cols; col++) {
                    placedValues[i++] = TPlacedValue{srcRow[col], col, row};
                }
            }

            TPlacedValue *placedValuesBegin = placedValues;
            std::sort(placedValuesBegin, placedValues + size);

            while (true) {
                const TPlacedValue& value = *placedValuesBegin++;
                if (busyRows[value.Row] || busyCols[value.Col]) {
                    continue;
                }
                cost += value.Value;
                busyRows[value.Row] = true;
                busyCols[value.Col] = true;
                if (--valuesToCount == 0) {
                    break;
                }
            }
        } else {
            int bufferSize;
            ippsSortRadixIndexGetBufferSize(size, IppDataType::ipp32f, &bufferSize);
            void* buffer = alloca(bufferSize);
            int* index = (int*) alloca(size * sizeof(int));
            ippsSortRadixIndexAscend_32f(data, sizeof(float), index, size, (ui8*) buffer);

            for (size_t i = 0; i < size; ++i) {
                int idx = *index++;
                size_t row = idx / cols;
                size_t col = idx % cols;
                if (busyRows[row] || busyCols[col]) {
                    continue;
                }
                cost += data[idx];
                busyRows[row] = true;
                busyCols[col] = true;
                if (--valuesToCount == 0) {
                    break;
                }
            }
        }

        const auto penalty = NonSquareMatrixPenalty(cols, rows, maxDistance);
        return DeNan((cost + Max(penalty.RowsPenalty, penalty.ColsPenalty)) / penalty.SquareSize);
    }

    // TMatrix convenience overload: unpacks the matrix and forwards to the
    // raw-pointer GreedWmd.
    float GreedWmd(TMatrix &distances, float maxDistance) {
        const float *const cells = distances.Data;
        return GreedWmd(cells, distances.Cols, distances.Rows, maxDistance);
    }

    // Overload for temporaries: lets callers pass a freshly built matrix
    // directly. Computes the same value as the lvalue overload.
    float GreedWmd(TMatrix &&distances, float maxDistance) {
        return GreedWmd(distances.Data, distances.Cols, distances.Rows, maxDistance);
    }
//    struct TSolutionIndex {
//        TSolutionIndex(ui32 row, ui32 col, ui32 size, const float *weights) noexcept
//                : Col(col), Row(weights + row * size) {
//
//        }
//
//        bool operator<(const TSolutionIndex &index) const noexcept {
//            return *(Row + Col) > *(index.Row + index.Col);
//        }
//
//        TSolutionIndex(const TSolutionIndex &) noexcept = default;
//
//        TSolutionIndex(TSolutionIndex &&) noexcept = default;
//
//        TSolutionIndex &operator=(const TSolutionIndex &) noexcept = default;
//
//        TSolutionIndex &operator=(TSolutionIndex &&) noexcept = default;
//
//        ui32 Col;
//        const float *Row;
//    };

//    static inline float cost(const TSolutionIndex *b, const TSolutionIndex *e) {
//        float sum{};
//
//        for (auto it = b; it != e; ++it) {
//            sum += *(it->Row + it->Col);
//        }
//
//        return sum;
//    }

//    float WMD(const TMatrix &distances) {
//        if (distances.Empty())
//            return -1.f;
//        const size_t size = distances.GetSize();
//
//        TTempArray<TSolutionIndex> solution(size);
//
//        const auto solveBegin = solution.Data();
//        const auto solveEnd = solution.Data() + size;
//
//        for (size_t i = 0; i < size; i++) {
//            new(solveBegin + i) TSolutionIndex(i, i, size, distances.Data());
//        }
//
//        std::sort(solveBegin, solveEnd);
//
//        auto minCost = cost(solveBegin, solveEnd);
//        ui32 permutations{};
//
//        AGAIN:
//        {
//            for (auto prev = solveBegin; prev != solveEnd - 1; ++prev) {
//                for (auto it = prev + 1; it != solveEnd; ++it) {
//                    const auto costOutcome = *(it->Row + it->Col) + *(prev->Row + prev->Col);
//                    const auto costIncome = *(it->Row + prev->Col) + *(prev->Row + it->Col);
//
//                    if (const auto delta = costOutcome - costIncome; delta > 0) {
//                        minCost -= delta;
//
//                        std::swap(prev->Col, it->Col);
//                        std::partial_sort(solveBegin, it, solveEnd);
//                        permutations++;
//
//                        goto AGAIN;
//                    }
//                }
//            }
//        };
//        return minCost / size;
//    }
} // namespace NWmd
