#!/usr/bin/env python3
# encoding: utf-8

from http import HTTPStatus
from http.server import HTTPServer, BaseHTTPRequestHandler
from subprocess import run, Popen, PIPE
import argparse
import datetime
import ipaddress
import json
import math
import os
import re
import requests
import setproctitle
import shlex
import socket
import threading
import time


# Rename the process so it shows up as "ipblocker" in process listings.
setproctitle.setproctitle('ipblocker')
# `requests` honours REQUESTS_CA_BUNDLE; point it at this CA bundle path for
# TLS verification of the cbb endpoint.
os.environ['REQUESTS_CA_BUNDLE'] = '/etc/ssl/certs/ca-certificates.crt'


class _RequestHandler(BaseHTTPRequestHandler):
    def _set_headers(self):
        self.send_response(HTTPStatus.OK.value)
        self.send_header('Content-type', 'application/json')
        self.end_headers()

    def do_GET(self):
        self._set_headers()
        count = get_metrics()
        self.wfile.write(json.dumps([["ipblocker_axxx", count]]).encode('utf-8'))

    def log_message(self, format, *args):
        return


class HTTPServerV6(HTTPServer):
    """HTTPServer variant whose listening socket uses the IPv6 address family."""

    # Overrides the base class's address family so the server socket is AF_INET6.
    address_family = socket.AF_INET6


def run_statserver(yasm_port):
    """Bind the metrics HTTP server to the IPv6 wildcard address and serve forever.

    This call blocks; it only returns if ``serve_forever`` does.
    """
    host = '::'
    stat_server = HTTPServerV6((host, yasm_port), _RequestHandler)
    print('[INFO] serving statserver at %s:%d' % (host, yasm_port))
    stat_server.serve_forever()


def call_cbb(url, params):
    """Issue a GET to *url* with query *params* and a 5 second timeout.

    Returns the ``requests`` response object; exceptions raised by
    ``requests`` propagate to the caller.
    """
    response = requests.get(url, params=params, timeout=5)
    return response


# runs a subprocess (no shell involved: the command string is split with shlex)
def run_subprocess(cmd):
    """Run *cmd* and return ``(stdout, stderr, returncode)``.

    *cmd* is a command-line string; it is tokenised with ``shlex.split`` and
    executed directly, not through a shell. Both output streams are captured
    and decoded as UTF-8. A stream that is not valid UTF-8 is reported as an
    empty string; the other stream is unaffected.
    """
    p = Popen(shlex.split(cmd), close_fds=True, stdout=PIPE, stderr=PIPE)
    raw_stdout, raw_stderr = p.communicate()

    try:
        stdout = raw_stdout.decode("utf-8")
    except UnicodeDecodeError:
        stdout = ''

    try:
        stderr = raw_stderr.decode("utf-8")
    except UnicodeDecodeError:
        # Reset stderr only: clearing stdout here would discard good output
        # and leave stderr unbound for the return below.
        stderr = ''

    return stdout, stderr, p.returncode


def run_capture(cmd):
    """Run *cmd* (tokenised with shlex, no shell) and return the CompletedProcess.

    stdout and stderr are captured and decoded as UTF-8 text.
    """
    argv = shlex.split(cmd)
    completed = run(argv, stdout=PIPE, stderr=PIPE, encoding='utf-8')
    return completed


# current unix time
def get_timestamp():
    """Return the current Unix time truncated to whole seconds."""
    return int(time.time())


# current wall-clock time, used as the prefix of log lines
def get_datetime():
    """Return ``datetime.datetime.now()`` (a naive, local-time datetime)."""
    return datetime.datetime.now()


# ask cbb when the given flag was last modified
def get_cbb_mtime(cbb_endpoint, cbb_flag):
    """Fetch the ``check_flag.pl`` response for *cbb_flag*.

    Returns the raw response body as a string, or ``-1`` after logging the
    error when the request fails or the server answers with an error status.
    """
    stamp = str(get_datetime())
    url = cbb_endpoint + '/cgi-bin/check_flag.pl'
    query = {"flag": str(cbb_flag)}

    try:
        response = call_cbb(url, query)
        response.raise_for_status()
    except requests.exceptions.RequestException as err:
        print('[ERROR] {dt} Failed to get last modification time from cbb:'.format(dt=stamp))
        print(err)
        return -1

    return response.text


