"""Collect cloudwatch information and store it in graphite."""

import json
import logging
import os
import random
import socket
import sys
import time

from datetime import datetime, timedelta
import dateutil.parser

import boto3
import botocore
import redis

# Graphite endpoint that send_metrics() writes metric lines to.
GRAPHITE_HOST = "graphite.internal.justin.tv"
GRAPHITE_PORT = 2003
# Seconds a cached assumed-role session is reused before get_boto_client()
# assumes the role again.
SESSION_TTL = 1800
# Graphite path: cloudwatch.<account name>.<region>.<resource type>.<postfix>
METRIC_PREFIX_TEMPLATE = "cloudwatch.{acct}.{region}.{t}.{postfix}"

# Assumed-role session caches used by get_boto_client(), keyed by AWS account
# id; each value is {'timestamp': <epoch seconds>, 'session': boto3 Session}.
MASTER_SESSION_MAP = {}
ACCOUNT_SESSION_MAP = {}

# Metrics scraped from namespace AWS/ELB (for 'elb' and 'beanstalk' resources).
METRICS_TO_COLLECT = [
    'RequestCount',
    'HTTPCode_Backend_5XX',
    'HTTPCode_ELB_5XX',
    'Latency',
]
# Metrics scraped from namespace AWS/ApplicationELB (for 'alb' resources).
METRICS_TO_COLLECT_ALB = [
    'RequestCount',
    'HTTPCode_Target_5XX_Count',
    'HTTPCode_ELB_5XX_Count',
    'TargetResponseTime',
]

# Role chains tried in order by get_boto_client(): assume master_role first,
# then, from that session, account_role_template formatted with the target
# account id.
ROLES = [
    {
        "master_role": "arn:aws:iam::578510050023:role/twitch-inventory-master",
        "account_role_template": "arn:aws:iam::{acct}:role/twitch-inventory",
        "master_account": "578510050023",
    },
    {
        "master_role": "arn:aws:iam::007917851548:role/master-inventory-read-only",
        "account_role_template": "arn:aws:iam::{acct}:role/inventory-read-only",
        "master_account": "007917851548",
    },
]

# Root logger; setup_logging() (called by lambda_handler) resets its handlers.
LOGGER = logging.getLogger()


def setup_logging(log_level):
    """Reset LOGGER to a single stdout handler at ``log_level`` and return it.

    Every handler already attached to LOGGER, including one added by an
    earlier call to this function, is removed first so lines are not
    emitted more than once.

    Args:
        log_level: level passed to ``LOGGER.setLevel`` (e.g. ``logging.INFO``).

    Returns:
        LOGGER, for convenience.
    """
    # Iterate over a copy: removeHandler() mutates LOGGER.handlers, and
    # removing from the live list while iterating it skips every other handler.
    for handler in list(LOGGER.handlers):
        LOGGER.removeHandler(handler)

    handler = logging.StreamHandler(sys.stdout)

    # use whatever format you want here
    formatter = logging.Formatter('%(levelname)s - %(message)s - %(filename)s - %(lineno)d')
    handler.setFormatter(formatter)
    LOGGER.addHandler(handler)
    LOGGER.setLevel(log_level)

    return LOGGER


