#include "coder.h"

#include <contrib/libs/libidn/lib/idna.h>

#include <library/cpp/string_utils/base64/base64.h>
#include <library/cpp/string_utils/quote/quote.h>

#include <util/generic/yexception.h>

#include <array>

namespace NPassport::NUtils {
    namespace {
        // Builds a 256-entry reverse lookup table for a Base64 alphabet.
        // `b64_encode` must point to at least 64 symbols. The entry for the symbol
        // at position i is set to i; every other entry is 0xff, which is > 63 and
        // therefore marks the byte as "not part of the alphabet".
        constexpr std::array<unsigned char, 256> B64Init(const unsigned char* b64_encode) {
            std::array<unsigned char, 256> table{};
            for (size_t slot = 0; slot < table.size(); ++slot) {
                table[slot] = 0xff;
            }

            for (unsigned char value = 0; value < 64; ++value) {
                table[b64_encode[value]] = value;
            }

            return table;
        }
    }

    // URL- and filename-safe alphabet: '-' and '_' encode the values 62 and 63.
    namespace NUrl {
        constexpr unsigned char B64_ENCODE[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
        constexpr std::array<unsigned char, 256> B64_DECODE = B64Init(B64_ENCODE);
    }

    // Standard alphabet: '+' and '/' encode the values 62 and 63.
    namespace NPlain {
        constexpr unsigned char B64_ENCODE[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        constexpr std::array<unsigned char, 256> B64_DECODE = B64Init(B64_ENCODE);
    }

    // Encodes `buf` as Base64. The alphabet is chosen at compile time: the
    // URL-safe one ('-', '_') when `isUrl` is true, the standard one ('+', '/')
    // otherwise. When `pad` is true, a trailing 1-byte group is followed by "=="
    // and a trailing 2-byte group by "="; when it is false, no '=' is emitted.
    // Empty input yields an empty string.
    template <bool isUrl>
    static TString BinToBase64Impl(const TStringBuf buf, bool pad) {
        if (buf.empty()) {
            return TString();
        }

        constexpr const unsigned char* alphabet = isUrl ? NUrl::B64_ENCODE : NPlain::B64_ENCODE;

        const size_t srcLen = buf.size();
        TString out;
        out.resize(((srcLen + 2) / 3) * 4, 0); // padded length is the maximum; trimmed below

        const unsigned char* src = (const unsigned char*)buf.data();
        const unsigned char* const srcEnd = src + srcLen;
        unsigned char* dst = (unsigned char*)out.data();

        // Each full 3-byte group becomes four 6-bit symbols.
        for (; srcEnd - src >= 3; src += 3, dst += 4) {
            const unsigned int group = (unsigned int)src[0] << 16 | (unsigned int)src[1] << 8 | (unsigned int)src[2];
            dst[0] = alphabet[(group >> 18) & 0x3F];
            dst[1] = alphabet[(group >> 12) & 0x3F];
            dst[2] = alphabet[(group >> 6) & 0x3F];
            dst[3] = alphabet[group & 0x3F];
        }

        // 0, 1 or 2 input bytes remain.
        const size_t rest = srcEnd - src;
        if (rest > 0) {
            const unsigned int first = src[0];
            const unsigned int second = (rest == 2) ? src[1] : 0u;
            *dst++ = alphabet[first >> 2];
            *dst++ = alphabet[((first & 0x3) << 4) | (second >> 4)];
            if (rest == 2) {
                *dst++ = alphabet[(second & 0xF) << 2];
            } else if (pad) {
                *dst++ = '=';
            }
            if (pad) {
                *dst++ = '=';
            }
        }

        out.resize(dst - (unsigned char*)out.data());
        return out;
    }

