#!/usr/bin/env python3
# encoding: utf-8

import argparse
import requests
import time
import datetime
import re
import os
import ipaddress
import sys
import yaml

# Set REQUESTS_CA_BUNDLE for this process to the certificate bundle at
# /etc/ssl/certs/ca-certificates.crt.
# NOTE: this unconditionally replaces any value inherited from the environment.
# NOTE(review): presumably done so that `requests` verifies TLS against this
# bundle instead of its own default one -- confirm.
os.environ['REQUESTS_CA_BUNDLE'] = '/etc/ssl/certs/ca-certificates.crt'


def call_cbb(url, params, tvm_ticket):
    """Send a GET request to a CBB API URL using a TVM service ticket.

    Args:
        url: Full URL of the API method to call.
        params: Query parameters, passed through to ``requests.get``.
        tvm_ticket: Value sent in the ``X-Ya-Service-Ticket`` header.

    Returns:
        The response object returned by ``requests.get``. The HTTP status is
        NOT checked here; callers must inspect it themselves.

    Any exception raised by ``requests.get`` (the request uses a 5 second
    timeout) propagates to the caller.
    """
    auth_headers = {"X-Ya-Service-Ticket": tvm_ticket}
    response = requests.get(url, params=params, headers=auth_headers, timeout=5)
    return response


# get ticket from tvmtool
def get_service_ticket():
    """Fetch a TVM service ticket for the configured destination from tvmtool.

    The tvmtool base URL is read from the ``DEPLOY_TVM_TOOL_URL`` environment
    variable and the auth token from ``TVMTOOL_LOCAL_AUTHTOKEN``. Source and
    destination come from ``config['tvm']['src']`` and ``config['tvm']['dst']``.

    Returns:
        The value found at ``response_json[dst]['ticket']``.

    Exits the process with status 1, after printing an ``[ERROR]`` message,
    when the URL variable is unset, when the request fails (including a
    non-2xx status), or when the response body does not have the expected
    shape.
    """
    tvmtool_url = os.getenv('DEPLOY_TVM_TOOL_URL')
    if not tvmtool_url:
        # Without this guard an unset variable crashed with a TypeError
        # (None + str) instead of producing a readable error.
        print('[ERROR] Failed to get service ticket from tvmtool:')
        print('DEPLOY_TVM_TOOL_URL is not set')
        sys.exit(1)

    dst = config['tvm']['dst']

    try:
        res = requests.get(
            tvmtool_url + '/tvm/tickets',
            params={'src': config['tvm']['src'], 'dsts': dst},
            headers={'Authorization': os.getenv('TVMTOOL_LOCAL_AUTHTOKEN')},
            timeout=1,
        )
        res.raise_for_status()
    except requests.exceptions.RequestException as e:
        print('[ERROR] Failed to get service ticket from tvmtool:')
        print(e)
        sys.exit(1)

    try:
        return res.json()[dst]['ticket']
    except (ValueError, KeyError, TypeError) as e:
        # Body is not JSON, or lacks the [dst]['ticket'] entry.
        print('[ERROR] Failed to get service ticket from tvmtool:')
        print(f'unexpected response: {e!r}')
        sys.exit(1)


# modify blocked ips list
def add_ip_to_cbb(ip, service_ticket):
    """Add a single IP address to CBB, or only log it in dry-run mode.

    Args:
        ip: Address to block, as text (IPv4 or IPv6).
        service_ticket: TVM service ticket forwarded to ``call_cbb``.

    Returns:
        -1 if ``ip`` is listed in ``config['ip_whitelist']`` (nothing is
        sent); 0 otherwise. An invalid address or a failed request is only
        logged and still returns 0.
    """
    if ip in config['ip_whitelist']:
        print(f'[ERROR] Tried to ban ip {ip} from whitelist')
        return -1

    try:
        version = ipaddress.ip_address(ip).version
    except ValueError:
        # Catch only parse errors; a bare except would also swallow
        # KeyboardInterrupt and SystemExit.
        print(f'[WARN] Invalid ip {ip}, skipping')
        return 0

    if args.dry_run:
        print(f'[INFO] dry-run: Added ip {ip} to cbb v{version}')
        return 0

    # Expiry as a Unix timestamp. strftime("%s") is a platform-specific
    # extension and is not available everywhere.
    expire = int(time.time() + config['cbb']['expire_delay'])

    try:
        response = call_cbb(config['cbb']['endpoint'] + '/api/v1/set_range', {
            "operation": "add",
            "range_src": str(ip),
            "range_dst": str(ip),
            "flag": config['cbb']['flag'],
            "description": "CH ipblocker",
            "version": version,
            "expire": expire,
        }, service_ticket)
        # Without this check an HTTP 4xx/5xx reply would be reported as success.
        response.raise_for_status()
        print(f'[INFO] Added ip {ip} to cbb v{version}')
    except requests.exceptions.RequestException as e:
        print(f'[ERROR] Failed to add ip {ip} to cbb:')
        print(e)

    return 0