def lambda_handler(event, _):
    """Scrape load-balancer metrics from CloudWatch and send them to graphite.

    Expected event format::

        {
          "timestamp": <iso8601 string>,
          "resources": [
            {
              "awsaccountid": ...,
              "awsaccountname": ...,
              ...
            },
            ...
          ]
        }

    Each resource is validated, expanded with any extra lookups its type
    needs, then scraped for the one-minute window ending at ``timestamp``.
    Per-resource failures are collected instead of aborting the batch.
    expose_errors() then raises them together at the end.

    Args:
        event: the payload described above. An event without 'timestamp' is
            logged and ignored.
        _: Lambda context object (unused).
    """
    setup_logging(logging.INFO)
    # boto3.set_stream_logger() attaches a new StreamHandler on every call;
    # calling it unconditionally would duplicate each botocore log line once
    # per earlier invocation served by the same process. Attach it only once.
    botocore_logger = logging.getLogger('botocore')
    if not any(isinstance(h, logging.StreamHandler) for h in botocore_logger.handlers):
        boto3.set_stream_logger('botocore', level='INFO')
    LOGGER.info("Starting lambda_handler")

    scrape_errors = []
    try:
        start_time, end_time = format_timestamps(event['timestamp'])
    except KeyError:
        LOGGER.info("no 'timestamp' in event payload: %s", event)
        return

    # If we run into errors, gobble them up and stash them, then throw them at the end
    for resource in event['resources']:
        # Validate resource data. validate_resource_data raises NameError (not
        # KeyError) for an unsupported 'type'; catch both so that one bad
        # resource cannot abort processing of the others.
        try:
            validate_resource_data(resource)
        except (KeyError, NameError) as err:
            LOGGER.info("invalid resource data; resource=%s; reason=%s", resource, err)
            scrape_errors.append({
                'resource': resource,
                'error': str(err)
            })
            continue
        try:
            expand_resource_data(resource)
        except Exception as err:
            LOGGER.info("error retrieving additional required data: %s", err)
            scrape_errors.append({
                'resource': resource,
                'error': str(err)
            })
            continue

        metric_prefix = format_metric_prefix(resource)
        resource_type = resource.get('type')
        if resource_type == "beanstalk":
            # For beanstalk, format_metric_prefix returns None and instead maps
            # each load balancer in resource['elbnames'] to its own prefix.
            for elb in resource.get('elbnames', []):
                resource['elbname'] = elb
                scrape_errors.extend(produce_metrics(resource, resource['elbnames'][elb], start_time, end_time))
        else:
            scrape_errors.extend(produce_metrics(resource, metric_prefix, start_time, end_time))

    # Any errors encountered during this lambda should be grouped and emitted
    # as an unhandled error for complete logging purposes and correctness of
    # AWS Lambda metrics
    expose_errors(scrape_errors)


def produce_metrics(resource, metric_prefix, start_time, end_time):
    """Scrape one load balancer's CloudWatch metrics and send them to graphite.

    Args:
        resource: validated (and expanded) resource dict for the load balancer.
        metric_prefix: graphite path prefix for this load balancer's metrics.
        start_time: start of the query window (naive datetime).
        end_time: end of the query window (naive datetime).

    Returns:
        A list of ``{'resource': ..., 'error': ...}`` dicts describing failures;
        empty on success. When no cloudwatch client can be built for the
        account, nothing is scraped and no error is reported.
    """
    scrape_errors = []

    def record(err):
        """Append ``err`` to scrape_errors, tagged with a snapshot of ``resource``."""
        # Copy the dict: lambda_handler reuses one resource dict for every ELB
        # of a beanstalk application (rewriting 'elbname'). Storing the live
        # dict would make every recorded error show the last ELB's name.
        scrape_errors.append({
            'resource': dict(resource),
            'error': str(err)
        })

    # Generate a cloudwatch client
    cloudwatch_client = generate_cloudwatch_client(resource['awsaccountid'], resource['awsregion'])
    if not cloudwatch_client:
        return scrape_errors
    try:
        data = fetch_metric_data(cloudwatch_client, resource, start_time, end_time)
        # Format metrics to send, given prefix + datapoints
        graphite_lines = format_metric_lines(metric_prefix, data)
    except Exception as err:
        LOGGER.info("error scraping cloudwatch data: %s", err)
        record(err)
        return scrape_errors

    # Send data to graphite. A connection/send failure is recorded like any
    # other scrape error rather than aborting the remaining resources.
    try:
        send_metrics(graphite_lines)
    except (socket.error, OSError) as err:
        LOGGER.info("error sending metrics to graphite: %s", err)
        record(err)
    return scrape_errors


