#include <stdio.h>
#include <string.h>
#include <stdlib.h>
extern "C" {
#include "rbldnsd.h"
}
#ifdef _WIN32
#include <WinSock2.h>
#include <in6addr.h>
#else
#include <syslog.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#endif
#include <util/generic/strbuf.h>
#include <util/string/cast.h>
#include <util/thread/pool.h>
#include <mail/so/spamstop/tools/postgreclient/PostgreBase.h>
#include <mail/so/spamstop/tools/mongoclient_v3/StorageMongo.h>

/* Storage backend selected by the "type:" option of the dataset file.
 * None is the default until a recognised "type:" line is parsed. */
enum EDBType {
    None = 0,    // no backend configured; ds_storage_finish() reports "unknown type"
    Postgre = 1, // "type:Postgre"
    Mongo = 2,   // "type:Mongo"
};

/* Backend state of one "storage" dataset: settings parsed from the dataset
 * file by ds_storage_line() and the storage client that ds_storage_finish()
 * builds from them. */
struct TStorageRbl {
    TStorageRbl() {
        // NOTE(review): presumably 0/0 means "no fixed thread/queue limit" for
        // the adaptive pool -- confirm against util/thread/pool.h.
        pool->Start(0, 0);
    }

    TString uri;        // connection string passed to storage->Connect()
    TString ip_field;   // field matched against the textual queried address
    TString collection; // table/collection searched on lookups

    // Storage client; empty until ds_storage_finish() creates one.
    THolder<TStorageBase> storage;
    // Worker pool, handed to the Postgre backend when it is constructed.
    TAtomicSharedPtr<IThreadPool> pool{MakeAtomicShared<TAdaptiveThreadPool>()};

    EDBType type{None};
    size_t connectionCount{32};
    TDuration timeout{TDuration::MilliSeconds(100)};
};

/* Dataset-specific data reachable as ds->ds_dsd for every "storage" dataset.
 * NOTE(review): presumably allocated by the rbldnsd core using a size recorded
 * by definedstype() below -- keep it a plain struct; confirm in rbldnsd.h. */
struct dsdata {
    TStorageRbl* impl;  /* owned: created in ds_storage_start(), deleted in ds_storage_reset() */
    const char* def_rr; /* default A and TXT RRs added to the answer on every match */
};

/* Register the "storage" dataset type; the DSTF_* flags select IPv4 and IPv6
 * reverse-name queries. */
definedstype(storage, DSTF_IP4REV | DSTF_IP6REV, "set of ips(v4 and v6)");

/* Reset all previos connection and replies */
static void ds_storage_reset(struct dsdata* dsd, int UNUSED unused_freeall) {
    if (dsd->impl) {
        delete dsd->impl;
        dsd->impl = nullptr;
    }

    dsd->def_rr = nullptr;
}

static void ds_storage_start(struct dataset* ds) {
    ds->ds_dsd->impl = new TStorageRbl();
}

/* Parse one line of the "storage" dataset file.
 *
 * Accepted lines:
 *   :A:TXT                     default answer added for every matching address
 *   uri:<value>                connection string passed to storage->Connect()
 *   collection:<value>         table/collection searched on every query
 *   ip_field:<value>           field compared with the textual IPv6 address
 *   connection_count:<n>       pool size (unsigned integer)
 *   connection_timeout:<d>     pool timeout, parsed by TDuration::TryParse
 *   type:Postgre | type:Mongo  backend created in ds_storage_finish()
 * Empty lines and comment lines are skipped. Values are taken verbatim after
 * the colon (no whitespace trimming).
 *
 * Returns 1 to keep loading. Returns 0 after a warning on an unknown option or
 * an unparsable value. Also returns 0 if the default answer cannot be copied
 * (mp_dmemdup returned NULL).
 *
 * Lines are intentionally not echoed to the log: a uri: value may embed
 * credentials. */
static int ds_storage_line(struct dataset* ds, char* s, struct dsctx* dsc) {
    struct dsdata* dsd = ds->ds_dsd;

    // Default A/TXT answer; the global def_rr is passed as parse_a_txt's fallback.
    if (*s == ':') {
        const char* rr;
        const unsigned rrl = parse_a_txt(s, &rr, def_rr, dsc);
        if (!rrl)
            return 1; // an unparsable ':' line is skipped, not fatal
        dsd->def_rr = static_cast<const char*>(mp_dmemdup(ds->ds_mp, rr, rrl));
        return dsd->def_rr ? 1 : 0;
    }

    TStringBuf data(s);
    if (data.empty() || ISCOMMENT(data[0])) {
        //skip
    } else if (data.SkipPrefix("uri:")) {
        dsd->impl->uri = ToString(data);
    } else if (data.SkipPrefix("collection:")) {
        dsd->impl->collection = ToString(data);
    } else if (data.SkipPrefix("ip_field:")) {
        dsd->impl->ip_field = ToString(data);
    } else if (data.SkipPrefix("connection_count:")) {
        if (!TryFromString(data, dsd->impl->connectionCount)) {
            // data.data() is NUL-terminated: it points into the original line s.
            dswarn(dsc, "connection_count is not integer: %s", data.data());
            return 0;
        }
    } else if (data.SkipPrefix("connection_timeout:")) {
        if (!TDuration::TryParse(data, dsd->impl->timeout)) {
            dswarn(dsc, "cannot parse connection_timeout: %s", data.data());
            return 0;
        }
    } else if (data.SkipPrefix("type:")) {
        if (data == "Postgre")
            dsd->impl->type = Postgre;
        else if (data == "Mongo")
            dsd->impl->type = Mongo;
        else {
            dswarn(dsc, "cannot parse type: %s", data.data());
            return 0;
        }
    } else {
        dswarn(dsc, "unknown option: %s", data.data());
        return 0;
    }

    return 1;
}