# Line formats returned by get_range.pl with "range_src,range_dst,expire":
#   1.2.3.4; 1.2.3.4; 1
#   2a01:4f9:c010:68b7::1; 2a01:4f9:c010:68b7::1; 1627486937
_CBB_RANGE_V4_RE = re.compile(r"(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}); (\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}); (\d+)")
_CBB_RANGE_V6_RE = re.compile(r"([0-9a-f:]+); ([0-9a-f:]+); (\d+)")


def _extract_range_src(text, pattern):
    """Return the first captured field (range_src) of every line matching *pattern*."""
    matches = (pattern.match(line) for line in str(text).split('\n'))
    return [m.group(1) for m in matches if m]


# get blocked ips list
def get_cbb_ips(cbb_endpoint, cbb_flag):
    """Fetch the IPv4 and IPv6 addresses listed under *cbb_flag*.

    Two requests are made to ``get_range.pl``: one without a version and one
    with ``version=6``. From each matching line only ``range_src`` is kept;
    ``range_dst`` and ``expire`` are matched but not used.

    Returns a list of address strings (IPv4 entries first, then IPv6), or
    ``-1`` after logging the error if either request fails.
    """
    dt = str(get_datetime())
    url = cbb_endpoint + '/cgi-bin/get_range.pl'
    params = {"flag": str(cbb_flag), "with_format": "range_src,range_dst,expire"}

    try:
        res = call_cbb(url, params)
        res.raise_for_status()
    except requests.exceptions.RequestException as e:
        print('[ERROR] {dt} Failed to get ips from cbb:'.format(dt=dt))
        print(e)
        return -1

    ips = _extract_range_src(res.text, _CBB_RANGE_V4_RE)

    if not ips:
        print('[INFO] {dt} Got zero ips from cbb'.format(dt=dt))
    else:
        print('[INFO] {dt} Got ips {ips} from cbb'.format(dt=dt, ips=str(ips).strip('[]')))

    try:
        res = call_cbb(url, dict(params, version=6))
        res.raise_for_status()
    except requests.exceptions.RequestException as e:
        print('[ERROR] {dt} Failed to get ipv6s from cbb:'.format(dt=dt))
        print(e)
        return -1

    # Collect IPv6 entries in their own list so the log line below reports
    # them; they used to be appended to `ips`, leaving this list always empty.
    ipv6s = _extract_range_src(res.text, _CBB_RANGE_V6_RE)

    if not ipv6s:
        print('[INFO] {dt} Got zero ipv6s from cbb'.format(dt=dt))
    else:
        print('[INFO] {dt} Got ipv6s {ips} from cbb'.format(dt=dt, ips=str(ipv6s).strip('[]')))

    return ips + ipv6s


def get_cbb_networks(cbb_endpoint, cbb_flag):
    """Fetch the networks listed under *cbb_flag* for IP versions 4 and 6.

    Each non-blank response line is split on ``;``; the first two fields are
    joined as ``"<field0>/<field1>"``. Lines without a ``;`` are logged and
    skipped instead of raising IndexError.

    Returns a list of such strings, or ``-1`` after logging the error if a
    request fails. The strings are not validated here.
    """
    dt = str(get_datetime())
    networks_list = []
    try:
        for version in [4, 6]:
            res = call_cbb(
                cbb_endpoint + '/api/v1/get_netblock', {"flag": str(cbb_flag), "version": version}
            )
            res.raise_for_status()
            for d in str(res.text).split('\n'):
                if not d.strip():
                    continue
                fields = d.split(';')
                if len(fields) < 2:
                    print('[WARN] {dt} Skipping malformed netblock line from cbb: {line!r}'.format(dt=dt, line=d))
                    continue
                network_ip = fields[0].strip()
                network_mask = fields[1].strip()
                networks_list.append(f"{network_ip}/{network_mask}")

    except requests.exceptions.RequestException as e:
        print('[ERROR] {dt} Failed to get networks from cbb:'.format(dt=dt))
        print(e)
        return -1

    return networks_list