def format_timestamps(timestamp):
    """Return the 1 minute interval ending at the given timestamp, as naive UTC.

    Args:
        timestamp: an ISO 8601 string, e.g. "2020-01-01T12:00:00Z".

    Returns:
        ``(start, end)`` naive datetimes. ``end`` is ``timestamp`` and ``start``
        is 60 seconds earlier. An offset-aware input is converted to UTC before
        its tzinfo is dropped, so a non-UTC offset does not shift the window.
        A naive input is returned as-is.
    """
    parsed = dateutil.parser.parse(timestamp)
    # Downstream code treats these naive values as UTC (e.g. format_metric_lines
    # measures them against a naive 1970-01-01 epoch), so normalise to UTC
    # instead of just discarding the offset.
    offset = parsed.utcoffset()
    end = parsed.replace(tzinfo=None)
    if offset is not None:
        end -= offset
    # Go back 1 minute
    start = end - timedelta(seconds=60)
    return start, end


def validate_resource_data(resource):
    """Check that ``resource`` has every field needed to scrape it.

    The account fields 'awsaccountid', 'awsaccountname' and 'awsregion' are
    always required. 'type' must then be one of:

    - 'elb', which needs 'elbname';
    - 'beanstalk', which needs 'applicationname';
    - 'alb', which needs 'albname'.

    Raises:
        KeyError: a required field is missing, or 'type' is missing/None.
        NameError: 'type' is set to an unsupported value.
    """
    for common_field in ('awsaccountid', 'awsaccountname', 'awsregion'):
        check_field_exists(resource, common_field)

    kind = resource.get('type')
    if kind is None:
        raise KeyError("type not found")

    # Fields required per resource type; add more fields here later.
    per_type_fields = (
        ('elb', ('elbname',)),
        ('beanstalk', ('applicationname',)),
        ('alb', ('albname',)),
    )
    # Compare with == (not a dict lookup) so unhashable 'type' values still
    # fall through to the NameError below.
    for known_kind, required in per_type_fields:
        if kind == known_kind:
            break
    else:
        raise NameError("'type' must be 'elb' or 'beanstalk' or 'alb'")

    for type_field in required:
        check_field_exists(resource, type_field)


def expand_resource_data(resource):
    """Run the extra lookups some resource types need before scraping.

    'alb' resources get ``resource['albid']`` set from get_alb_information().
    'beanstalk' resources are passed to get_beanstalk_information(), which
    mutates ``resource`` in place. Every other type is left untouched.
    """
    kind = resource.get('type')
    if kind == 'alb':
        resource['albid'] = get_alb_information(resource)
    elif kind == "beanstalk":
        get_beanstalk_information(resource)


def get_alb_information(resource):
    """Look up the id of the ALB (elbv2 load balancer) named ``resource['albname']``.

    Returns:
        The last '/'-separated segment of the load balancer's ARN, which
        fetch_metric_data uses as ``<id>`` in its ``app/<name>/<id>`` dimension.
        Returns None in these cases (all but the first are logged as warnings):

        - no elbv2 client could be built for the account;
        - zero or several load balancers match the name;
        - the API call raised a ClientError.
    """
    elbv2 = get_boto_client(resource['awsaccountid'], resource['awsregion'], 'elbv2')
    if not elbv2:
        return None

    alb_name = resource['albname']
    try:
        load_balancers = elbv2.describe_load_balancers(Names=[alb_name])['LoadBalancers']
    except botocore.exceptions.ClientError as err:
        # Logger.warn is a deprecated alias; use warning().
        LOGGER.warning("ClientError accessing: %s, %s", alb_name, err)
        return None

    if len(load_balancers) == 1:
        return load_balancers[0]['LoadBalancerArn'].split('/')[-1]
    if load_balancers:
        LOGGER.warning("Multiple load balancers were found with name: %s", alb_name)
    else:
        LOGGER.warning("No load balancers were found with name: %s", alb_name)
    return None


