//
// ip_matcher implementation by alg
//
#include <util/stream/file.h>
#include <mail/so/spamstop/tools/so-common/so_log.h>
#include "ip_match.h"

#include <cstdlib>

// Removes leading spaces and tabs from `l` in place.
// The blank prefix is measured first and erased with a single call, so the
// remaining characters are shifted once instead of once per removed blank.
static inline void
skip_sp(TString& l) {
    size_t blanks = 0;
    while (blanks < l.size() && (l[blanks] == ' ' || l[blanks] == '\t'))
        ++blanks;
    if (blanks != 0)
        l.erase(0, blanks);
}

void ip_match::read_file(const TFsPath& fname) {
    TFileInput f(fname);

    TString line;
    while (f.ReadLine(line)) {
        skip_sp(line);
        if (line.empty() || line[0] == '#')
            continue;
        add_mask(line);
    }
}

bool compare_domains(const TString& pattern, const TString& candidate) {
    bool is_last_ltr_good = true;
    if (candidate.size() > pattern.size()) {
        int dot_idx = candidate.size() - pattern.size() - 1;
        if (candidate[dot_idx] != '.')
            is_last_ltr_good = false;
    }
    return (is_last_ltr_good && candidate.StartsWith(pattern));
}

bool ip_match::check_domain(const TString* domains, const TString& domain) const {
    for (int i = 0; i < 3; ++i) {
        if (compare_domains(domains[i], domain))
            return true;
    }
    return false;
}

// Looks `ip` up in the stored networks, in insertion order.
// The first entry whose masked address equals `ip & mask` decides the
// result; later entries are not consulted even if they also cover `ip`:
//   - entry without a domain restriction        -> true
//   - entry with a restriction, `domain` empty   -> false
//   - entry with a restriction, `domain` given   -> check_domain() result
// Returns false when no stored network covers `ip`.
bool ip_match::match(const TIpAddr& ip, const TString& domain) const {
    for (const auto& net : ips) {
        const bool covers = (ip & net.mask) == net.addr;
        if (!covers)
            continue;

        if (!net.needs_domain_check)
            return true;

        return domain.length() > 0 && check_domain(net.domains, domain);
    }
    return false;
}

// Convenience overload: builds a TIpAddr from the textual address `cli` and
// forwards to match(const TIpAddr&, const TString&). The validity of the
// parsed address is not checked here.
bool ip_match::match(const TString& cli, const TString& domain) const {
    return match(TIpAddr(cli), domain);
}

bool ip_match::add_mask(const TString& val) {
    size_t pos = val.find('/');
    TIpAddr ip(val.substr(0, pos));

    if (!ip.IsValid()) {
        //std::cerr << "Invalid ip network " << val << std::endl;
        return false;
    }

    int len = 128;
    if (pos != TString::npos) {
        len = atol(val.substr(pos + 1).c_str());
        if (len < 0) {
            // std::cerr << "Shit mask found: " << val << std::endl;
            len = 128;
        }
    }

    matcher_t s;
    TIpAddr mask;
    mask.MakeMask(len, ip.IsIpv6());

    s.addr = ip & mask;
    s.mask = mask;
    s.needs_domain_check = false;
    size_t domain_start = val.find_first_of("\t ");
    if (domain_start != TString::npos) {
        for (int i = 0; i < 3; ++i) {
            domain_start += 1;
            size_t domain_end = val.find_first_of(" \t", domain_start);
            int domain_len = 0;
            if (domain_end == TString::npos)
                domain_len = val.size() - domain_start;
            else
                domain_len = domain_end - domain_start;
            s.domains[i] = val.substr(domain_start, domain_len);
            s.needs_domain_check = true;
            if (domain_end == TString::npos)
                break;
            domain_start = domain_end;
        }
    }

    //std::cerr << "ip=" << ip.ToString() << " mask=" << mask.ToString() << " ip&mask=" << s.addr.ToString() << " len=" << len << std::endl;
    ips.push_back(s);

    return true;
}