    // Decodes Base64 text using the URL-safe (isUrl == true) or the standard
    // (isUrl == false) alphabet.
    //
    // Accepted input is one or more alphabet symbols, optionally followed by any
    // number of '=' characters. Padding is neither required nor counted.
    // Returns an empty string when the input is empty, starts with a
    // non-alphabet byte, contains anything other than '=' after the symbols, or
    // has a symbol count of the form 4k+1 (a lone trailing symbol carries only
    // 6 bits and cannot encode a byte).
    //
    // All lengths are kept in size_t so that inputs of 2 GiB or more do not
    // overflow `int` arithmetic.
    template <bool isUrl>
    static TString Base64ToBinImpl(const TStringBuf buf) {
        constexpr const unsigned char* b64_decode = isUrl ? NUrl::B64_DECODE.data() : NPlain::B64_DECODE.data();

        const unsigned char* const begin = (const unsigned char*)buf.data();
        const size_t len = buf.size();
        if (begin == nullptr || len == 0 || b64_decode[*begin] > 63) {
            return TString();
        }
        const unsigned char* const end = begin + len;

        // Find the end of the leading run of alphabet symbols (the first byte is known valid).
        const unsigned char* symbolsEnd = begin + 1;
        while (symbolsEnd < end && b64_decode[*symbolsEnd] < 64) {
            ++symbolsEnd;
        }
        size_t nprbytes = symbolsEnd - begin;
        if (nprbytes % 4 == 1) {
            return TString(); // It is impossible case: one char cannot contain info about one byte
        }

        // Whatever follows the symbols must be '=' padding.
        for (const unsigned char* it = symbolsEnd; it < end; ++it) {
            if (*it != '=') {
                return TString();
            }
        }

        TString res;
        res.resize(((nprbytes + 3) / 4) * 3); // upper bound; trimmed to the exact size below

        unsigned char* bufout = (unsigned char*)res.data();
        const unsigned char* bufin = begin;

        // Full 4-symbol groups -> 3 bytes each. The final group (2..4 symbols,
        // possibly a complete one) is decoded separately below.
        while (nprbytes > 4) {
            const unsigned char a = b64_decode[bufin[0]];
            const unsigned char b = b64_decode[bufin[1]];
            const unsigned char c = b64_decode[bufin[2]];
            const unsigned char d = b64_decode[bufin[3]];
            *(bufout++) = (unsigned char)(a << 2 | b >> 4);
            *(bufout++) = (unsigned char)(b << 4 | c >> 2);
            *(bufout++) = (unsigned char)(c << 6 | d);
            bufin += 4;
            nprbytes -= 4;
        }

        // Final group: 2 symbols -> 1 byte, 3 -> 2 bytes, 4 -> 3 bytes.
        if (nprbytes > 1) {
            *(bufout++) = (unsigned char)(b64_decode[bufin[0]] << 2 | b64_decode[bufin[1]] >> 4);
        }
        if (nprbytes > 2) {
            *(bufout++) = (unsigned char)(b64_decode[bufin[1]] << 4 | b64_decode[bufin[2]] >> 2);
        }
        if (nprbytes > 3) {
            *(bufout++) = (unsigned char)(b64_decode[bufin[2]] << 6 | b64_decode[bufin[3]]);
        }

        // Exact decoded size is the number of bytes actually written.
        res.resize(bufout - (unsigned char*)res.data());
        return res;
    }

    // Encodes `buf` with the URL-safe Base64 alphabet ('-' and '_').
    // '=' padding is appended only when `pad` is true.
    TString Bin2base64url(const TStringBuf buf, bool pad) {
        constexpr bool urlAlphabet = true;
        return BinToBase64Impl<urlAlphabet>(buf, pad);
    }

    // Decodes URL-safe Base64. Trailing '=' padding is optional.
    // Returns an empty string on malformed input.
    TString Base64url2bin(const TStringBuf buf) {
        constexpr bool urlAlphabet = true;
        return Base64ToBinImpl<urlAlphabet>(buf);
    }

    // Encodes `buf` as standard Base64 via the library Base64Encode().
    // When `pad` is false, any trailing '=' characters are removed.
    TString BinToBase64(const TStringBuf buf, bool pad) {
        TString encoded = Base64Encode(buf);
        if (!pad) {
            while (!encoded.empty() && encoded.back() == '=') {
                encoded.pop_back();
            }
        }
        return encoded;
    }

    // Decodes standard-alphabet Base64 ('+' and '/'). Trailing '=' padding is optional.
    // Returns an empty string on malformed input.
    TString Base64ToBin(const TStringBuf buf) {
        constexpr bool urlAlphabet = false;
        return Base64ToBinImpl<urlAlphabet>(buf);
    }

    bool IsBase64url(const TStringBuf str) {
        auto it = std::find_if_not(str.begin(), str.end(), [](char c) {
            return (c >= 'a' && c <= 'z') ||
                   (c >= 'A' && c <= 'Z') ||
                   (c >= '0' && c <= '9') ||
                   c == '-' ||
                   c == '_';
        });
        if (it == str.end()) {
            return true; // only valid base64url symbols
        }

        return std::find_if(it, str.end(), [](char c) {
                   return c != '=';
               }) == str.end(); // base64url if only '=' are left at the end
    }

    bool IsBase64(const TStringBuf str) {
        auto it = std::find_if_not(str.begin(), str.end(), [](char c) {
            return (c >= 'a' && c <= 'z') ||
                   (c >= 'A' && c <= 'Z') ||
                   (c >= '0' && c <= '9') ||
                   c == '+' ||
                   c == '/';
        });
        if (it == str.end()) {
            return true; // only valid base64 symbols
        }

        return std::find_if(it, str.end(), [](char c) {
                   return c != '=';
               }) == str.end(); // base64 if only '=' are left at the end
    }

