#!/usr/bin/env python3
# encoding: utf-8

import argparse
import datetime
import ipaddress
import logging
import os
import re
import requests
import setproctitle
import sys
import time
import yt.wrapper as yt

# Set the process title (what ps/top show) to 'ipblocker'.
setproctitle.setproctitle('ipblocker')
# Make every `requests` HTTPS call in this script verify against the system CA bundle.
os.environ['REQUESTS_CA_BUNDLE'] = '/etc/ssl/certs/ca-certificates.crt'

# Root logger: INFO and above, written to stdout. No formatter is set, so the
# '[INFO]' / '[ERROR]' tags inside the messages are the only level markers.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
logger.addHandler(handler)

def call_cbb(url, params, tvm_ticket):
    """Issue a GET request to the CBB API, authenticated with a TVM service ticket.

    :param url: full CBB endpoint URL.
    :param params: query-string parameters for the request.
    :param tvm_ticket: TVM service ticket, sent in the X-Ya-Service-Ticket header.
    :return: the ``requests.Response``; its HTTP status is not checked here.
    """
    ticket_header = {"X-Ya-Service-Ticket": tvm_ticket}
    return requests.get(url, params=params, headers=ticket_header, timeout=5)


def get_service_ticket():
    """Fetch a TVM service ticket from the local tvmtool.

    Uses ``config['tvm']`` (tvmtool URL, src and dst ids) and authenticates
    with the ``TVMTOOL_LOCAL_AUTHTOKEN`` environment variable.

    :return: the ticket string issued for ``config['tvm']['dst']``.
    Terminates the process with exit code 1 if the request fails or the
    response does not contain the expected ticket.
    """
    tvm = config['tvm']
    try:
        res = requests.get(
            tvm['base_url'],
            params={'src': tvm['src'], 'dsts': tvm['dst']},
            headers={'Authorization': os.getenv('TVMTOOL_LOCAL_AUTHTOKEN')},
            timeout=1,
        )
        res.raise_for_status()
    except requests.exceptions.RequestException as e:
        logger.error('[ERROR] Failed to get service ticket from tvmtool:')
        logger.error(e)
        sys.exit(1)

    try:
        return res.json()[tvm['dst']]['ticket']
    except (ValueError, KeyError, TypeError) as e:
        # Malformed JSON or an unexpected payload shape: fail the same way as
        # a transport error instead of dying with an unexplained traceback.
        logger.error('[ERROR] Unexpected response from tvmtool:')
        logger.error(e)
        sys.exit(1)


def get_ch_ips(timeout=60):
    """Query ClickHouse and return the ips whose request count reaches the threshold.

    The SQL template is read from the ``CH_QUERY`` environment variable and
    formatted with ``{seconds}`` and ``{limit}`` from the command-line args.
    Each response line is expected to be ``<ip><TAB><count>``; ips with a
    count of at least ``args.threshold`` are returned in response order.

    :param timeout: HTTP timeout in seconds for the ClickHouse request.
    :return: list of ip strings; empty if the query fails or nothing matches.
    """
    query = os.environ['CH_QUERY'].format(
        seconds=args.seconds,
        limit=args.limit,
    )
    if args.dry_run:
        logger.info('[INFO] dry-run: query')
        logger.info(query)

    try:
        res = requests.post(
            config['ch']['host'],
            auth=requests.auth.HTTPBasicAuth(config['ch']['user'], config['ch']['password']),
            data=query,
            # Without a timeout a stuck ClickHouse request would block the
            # daemon loop indefinitely.
            timeout=timeout,
        )
        res.raise_for_status()
    except requests.exceptions.RequestException as e:
        logger.error('[ERROR] Failed to run query:')
        logger.error(e)
        return []

    # e.g. "1.2.3.4<TAB>123"; raw string so "\d" is a regex class, not a
    # (deprecated) invalid string escape.
    line_re = re.compile(r'([0-9a-f.:]+)\t(\d+)')
    ips = []
    for line in res.text.split('\n'):
        match = line_re.match(line)
        if match:
            ip, cnt = match.groups()
            if int(cnt) >= args.threshold:
                ips.append(ip)

    if not ips:
        logger.info('[INFO] Got zero ips from CH')
    else:
        logger.info('[INFO] Got ips {ips} from CH'.format(
            ips=str(ips).strip('[]'),
        ))

    return ips

