#include "address.h"
#include "storage.h"

#include <balancer/serval/contrib/cone/cone.hh>

#include <contrib/libs/c-ares/include/ares.h>

#include <util/generic/hash.h>
#include <util/generic/ptr.h>
#include <util/generic/yexception.h>

#include <netdb.h>

// Normalizes the path part of `path` in place and returns a (possibly shorter)
// view over the same buffer:
//   * repeated `/` collapse to one (a leading and a trailing `/` are kept);
//   * `.` components are removed;
//   * `..` removes the previously kept component together with its `/`; when there
//     is nothing to remove (e.g. `/..`, or a leading `..` in a relative path) the
//     `..` is simply dropped.
// Everything from the first `?` or `#` onwards is preserved byte-for-byte and moved
// down to follow the normalized path.
// NOTE: writes through `path.data()`, so the viewed buffer must be writable.
TStringBuf NSv::NormalizePathInPlace(TStringBuf path) noexcept {
    // Preserve ?query and #fragment as-is.
    // `limit` is the end of the path part.
    size_t limit = Min(path.size(), path.find('#'), path.find('?'));
    // The output is compacted towards the front: bytes [0, i - shift) hold the
    // normalized result so far, and `shift` is how far the output lags the input.
    size_t shift = 0;
    for (size_t i = 0, len; i < limit; i += len) {
        // [i, j) is a single component, not including the delimiting `/`.
        size_t j = Min(limit, path.find('/', i));
        bool isEmpty = (j == i && i != 0 && j != limit); // Preserve `/` at ends.
        bool isCurDir = (j == i + 1 && path[i] == '.');
        bool isParDir = (j == i + 2 && path[i] == '.' && path[i + 1] == '.');
        if (isParDir && i - shift > 1 /* `/../` is still `/` */) {
            // Pop the last kept component: rewind the output to just after the `/`
            // that precedes it (or to the very start if there is none).
            // NOTE(review): the output currently ends in `/` at index i - shift - 1,
            // so this relies on util's TStringBuf::rfind(char, pos) searching only
            // [0, pos). std::string_view::rfind includes `pos` and would find that
            // trailing `/` and never pop. Re-check if the util semantics change.
            size_t p = path.rfind('/', i - shift - 1);
            shift = i - (p == TStringBuf::npos ? 0 : p + 1);
        }
        len = (j - i) + (j != limit); // Include the next `/`, if it exists.
        if (!isEmpty && !isCurDir && !isParDir) {
            // Keep the component; source and destination may overlap, hence memmove.
            memmove((char*)path.data() + i - shift, path.data() + i, len);
        } else {
            // Drop the component (and its `/`) by letting the output lag further.
            shift += len;
        }
    }
    if (limit != path.size()) {
        // Move the untouched ?query / #fragment tail down behind the path.
        memmove((char*)path.data() + limit - shift, path.data() + limit, path.size() - limit);
    }
    return path.Trunc(path.size() - shift);
}

// Parses a numeric IPv6 or IPv4 address (IPv6 is tried first), optionally wrapped
// in `[...]` as in URLs, and pairs it with `port`.
// Returns a default-constructed IP if `addr` is not a valid numeric address.
NSv::IP NSv::IP::Parse(TStringBuf addr, ui32 port) noexcept {
    // Technically, this is part of the URL format, not IPv6...
    if (addr.size() > 1 && addr.front() == '[' && addr.back() == ']') {
        addr = TStringBuf(addr.data() + 1, addr.size() - 2);
    }
    // INET6_ADDRSTRLEN includes the terminating NUL, so no valid textual address is
    // longer than INET6_ADDRSTRLEN - 1 characters. Reject longer input instead of
    // parsing a truncated prefix. Also reject embedded NULs, which would otherwise
    // end the C string early and let e.g. "127.0.0.1\0junk" parse as 127.0.0.1.
    if (addr.empty() || addr.size() >= INET6_ADDRSTRLEN || addr.find('\0') != TStringBuf::npos) {
        return {};
    }
    char buf[INET6_ADDRSTRLEN] = {}; // zero-filled, so the copy below is NUL-terminated
    memcpy(buf, addr.data(), addr.size());
    if (in6_addr v6; inet_pton(AF_INET6, buf, &v6) == 1) {
        return {v6, port};
    }
    if (in_addr v4; inet_pton(AF_INET, buf, &v4) == 1) {
        return {v4, port};
    }
    return {};
}

