import json,time,uuid,base64

import boto3
import botocore.exceptions


# Seconds to sleep between successive polls of the pipeline's execution history.
WAIT_PERIOD_S = 5
# Polling stops once fewer than this many milliseconds of invocation time remain.
BUFFER_MS = 10 * 1000
# Number of tries used when the "retries" user parameter is missing or not an integer.
DEFAULT_RETRIES = 10


class ExecutionStatus(object):
    """CodePipeline execution status strings this module compares against and returns."""

    SUCCEEDED = 'Succeeded'
    IN_PROGRESS = 'InProgress'
    FAILED = 'Failed'


class Timeout(Exception):
    """Signals that too little invocation time remains to keep polling."""


def get_version(s3_bucket, s3_key, region):
    """Return the VersionId S3 reports for the object at s3_bucket/s3_key.

    Issues a HEAD request (no object body is downloaded) with a client bound to `region`.
    """
    head = boto3.client('s3', region_name=region).head_object(Bucket=s3_bucket, Key=s3_key)
    return head[u'VersionId']


def is_relevant_version(context, client, pipeline, execution, version):
    """Tell whether `execution` ran with `version` among its artifact revisions.

    Raises Timeout once fewer than BUFFER_MS milliseconds of invocation time remain.
    The deadline check lives here because this is the innermost loop of the polling logic.
    """
    remaining_ms = context.get_remaining_time_in_millis()
    if remaining_ms < BUFFER_MS:
        raise Timeout

    response = client.get_pipeline_execution(
        pipelineName=pipeline,
        pipelineExecutionId=execution['pipelineExecutionId'],
    )
    revisions = response['pipelineExecution']['artifactRevisions']
    return any(version == revision['revisionId'] for revision in revisions)


def get_pipeline_state(context, client, pipeline, version, max_results=10):
    """Query the pipeline for the state of previous executions and look for runs matching our artifact version.

    Only the `max_results` executions returned by list_pipeline_executions are inspected.

    Returns an ExecutionStatus and failure details. Failure details will be None if the status is not FAILED.
    Failure details match the format for put_job_failure_result:
    http://boto3.readthedocs.io/en/latest/reference/services/codepipeline.html#CodePipeline.Client.put_job_failure_result

    Raises Timeout (via is_relevant_version) when the invocation is close to its deadline.
    """
    executions = client.list_pipeline_executions(
        pipelineName=pipeline, maxResults=max_results)['pipelineExecutionSummaries']

    relevant_executions = [
        execution for execution in executions
        if is_relevant_version(context, client, pipeline, execution, version)
    ]
    statuses = [e['status'] for e in relevant_executions]

    # Any successful execution for this version ==> success
    if ExecutionStatus.SUCCEEDED in statuses:
        return ExecutionStatus.SUCCEEDED, None

    # No executions for this version (yet, hopefully) or any in progress execution ==> in progress
    if not relevant_executions or ExecutionStatus.IN_PROGRESS in statuses:
        return ExecutionStatus.IN_PROGRESS, None

    failed = [e for e in relevant_executions if e['status'] == ExecutionStatus.FAILED]
    if failed:
        return ExecutionStatus.FAILED, get_failure_details(client, pipeline, failed[0]['pipelineExecutionId'])

    # Every relevant execution ended in some other terminal-ish status (e.g. Stopped or Superseded).
    # Report it as a failure rather than crashing on an empty list.
    other = relevant_executions[0]
    return ExecutionStatus.FAILED, {
        'type': 'JobFailed',
        'message': 'Pipeline execution {} ended with status {}.'.format(
            other['pipelineExecutionId'], other['status']),
    }


def make_continuation_token(job_id, try_number):
    """The continuation token lets us pass information between runs.

    Returns base64-encoded JSON holding `previous_id` and `try_number`. The JSON is
    encoded to bytes before base64 and decoded back to text afterwards, so this works
    on both Python 2 and Python 3.
    """
    payload = json.dumps({
        'previous_id': job_id,
        'try_number': try_number,
    })
    return base64.b64encode(payload.encode('utf-8')).decode('ascii')