def add_ip_to_cbb(ip, service_ticket):
    """Block a single ip in CBB for a limited time (only log it in dry-run mode).

    The ip is sent to ``<endpoint>/api/v1/set_range`` as a one-address range
    (``range_src == range_dst``) with flag ``config['cbb']['flag']`` and an
    ``expire`` timestamp ``config['cbb']['expire_delay']`` seconds from now.

    :param ip: textual IPv4 or IPv6 address.
    :param service_ticket: TVM service ticket forwarded to ``call_cbb``.
    :return: -1 if the ip is whitelisted; 0 otherwise (invalid ips and failed
        CBB requests are logged and skipped, not raised).
    """
    whitelist = config['ip_whitelist']
    if ip in whitelist:
        logger.error('[ERROR] Tried to ban ip {ip} from whitelist'.format(ip=ip))
        return -1

    try:
        addr = ipaddress.ip_address(ip)
    except ValueError:
        logger.warning('[WARN] Invalid ip {ip}, skipping'.format(ip=ip))
        return 0

    # Also compare the canonical spelling, so an alternative textual form of a
    # whitelisted address (e.g. uncompressed IPv6) cannot slip through.
    if str(addr) in whitelist:
        logger.error('[ERROR] Tried to ban ip {ip} from whitelist'.format(ip=ip))
        return -1

    version = addr.version  # 4 or 6

    if args.dry_run:
        logger.info('[INFO] dry-run: Added ip {ip} to cbb v{version}'.format(ip=ip, version=version))
        return 0

    # Unix timestamp of the expiry. time.time() avoids the undocumented,
    # platform-specific strftime('%s') applied to a naive local datetime.
    expire = int(time.time()) + config['cbb']['expire_delay']

    try:
        res = call_cbb(config['cbb']['endpoint'] + '/api/v1/set_range', {
            "operation": "add",
            "range_src": str(ip),
            "range_dst": str(ip),
            "flag": config['cbb']['flag'],
            "description": "CH ipblocker",
            "version": version,
            "expire": str(expire),
        }, service_ticket)
        # call_cbb returns the raw response; without this an HTTP error from
        # CBB would be logged as a successful block.
        res.raise_for_status()
        logger.info('[INFO] Added ip {ip} to cbb v{version}'.format(ip=ip, version=version))
    except requests.exceptions.RequestException as e:
        logger.error('[ERROR] Failed to add ip {ip} to cbb:'.format(ip=ip))
        logger.error(e)

    return 0


def run():
    """Main daemon loop; it has no exit condition of its own.

    Every ``args.sleep`` seconds it acquires or renews the YT lock. If this
    host holds the lock, it fetches a fresh TVM ticket, asks ClickHouse for
    offending ips and submits each one to CBB. Only an exception or process
    exit ends the loop.
    """
    while True:
        have_lock = check_lock()
        if have_lock:
            ticket = get_service_ticket()
            offenders = get_ch_ips()
            for offender in offenders:
                add_ip_to_cbb(offender, ticket)
        time.sleep(args.sleep)


