import os
import json
import boto3
import logging

# Settings read once from the process environment at import time.
# Each value is the raw string, or None when the variable is unset.
env = {
    'debug': os.environ.get('DEBUG'),
    'alarm_sns_arn': os.environ.get('ALARM_SNS_ARN'),
}

# Root logger; its level is chosen by init().
logger = logging.getLogger()

# AWS clients are created once at module level and shared by the functions below.
ec2 = boto3.client('ec2')
cw = boto3.client('cloudwatch')

# Set to True by init() once the required configuration is present.
initialized = False


def init():
    """Validate configuration and pick the log level.

    Leaves ``initialized`` False and returns early when the
    ``alarm_sns_arn`` setting is missing or empty; otherwise sets the
    logger level and flips ``initialized`` to True.
    """
    global initialized

    # Without an SNS target the alarms would have nowhere to notify.
    if not env['alarm_sns_arn']:
        initialized = False
        return

    # NOTE(review): `initialized` is still False on the first successful
    # call, so that call always selects DEBUG regardless of env['debug'].
    # Confirm whether that is intended.
    verbose = env['debug'] or not initialized
    logger.setLevel(logging.DEBUG if verbose else logging.INFO)

    initialized = True


def lambda_handler(event, context):
    """Lambda entry point: react to an EC2 instance state change.

    Dispatches on the parsed ``instance_state``:
    ``running`` installs alarms for the instance, ``terminated`` removes
    them, ``all`` installs alarms for every instance, anything else is
    ignored.

    Args:
        event: Event payload; passed to ``parse_event``.
        context: Lambda context object (unused).

    Returns:
        ``{'statusCode': 500}`` when the module is not initialized,
        otherwise ``{'statusCode': 200}``.
    """
    logger.debug(env)

    # Refuse to do any work if init() did not complete.
    if not initialized:
        logger.error("not initialized")
        return {'statusCode': 500}

    logger.debug('Received event: ' + json.dumps(event, indent=2))
    change = parse_event(event)
    logger.info('EC2 State Change Detected: ' + json.dumps(change, indent=2))

    state = change['instance_state']
    if state == 'running':
        install_alarms(change)
    elif state == 'terminated':
        remove_alarms(change)
    elif state == 'all':
        install_all_alarms()
    else:
        logger.debug('doing nothing')

    return {'statusCode': 200}


def install_all_alarms():
    """Install the alarm set for every instance returned by ``get_all_instance_ids``.

    ``install_alarms`` expects a mapping with an ``'instance_id'`` key, so
    each bare id is wrapped in one before being passed on.
    """
    for instance_id in get_all_instance_ids():
        install_alarms({'instance_id': instance_id})


def _put_instance_alarm(instance_id, namespace, metric_name, dimensions,
                        statistic, datapoints_to_alarm):
    """Put one alarm named ``'<instance_id>_<metric_name>'``.

    All alarms share the same shape: 60-second period, fires when the
    statistic is greater than 90 for ``datapoints_to_alarm`` of the last
    10 periods, and notifies the configured SNS topic on both ALARM and OK.
    """
    sns_arn = env['alarm_sns_arn']
    cw.put_metric_alarm(
        AlarmName='%s_%s' % (instance_id, metric_name),
        OKActions=[sns_arn],
        AlarmActions=[sns_arn],
        Namespace=namespace,
        MetricName=metric_name,
        Dimensions=dimensions,
        Statistic=statistic,
        Period=60,
        ComparisonOperator='GreaterThanThreshold',
        Threshold=90,
        EvaluationPeriods=10,
        DatapointsToAlarm=datapoints_to_alarm,
        TreatMissingData='missing',
    )


