#!/usr/bin/env python

from __future__ import unicode_literals, absolute_import, print_function

import ssl
import json
import argparse
import redis
import uuid


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Exactly one input source must be given: ``-f/--file`` (a file of JSON
    lines) or ``-t/--targets`` (a comma-separated target list).
    ``-o/--output`` optionally names the file that receives the JSON results.

    :returns: the parsed ``argparse.Namespace``.
    """
    ap = argparse.ArgumentParser()

    # The two input sources are mutually exclusive and one of them is required.
    source = ap.add_mutually_exclusive_group(required=True)
    input_options = (
        (("-f", "--file"), 'Input file'),
        (("-t", "--targets"), 'Targets'),
    )
    for flags, text in input_options:
        source.add_argument(*flags, help=text)

    ap.add_argument("-o", "--output", help='Output file name')

    parsed = ap.parse_args()
    return parsed

def _port_from_json(value):
    """Return *value* as an int port, or None if it is not a non-negative integer."""
    # bool is a subclass of int, but JSON true/false is not a port number.
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value if value >= 0 else None
    try:
        if value.isdigit():
            return int(value)
    except (AttributeError, ValueError):
        # Not a string (e.g. float/null/object), or a non-ASCII digit int() rejects.
        pass
    return None


def read_targets_from_file(filepath):
    """Read ``(ip, port)`` targets from a file containing one JSON object per line.

    Expected line format::

        {"result":{"dest_ip":"127.0.0.1", "dest_port":[443]}}
        {"result":{"dest_ip":"127.0.0.2", "ports":["6379", "6380"]}}

    The port list is taken from ``result["ports"]``, falling back to
    ``result["dest_port"]``. A single port may be given instead of a list, and
    ports may be JSON integers or digit strings.

    Lines are skipped silently when they:

    * are not valid JSON, or are not JSON objects;
    * have no ``result`` object;
    * lack an IP or ports.

    Individual ports that are not non-negative integers are also skipped.

    :param filepath: path of the input file.
    :returns: list of ``(ip, port)`` tuples (``port`` is an ``int``), in file order.
    """
    with open(filepath, 'r') as fp:
        lines = fp.read().split('\n')

    results = list()

    for raw in lines:
        try:
            record = json.loads(raw)
        except ValueError:
            continue

        # json.loads accepts any JSON value; only objects can carry "result".
        if not isinstance(record, dict):
            continue

        result = record.get('result')
        if not isinstance(result, dict):
            continue

        ip = result.get('dest_ip')
        # "dest_port" is the documented key; "ports" is still accepted because
        # earlier versions of this reader looked only at that key.
        ports = result.get('ports') or result.get('dest_port')
        if not ip or not ports:
            continue

        if not isinstance(ports, list):
            ports = [ports]

        for port in ports:
            number = _port_from_json(port)
            if number is not None:
                results.append((ip, number))

    return results


def read_targets_from_cmdline(targetsline):
    """Parse a comma-separated ``--targets`` value into ``(host, port)`` tuples.

    Each entry is ``host:port``, ``[ipv6]:port`` or a bare host. A bare host
    (including an unbracketed IPv6 address such as ``::1``, or an entry whose
    port part is not numeric) expands to two targets, on ports 6379 and 6380.
    Whitespace around entries is ignored. Empty entries, e.g. from a trailing
    comma, are skipped.

    :param targetsline: raw comma-separated target string.
    :returns: list of ``(host, port)`` tuples; brackets around IPv6 hosts are
        kept exactly as given.
    """
    results = list()

    for target in targetsline.split(','):
        entry = target.strip()
        if not entry:
            continue

        host, sep, port = entry.rpartition(':')
        bracketed = host.startswith('[') and host.endswith(']')
        # A colon left inside an unbracketed host means the whole entry is a
        # bare IPv6 address whose last group only looks like a port.
        if sep and port.isdigit() and (bracketed or ':' not in host):
            results.append((host, int(port)))
        else:
            results.append((entry, 6379))
            results.append((entry, 6380))

    return results


def try_redis(ip, port, use_ssl=False):
    """Probe one Redis endpoint and classify what answers there.

    Sends ``INFO`` and then ``GET`` on a random UUID key, and records how the
    server responded.

    :param ip: host name or IP address to connect to.
    :param port: TCP port.
    :param use_ssl: connect over TLS, with certificate verification disabled.
    :returns: dict with these keys:

        * ``ip`` and ``port``: the probed endpoint.
        * ``enabled``: False only when the connection failed or timed out.
        * ``version``: ``redis_version`` from INFO, or None.
        * ``auth_required``: True or False, or None when it could not be
          determined.
        * ``unknown_exception`` and ``protected_mode``.
    """
    result = {
        'ip': ip,
        'port': port,
        'enabled': True,
        'version': None,
        'auth_required': None,
        'unknown_exception': False,
        'protected_mode': False,
    }

    if use_ssl:
        # Certificate verification is disabled: the goal is to detect whether
        # Redis answers over TLS, not to validate the server's certificate.
        # (ssl_ca_certs expects a CA bundle path; ssl_cert_reqs is the knob.)
        client = redis.StrictRedis(host=ip, port=port, db=0, socket_timeout=2, socket_connect_timeout=2,
                                   ssl=True, ssl_cert_reqs=ssl.CERT_NONE)
    else:
        client = redis.StrictRedis(host=ip, port=port, db=0, socket_timeout=3, socket_connect_timeout=3)

    try:
        ver = client.info().get('redis_version')
        if ver:
            result['version'] = ver

        if client.get(str(uuid.uuid4())) is None:
            result['auth_required'] = False
        else:
            # A freshly generated UUID key should not exist on the server.
            result['unknown_exception'] = True

    except redis.exceptions.AuthenticationError:
        # Newer redis-py versions raise AuthenticationError for a NOAUTH reply.
        # There it subclasses ConnectionError, so it must be caught first or the
        # server would be misreported as unreachable.
        result['auth_required'] = True

    except redis.exceptions.ResponseError as e:
        # str(e) works on Python 2 and 3; e.message is Python 2 only.
        message = str(e)
        if 'NOAUTH' in message or 'Authentication required' in message:
            result['auth_required'] = True
        elif "Redis is running in protected mode" in message:
            result['protected_mode'] = True
        else:
            result['unknown_exception'] = True

    except (redis.exceptions.TimeoutError, redis.exceptions.ConnectionError):
        result['enabled'] = False

    except Exception:
        # KeyboardInterrupt is not an Exception subclass, so it still propagates.
        result['unknown_exception'] = True

    finally:
        # Release the probe's socket(s) now instead of waiting for garbage
        # collection. This is best-effort: a failed disconnect must not change
        # the result.
        try:
            client.connection_pool.disconnect()
        except Exception:
            pass

    return result


def check_one(ip, port):
    """Probe ``ip:port`` for Redis, first over plain TCP and then over TLS.

    Square brackets around the host (``[::1]``) are stripped before
    connecting. An empty host is passed through unchanged instead of raising
    IndexError.

    :param ip: host or IP address, optionally bracketed.
    :param port: TCP port.
    :returns: the ``try_redis`` result dict plus an ``ssl`` key:

        * False if the plain-TCP probe reached the server;
        * True if only the TLS probe did;
        * None if neither did. In that case the other fields come from the
          TLS attempt.
    """
    host = ip[1:-1] if ip.startswith('[') and ip.endswith(']') else ip

    res = None
    for use_ssl in (False, True):
        res = try_redis(host, port, use_ssl=use_ssl)
        if res['enabled']:
            res['ssl'] = use_ssl
            return res

    res['ssl'] = None
    return res


def check_all(targets):
    """Probe every target in order and collect the per-target results.

    :param targets: iterable of sequences whose first two items are the host
        and the port.
    :returns: list of ``check_one`` result dicts, in the same order as *targets*.
    """
    return [check_one(target[0], target[1]) for target in targets]


def main():
    """Command-line entry point: load targets, probe them, and emit JSON results.

    Targets come from ``--file`` or ``--targets``. The JSON list of results is
    printed to stdout, or written to ``--output`` when that option is given.
    """
    args = parse_args()

    # argparse leaves the unused option as None. Testing "is not None" keeps
    # `-f ""` from falling through to the cmdline parser with targets=None.
    if args.file is not None:
        targets = read_targets_from_file(args.file)
    else:
        targets = read_targets_from_cmdline(args.targets)

    results = check_all(targets)
    # Serialize before opening the output file, so a serialization failure
    # does not leave an empty file behind.
    payload = json.dumps(results)

    if not args.output:
        print(payload)
    else:
        # `with` closes the file even if the write raises.
        with open(args.output, 'w') as fp:
            fp.write(payload)


if __name__ == '__main__':
    # Run the scanner only when executed as a script, not when imported.
    main()