def get_beanstalk_information(resource):
    """Populate ``resource['elbnames']`` with the load balancers of a beanstalk application.

    The mapping is ``{elb_name: None}``; format_metric_prefix later replaces
    each value with that ELB's metric prefix. When redis is configured, the
    mapping is cached under ``beanstalk:<applicationname>`` for a random
    2-5 minutes.

    Nothing is populated when no elasticbeanstalk client can be built for the
    account. A ClientError from the Elastic Beanstalk API is logged. If that
    happens before ``resource['elbnames']`` is first assigned, the final
    emptiness check raises KeyError, which the caller records as a scrape error.
    """
    beanstalk = get_boto_client(resource['awsaccountid'], resource['awsregion'], 'elasticbeanstalk')
    if not beanstalk:
        return

    app_name = resource['applicationname']
    # get_redis_session() returns None when the 'environment' variable is
    # unset; treat that as "no cache" instead of crashing on None.get().
    redis_session = get_redis_session()
    redis_key = "beanstalk:{applicationname}".format(applicationname=app_name)
    elbjson = redis_session.get(redis_key) if redis_session is not None else None
    if elbjson is None:
        try:
            bs_info = beanstalk.describe_environments(ApplicationName=app_name)
            resource['elbnames'] = {}
            for env in bs_info['Environments']:
                env_resources = beanstalk.describe_environment_resources(EnvironmentId=env['EnvironmentId'])
                for load_balancer in env_resources['EnvironmentResources']['LoadBalancers']:
                    resource['elbnames'][load_balancer['Name']] = None
            if redis_session is not None:
                # Randomised TTL, presumably so cached entries for different
                # applications don't all expire at the same moment.
                redis_session.setex(redis_key, random.randint(2*60, 5*60), json.dumps(resource['elbnames']))
        except botocore.exceptions.ClientError as err:
            LOGGER.warning("Client error accessing load balancer: %s, %s", app_name, err)
    else:
        resource['elbnames'] = json.loads(elbjson)

    if not resource['elbnames']:
        LOGGER.warning("No load balancers were found with application name: %s", app_name)


def check_field_exists(resource, field):
    """Raise ``KeyError("'<field>' not found")`` unless ``field`` is a key of ``resource``.

    Only key presence is checked; a key whose value is None still passes.
    """
    if field in resource:
        return
    raise KeyError("'{f}' not found".format(f=field))


def format_metric_prefix(resource):
    """Build the graphite metric prefix for ``resource``.

    Behavior depends on ``resource['type']``:

    - 'elb' / 'alb': return the prefix, with 'elbname' / 'albname' as the
      postfix.
    - 'beanstalk': return None. Instead, rewrite every key of
      ``resource['elbnames']`` (when present) in place so that it maps to its
      own prefix, with postfix '<applicationname>.<elb name>'.

    Any other type raises UnboundLocalError, because no postfix is chosen.
    """
    kind = resource.get('type')

    if kind == "beanstalk":
        elb_map = resource.get('elbnames', [])
        for elb_name in elb_map:
            elb_postfix = "{app}.{elb}".format(
                app=resource['applicationname'],
                elb=elb_name
            )
            elb_map[elb_name] = METRIC_PREFIX_TEMPLATE.format(
                acct=resource['awsaccountname'],
                region=resource['awsregion'],
                t=kind,
                postfix=elb_postfix
            )
        return None

    if kind == "elb":
        postfix = resource['elbname']
    elif kind == 'alb':
        postfix = resource['albname']

    return METRIC_PREFIX_TEMPLATE.format(
        acct=resource['awsaccountname'],
        region=resource['awsregion'],
        t=kind,
        postfix=postfix
    )