def parse_try_number(job_data):
    """Pull the try number out of a continuation token.

    On the first run the token won't exist, that's expected. A missing, null or
    malformed token also yields 1, so the handler never crashes on it.
    """
    try:
        raw = base64.b64decode(job_data['continuationToken']).decode('utf-8')
        return json.loads(raw)['try_number']
    except (KeyError, ValueError, TypeError):
        # KeyError: no token / no try_number. ValueError: bad base64 (py3), bad UTF-8
        # or bad JSON. TypeError: bad base64 padding (py2), None token, non-dict JSON.
        return 1


def put_result(job_id, result, try_number, retries, failure_details):
    """Report the outcome for `job_id` back to CodePipeline.

    FAILED reports failure with `failure_details`. IN_PROGRESS reports success with a
    continuation token (so the job is invoked again) while try_number < retries, and
    reports failure once retries are used up. Any other result reports plain success.
    """
    codepipeline = boto3.client('codepipeline')

    if result == ExecutionStatus.FAILED:
        codepipeline.put_job_failure_result(jobId=job_id, failureDetails=failure_details)
        return

    if result != ExecutionStatus.IN_PROGRESS:
        codepipeline.put_job_success_result(jobId=job_id)
        return

    # Still running: out of tries means give up, otherwise schedule another invocation.
    if try_number >= retries:
        codepipeline.put_job_failure_result(jobId=job_id, failureDetails={
            'message': 'Pipeline took too long to complete. Configure more retries if necessary',
            'type': 'JobFailed',
        })
        return

    codepipeline.put_job_success_result(
        jobId=job_id,
        continuationToken=make_continuation_token(job_id, try_number + 1),
    )


# Seconds to back off after CodePipeline throttles us.
_THROTTLE_SLEEP_S = 30


def _codepipeline_client(role_arn, region):
    """Build a CodePipeline client for `region`, assuming `role_arn` first when one is given."""
    if not role_arn:
        return boto3.client('codepipeline', region_name=region)

    sts = boto3.client('sts')
    creds = sts.assume_role(RoleSessionName=str(uuid.uuid4()), RoleArn=role_arn)['Credentials']
    session = boto3.Session(
        aws_access_key_id=creds['AccessKeyId'],
        aws_secret_access_key=creds['SecretAccessKey'],
        aws_session_token=creds['SessionToken'],
    )
    return session.client('codepipeline', region_name=region)


def _has_time_for(context, ms):
    """True if spending `ms` more milliseconds still leaves at least BUFFER_MS of invocation time."""
    return context.get_remaining_time_in_millis() - ms >= BUFFER_MS


def poll_until_timeout(context, role_arn, pipeline, version, region):
    """Poll get_pipeline_state until we run out of time or get a terminal state.

    An empty (or None) role_arn means no role is assumed.

    Returns (status, failure_details). When time runs out, the status is IN_PROGRESS so the
    caller can schedule another try. Throttling is retried after a back-off while time allows.
    Any other botocore ClientError is re-raised.
    """
    client = _codepipeline_client(role_arn, region)
    timeout_details = {
        'message': 'Pipeline took too long to complete. Configure more retries if necessary',
        'type': 'JobFailed',
    }

    while True:
        # is_relevant_version also checks the deadline, but it is never called when the
        # pipeline has no executions, so check here as well to avoid looping past the deadline.
        if not _has_time_for(context, 0):
            return ExecutionStatus.IN_PROGRESS, timeout_details

        try:
            result, failure_details = get_pipeline_state(context, client, pipeline, version)
        except Timeout:
            return ExecutionStatus.IN_PROGRESS, timeout_details
        except botocore.exceptions.ClientError as e:
            # ClientError carries the parsed error in e.response; dig the code out to detect throttling.
            error_code = e.response.get('Error', {}).get('Code', 'Unknown')
            if error_code != 'ThrottlingException':
                raise
            # Don't sleep past the deadline; hand over to the next invocation instead.
            if not _has_time_for(context, _THROTTLE_SLEEP_S * 1000):
                return ExecutionStatus.IN_PROGRESS, timeout_details
            print('Throttled, sleeping for 30 seconds.')
            time.sleep(_THROTTLE_SLEEP_S)
            continue

        if result != ExecutionStatus.IN_PROGRESS:
            return result, failure_details
        time.sleep(WAIT_PERIOD_S)