    // Undoes CGI-style escaping of `in` via CGIUnescapeRet().
    TString Urldecode(const TStringBuf in) {
        TString decoded = CGIUnescapeRet(in);
        return decoded;
    }

    // Percent-encodes `in` with Quote(). An empty set of extra "safe" characters
    // is passed, so nothing beyond Quote's own unreserved set is left as-is.
    TString Urlencode(const TStringBuf in) {
        TString encoded(in);
        Quote(encoded, "");
        return encoded;
    }

    static const char hexDigits[] = "0123456789abcdef";

    // Returns the numeric value (0..15) of a hexadecimal digit. Both lowercase
    // and uppercase letters are accepted; any other character maps to 0.
    //
    // The letter ranges are compared explicitly instead of calling tolower().
    // tolower() has undefined behavior for negative `char` values (bytes >= 0x80
    // where char is signed), and its result depends on the current locale.
    static char Hexval(char ch) {
        if (ch >= '0' && ch <= '9') {
            return ch - '0';
        }
        if (ch >= 'a' && ch <= 'f') {
            return ch - 'a' + 10;
        }
        if (ch >= 'A' && ch <= 'F') {
            return ch - 'A' + 10;
        }
        return 0;
    }

    // Encodes `buf` as lowercase hexadecimal: two digits per byte, high nibble first.
    // Returns an empty string for empty input.
    TString Bin2hex(const TStringBuf buf) {
        if (!buf) {
            return TString();
        }

        TString res;
        res.reserve(buf.size() * 2);

        // Each byte is handled as unsigned char. Right-shifting a negative `char`
        // is implementation-defined before C++20, and a range-for has no narrow
        // index that could wrap on very large (> 4 GiB) buffers.
        for (const char ch : buf) {
            const unsigned char byte = static_cast<unsigned char>(ch);
            res.push_back(hexDigits[byte >> 4]);
            res.push_back(hexDigits[byte & 0xf]);
        }
        return res;
    }

    // Decodes hexadecimal text into bytes, two digits per byte, using Hexval()
    // for each digit. An odd-length input is treated as if a leading '0' were
    // prepended. Characters that are not hex digits are not rejected: Hexval()
    // maps them to 0. Returns an empty string for empty input.
    TString Hex2bin(const TStringBuf buf) {
        if (!buf) {
            return TString();
        }

        TString res;
        res.reserve((buf.size() + 1) / 2);

        // size_t: an `unsigned` index would wrap on buffers larger than 4 GiB.
        size_t i = 0;
        if (buf.size() % 2 == 1) { // odd length: the first digit forms a byte on its own
            res.push_back(Hexval(buf[i++]));
        }

        for (; i + 1 < buf.size(); i += 2) {
            res.push_back(static_cast<char>((Hexval(buf[i]) << 4) | Hexval(buf[i + 1])));
        }

        return res;
    }

    namespace {
        // Builds a 256-entry reverse lookup table for a 32-symbol alphabet.
        // Each symbol maps to its position. For symbols >= 'a', the character
        // with bit 0x20 flipped (the ASCII uppercase letter) maps to the same
        // position, so decoding is case-insensitive. Every other entry is 0xff.
        constexpr std::array<unsigned char, 256> B32Init(const unsigned char* b32_encode) {
            std::array<unsigned char, 256> table{};
            for (size_t slot = 0; slot < table.size(); ++slot) {
                table[slot] = 0xff;
            }

            for (unsigned char value = 0; value < 32; ++value) {
                const unsigned char symbol = b32_encode[value];
                table[symbol] = value;
                if (symbol >= 'a') {
                    table[static_cast<unsigned char>(symbol ^ 0x20)] = value;
                }
            }

            return table;
        }

        // Lowercase alphabet without the letters i, l, o and s, which are easily
        // confused with 1, 1, 0 and 5.
        constexpr unsigned char B32_ENCODE[] = "0123456789abcdefghjkmnpqrtuvwxyz";
        constexpr std::array<unsigned char, 256> B32_DECODE = B32Init(B32_ENCODE);
    }

