#!/usr/bin/env python

import argparse
import json
import sys
import requests
import socket
from subprocess import Popen, PIPE

# Base URL of the Stormwatch service that reports each secret's current version.
STORMWATCH_ENDPOINT = 'https://prod.stormwatch.twitch.a2z.com'
# Local sandstorm-agent health endpoint polled by default (override with -e).
DEFAULT_ENDPOINT = 'http://localhost:7331/health'
# Host handed to the network diagnostics when the agent health request times out.
SQS_ENDPOINT = 'sqs.us-west-2.amazonaws.com'

# Process exit codes emitted by message(); the 0-3 OK/WARN/CRIT/UNK scheme
# presumably targets a Nagios-compatible monitor.
OK, WARN, CRIT, UNK = 0, 1, 2, 3
# Label printed in front of each status line, keyed by exit code.
_EXIT_CODE_TO_MSG_MAP = dict(zip(
    (OK, WARN, CRIT, UNK),
    ('OK', 'WARN', 'CRIT', 'UNK'),
))

# Values accepted for --mode: which class of agent error the check reports.
USER_ERROR_MODE = 'user'
OPS_ERROR_MODE = 'ops'


class Check(object):
    """Health check for a local sandstorm-agent and the Stormwatch service.

    Every failure path goes through ``message()``, which prints a status
    line and terminates the process with the matching exit code, so the
    first problem found ends the run.
    """

    def __init__(self, mode=USER_ERROR_MODE, endpoint=DEFAULT_ENDPOINT):
        """Configure the check.

        Args:
            mode: USER_ERROR_MODE or OPS_ERROR_MODE; selects which agent
                errors are reported and which Stormwatch check runs.
            endpoint: URL of the sandstorm-agent health endpoint.
        """
        self.mode = mode
        self.endpoint = endpoint

    def check_stormwatch(self, secret_names=None):
        """Compare each secret's local version with Stormwatch's current one.

        Args:
            secret_names: iterable of secret dicts as reported by the agent
                (each is read for ``name`` and ``updated_at``). ``None``, the
                default, means there is nothing to check.

        Exits via ``message()`` with WARN on a request timeout, UNK on any
        other request or JSON error, and CRIT on the first version mismatch.
        Returns None when every secret matches.
        """
        # ``or []`` replaces the old shared mutable default and tolerates an
        # explicit None from the agent payload.
        for secret in secret_names or []:
            name = secret.get('name')
            url = '/'.join([STORMWATCH_ENDPOINT, 'secrets', name])

            try:
                stormwatch_resp = requests.get(url, timeout=30).text
                stormwatch_data = json.loads(stormwatch_resp)
            except requests.Timeout as e:
                return message(WARN, "Requests Read Exception: {}\n{}".format(e, self.netstats(STORMWATCH_ENDPOINT)))
            except socket.timeout as e:
                return message(WARN, "Socket Read Exception: {}\n{}".format(e, self.netstats(STORMWATCH_ENDPOINT)))
            except Exception as e:
                return message(UNK, "{} on endpoint: {}".format(e, url))

            # The two defaults differ (-1 vs 0) so a secret missing on both
            # sides is still reported as a mismatch.
            expected_version = int(stormwatch_data.get('CurrentVersion', -1))
            current_version = int(secret.get('updated_at', 0))
            if current_version != expected_version:
                err_msg = 'secret {} version {}, expected version {}'.format(
                    name, current_version, expected_version)
                return message(CRIT, "{} on endpoint: {}".format(err_msg, url))

    def check_stormwatch_health(self):
        """Probe Stormwatch's /health endpoint.

        Exits via ``message()`` with WARN on a timeout, any request error,
        or a non-200 status. Returns None when the endpoint answers 200.
        """
        url = '/'.join([STORMWATCH_ENDPOINT, 'health'])
        try:
            code = requests.get(url, timeout=30).status_code
        except requests.Timeout as e:
            return message(WARN, "URL Exception: {}\n{}".format(e, self.netstats(STORMWATCH_ENDPOINT)))
        except socket.timeout as e:
            return message(WARN, "Socket Read Exception: {}\n{}".format(e, self.netstats(STORMWATCH_ENDPOINT)))
        except Exception as e:
            return message(WARN, "{} on endpoint: {}".format(e, url))
        if code != 200:
            # No exception is bound here; report the status code instead.
            return message(WARN, "HTTP {} on endpoint: {}".format(code, url))

    def netstats(self, endpoint):
        """Collect ``dig`` and ``traceroute`` output for *endpoint*'s host.

        Args:
            endpoint: a bare hostname or a URL; any scheme, path and port
                are stripped before the tools are run.

        Returns:
            The dig output and the traceroute output joined by a newline, or
            "Netstats function failed." if launching either tool raises
            OSError (e.g. the binary is not installed).
        """
        host = self._hostname(endpoint)
        try:
            outputs = [self._run_diagnostic([tool, host])
                       for tool in ('dig', 'traceroute')]
        except OSError:
            return "Netstats function failed."
        return "\n".join(outputs)

    @staticmethod
    def _hostname(endpoint):
        """Strip any scheme, path and port from *endpoint*, leaving the host."""
        host = endpoint.split('//', 1)[-1]
        host = host.split('/', 1)[0]
        return host.rsplit(':', 1)[0]

    @staticmethod
    def _run_diagnostic(cmd):
        """Run *cmd* to completion and return its stdout as text."""
        proc = Popen(cmd, stdout=PIPE)
        # communicate() waits for the child, so no zombie is left behind.
        out, _ = proc.communicate()
        # Python 3 yields bytes; decode so the report doesn't show b'...'.
        if not isinstance(out, str):
            out = out.decode('utf-8', 'replace')
        return out

    def check(self):
        """Run the full check for the configured mode; exits via ``message()``.

        Sequence: fetch and parse the agent health JSON, report a recent
        agent error of the kind this mode cares about, then (user mode)
        verify secret versions against Stormwatch or (ops mode) probe
        Stormwatch health, and finally report OK.
        """
        if self.mode not in (USER_ERROR_MODE, OPS_ERROR_MODE):
            # Every branch below is keyed on the two known modes; anything
            # else would otherwise end without printing a status.
            return message(UNK, "unknown mode: {}".format(self.mode))

        try:
            sandstorm_resp = requests.get(self.endpoint, timeout=30)
            data = json.loads(sandstorm_resp.text)
        except requests.Timeout as e:
            return message(WARN, "URL Exception: {}\n{}".format(e, self.netstats(SQS_ENDPOINT)))
        except socket.timeout as e:
            return message(WARN, "Socket Read Exception: {}\n{}".format(e, self.netstats(SQS_ENDPOINT)))
        except Exception as e:
            # User mode escalates an unreachable/unparseable agent to CRIT;
            # ops mode reports it as WARN.
            severity = CRIT if self.mode == USER_ERROR_MODE else WARN
            return message(severity, "{} on endpoint: {}".format(e, self.endpoint))

        recent_error = data.get('RecentError', '')
        if recent_error:
            # 'system_error' classifies the error. When it is absent the
            # defaults (True for user, False for ops) mean neither mode
            # reports it.
            if self.mode == USER_ERROR_MODE:
                report = not data.get('system_error', True)
            else:
                report = data.get('system_error', False)
            if report:
                return message(CRIT, "recent agent error: {}".format(recent_error))

        if self.mode == USER_ERROR_MODE:
            self.check_stormwatch(data.get('SecretStatus', []))
            return message(OK, 'all secrets up to date')
        self.check_stormwatch_health()
        return message(OK, 'SeemsGood')