# Values CodePipeline accepts for FailureDetails.type in put_job_failure_result.
_FAILURE_TYPES = frozenset([
    'JobFailed', 'ConfigurationError', 'PermissionError',
    'RevisionOutOfSync', 'RevisionUnavailable', 'SystemUnavailable',
])


def get_failure_details(client, pipeline, execution_id):
    """Get failure details by inspecting the current pipeline state.

    Unfortunately, we can't get historical details. Only the last execution. Once that's overridden, GG.

    Returns a dict with 'type' and 'message' suitable for put_job_failure_result.
    Falls back to a generic 'JobFailed' message when no failed action for `execution_id` is found.
    """
    pipeline_state = client.get_pipeline_state(name=pipeline)
    default = {'type': 'JobFailed', 'message': 'Unknown error, check the pipeline in the console.'}

    # Iterate backward to find the latest stage in the pipeline with the right execution id.
    for stage in reversed(pipeline_state.get('stageStates', [])):
        if stage.get('latestExecution', {}).get('pipelineExecutionId') != execution_id:
            continue
        # Similarly, we want the latest action in the stage. An action's status and error
        # details live under its 'latestExecution', not on the action state itself.
        for action in reversed(stage.get('actionStates', [])):
            latest = action.get('latestExecution', {})
            if latest.get('status') != ExecutionStatus.FAILED:
                continue
            error = latest.get('errorDetails')
            if not error:
                return default
            # Rename code -> type, but only keep codes the failure API accepts.
            code = error.get('code')
            return {
                'type': code if code in _FAILURE_TYPES else default['type'],
                'message': error.get('message') or default['message'],
            }
    return default


def handler(event, context):
    """Lambda entry point for a CodePipeline job that waits on another pipeline.

    Reads the 'CodePipeline.job' event. UserParameters is JSON with these keys:
    src_region, dst_region, dst_account, pipeline, artifact, and optionally iam_role
    and retries. The destination bucket name is derived from the input artifact's bucket
    by swapping the source region/account for the destination ones. The handler then
    polls `pipeline` for the current version of `artifact`, reports the outcome to
    CodePipeline, and returns the resulting ExecutionStatus string.
    """
    job = event['CodePipeline.job']
    job_id = job['id']
    src_account = job['accountId']
    data = job['data']

    bucket = data['inputArtifacts'][0]['location']['s3Location']['bucketName']

    params = json.loads(data['actionConfiguration']['configuration']['UserParameters'])

    # An empty role means no role is assumed; allow omitting the key altogether.
    role = params.get('iam_role', '')
    src_region = params['src_region']
    dst_region = params['dst_region']
    dst_account = params['dst_account']
    pipeline = params['pipeline']
    artifact = params['artifact']

    s3_bucket = bucket.replace(src_region, dst_region).replace(src_account, dst_account)
    version = get_version(s3_bucket, artifact, dst_region)
    # Single-argument print behaves the same on Python 2 and 3.
    print('beginning to monitor version: {}'.format(version))

    try:
        retries = int(params['retries'])
    except (KeyError, ValueError, TypeError):
        # TypeError covers e.g. "retries": null in the user parameters.
        retries = DEFAULT_RETRIES
    try_number = parse_try_number(data)

    result, failure_details = poll_until_timeout(context, role, pipeline, version, dst_region)

    put_result(job_id, result, try_number, retries, failure_details)
    return result