def parse_args(argv=None):
    """Parse the command-line options of the ipblocker daemon.

    :param argv: argument list to parse; ``None`` (default) means ``sys.argv[1:]``,
        matching the previous zero-argument behavior.
    :return: ``argparse.Namespace`` with ``sleep``, ``seconds``, ``limit``,
        ``threshold`` and ``dry_run``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--sleep", type=int, default=60, help="sleep delay in seconds (default: 60)")
    parser.add_argument("--seconds", type=int, default=300, help="seconds to grep (default: 300)")
    parser.add_argument("--limit", type=int, default=1000, help="query limit (default: 1000)")
    parser.add_argument("--threshold", type=int, default=5000, help="minimum requests count (default: 5000)")
    parser.add_argument("--dry-run", action='store_true', help="dry run (don't block, just print ips)")

    return parser.parse_args(argv)


def init_lock():
    """Configure the YT client and make sure the lock node ``$YT_LOCK`` exists.

    The client is pointed at the "locke" proxy and authenticated with
    ``$YT_TOKEN``. A missing lock node is created as a ``map_node``.
    Terminates the process with exit code 1 if the node can neither be
    found nor created.

    :return: True once the lock node exists.
    """
    lock = os.environ['YT_LOCK']
    yt.config["proxy"]["url"] = "locke"
    yt.config["token"] = os.environ['YT_TOKEN']
    if not yt.exists(lock):
        try:
            yt.create('map_node', lock)
        except Exception as e:
            # Another instance starting at the same time may have created the
            # node between exists() and create(); only a still-missing node
            # is fatal.
            if not yt.exists(lock):
                logger.error('[ERROR] can not create lock')
                logger.error(e)
                sys.exit(1)
    return True


def check_lock():
    """Try to acquire or renew the YT-based leader lease for this host.

    The lease is stored as attributes of the ``$YT_LOCK`` node: ``@host`` (the
    owner's fqdn) and ``@ts`` (lease expiry as unix time). Acquiring or renewing
    the lease sets the expiry to now + 3 * ``args.sleep``.

    NOTE(review): the read-then-write sequence is not atomic, so two hosts
    taking over an expired lease at the same moment can both believe they
    own it. Consider YT transactions/locks if this matters.

    :return: True if this host holds the lease after the call, else False.
    """
    lock = os.environ['YT_LOCK']
    fqdn = os.environ['DEPLOY_POD_PERSISTENT_FQDN']
    host_attr = lock + '/@host'
    ts_attr = lock + '/@ts'
    logger.debug('[INFO] check lock')

    if yt.exists(host_attr):
        owner = yt.get(host_attr)
        expires_at = yt.get(ts_attr)

        if owner == fqdn:
            # Already ours: just extend the lease.
            yt.set(ts_attr, time.time() + args.sleep * 3)
            return True

        lease_expired = time.time() > expires_at
        if not lease_expired:
            logger.info('[INFO] locked by another host')
            return False
        # The other host's lease has expired: fall through and take it over.

    logger.info('[INFO] locking current host')
    yt.set(host_attr, fqdn)
    yt.set(ts_attr, time.time() + args.sleep * 3)
    return True

# Entry point: parse options, build the config from the environment, run forever.
if __name__ == '__main__':
    args = parse_args()
    config = {
        # ClickHouse HTTP endpoint and basic-auth credentials (used by get_ch_ips).
        'ch':  {
            'host': os.environ['CH_HOST'],
            'user': os.environ['CH_USER'],
            'password': os.environ['CH_PASS'],
        },
        # CBB API settings; expire_delay is in seconds (added to the current time).
        'cbb': {
            'endpoint': os.environ['CBB_ENDPOINT'],
            'flag': os.environ['CBB_FLAG'],
            'expire_delay': int(os.environ['CBB_EXPIRE_DELAY']),
        },
        # tvmtool ticket endpoint and the src/dst ids a ticket is requested for.
        'tvm': {
            'base_url': os.environ['DEPLOY_TVM_TOOL_URL'] + '/tvm/tickets',
            'src': os.environ['TVM_SRC'],
            'dst': os.environ['TVM_DST'],
        },
        # IPs that add_ip_to_cbb refuses to block.
        'ip_whitelist': [
            '127.0.0.1',
        ]
    }

    logger.info('[INFO] Started')
    init_lock()
    try:
        run()
    except KeyboardInterrupt:
        # run() never returns on its own; without this the shutdown lines
        # below were unreachable. Ctrl-C now ends the process cleanly.
        logger.info('[INFO] Interrupted')

    logger.info('[INFO] Done')
    handler.close()
    sys.exit(0)