def message(retval, msg):
    """Print a status line and terminate the process with *retval*.

    The line has the form ``<LABEL>:Sandstorm: <msg>``; LABEL comes from
    ``_EXIT_CODE_TO_MSG_MAP`` (``None`` for an unmapped code). Never
    returns normally: ``sys.exit`` raises SystemExit.
    """
    label = _EXIT_CODE_TO_MSG_MAP.get(retval)
    status_line = "{}:Sandstorm: {}".format(label, msg)
    print(status_line)
    sys.exit(retval)


def parse_args(argv=None):
    """Parse the command-line options for the check.

    Args:
        argv: argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``endpoint`` and ``mode`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='sandstorm agent secret version check')
    parser.add_argument(
        '-e', '--endpoint', required=False, dest='endpoint',
        default=DEFAULT_ENDPOINT,
        help='local sandstorm-agent healthcheck endpoint')
    # Check only acts on these two modes, so reject anything else up front
    # rather than letting the run end without a status line.
    parser.add_argument(
        '-m', '--mode', required=False, dest='mode', default=USER_ERROR_MODE,
        choices=[USER_ERROR_MODE, OPS_ERROR_MODE],
        help='set this flag to user or ops to filter for either error type')
    return parser.parse_args(argv)


def main(args):
    """Build a Check from the parsed CLI *args* and run it."""
    checker = Check(args.mode, args.endpoint)
    checker.check()


if __name__ == '__main__':
    # Script entry point: parse the command line, then run the check.
    cli_args = parse_args()
    main(cli_args)