// Parses a netmask for `family` (AF_INET or AF_INET6). The mask may be given either
// as a full address (`255.255.255.0`, `ffff:ffff::`) or as a prefix length (`24`;
// at most 32 for AF_INET and 128 for AF_INET6).
// The result is a 128-bit TRaw whose first half holds the high 64 bits. IPv4 masks
// are expressed with the upper 96 bits set, so a /N IPv4 prefix equals a /(N + 96)
// IPv6 prefix. For the full-address form, the upper 32 bits of the low half are
// forced to ones (presumably the IPv4 address sits in the low 32 bits of
// Raw().second; confirm against IP::Raw()).
// Returns nothing for an unsupported family or malformed input.
TMaybe<NSv::IP::TRaw> NSv::IP::ParseMask(TStringBuf addr, int family) noexcept {
    const bool isV4 = family == AF_INET;
    if (!isV4 && family != AF_INET6) {
        return {};
    }

    auto asAddress = Parse(addr);
    if (asAddress.Data.Base.sa_family == family) {
        if (!isV4) {
            return asAddress.Raw();
        }
        return TRaw{-1, 0xFFFFFFFFull << 32 | asAddress.Raw().second};
    }

    ui32 prefixLen;
    const ui32 maxPrefixLen = isV4 ? 32 : 128;
    if (!TryFromString(addr, prefixLen) || prefixLen > maxPrefixLen) {
        return {};
    }

    // Number of leading one-bits in the 128-bit mask.
    const ui32 bits = isV4 ? prefixLen + 96 : prefixLen;
    if (bits == 0) {
        return TRaw{0, 0};
    }
    if (bits <= 64) {
        return TRaw{ui64(-1) << (64 - bits), 0};
    }
    return TRaw{-1, ui64(-1) << (128 - bits)};
}

// Resolves `name` into the list of its addresses, each paired with `port`.
//
// Numeric addresses (including bracketed IPv6) are handled by Parse() without
// touching c-ares. Otherwise one AF_INET and one AF_INET6 ares_gethostbyname()
// query are issued on a c-ares channel cached in a TThreadLocal slot. The calling
// coroutine then waits on a cone::event until both query callbacks have run.
//
// On failure, returns an empty vector:
//   * library or channel init failure, or no addresses: the error is recorded with
//     mun_error_at;
//   * a failed wait: MUN_RETHROW is applied to the wait's failure.
// NOTE(review): as the FIXMEs below say, this function never drives c-ares
// timeouts, and an interrupted wait leaves both queries outstanding.
TVector<NSv::IP> NSv::IP::Resolve(const TString& name, ui32 port) noexcept {
    if (auto parsed = Parse(name, port)) {
        // While c-ares *can* return parsed versions of numeric addresses, this avoids
        // initializing the channel if it's not needed.
        return {parsed};
    }

    // One-time, process-wide c-ares initialization; the result is cached, so a
    // failure is reported on every call.
    static const int initCode = ares_library_init(ARES_LIB_INIT_ALL);
    if (initCode != ARES_SUCCESS) {
        mun_error_at(ENODATA, "ares", MUN_CURRENT_FRAME, "%d: %s", initCode, ares_strerror(initCode));
        return {};
    }

    // Cached resolver state: the c-ares channel plus one coroutine guard per
    // (socket, isWrite) pair that c-ares currently wants watched.
    struct TAresChannel {
        // The destructor body runs before `Waiters` is destroyed, so any socket-state
        // or query callbacks that ares_destroy still invokes see a live map.
        ~TAresChannel() {
            ares_destroy(Chan);
        }

        ares_channel Chan;
        THashMap<std::pair<ares_socket_t, bool>, cone::guard> Waiters;
    };

    static NSv::TThreadLocal<TAresChannel> tls;
    auto chan = tls.Get();
    if (!chan) {
        // First use in this TThreadLocal slot: build and configure a channel.
        auto n = MakeHolder<TAresChannel>();
        ares_options opts = {};
        opts.timeout = 25; // milliseconds, since ARES_OPT_TIMEOUTMS is passed below
        opts.flags = ARES_FLAG_STAYOPEN;
        opts.sock_state_cb_data = n.Get();
        // c-ares reports which directions of `fd` it wants to be woken on
        // (r = readable, w = writable).
        opts.sock_state_cb = [](void* data, ares_socket_t fd, int r, int w) {
            auto chan = reinterpret_cast<TAresChannel*>(data);
            for (auto [isWrite, enable] : (std::pair<bool, bool>[]){{false, r}, {true, w}}) {
                // Watcher coroutine: each time `fd` becomes ready in this direction,
                // feed that readiness to c-ares. It loops while cone_iowait returns 0
                // (presumably success; it stops on error or cancellation).
                auto f = [chan, fd, isWrite = isWrite] {
                    while (!cone_iowait(fd, isWrite)) {
                        ares_process_fd(chan->Chan, isWrite ? ARES_SOCKET_BAD : fd, isWrite ? fd : ARES_SOCKET_BAD);
                    }
                    return false;
                };
                if (enable) {
                    // No-op if this (fd, direction) is already being watched.
                    chan->Waiters.emplace(std::make_pair(fd, isWrite), f);
                } else {
                    // Dropping the guard presumably cancels the watcher coroutine.
                    chan->Waiters.erase(std::make_pair(fd, isWrite));
                }
            }
        };
        if (int code = ares_init_options(&n->Chan, &opts, ARES_OPT_FLAGS | ARES_OPT_SOCK_STATE_CB | ARES_OPT_TIMEOUTMS)) {
            mun_error_at(ENODATA, "ares", MUN_CURRENT_FRAME, "%d: %s", code, ares_strerror(code));
            return {};
        }
        chan = tls.Reset(n.Release());
    }

    // State shared by the two query callbacks and this coroutine. It is refcounted
    // so it outlives an early return below: each pending query holds one reference.
    struct TAresResult : TSimpleRefCount<TAresResult> {
        ui16 Port;
        TVector<NSv::IP> Results;
        cone::event Done;
    };

    // Query completion: on ARES_SUCCESS, append every returned address (h_length 4
    // means IPv4, anything else is treated as IPv6). Always wake the waiter and
    // release this query's reference.
    auto cb = [](void* arg, int status, int /*timeouts*/, struct hostent* data) {
        auto out = reinterpret_cast<TAresResult*>(arg);
        if (status == ARES_SUCCESS) {
            for (auto p = data->h_addr_list; *p; p++) {
                if (data->h_length == 4) {
                    out->Results.emplace_back(*(in_addr*)(*p), out->Port);
                } else {
                    out->Results.emplace_back(*(in6_addr*)(*p), out->Port);
                }
            }
        }
        out->Done.wake();
        out->UnRef();
    };
    auto r = MakeIntrusive<TAresResult>();
    r->Port = port;
    r->Ref(2); // one reference per query issued below
    ares_gethostbyname(chan->Chan, name.c_str(), AF_INET, cb, r.Get());
    ares_gethostbyname(chan->Chan, name.c_str(), AF_INET6, cb, r.Get());
    // RefCount() == 1 means only `r` itself is left, i.e. both callbacks have run.
    // They may already have run inside ares_gethostbyname, in which case the loop
    // body never executes.
    while (r->RefCount() > 1) {
        // FIXME should notify ares of timeouts, too
        if (!r->Done.wait() MUN_RETHROW) {
            // FIXME should cancel the request somehow
            return {};
        }
    }
    if (!r->Results) {
        mun_error_at(ENODATA, "ares", MUN_CURRENT_FRAME, "no addresses resolved");
    }
    return std::move(r->Results);
}