def format_metric_lines(metric_prefix, data):
    """Turn CloudWatch statistics responses into graphite metric lines.

    Args:
        metric_prefix: graphite path prefix prepended to every metric name.
        data: list of ``get_metric_statistics`` responses. Only the first
            datapoint of each is used, and its 'Timestamp' is read as UTC
            (tzinfo is dropped before comparing with a naive 1970-01-01 epoch).

    Returns:
        A list of "<prefix>.<name> <value> <epoch seconds>" strings.
        Timing metrics ('Latency', 'TargetResponseTime') expand into
        ``_average``, ``_p50``, ``_p90`` and ``_p99`` lines; a missing
        percentile is written as 0.0. Every other metric emits its 'Sum'.
    """
    unix_epoch = datetime(1970, 1, 1)
    timing_metrics = ('Latency', 'TargetResponseTime')
    output = []
    for response in data:
        point = response['Datapoints'][0]
        when = point['Timestamp'].replace(tzinfo=None)
        seconds = int((when - unix_epoch).total_seconds())
        label = response['Label']

        if label not in timing_metrics:
            output.append("{pre}.{name} {v} {ts}".format(pre=metric_prefix, name=label, v=point['Sum'], ts=seconds))
            continue

        output.append("{pre}.{name}_average {v} {ts}".format(pre=metric_prefix, name=label, v=point['Average'], ts=seconds))
        percentiles = point['ExtendedStatistics']
        output.extend(
            "{pre}.{name}_{p} {v} {ts}".format(
                pre=metric_prefix, name=label, p=pct, v=percentiles.get(pct, float(0)), ts=seconds)
            for pct in ('p50', 'p90', 'p99')
        )
    return output


# TODO: Better individual datapoint fetch error handling
def fetch_metric_data(cloudwatch_client, resource, start_time, end_time):
    """Collect one period of CloudWatch data per metric for ``resource``.

    Args:
        cloudwatch_client: boto3 CloudWatch client for the resource's account/region.
        resource: validated and expanded resource dict. 'alb' resources need
            'albname' and a resolved 'albid'; all others need 'elbname'.
        start_time: start of the query window; also used as the timestamp of
            a synthetic zero datapoint.
        end_time: end of the query window.

    Returns:
        A list of ``get_metric_statistics`` responses, each holding exactly one
        datapoint. An empty result is replaced by an all-zero datapoint stamped
        ``start_time``. A metric whose fetch fails, or which returns several
        datapoints, is logged and omitted.

    Raises:
        ValueError: an 'alb' resource has no 'albid' (get_alb_information
            returns None when its lookup fails).
    """
    resource_type = resource.get('type')
    # ALBs have a different dimension scheme
    if resource_type == 'alb':
        alb_id = resource.get('albid')
        if alb_id is None:
            # Querying "app/<name>/None" would presumably return no
            # datapoints, and the zero-fill below would then publish
            # fabricated zeros for this ALB. Fail loudly instead.
            raise ValueError(
                "no load balancer id resolved for alb '{name}'".format(name=resource['albname']))
        metrics = METRICS_TO_COLLECT_ALB
        dimension = "LoadBalancer"
        lbname = "app/{name}/{id}".format(name=resource['albname'], id=alb_id)
        namespace = 'AWS/ApplicationELB'
    else:
        metrics = METRICS_TO_COLLECT
        dimension = "LoadBalancerName"
        lbname = resource['elbname']
        namespace = 'AWS/ELB'

    timing_metrics = ('Latency', 'TargetResponseTime')
    data = []
    for metric in metrics:
        # Hit the CloudWatch API to get 1 minute of data (1 datapoint)
        query = {
            'Namespace': namespace,
            'MetricName': metric,
            'Dimensions': [{"Name": dimension, "Value": lbname}],
            'StartTime': start_time,
            'EndTime': end_time,
            'Period': 60,
        }
        if metric in timing_metrics:
            # Timing data: mean plus percentiles
            query['Statistics'] = ['Average']
            query['ExtendedStatistics'] = ['p50', 'p90', 'p99']
        else:
            # Counter data: total over the period
            query['Statistics'] = ['Sum']
        try:
            results = cloudwatch_client.get_metric_statistics(**query)
            datapoints = results['Datapoints']
            # If 0 datapoints, make a synthetic zero datapoint
            if not datapoints:
                datapoints.append({
                    'Average': float(0),
                    'Sum': float(0),
                    'ExtendedStatistics': {
                        'p50': float(0),
                        'p90': float(0),
                        'p99': float(0)
                    },
                    'Timestamp': start_time
                })
            # Skip this metric if more than 1 datapoint was received.
            if len(datapoints) > 1:
                raise ValueError("CloudWatch scrape did not return exactly 1 datapoint")
            data.append(results)
        # Best effort: log the failure and continue with the next metric.
        except Exception as err:
            LOGGER.info("error fetching data for metric '%s': %s", metric, err)
    return data