# ban abusive NETS via ipset
def ban_networks(networks, ipset_timeout, ipset_net_id, ipset6_net_id):
    """Add *networks* to the IPv4/IPv6 network ipsets with a timeout.

    Each entry is validated with ``ipaddress.ip_network``; entries it rejects
    are printed and skipped. IPv6 networks go to *ipset6_net_id*, everything
    else to *ipset_net_id*. All additions are fed to one ``ipset -! restore``
    call (``-!`` is ipset's ``-exist`` shorthand).

    With 50 or fewer input entries every entry gets *ipset_timeout*. With
    more, the timeout grows with the entry's position, from 1x to roughly 2x
    the base value, so that the entries do not all expire at once.

    Returns ``0`` on success or when there is nothing to add, ``-1`` when
    ipset exits with a non-zero status.
    """
    lines = []
    accepted = []
    nets_count = len(networks)
    for net in networks:
        try:
            parsed = ipaddress.ip_network(net)
        except ValueError as e:
            print(e)
            continue

        ipset_dst = ipset6_net_id if parsed.version == 6 else ipset_net_id
        i = len(accepted)  # position among the entries that passed validation
        ttl = ipset_timeout if nets_count <= 50 else int(ipset_timeout * (1 + math.log10(1 + i * 10 / nets_count)))

        lines.append(f"add {ipset_dst} {net} timeout {int(ttl)}")
        accepted.append(net)

    if not lines:
        return 0

    p = Popen(["ipset", "-!", "restore"], stdout=PIPE, stdin=PIPE, stderr=PIPE)
    _, stderr = p.communicate(input="\n".join(lines).encode())

    dt = str(get_datetime())
    if p.returncode != 0:
        print('[ERROR] {dt} Failed to add network {networks} to ipset:'.format(dt=dt, networks=','.join(accepted)))
        print(stderr.decode('utf-8', errors='replace'))
        return -1

    print(
        '[INFO] {dt} Added networks {networks} to ipset'.format(
            dt=dt,
            networks=', '.join(accepted),
        )
    )
    # Return 0 on success too, matching the empty-input path and ban_ips.
    return 0


# ban abusive IPS via ipset
def ban_ips(ips, ipset_timeout, ipset_id, ip6set_id, ip_whitelist):
    """Add *ips* to the IPv4/IPv6 ipsets with a timeout.

    Addresses present in *ip_whitelist* (compared as strings) are reported
    and skipped. Addresses that ``ipaddress.ip_address`` rejects are printed
    and skipped as well, instead of aborting the whole batch. IPv6 addresses
    go to *ip6set_id*, everything else to *ipset_id*. All additions are fed to
    one ``ipset -! restore`` call (``-!`` is ipset's ``-exist`` shorthand).

    With 50 or fewer input entries every entry gets *ipset_timeout*. With
    more, the timeout grows with the entry's position, from 1x to roughly 2x
    the base value, so that the entries do not all expire at once.

    Returns ``0`` on success or when there is nothing to add, ``-1`` when
    ipset exits with a non-zero status.
    """
    lines = []
    accepted = []
    ips_count = len(ips)

    for ip in ips:
        if ip in ip_whitelist:
            print('[WARN] {dt} Tried to ban ip {ip} from whitelist'.format(dt=str(get_datetime()), ip=ip))
            continue

        try:
            parsed = ipaddress.ip_address(ip)
        except ValueError as e:
            # e.g. "999.1.1.1" passes the cbb regex but is not an address
            print(e)
            continue

        i = len(accepted)  # position among the entries that will be added
        ttl = ipset_timeout if ips_count <= 50 else int(ipset_timeout * (1 + math.log10(1 + i * 10 / ips_count)))
        ipset_dst = ip6set_id if parsed.version == 6 else ipset_id

        lines.append(
            'add {ipset_dst} {ip} timeout {ttl}'.format(
                ipset_dst=ipset_dst,
                ip=ip,
                ttl=int(ttl),
            )
        )
        accepted.append(ip)

    if not lines:
        return 0

    p = Popen(["ipset", "-!", "restore"], stdout=PIPE, stdin=PIPE, stderr=PIPE)
    _, stderr = p.communicate(input="\n".join(lines).encode())

    dt = str(get_datetime())
    if p.returncode != 0:
        print('[ERROR] {dt} Failed to add ip {ips} to ipset:'.format(dt=dt, ips=','.join(accepted)))
        print(stderr.decode('utf-8', errors='replace'))
        return -1

    print(
        '[INFO] {dt} Added ips {ips} to ipset'.format(
            dt=dt,
            ips=','.join(accepted),
        )
    )
    return 0