    // Encodes `buf` with the custom Base32 alphabet B32_ENCODE, without padding.
    // Bits are consumed most-significant first, 5 per symbol. Every 5 input bytes
    // yield 8 symbols, and a trailing partial group of k bytes yields
    // ceil(8 * k / 5) symbols, the last one zero-filled on the right.
    // Returns an empty string for empty input.
    TString BinToBase32(const TStringBuf buf) {
        if (buf.empty()) {
            return TString();
        }

        TString res;
        res.resize(((buf.size() + 4) / 5) * 8, 0); // upper bound; trimmed below

        unsigned char* out = (unsigned char*)res.data();

        // Bit accumulator. Only its low `pendingBits` bits (at most 12) are
        // meaningful; higher bits may wrap away harmlessly because every read is masked.
        unsigned int pending = 0;
        unsigned int pendingBits = 0;

        for (const char ch : buf) {
            pending = (pending << 8) | static_cast<unsigned char>(ch);
            pendingBits += 8;
            while (pendingBits >= 5) {
                pendingBits -= 5;
                *out++ = B32_ENCODE[(pending >> pendingBits) & 0x1F];
            }
        }

        // Flush the remaining 1..4 bits, left-aligned in a final symbol.
        if (pendingBits > 0) {
            *out++ = B32_ENCODE[(pending << (5 - pendingBits)) & 0x1F];
        }

        res.resize(out - (unsigned char*)res.data());
        return res;
    }

    // Decodes text in the custom Base32 alphabet (as produced by BinToBase32()).
    // Decoding is case-insensitive and expects no padding.
    // Valid lengths are 8k, 8k+2, 8k+4, 8k+5 and 8k+7 symbols; lengths 8k+1,
    // 8k+3 and 8k+6 cannot be produced by the encoder and are rejected.
    // Returns an empty string for empty input, for any byte outside the
    // alphabet, or for an invalid length. Unused low bits of the final symbol
    // are ignored rather than required to be zero.
    TString Base32ToBin(const TStringBuf buf) {
        if (!buf) {
            return TString();
        }

        // Reject any byte outside the alphabet. Valid symbols decode to 0..31;
        // every other byte decodes to 0xff.
        for (const unsigned char c : buf) {
            if (B32_DECODE[c] >= 32) {
                return TString();
            }
        }

        size_t len = buf.size();

        // check that length is correct
        size_t tail = len % 8;
        if (tail == 1 || tail == 3 || tail == 6) {
            return TString(); // invalid tail
        }

        TString res;
        res.resize(((len + 7) / 8) * 5, 0); // upper bound; trimmed below

        const unsigned char* pB = (const unsigned char*)buf.data();
        const unsigned char* pE = (const unsigned char*)buf.data() + len;

        unsigned char* p = (unsigned char*)res.data();
        // Full groups: 8 symbols (40 bits) -> 5 bytes.
        for (; pB + 7 < pE; pB += 8) {
            unsigned char a = B32_DECODE[*pB];
            unsigned char b = B32_DECODE[*(pB + 1)];
            *p++ = (unsigned char)(a << 3 | b >> 2);
            unsigned char c = B32_DECODE[*(pB + 2)];
            unsigned char d = B32_DECODE[*(pB + 3)];
            *p++ = (unsigned char)(b << 6 | c << 1 | d >> 4);
            unsigned char e = B32_DECODE[*(pB + 4)];
            *p++ = (unsigned char)(d << 4 | e >> 1);
            unsigned char f = B32_DECODE[*(pB + 5)];
            unsigned char g = B32_DECODE[*(pB + 6)];
            *p++ = (unsigned char)(e << 7 | f << 2 | g >> 3);
            unsigned char h = B32_DECODE[*(pB + 7)];
            *p++ = (unsigned char)(g << 5 | h);
        }

        // Partial group: 2, 4, 5 or 7 symbols left (guaranteed by the length check)
        // -> 1, 2, 3 or 4 bytes respectively.
        if (pB < pE) {
            unsigned char a = B32_DECODE[*pB++];
            unsigned char b = B32_DECODE[*pB++];
            *p++ = (unsigned char)(a << 3 | b >> 2);
            if (pB < pE) {
                unsigned char c = B32_DECODE[*pB++];
                unsigned char d = B32_DECODE[*pB++];
                *p++ = (unsigned char)(b << 6 | c << 1 | d >> 4);
                if (pB < pE) {
                    unsigned char e = B32_DECODE[*pB++];
                    *p++ = (unsigned char)(d << 4 | e >> 1);
                    if (pB < pE) {
                        unsigned char f = B32_DECODE[*pB++];
                        unsigned char g = B32_DECODE[*pB++];
                        *p++ = (unsigned char)(e << 7 | f << 2 | g >> 3);
                    }
                }
            }
        }

        res.resize(p - (unsigned char*)res.data());
        return res;
    }

    // Maps a 5-bit value (0..31) to its Base32 symbol.
    // Returns 0xff when `val` is out of range.
    unsigned char BinToBase32(ui8 val) {
        return val < 32 ? B32_ENCODE[val] : static_cast<unsigned char>(0xff);
    }

    // Maps a Base32 symbol (either case) to its 5-bit value.
    // Returns 0xff for bytes outside the alphabet.
    ui8 Base32ToBin(unsigned char val) {
        const ui8 decoded = B32_DECODE[val];
        return decoded;
    }
}