def get_boto_client(account_id, region, client):
    """Use the inventory cross-account credentials to access the destination account.

    Returns a boto3 client of service ``client`` (e.g. 'cloudwatch', 'elbv2')
    for ``account_id`` in ``region``. Returns None when no ROLES entry could
    assume a role in that account.

    Credentials are obtained in two hops:

    1. Assume the entry's ``master_role``; the session is cached per master
       account in MASTER_SESSION_MAP.
    2. From that session, assume ``account_role_template`` formatted with
       ``account_id``; the session is cached in ACCOUNT_SESSION_MAP.

    Cached sessions are reused while younger than SESSION_TTL seconds. ROLES
    entries are tried in order, and the first one whose account hop succeeds
    wins.

    Raises:
        Exceptions from the master-role hop (``get_caller_identity`` /
        ``assume_role`` on the base STS client) are not caught and propagate.
        A botocore ClientError from the account-role hop is logged, and the
        next ROLES entry is tried.
    """
    # One clock reading per call: used for every TTL check below and as the
    # cache timestamp of any session created during this call.
    timestamp = int(time.time())
    LOGGER.info("Account ID: %s", account_id)
    def valid_session(session_map, acct_id):
        """Helper to verify session status: cached for acct_id and younger than SESSION_TTL."""
        return acct_id in session_map and timestamp - session_map[acct_id]['timestamp'] < SESSION_TTL

    # Fast path: a still-fresh session for the target account.
    if valid_session(ACCOUNT_SESSION_MAP, account_id):
        LOGGER.info("Reusing previous account session: %s", account_id)
        return ACCOUNT_SESSION_MAP[account_id]['session'].client(client, region_name=region)

    for role_data in ROLES:
        LOGGER.info("Checking role: %s", role_data)
        if valid_session(MASTER_SESSION_MAP, role_data['master_account']):
            base_session = MASTER_SESSION_MAP[role_data['master_account']]['session']
        else:
            # First hop, using boto3's default credentials (presumably the
            # Lambda execution role). This is outside the try below, so errors
            # here propagate to the caller.
            base_sts = boto3.client('sts')
            # get_caller_identity is called only so the identity can be logged.
            LOGGER.info("Using identity: %s", base_sts.get_caller_identity())
            LOGGER.info("Assuming role: %s", role_data['master_role'])
            base_creds = base_sts.assume_role(RoleArn=role_data['master_role'], RoleSessionName="c2g-scraper-master")
            base_token = base_creds['Credentials']
            base_session = boto3.session.Session(
                aws_access_key_id=base_token['AccessKeyId'],
                aws_secret_access_key=base_token['SecretAccessKey'],
                aws_session_token=base_token['SessionToken']
            )
            MASTER_SESSION_MAP[role_data['master_account']] = {
                'timestamp': timestamp,
                'session': base_session,
            }
        # Make a cross-account assume role
        sts = base_session.client('sts')
        try:
            # Use the base session STS client to make creds in the cross-account
            role_arn = role_data['account_role_template'].format(acct=account_id)
            LOGGER.info("Checking role: %s", role_arn)
            creds = sts.assume_role(RoleArn=role_arn, RoleSessionName="c2g-scraper-account")
            token = creds['Credentials']
            session = boto3.session.Session(
                aws_access_key_id=token['AccessKeyId'],
                aws_secret_access_key=token['SecretAccessKey'],
                aws_session_token=token['SessionToken']
            )
            ACCOUNT_SESSION_MAP[account_id] = {'timestamp': timestamp, 'session': session}
            # NOTE(review): this logs master_role, not role_arn (the role
            # actually assumed in the account); possibly unintended.
            LOGGER.info("Success accessing role session for: %s, using role: %s", account_id, role_data['master_role'])
            break
        except botocore.exceptions.ClientError:
            # This role chain cannot reach the account; try the next ROLES entry.
            LOGGER.info("Client error accessing role session for: %s, using role: %s", account_id, role_data['master_role'])
    # NOTE(review): if a cached account session is older than SESSION_TTL and
    # every re-assume attempt above failed, this still returns a client built
    # from that stale (possibly expired) session. Confirm the fallback is intended.
    if ACCOUNT_SESSION_MAP.get(account_id, {}).get('session'):
        return ACCOUNT_SESSION_MAP[account_id]['session'].client(client, region_name=region)
    return None