# total number of entries across all ipsets
def get_metrics():
    """Sum the ``entries: N`` counters found in the output of ``ipset list``.

    Returns the total as an int, or ``-1`` after logging stderr when the
    command exits with a non-zero status.
    """
    stamp = str(get_datetime())
    out, err, rc = run_subprocess("ipset list")
    if rc != 0:
        print('[ERROR] {dt} Failed to get metrics:'.format(dt=stamp))
        print(str(err))
        return -1

    entries_re = re.compile(r"entries: (\d+)")
    # at most one counter is taken per output line (first match on the line)
    hits = (entries_re.search(row) for row in out.splitlines())
    return sum(int(hit.group(1)) for hit in hits if hit)


def _cbb_modified_recently(remote_mtime, now, period, dt, label):
    """Return True if *remote_mtime* is a timestamp less than *period* seconds before *now*.

    ``-1`` (fetch failure, already logged by the fetcher) and non-numeric
    values yield False. *label* only selects the wording of the "no
    modifications" log line.
    """
    if remote_mtime == -1:
        return False
    try:
        age = now - int(remote_mtime)
    except ValueError:
        print('[ERROR] {dt} Unexpected modification time from cbb: {mtime!r}'.format(dt=dt, mtime=remote_mtime))
        return False
    if age < period:
        return True
    print(
        '[INFO] {dt} No modifications {label} to cbb during last {sec} seconds'.format(
            dt=dt, label=label, sec=period
        )
    )
    return False


# run
def start(args, ip_whitelist):
    """Poll cbb forever and push recently changed IP and network lists into ipset.

    On every iteration the modification times of ``args.cbb_flag`` and
    ``args.cbb_net_flag`` are fetched. A list is downloaded and banned only
    when its flag changed less than ``args.cbb_mtime_period`` seconds ago.
    A failed download (``-1``) is skipped rather than handed to the ban
    function. The loop sleeps ``args.sleep`` seconds between iterations and
    never returns.
    """
    while True:
        ts = get_timestamp()
        dt = str(get_datetime())

        remote_mtime = get_cbb_mtime(args.cbb_endpoint, args.cbb_flag)
        remote_mtime_networks = get_cbb_mtime(args.cbb_endpoint, args.cbb_net_flag)

        if _cbb_modified_recently(remote_mtime, ts, args.cbb_mtime_period, dt, 'IP'):
            ips_list = get_cbb_ips(args.cbb_endpoint, args.cbb_flag)
            # -1 means the fetch failed and was already logged
            if ips_list != -1:
                ban_ips(ips_list, args.ipset_timeout, args.ipset_id, args.ip6set_id, ip_whitelist)

        if _cbb_modified_recently(remote_mtime_networks, ts, args.cbb_mtime_period, dt, 'NETWORKS'):
            networks_list = get_cbb_networks(args.cbb_endpoint, args.cbb_net_flag)
            if networks_list != -1:
                ban_networks(networks_list, args.ipset_timeout, args.ipset_net_id, args.ip6set_net_id)

        time.sleep(args.sleep)


def _ensure_ipset(name, set_type, timeout, family=None):
    """Create ipset *name* unless ``ipset list`` already finds it; exit with status 1 on failure."""
    if run_capture('ipset list {id}'.format(id=name)).returncode == 0:
        print('[INFO] ipset {id} already exists'.format(id=name))
        return

    family_opt = ' family {family}'.format(family=family) if family else ''
    res = run_capture(
        'ipset create {id} {set_type}{family_opt} timeout {ttl}'.format(
            id=name, set_type=set_type, family_opt=family_opt, ttl=timeout
        )
    )
    if res.returncode != 0:
        print('[FATAL] cant create ipset')
        print(res.stdout)
        print(res.stderr)
        raise SystemExit(1)
    print('[INFO] ipset {id} added'.format(id=name))


def _ensure_reject_rule(binary, set_name):
    """Append an INPUT REJECT rule matching *set_name* via *binary* unless ``-C`` reports it exists."""
    rule = 'INPUT -m set --match-set {id} src -j REJECT'.format(id=set_name)

    if run_capture('{binary} -C {rule}'.format(binary=binary, rule=rule)).returncode == 0:
        print('[INFO] ipset {id} already blocked with {binary}'.format(id=set_name, binary=binary))
        return

    res = run_capture('{binary} -A {rule}'.format(binary=binary, rule=rule))
    if res.returncode != 0:
        # NOTE(review): logged as FATAL but startup continues, unlike a failed
        # ipset creation — confirm this is intended.
        print('[FATAL] cant add {binary} rule'.format(binary=binary))
        print(res.stderr)
        print(res.stdout)
        return
    print('[INFO] ipset {id} blocked with {binary}'.format(id=set_name, binary=binary))


