from datetime import datetime
import boto3
import json
import math
import time

# Result codes produced by check_beanstalk_health / monitor_beanstalk_app.
SUCCESS, FAILED, NED = 0, 1, 2  # NED: not enough data yet to decide

# Key under which the monitoring start time is stored in the continuation token.
START = "start"
# Length in seconds of one CloudWatch aggregation period; also the poll interval.
PERIOD_S = 60
# How many of the most recent periods must all be healthy for SUCCESS.
CONSECUTIVE_SUCCESS_THRESHOLD = 2
# Fewest datapoints required before any health verdict is attempted.
MINIMUM_DATA_POINTS = 4

def get_cloudwatch_metrics(client, app, start):
    """Fetch the per-period ``EnvironmentHealth`` sums for one environment.

    Queries ``AWS/ElasticBeanstalk`` / ``EnvironmentHealth`` for the
    ``EnvironmentName`` dimension ``app``, from ``start`` until now, with one
    ``Sum`` per ``PERIOD_S`` seconds.

    :param client: CloudWatch client (anything with ``get_metric_statistics``).
    :param app: Beanstalk environment name.
    :param start: datetime at which the query window begins.
    :returns: list of datapoint dicts ordered by ``Timestamp``, oldest first.
    """
    # NOTE(review): EndTime is a naive local-time datetime; botocore appears to
    # treat naive datetimes as UTC, so this assumes a UTC host clock — confirm.
    dimensions = [{'Name': 'EnvironmentName', 'Value': app}]
    response = client.get_metric_statistics(
        Namespace='AWS/ElasticBeanstalk',
        MetricName='EnvironmentHealth',
        Dimensions=dimensions,
        Period=PERIOD_S,
        Unit='None',
        StartTime=start,
        EndTime=datetime.now(),
        Statistics=['Sum'],
    )
    datapoints = list(response[u'Datapoints'])
    datapoints.sort(key=lambda point: point[u'Timestamp'])
    return datapoints

def enough_datapoints(data):
    """Return True when ``data`` has at least MINIMUM_DATA_POINTS entries."""
    return MINIMUM_DATA_POINTS <= len(data)

def breach_threshold(data, threshold):
    """Report whether the share of failing periods has reached ``threshold``.

    A datapoint counts as failing when its ``Sum`` is positive. The result is
    True when the number of failing datapoints is at least
    ``threshold * len(data)``; note that an empty ``data`` therefore counts
    as a breach.

    :param data: datapoint dicts carrying a ``Sum`` key.
    :param threshold: fraction (0..1) of failing datapoints that is a breach.
    :returns: bool.
    """
    limit = threshold * len(data)
    failing = sum(1 for datum in data if datum[u'Sum'] > 0)
    return failing >= limit

def recent_failures(data):
    """Return True if any of the last CONSECUTIVE_SUCCESS_THRESHOLD
    datapoints has a positive ``Sum``.

    Positions are computed as ``len(data) - offset``; callers are expected to
    have checked enough_datapoints first so the window fits inside ``data``.
    """
    size = len(data)
    offsets = range(1, CONSECUTIVE_SUCCESS_THRESHOLD + 1)
    return any(data[size - offset][u'Sum'] > 0 for offset in offsets)

def check_beanstalk_health(client, app, start, threshold):
    """Classify the environment's health from metrics gathered since ``start``.

    :returns: NED when there are too few datapoints to judge; FAILED when the
        failing-period threshold is reached or one of the most recent periods
        failed; SUCCESS otherwise.
    """
    datapoints = get_cloudwatch_metrics(client, app, start)
    if not enough_datapoints(datapoints):
        return NED
    unhealthy = breach_threshold(datapoints, threshold) or recent_failures(datapoints)
    return FAILED if unhealthy else SUCCESS