static void ds_storage_finish(struct dataset* ds, struct dsctx* dsc) try {
    switch (ds->ds_dsd->impl->type) {
        case Mongo:
            ds->ds_dsd->impl->storage = MakeHolder<mongo_v3::TStorageMongo>();
            break;
        case Postgre:
            ds->ds_dsd->impl->storage = MakeHolder<sql::TPostgreBase>(
                TPoolParams{ds->ds_dsd->impl->connectionCount, ds->ds_dsd->impl->timeout},
                TPoolParams{ds->ds_dsd->impl->connectionCount, ds->ds_dsd->impl->timeout},
                ds->ds_dsd->impl->pool);
            break;
        default:
            dswarn(dsc, "unknown type");
            return;
    }

    ds->ds_dsd->impl->storage->Connect(ds->ds_dsd->impl->uri);
} catch (const std::exception& e) {
    dswarn(dsc, "cant connect to database: %s", e.what());
}

int get_addr_str(const struct dnsqinfo* qi, char* addr_str, int maxsize) {
    if (maxsize > 0)
        memset(addr_str, 0, maxsize);

    if (qi->qi_ip4valid) {
        struct in6_addr addr;
        memset(&addr, 0, sizeof(addr));
        addr.s6_addr[10] = 0xFF;
        addr.s6_addr[11] = 0xFF;

        ui32 ipv4_data = ntohl(qi->qi_ip4);
        memcpy(addr.s6_addr + 12, &ipv4_data, 4);

        return inet_ntop(AF_INET6, &addr, addr_str, maxsize) != NULL;
    } else if (qi->qi_ip6valid) {
        return inet_ntop(AF_INET6, qi->qi_ip6, addr_str, maxsize) != NULL;
    }

    return 0;
}

static int ds_storage_query(const struct dataset* ds, const struct dnsqinfo* qi, struct dnspacket* pkt) try {
    struct dsdata* dsd = ds->ds_dsd;
    if (!dsd->impl->storage || (!qi->qi_ip4valid && !qi->qi_ip6valid))
        return 0;
    check_query_overwrites(qi);

    char addr_str[128];
    if (!get_addr_str(qi, addr_str, sizeof(addr_str))) {
#ifdef LOG_WARNING
        dslog(LOG_WARNING, 0, "error while transforming addr");
#endif
        return 0;
    }

    TFindAction action;
    action.query.equals[dsd->impl->ip_field] = ToString(addr_str);

    NAnyValue::TScalarMap result;
    dsd->impl->storage->FindOne(dsd->impl->collection, action, result);

    if (!result.empty()) {
        const char* ipsubst = nullptr;
        if (qi->qi_ip4valid)
            ipsubst = (qi->qi_tflag & NSQUERY_TXT) ? ip4atos(qi->qi_ip4) : NULL;
        else if (qi->qi_ip6valid)
            ipsubst = (qi->qi_tflag & NSQUERY_TXT) ? ip6atos(qi->qi_ip6, IP6ADDR_FULL) : NULL;

        addrr_a_txt(pkt, qi->qi_tflag, dsd->def_rr, ipsubst, ds);

        return NSQUERY_FOUND;
    }

    return 0;
} catch (const std::exception& e) {
#ifdef LOG_WARNING
    dslog(LOG_WARNING, 0, "request error: %s", e.what());
#endif
    return 0;
}

static void ds_storage_dump(const struct dataset* ds, const unsigned char UNUSED* unused_odn, FILE* f) try {
    TFindAction action;
    TFindResults result;
    ds->ds_dsd->impl->storage->Find(ds->ds_dsd->impl->collection, action, result);

    for (const auto& item : result) {
        auto it = item.find(ds->ds_dsd->impl->ip_field);
        if (it != item.end()) {
            struct in6_addr addr;
            memset(&addr, 0, sizeof(addr));
            inet_pton(AF_INET6, it->second.AsString().c_str(), &addr);
            dump_ip6(reinterpret_cast<ip6oct_t*>(&addr), 0, ds->ds_dsd->def_rr, ds, f);
        }
    }
} catch (const std::exception& e) {
#ifdef LOG_WARNING
    dslog(LOG_WARNING, 0, "dump error: %s", e.what());
#endif
}