def init(args):
    """Make sure the four ipsets and their iptables/ip6tables REJECT rules exist.

    Sets: ``args.ipset_id`` and ``args.ip6set_id`` (type ``iphash``),
    ``args.ipset_net_id`` and ``args.ip6set_net_id`` (type ``hash:net``), all
    created with ``args.ipset_timeout`` as the default timeout. The network
    set names come from the ``*_net_id`` arguments, the same names the main
    loop passes to ``ban_networks``. They were previously derived as
    ``<ipset_id>net``, which matches only with the default arguments.

    Exits the process with status 1 if a set cannot be created.
    """
    _ensure_ipset(args.ipset_id, 'iphash', args.ipset_timeout)
    _ensure_ipset(args.ipset_net_id, 'hash:net', args.ipset_timeout)
    _ensure_ipset(args.ip6set_id, 'iphash', args.ipset_timeout, family='inet6')
    _ensure_ipset(args.ip6set_net_id, 'hash:net', args.ipset_timeout, family='inet6')

    _ensure_reject_rule('iptables', args.ipset_id)
    _ensure_reject_rule('ip6tables', args.ip6set_id)
    _ensure_reject_rule('iptables', args.ipset_net_id)
    _ensure_reject_rule('ip6tables', args.ip6set_net_id)


def make_argument_parser():
    """Build the command-line parser for the daemon.

    The cbb network flag is accepted as ``--cbb-net-flag`` (consistent with
    the other options) and under its original spelling ``--cbb-net_flag``;
    both store into ``cbb_net_flag``.
    """
    parser = argparse.ArgumentParser(description='Ipblocker')
    parser.add_argument(
        '--cbb-flag',
        dest='cbb_flag',
        default='144',
        help='cbb flag',
    )
    parser.add_argument(
        '--cbb-net-flag',
        '--cbb-net_flag',  # original spelling, kept so existing invocations still work
        dest='cbb_net_flag',
        default='1044',
        help='cbb network flag',
    )
    parser.add_argument(
        '--cbb-mtime-period',
        dest='cbb_mtime_period',
        default=5,
        type=int,
        help='cbb mtime period',
    )
    parser.add_argument(
        '--cbb-endpoint',
        dest='cbb_endpoint',
        default='https://cbb-ext.yandex-team.ru',
        help='cbb endpoint',
    )
    parser.add_argument(
        '--sleep',
        dest='sleep',
        default=3,
        type=int,
        help='sleep period',
    )
    parser.add_argument(
        '--ipset-id',
        dest='ipset_id',
        default='ipblocker',
        help='ipset id for ipv4',
    )
    parser.add_argument(
        '--ip6set-id',
        dest='ip6set_id',
        default='ip6blocker',
        help='ipset id for ipv6',
    )
    parser.add_argument(
        '--ipset-net-id',
        dest='ipset_net_id',
        default='ipblockernet',
        help='ipset id for ipv4 networks',
    )
    parser.add_argument(
        '--ip6set-net-id',
        dest='ip6set_net_id',
        default='ip6blockernet',
        help='ipset id for ipv6 networks',
    )
    parser.add_argument(
        '--ipset-timeout',
        dest='ipset_timeout',
        type=int,
        default=300,
        help='ipset base timeout',
    )
    parser.add_argument(
        '--yasm-port',
        dest='yasm_port',
        type=int,
        default=8082,
        help='statserver port',
    )

    # NOTE(review): parsed into args.dry_run; no use of it is visible in this file.
    parser.add_argument('--dry-run', action='store_true')

    return parser


# main
def main():
    """Parse arguments, prepare ipset/iptables, start the statserver and run the poll loop.

    The statserver runs in a daemon thread, so it cannot keep the process
    alive once the main loop has stopped (for example after an unhandled
    exception); the process then exits instead of lingering as a
    metrics-only server.
    """
    parser = make_argument_parser()
    args = parser.parse_args()

    ip_whitelist = ['127.0.0.1']

    print('[INFO] {dt} Started'.format(dt=str(get_datetime())))

    init(args)

    t = threading.Thread(target=run_statserver, args=(args.yasm_port,), daemon=True)
    t.start()
    start(args, ip_whitelist)

    # Only reached if start() ever returns; the timestamp is taken here and
    # formatted in, where the message used to print a literal "{dt}".
    print('[INFO] {dt} Done'.format(dt=str(get_datetime())))

    raise SystemExit(0)


# Run the daemon only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