def add_network_to_cbb(network, service_ticket):
    """Add a network to CBB as a netblock, or only log it in dry-run mode.

    Args:
        network: Network in ``address/prefix`` text form (IPv4 or IPv6).
            Values rejected by ``ipaddress.ip_network`` (e.g. with host bits
            set) are treated as invalid.
        service_ticket: TVM service ticket forwarded to ``call_cbb``.

    Returns:
        -1 if the network contains an address listed in
        ``config['ip_whitelist']`` (nothing is sent); 0 otherwise. An invalid
        network or a failed request is only logged and still returns 0.
    """
    try:
        net = ipaddress.ip_network(network)
    except (ValueError, TypeError):
        # Catch only parse errors; a bare except would also swallow
        # KeyboardInterrupt and SystemExit.
        print(f'[WARN] Invalid network {network}, skipping')
        return 0

    # Keep parity with add_ip_to_cbb: never ban a netblock that covers a
    # whitelisted address. Entries that are not plain IPs are ignored here.
    for entry in config['ip_whitelist'] or ():
        try:
            protected = ipaddress.ip_address(str(entry))
        except ValueError:
            continue
        if protected in net:
            print(f'[ERROR] Tried to ban network {net} containing ip {entry} from whitelist')
            return -1

    version = net.version

    if args.dry_run:
        print(f'[INFO] dry-run: Added network {net} to cbb v{version}')
        return 0

    # Expiry as a Unix timestamp. strftime("%s") is a platform-specific
    # extension and is not available everywhere.
    expire = int(time.time() + config['cbb']['expire_delay'])

    try:
        response = call_cbb(config['cbb']['endpoint'] + '/api/v1/set_netblock', {
            "operation": "add",
            "flag": config['cbb']['flag_netblock'],
            # Send the parsed, normalized form so that the mask is always a
            # prefix length even if the input used a dotted netmask.
            "net_ip": str(net.network_address),
            "net_mask": net.prefixlen,
            "description": "CH ipblocker",
            "version": version,
            "expire": expire,
        }, service_ticket)
        # Without this check an HTTP 4xx/5xx reply would be reported as success.
        response.raise_for_status()
        print(f'[INFO] Added network {net} to cbb v{version}')
    except requests.exceptions.RequestException as e:
        print(f'[ERROR] Failed to add network {net} to cbb:')
        print(e)

    return 0


# run
def run():
    """Read addresses from stdin and submit each of them to CBB.

    A service ticket is obtained first. Then all stdin lines are read and
    stripped; a line containing '/' is handed to ``add_network_to_cbb``, any
    other line to ``add_ip_to_cbb``.
    """
    ticket = get_service_ticket()
    entries = [line.strip() for line in sys.stdin.readlines()]
    for entry in entries:
        handler = add_network_to_cbb if '/' in entry else add_ip_to_cbb
        handler(entry, ticket)


def parse_args():
    """Parse the command-line options of the script.

    Returns:
        An ``argparse.Namespace`` with ``config`` (path of the YAML config
        file, a string) and ``dry_run`` (True when ``--dry-run`` was given).
    """
    cli = argparse.ArgumentParser(usage='echo "1.2.3.4" | ipblocker --dry-run')
    cli.add_argument(
        "--config",
        type=str,
        default='/app/config/ipblocker/web.yaml',
        help="config (default: /app/config/ipblocker/web.yaml)",
    )
    cli.add_argument(
        "--dry-run",
        action='store_true',
        help="dry run (don't block, just print ips)",
    )
    namespace = cli.parse_args()
    return namespace


# main
if __name__ == '__main__':
    # `args` and `config` are module-level globals; the functions in this
    # file read them directly.
    args = parse_args()

    # Load the YAML config inside a context manager so the file handle is
    # closed, and report an unreadable or invalid file without a traceback.
    try:
        with open(args.config, 'r') as config_file:
            config = yaml.safe_load(config_file)
    except (OSError, yaml.YAMLError) as e:
        print(f'[ERROR] Failed to load config {args.config}:')
        print(e)
        sys.exit(1)

    run()

    # sys.exit is used because the bare exit() helper is injected by the
    # `site` module and is not guaranteed to exist.
    sys.exit(0)