def generate_cloudwatch_client(account_id, region):
    """Return a boto3 CloudWatch client for ``account_id`` in ``region``.

    Thin wrapper over get_boto_client(). It returns None when no configured
    role can reach the account.
    """
    cloudwatch = get_boto_client(account_id=account_id, region=region, client='cloudwatch')
    return cloudwatch


def send_metrics(lines, timeout=10):
    """Send the given metric lines to graphite.

    The lines are joined with newlines (plus a trailing newline), UTF-8
    encoded, and written to GRAPHITE_HOST:GRAPHITE_PORT over one TCP
    connection. Nothing is sent when ``lines`` joins to an empty string.

    Args:
        lines: iterable of "<path> <value> <timestamp>" strings without newlines.
        timeout: seconds allowed for the connect and each send. This is a new,
            optional parameter; without it the socket would block indefinitely.

    Raises:
        socket.error / OSError (including socket.timeout) on connection or
        send failure. The socket is closed in every case.
    """
    body = "\n".join(lines)
    if not body:
        return
    # sendall() needs bytes on Python 3; encoding also works for the text
    # strings Python 2 produces here.
    payload = (body + "\n").encode("utf-8")
    sock = socket.create_connection((GRAPHITE_HOST, GRAPHITE_PORT), timeout=timeout)
    try:
        sock.sendall(payload)
    finally:
        sock.close()


def expose_errors(errors):
    """Fail the invocation when any scrape errors were collected.

    Args:
        errors: list of error dicts gathered during the run. An empty list
            (or any falsy value) is a no-op.

    Raises:
        RuntimeError: whose message is the JSON document ``{"errors": errors}``.
    """
    if not errors:
        return
    raise RuntimeError(json.dumps({'errors': errors}))


def get_redis_session():
    """Return a redis client for the cache of the current deployment environment.

    The 'environment' variable selects the endpoint: 'production' or 'staging'
    maps to its cache host (port 6379, db 0).

    Returns:
        A ``redis.StrictRedis`` client, or None when 'environment' is unset.

    Raises:
        NameError: 'environment' is set to any other value.
    """
    env = os.environ.get('environment')
    if env is None:
        return None

    endpoints = {
        "production": "c2g-describe-prod.qwrxfk.0001.usw2.cache.amazonaws.com",
        "staging": "c2g-describe-staging.xvzvze.0001.usw2.cache.amazonaws.com",
    }
    if env not in endpoints:
        raise NameError("no such environment '{env}'".format(env=env))

    return redis.StrictRedis(host=endpoints[env], port=6379, db=0)