def install_alarms(e):
    """Install the CPU, disk and memory alarms for one instance.

    Args:
        e: Mapping with an ``'instance_id'`` key.

    Raises:
        Whatever the EC2 / CloudWatch calls raise; nothing is caught here.
    """
    instance_id = e['instance_id']
    logger.info('Installing alarms for ' + instance_id)
    instance = get_instance_description(instance_id)

    _put_instance_alarm(
        instance_id,
        namespace='AWS/EC2',
        metric_name='CPUUtilization',
        dimensions=[{'Name': 'InstanceId', 'Value': instance_id}],
        statistic='Average',
        datapoints_to_alarm=5,
    )

    # Dimensions shared by both CWAgent metrics.
    agent_dimensions = [{'Name': 'InstanceId', 'Value': instance_id}]
    asg_name = get_instance_tag(instance, 'aws:autoscaling:groupName')
    if asg_name:
        # CloudWatch rejects empty dimension values, so the dimension is
        # only sent when the instance actually carries the ASG tag.
        agent_dimensions.append(
            {'Name': 'AutoScalingGroupName', 'Value': asg_name}
        )
    agent_dimensions.extend([
        {'Name': 'ImageId', 'Value': instance['ImageId']},
        {'Name': 'InstanceType', 'Value': instance['InstanceType']},
    ])

    # NOTE(review): device and fstype are hard-coded; presumably they match
    # the root volume of the AMIs in use -- confirm.
    disk_dimensions = (
        [{'Name': 'path', 'Value': '/'}]
        + agent_dimensions
        + [
            {'Name': 'device', 'Value': 'nvme0n1p1'},
            {'Name': 'fstype', 'Value': 'xfs'},
        ]
    )
    _put_instance_alarm(
        instance_id,
        namespace='CWAgent',
        metric_name='disk_used_percent',
        dimensions=disk_dimensions,
        statistic='Maximum',
        datapoints_to_alarm=3,
    )

    _put_instance_alarm(
        instance_id,
        namespace='CWAgent',
        metric_name='mem_used_percent',
        dimensions=agent_dimensions,
        statistic='Maximum',
        datapoints_to_alarm=5,
    )


def remove_alarms(e):
    """Delete the three per-instance alarms for ``e['instance_id']``.

    Args:
        e: Mapping with an ``'instance_id'`` key.
    """
    instance_id = e['instance_id']
    logger.info("Removing alarms for " + instance_id)

    # One name template per alarm.
    name_templates = (
        '%s_CPUUtilization',
        '%s_disk_used_percent',
        '%s_mem_used_percent',
    )
    alarm_names = [template % instance_id for template in name_templates]
    cw.delete_alarms(AlarmNames=alarm_names)


def parse_event(event):
    """Extract the fields this module needs from a state-change event.

    Args:
        event: Mapping with a ``'region'`` key and a ``'detail'`` mapping
            holding ``'instance-id'`` and ``'state'``.

    Returns:
        Dict with ``'region'``, ``'instance_id'`` and ``'instance_state'``.

    Raises:
        KeyError: If any of the expected keys is missing from ``event``.
        ValueError: If any extracted field is an empty string.
    """
    parsed = {
        'region': event['region'],
        'instance_id': event['detail']['instance-id'],
        'instance_state': event['detail']['state'],
    }
    # Explicit raise rather than `assert`, which is stripped when Python
    # runs with -O and would let empty fields through.
    for field, value in parsed.items():
        if value == '':
            raise ValueError('event field %r must not be empty' % field)
    return parsed


def get_all_instance_ids():
    """Return the ids of every instance reported by ``describe_instances``.

    Uses the paginator so that results spread over several response pages
    are all collected; a single call only covers the first page.

    Returns:
        List of instance id strings, in the order the API returns them.
    """
    ids = []
    paginator = ec2.get_paginator('describe_instances')
    for page in paginator.paginate():
        for reservation in page['Reservations']:
            ids.extend(
                instance['InstanceId'] for instance in reservation['Instances']
            )
    logger.debug(ids)
    return ids


def get_instance_description(instance_id):
    """Return the first instance record from ``describe_instances`` for this id.

    Args:
        instance_id: Id passed as the only entry of ``InstanceIds``.
    """
    response = ec2.describe_instances(InstanceIds=[instance_id])
    # Deliberately unguarded: a missing reservation or instance should raise
    # so the stack trace ends up in the logs.
    first_reservation = response['Reservations'][0]
    return first_reservation['Instances'][0]


def get_instance_tag(instance, tagname):
    """Return the value of tag ``tagname`` on ``instance``, or ``""`` if absent.

    Args:
        instance: Instance mapping; its optional ``'Tags'`` entry is a list
            of ``{'Key': ..., 'Value': ...}`` mappings.
        tagname: Tag key to look up.

    Returns:
        The value of the first matching tag, or an empty string when the
        instance has no tags or none of them matches.
    """
    # An instance with no tags has no 'Tags' key at all, so treat a
    # missing (or None) entry as an empty tag list.
    for tag in instance.get('Tags') or []:
        if tag["Key"] == tagname:
            return tag["Value"]
    return ""


# Run once at import time: checks the required setting and sets the log level.
init()