def monitor_beanstalk_app(app, start, total, threshold):
    """Poll an environment's health until a verdict is available.

    Re-checks every PERIOD_S seconds while check_beanstalk_health reports NED
    (not enough data). Returns FAILED immediately on a FAILED verdict;
    otherwise waits one further period and returns SUCCESS.

    :param app: Beanstalk environment name.
    :param start: datetime from which metrics are evaluated.
    :param total: unused here; kept for interface compatibility (the caller
        applies the overall monitoring time itself).
    :param threshold: failing-period fraction passed to check_beanstalk_health.
    :returns: SUCCESS or FAILED.
    """
    client = boto3.client('cloudwatch')

    # NOTE(review): no upper bound — if datapoints never arrive this keeps
    # polling until the invocation is terminated.
    result = check_beanstalk_health(client, app, start, threshold)
    while result == NED:
        time.sleep(PERIOD_S)
        result = check_beanstalk_health(client, app, start, threshold)

    if result == FAILED:
        return FAILED

    time.sleep(PERIOD_S)
    return SUCCESS

def parse_json(cp, blob, job_id):
    """Decode a JSON string taken from a CodePipeline job event.

    An empty or missing ``blob`` yields ``{}``. If decoding fails, the job is
    marked failed in CodePipeline (type ``ConfigurationError``) and the
    decoding error is re-raised, so the caller never continues with unusable
    parameters.

    :param cp: CodePipeline client used to report a failure.
    :param blob: JSON text, e.g. UserParameters or a continuation token.
    :param job_id: id of the CodePipeline job being processed.
    :returns: the decoded value.
    :raises ValueError: if ``blob`` is not valid JSON.
    :raises TypeError: if ``blob`` is not a string.
    """
    if not blob:
        return {}
    try:
        return json.loads(blob)
    except (TypeError, ValueError) as e:
        print(e)
        # failureDetails is a required parameter of PutJobFailureResult.
        cp.put_job_failure_result(
            jobId=job_id,
            failureDetails={'type': 'ConfigurationError', 'message': str(e)})
        raise

def handler(event, context):
    """Handle a CodePipeline job that monitors a Beanstalk environment.

    The ``(event, context)`` signature suggests an AWS Lambda handler.
    ``UserParameters`` must be a JSON object with:

    * ``app_name``   - environment name used as the CloudWatch dimension;
    * ``total_time`` - monitoring window, compared against the number of
      elapsed PERIOD_S-second periods (the name suggests minutes);
    * ``threshold``  - fraction of failing periods that fails the job.

    The window's start time is carried between invocations in a JSON
    continuation token under the START key. A FAILED check fails the job.
    Otherwise, while the window is still open the job is reported with a
    continuation token (so CodePipeline invokes the action again); once the
    window has elapsed it is reported as a final success.

    :returns: the monitor result code (SUCCESS or FAILED).
    """
    time_format = '%Y-%m-%d %H:%M:%S.%f'
    cp = boto3.client("codepipeline")
    job_id = event["CodePipeline.job"]["id"]
    data = event["CodePipeline.job"]["data"]
    params = parse_json(cp, data["actionConfiguration"]["configuration"]["UserParameters"], job_id)
    app = params['app_name']
    # UserParameters is hand-written JSON; accept numbers given as strings,
    # which would otherwise break the arithmetic/comparisons downstream.
    total_time_min = float(params['total_time'])
    threshold = float(params['threshold'])

    token = {}
    if "continuationToken" in data:
        token = parse_json(cp, data["continuationToken"], job_id)
    if token:
        # Re-invocation: keep measuring from the original start time.
        start = datetime.strptime(token[START], time_format)
    else:
        start = datetime.now()
        # Store the start as a string: datetime is not JSON serializable, and
        # this is the exact format parsed back above.
        token = {START: start.strftime(time_format)}

    result = monitor_beanstalk_app(app, start, total_time_min, threshold)
    elapsed_periods = math.ceil((datetime.now() - start).total_seconds() / PERIOD_S)
    if result == FAILED:
        # failureDetails is a required parameter of PutJobFailureResult.
        cp.put_job_failure_result(
            jobId=job_id,
            failureDetails={
                'type': 'JobFailed',
                'message': 'Beanstalk environment %s failed its health check' % app,
            })
    elif elapsed_periods < total_time_min:
        # Window still open: a continuation token asks CodePipeline to
        # invoke this action again.
        cp.put_job_success_result(jobId=job_id, continuationToken=json.dumps(token))
    else:
        cp.put_job_success_result(jobId=job_id)
    return result