// Splits `addr` into scheme, host, port, path, query and fragment. The returned URL
// owns the string (`Full`); the textual parts are views into it.
//
// Accepted shape: `[scheme://]host[:port][/path][?query][#fragment]`. The path keeps
// its leading `/`.
// Without an explicit port, the scheme selects one: 443 for https/h2/h2d/wss, 80 for
// http/h2c/ws, and 0 otherwise.
// With `withTimeout`, one more trailing `:<number>` is read as a timeout:
//   * with no path, it comes after the port (`host:port:timeout`);
//   * otherwise it is taken from the very end of the path (before query and
//     fragment are split off).
TMaybe<NSv::URL> NSv::URL::Parse(TString addr, bool withTimeout) noexcept {
    // Strips a trailing `:<number>` from `part` and stores the number in `value`.
    // Leaves both arguments untouched if there is no such suffix.
    auto popNumber = [](TStringBuf& part, ui32& value) {
        TStringBuf head, tail;
        ui32 parsed;
        if (!part.TryRSplit(":", head, tail) || !TryFromString(tail, parsed)) {
            return false;
        }
        part = head;
        value = parsed;
        return true;
    };

    auto url = MakeMaybe<NSv::URL>();
    url->Full = std::move(addr);

    TStringBuf rest = url->Full;
    rest.TrySplit("://", url->Scheme, rest);
    const bool hasPath = rest.TrySplitAt(rest.find('/'), rest, url->Path);

    ui32 first = 0;
    if (popNumber(rest, first)) {
        url->Port = first;
        ui32 port = 0;
        if (withTimeout && !hasPath && popNumber(rest, port)) {
            // `host:port:timeout` -- the number popped first was the timeout.
            url->Port = port;
            url->Timeout = first;
        }
    } else if (EqualToOneOf(url->Scheme, "https", "h2", "h2d" /* h2 downgradable */, "wss")) {
        url->Port = 443;
    } else if (EqualToOneOf(url->Scheme, "http", "h2c", "ws")) {
        url->Port = 80;
    } else {
        url->Port = 0;
    }
    url->Host = rest;

    if (ui32 timeout = 0; withTimeout && hasPath && popNumber(url->Path, timeout)) {
        url->Timeout = timeout;
    }

    // `#` first, so that a `?` inside the fragment is not mistaken for the query.
    url->Path.TrySplit("#", url->Path, url->Fragment);
    url->Path.TrySplit("?", url->Path, url->Query);
    return url; // never fails right now, but return TMaybe for possible further improvement
}
