from __future__ import print_function
import json
# you can import stuff here after installing it in install_deps.sh !
from datetime import datetime
import datetime
import logging
import os
import pg8000
import requests
import sys

from twitch.malachai import S2SAuth

from environment import creds, db, catalog

# Root logger shared by the whole module; setup_logging() swaps its handlers
# and level on every invocation.
logger = logging.getLogger()
# Latency percentiles (as strings, spliced into Graphite metric names and the
# catalog's "percentile_<p>" query types) that we care about.
latency_percentiles = ['50', '90', '99']


def setup_logging(log_level):
    """Send the module's root logger to stdout at ``log_level``.

    Every handler already attached to the logger is removed first. Then a
    single stdout handler using the ``LEVEL - message`` format is installed.

    Args:
        log_level: any value accepted by ``Logger.setLevel`` (an int such as
            ``logging.WARN`` or a level name string).

    Returns:
        The configured logger (the module-level ``logger``).
    """
    # Iterate over a snapshot: removeHandler() mutates logger.handlers, and
    # removing items while iterating the live list skips every other handler.
    for existing in list(logger.handlers):
        logger.removeHandler(existing)

    handler = logging.StreamHandler(sys.stdout)
    # use whatever format you want here
    handler.setFormatter(logging.Formatter('%(levelname)s - %(message)s'))
    logger.addHandler(handler)
    logger.setLevel(log_level)

    return logger


def _log_errors(errors):
    """Log a non-empty error map (from the catalog helpers) as one JSON line."""
    if errors:
        # default=str: some helpers may return raw response bodies (bytes on
        # Python 3), which json cannot serialise natively; this is log output only.
        logger.error(json.dumps(errors, default=str))


def _sync_source_table(conn, query, record_to_dict, collapse=None):
    """Read rows with ``query``, ensure their accounts exist, and submit them.

    Args:
        conn: open DB-API connection to Redshift.
        query: constant SELECT statement for one inventory table.
        record_to_dict: converts one row into a component dict.
        collapse: optional callable applied to the component dicts before
            they are submitted (but after accounts are created).
    """
    records = get_query_results(conn, query)
    # Materialise as a list: it is iterated twice below, and on Python 3 a
    # ``map`` object would be exhausted by the first pass.
    items = [record_to_dict(record) for record in records]
    # Create accounts in the catalogDB if they are missing
    _log_errors(create_missing_accounts(
        [(item['awsaccountid'], item.get('awsaccountname', None)) for item in items]))
    if collapse is not None:
        items = collapse(items)
    # Batch and flush resource data
    _log_errors(submit_to_catalog(items))


def lambda_handler(event, context):
    """Lambda entry point: sync ELBs, Beanstalk apps, ALBs and Visage clients.

    Args:
        event: Lambda event dict; an optional ``LOG_LEVEL`` key sets the log
            level (default ``logging.WARN``).
        context: Lambda context (unused).

    Raises:
        RuntimeError: if the Redshift credentials cannot be read.
        requests.HTTPError: if the Graphite lookup of Visage backends fails.
    """
    setup_logging(event.get('LOG_LEVEL', logging.WARN))
    # Fetch username and password for redshift access
    user, password = creds.get_redshift_creds()
    if not user or not password:
        raise RuntimeError("redshift username and password could not be read from environment")
    (host, port, db_name) = db.get_db_info()

    # Use pg8000 to connect to Redshift
    conn = pg8000.connect(
        database=db_name,
        host=host,
        port=port,
        user=user,
        password=password,
        ssl=True
    )
    try:
        _sync_source_table(conn, "SELECT * FROM aws_elb", elb_record_to_dict)
        _sync_source_table(conn, "SELECT * FROM aws_bs", beanstalk_record_to_dict,
                           collapse=collapse_beanstalks)
        _sync_source_table(conn, "SELECT * FROM aws_alb", alb_record_to_dict)
    finally:
        # Always release the Redshift connection, even if a sync step raises.
        conn.close()

    # Get Backends of Visage in us-west-2 (we can assume that all back ends are deployed there)
    resp = requests.get("http://graphite-web.internal.justin.tv/metrics/find?query=stats.counters.visage.production.us-west-2.all.hystrix.*")
    resp.raise_for_status()
    # Filter out the bogus things in the graphite bucket
    ignored = ('get_authorization_token', 'ratelimit')
    visage_clients = [visage_client_to_dict(v['text'])
                      for v in resp.json() if v['text'] not in ignored]

    _log_errors(submit_to_catalog(visage_clients))

def get_query_results(conn, query):
    """Execute ``query`` on DB-API connection ``conn`` and return all rows.

    The cursor is closed afterwards, even when execute/fetch raises, so
    repeated calls on one connection do not leave cursors open.
    """
    cursor = conn.cursor()
    try:
        cursor.execute(query)
        return cursor.fetchall()
    finally:
        cursor.close()


def collapse_beanstalks(beanstalk_elbs):
    """Keep one record per Beanstalk application, preserving input order.

    A Beanstalk application can appear in several ``aws_bs`` rows (one per
    ELB). Only the first row for each (AWS account ID, region, application
    name) is kept. Same-named applications in another account or region are
    kept, because format_catalog_json() gives those distinct labels.

    Args:
        beanstalk_elbs: iterable of dicts from beanstalk_record_to_dict().

    Returns:
        A new list of the retained dicts.
    """
    seen = set()
    results = []
    for bs_elb in beanstalk_elbs:
        key = (bs_elb['awsaccountid'], bs_elb['awsregion'], bs_elb['applicationname'])
        # A set lookup both fixes the skip (a bare ``continue`` inside an
        # inner loop never skipped anything) and avoids an O(n^2) rescan.
        if key in seen:
            continue
        seen.add(key)
        results.append(bs_elb)
    return results

def elb_record_to_dict(record):
    """Turn one ``aws_elb`` row into a component dict.

    Column positions used: 0 account id, 1 account name, 2 ELB name,
    3 beanstalk flag, 4 region. Any further columns are ignored.
    """
    account_id = record[0]
    account_name = record[1]
    elb_name = record[2]
    # Only the literal string 'True' marks a Beanstalk-owned ELB.
    # NOTE(review): if this column is ever returned as a real boolean, the
    # comparison is always False — confirm the Redshift column type.
    is_beanstalk = (record[3] == 'True')
    region = record[4]
    return dict(
        source_table='aws_elb',
        awsaccountid=account_id,
        awsaccountname=account_name,
        elbname=elb_name,
        isbeanstalk=is_beanstalk,
        awsregion=region,
    )


def beanstalk_record_to_dict(record):
    """Turn one ``aws_bs`` row into a component dict.

    Columns are read positionally, in the order listed below; any extra
    trailing columns are ignored.
    """
    column_names = ('awsaccountid', 'awsaccountname', 'applicationname', 'elbname', 'awsregion')
    converted = {'source_table': 'aws_bs'}
    for position, column in enumerate(column_names):
        converted[column] = record[position]
    return converted


def alb_record_to_dict(record):
    """Turn one ``aws_alb`` row into a component dict.

    Column positions used: 0 account id, 1 account name, 2 ALB name,
    3 ALB id, 4 beanstalk flag (string 'True' means set), 5 region.
    """
    entry = {'source_table': 'aws_alb'}
    entry['awsaccountid'] = record[0]
    entry['awsaccountname'] = record[1]
    entry['albname'] = record[2]
    entry['albid'] = record[3]
    entry['isbeanstalk'] = True if record[4] == 'True' else False
    entry['awsregion'] = record[5]
    return entry


def visage_client_to_dict(name):
    """Wrap a Visage hystrix client name as a component dict.

    'visage' is not a real Redshift table; it is just the tag that
    format_catalog_json() dispatches on.
    """
    return dict(source_table='visage', name=name)

def _catalog_error_detail(resp):
    """Return the ``errors`` field of a failed catalog response, or its raw text.

    Falls back to the body text when the response is not JSON or has no
    ``errors`` key (e.g. an HTML error page from a proxy). One bad response
    is then recorded instead of raising and aborting the whole run.
    """
    try:
        return resp.json()['errors']
    except (ValueError, KeyError, TypeError):
        return resp.text


def _catalog_create_queries(query_url, queries, errors, error_key):
    """POST each query map and return the IDs of the ones that were created.

    A failed query is recorded in ``errors[error_key]`` and skipped. The
    remaining queries are still attempted, so the metric may end up with only
    some of its queries.
    """
    query_ids = []
    for query in queries:
        logger.info("Creating query {0}".format(query))
        resp = requests.post(query_url, data=json.dumps(query))
        if resp.status_code not in (200, 201):
            detail = _catalog_error_detail(resp)
            errors[error_key] = detail
            logger.error("Failed to create query because of {0}".format(detail))
            continue
        query_ids.append(resp.json()['id'])
        logger.info("Created query successfully")
    return query_ids


def _catalog_create_metric(metric_url, query_url, metric, queries, errors, error_key):
    """Create ``queries``, then a metric referencing them.

    ``metric['queries']`` is overwritten in place with the created query IDs.

    Returns:
        The new metric's ID, or None if creating the metric failed (the
        failure is recorded in ``errors[error_key]``).
    """
    metric['queries'] = _catalog_create_queries(query_url, queries, errors, error_key)
    logger.info("Creating metric {0}".format(metric))
    resp = requests.post(metric_url, data=json.dumps(metric))
    if resp.status_code not in (200, 201):
        detail = _catalog_error_detail(resp)
        errors[error_key] = detail
        logger.error("failed to create metric because of {0}".format(detail))
        return None
    logger.info("Created metric")
    return resp.json()['id']


def submit_to_catalog(components):
    """Create each component in the service catalog along with its metrics.

    For every component that does not already exist (looked up by label):
    create the component, an availability metric with its queries, and a
    latency metric with its queries. Then attach both metrics (and the AWS
    account, when exactly one matches the alias) to the component. ELB
    components flagged ``isbeanstalk`` are skipped entirely.

    Args:
        components: iterable of dicts produced by the ``*_record_to_dict`` /
            ``visage_client_to_dict`` helpers.

    Returns:
        dict mapping component name -> error detail (only the last failure
        per component is kept). Empty when everything succeeded.

    NOTE(review): once a component is created, a later failure leaves it
    partially configured. Subsequent runs skip it because its label already
    exists, so recorded failures may need manual cleanup.
    """
    url_base = catalog.get_catalog_endpoint()
    logger.info("Using catalog endpoint {0}".format(url_base))
    # These URLs don't depend on the component, so build them once.
    component_url = url_base + "/components"
    metric_url = url_base + "/metrics"
    query_url = url_base + "/queries"
    account_url = url_base + "/accounts"

    errors = {}
    for component in components:
        # Skip beanstalk ELBs
        if component['source_table'] == 'aws_elb' and component['isbeanstalk']:
            continue

        # Turn the component into the maps we send to the catalog as JSON
        data = format_catalog_json(component)
        name = data['component']['name']

        # Look the component up by label. This route returns 200 with a list
        # that is empty when the component does not exist yet. ``params``
        # URL-encodes the label.
        logger.info("trying to retrieve component {0}".format(component))
        resp = requests.get(component_url,
                            params={'label': data['component']['label']},
                            headers={'Connection': 'close'})
        if resp.status_code != 200:
            errors[name] = resp.text
            logger.error("Could not find component {0}".format(resp.text))
            continue
        if resp.json():
            logger.info("Found component, continuing")
            continue

        logger.info("Creating component since did not find")
        resp = requests.post(component_url, data=json.dumps(data['component']))
        if resp.status_code not in (200, 201):
            # 422 is the expected non-201 case: the component already exists,
            # so it is logged but not reported as an error.
            if resp.status_code != 422:
                errors[name] = _catalog_error_detail(resp)
            logger.error("Could not create component {0}".format(resp.text))
            continue
        created = resp.json()
        logger.info("Successfully created component. Received {0}".format(created))
        component_id = created['id']

        metric_id = _catalog_create_metric(metric_url, query_url, data['metric'],
                                           data['queries'], errors, name)
        if metric_id is None:
            continue
        latency_metric_id = _catalog_create_metric(metric_url, query_url,
                                                   data['latency_metric'],
                                                   data['latency_queries'], errors, name)
        if latency_metric_id is None:
            continue

        update_map = {'id': component_id, 'metric_ids': [metric_id, latency_metric_id]}
        alias = component.get('awsaccountname')
        if alias is not None:
            logger.info("Trying to retrieve AWS Account")
            resp = requests.get(account_url, params={'alias': alias},
                                headers={'Connection': 'close'})
            if resp.status_code == 200:
                accounts = resp.json()
                logger.info("retrieved account")
                if len(accounts) == 1:
                    update_map['account'] = accounts[0]['id']
            else:
                # Record the failure but still associate the metrics. Bailing
                # here would leave the new component permanently unassociated,
                # because later runs skip components that already exist.
                errors[name] = resp.text
                logger.error("Failed to retrieve aws account because of {0}".format(resp.text))

        # Associate both metrics (availability and latency), and the account
        # if one was found, with the component.
        logger.info("Trying to associate component with metrics and accounts")
        resp = requests.put(component_url + "/" + str(component_id),
                            data=json.dumps(update_map))
        # Accept any of the usual success codes for an update.
        if resp.status_code not in (200, 201, 204):
            detail = _catalog_error_detail(resp)
            errors[name] = detail
            logger.error("Failed to associate component {0}".format(detail))
            continue
        logger.info("Completed associating component")

    # If non-empty, the caller logs this map in the lambda's output.
    return errors

# Graphite query / label templates, filled in by format_catalog_json() via
# str.format. Placeholders:
#   {account} - AWS account name (the component's awsaccountname)
#   {region}  - AWS region
#   {name}    - ELB name, Beanstalk application name, ALB name or Visage client
#   {id}      - AWS account ID (description templates only)
#   {p}       - latency percentile taken from latency_percentiles
# Request/error counts are summed into 300s buckets. Request counts get an
# offset of 0.01, presumably so an error-rate denominator is never zero —
# NOTE(review): confirm against the catalog's error_rate calculation.
# `cloudwatch.zero` appears to be a constant-zero series that keeps
# sumSeries() defined when no matching series exist — TODO confirm.

# Elastic Beanstalk: the `*` segment matches every series under the
# application (presumably one per environment); latency takes the max of them.
bs_request_count_template = "summarize(offset(transformNull(sumSeries(cloudwatch.{account}.{region}.beanstalk.{name}.*.RequestCount, cloudwatch.zero), 0), 0.01), '300s', 'sum', true)"
bs_error_count_template = "summarize(transformNull(sumSeries(cloudwatch.{account}.{region}.beanstalk.{name}.*.HTTPCode_Backend_5XX, cloudwatch.{account}.{region}.beanstalk.{name}.*.HTTPCode_ELB_5XX, cloudwatch.zero), 0), '300s', 'sum', true)"
bs_latency_template = "maxSeries(cloudwatch.{account}.{region}.beanstalk.{name}.*.Latency_p{p})"
bs_description_template = "Availability calculation for Beanstalk application [app={name}, account={account}, account_id={id}, region={region}]"
bs_name_template = "beanstalk:{account}:{name}:{region}"

# Classic ELB: errors count both backend and load-balancer 5XX responses.
elb_request_count_template = "summarize(offset(transformNull(sumSeries(cloudwatch.{account}.{region}.elb.{name}.RequestCount, cloudwatch.zero), 0), 0.01), '300s', 'sum', true)"
elb_error_count_template = "summarize(transformNull(sumSeries(cloudwatch.{account}.{region}.elb.{name}.HTTPCode_Backend_5XX, cloudwatch.{account}.{region}.elb.{name}.HTTPCode_ELB_5XX, cloudwatch.zero), 0), '300s', 'sum', true)"
elb_latency_template = "cloudwatch.{account}.{region}.elb.{name}.Latency_p{p}"
elb_description_template = "Availability calculation for ELB [elb={name}, account={account}, account_id={id}, region={region}]"
elb_name_template = "elb:{account}:{name}:{region}"

# Application Load Balancer: uses the ALB metric names (Target/ELB 5XX counts,
# TargetResponseTime for latency).
alb_request_count_template = "summarize(offset(transformNull(sumSeries(cloudwatch.{account}.{region}.alb.{name}.RequestCount, cloudwatch.zero), 0), 0.01), '300s', 'sum', true)"
alb_error_count_template = "summarize(transformNull(sumSeries(cloudwatch.{account}.{region}.alb.{name}.HTTPCode_Target_5XX_Count, cloudwatch.{account}.{region}.alb.{name}.HTTPCode_ELB_5XX_Count, cloudwatch.zero), 0), '300s', 'sum', true)"
alb_latency_template = "cloudwatch.{account}.{region}.alb.{name}.TargetResponseTime_p{p}"
alb_description_template = "Availability calculation for ALB [alb={name}, account={account}, account_id={id}, region={region}]"
alb_name_template = "alb:{account}:{name}:{region}"

# Visage hystrix clients: `[a-z][a-z]-*-[0-9]` matches region-shaped path
# segments (e.g. us-west-2), so these aggregate across regions. Latency is
# scaled by 0.001, presumably ms -> seconds (TODO confirm).
vc_request_count_template = "summarize(offset(transformNull(sumSeries(stats.counters.visage.production.[a-z][a-z]-*-[0-9].all.hystrix.{name}.*.attempts.sum, cloudwatch.zero), 0), 0.01), '300s', 'sum', true)"
vc_error_count_template = "summarize(transformNull(sumSeries(stats.counters.visage.production.[a-z][a-z]-*-[0-9].all.hystrix.{name}.*.errors.sum, cloudwatch.zero), 0), '300s', 'sum', true)"
vc_latency_template = "maxSeries(scale(stats.timers.visage.production.[a-z][a-z]-*-[0-9].all.hystrix.{name}.*.totalDuration.upper_{p}, 0.001))"
vc_description_template = "Availability calculation for Visage Client [client={name}]"
vc_name_template = "visage:{name}"

def format_catalog_json(component):
    """Build the service-catalog payloads for one component.

    Args:
        component: dict from one of the ``*_record_to_dict`` helpers or
            visage_client_to_dict(). ``component['source_table']`` selects the
            template set ('aws_elb', 'aws_bs', 'aws_alb' or 'visage').

    Returns:
        dict with keys:
          'component': map for the component,
          'metric': availability (error-rate) metric map,
          'latency_metric': latency metric map,
          'queries': request-count and error-count query maps for 'metric',
          'latency_queries': one max-aggregated query map per entry of
              latency_percentiles, for 'latency_metric'.

    Raises:
        ValueError: if ``source_table`` is not one of the supported values.
    """
    source = component['source_table']
    if source == 'visage':
        comp_type = 'visageClient'
        fields = {'name': component['name']}
        templates = (vc_name_template, vc_description_template,
                     vc_request_count_template, vc_error_count_template,
                     vc_latency_template)
    else:
        # source_table -> (catalog type, key holding the resource name,
        #                  (name, description, request count, error count, latency) templates)
        aws_sources = {
            'aws_elb': ('elb', 'elbname',
                        (elb_name_template, elb_description_template,
                         elb_request_count_template, elb_error_count_template,
                         elb_latency_template)),
            'aws_bs': ('beanstalk', 'applicationname',
                       (bs_name_template, bs_description_template,
                        bs_request_count_template, bs_error_count_template,
                        bs_latency_template)),
            'aws_alb': ('alb', 'albname',
                        (alb_name_template, alb_description_template,
                         alb_request_count_template, alb_error_count_template,
                         alb_latency_template)),
        }
        if source not in aws_sources:
            raise ValueError("unsupported source_table: {0!r}".format(source))
        comp_type, name_key, templates = aws_sources[source]
        fields = {
            'account': component['awsaccountname'],
            'id': component['awsaccountid'],
            'name': component[name_key],
            'region': component['awsregion'],
        }

    (name_template, description_template, request_count_template,
     error_count_template, latency_template) = templates
    # str.format ignores unused keyword arguments, so every template can be
    # rendered from the same ``fields`` map.
    comp_name = name_template.format(**fields)
    description = description_template.format(**fields)
    request_count_query = request_count_template.format(**fields)
    error_count_query = error_count_template.format(**fields)

    component_data = {
        'label': comp_name,
        'name': comp_name,
        'type': comp_type,
    }
    metric_data = {
        'label': comp_name,
        'name': comp_name,
        'description': description,
        'rollup': False,
        'component_rollup': True,
        'autogenerated': True,
        'threshold': 5,
        'calculation_type': 'error_rate'
    }
    latency_metric_data = {
        'label': comp_name + ":latency",
        'name': comp_name + ":latency",
        'description': description,
        'rollup': False,
        'component_rollup': True,
        'autogenerated': True,
        'threshold': 5,
        'calculation_type': 'latency',
        'latency_query': 'percentile_90'
    }

    query_data = [
        {
            'type': 'request_count',
            'query': request_count_query,
            'aggregate_type': 'sum'
        },
        {
            'type': 'error_count',
            'query': error_count_query,
            'aggregate_type': 'sum'
        }
    ]
    latency_query_data = [
        {
            'type': "percentile_{p}".format(p=percentile),
            'query': latency_template.format(p=percentile, **fields),
            'aggregate_type': 'max'
        }
        for percentile in latency_percentiles
    ]

    return {
        'component': component_data,
        'metric': metric_data,
        'latency_metric': latency_metric_data,
        'queries': query_data,
        'latency_queries': latency_query_data
    }


def get_s2s_service_name():
    """Return the S2S service name for the current deployment environment.

    The choice comes from the ``environment`` environment variable:
    "production" -> "inv-collector-prod"; "staging" or unset ->
    "inv-collector-dev".

    Raises:
        NameError: if ``environment`` is set to any other value.
    """
    env = os.environ.get("environment")
    if env is None:
        return "inv-collector-dev"
    names_by_env = {
        "production": "inv-collector-prod",
        "staging": "inv-collector-dev",
    }
    if env not in names_by_env:
        raise NameError("no such environment '{env}'".format(env=env))
    return names_by_env[env]


# Shared requests session for the catalog's GraphQL endpoint (used by
# create_missing_accounts), authenticated via S2SAuth for the service name
# picked by get_s2s_service_name(). This runs at import time, so importing
# the module raises NameError if `environment` is set to an unknown value.
s2s_session = requests.Session()
s2s_session.auth = S2SAuth(get_s2s_service_name())


def create_missing_accounts(accounts):
    """Create catalog accounts for AWS account IDs the catalog does not know yet.

    An account is only added when no catalog account has the same AWS account
    ID. Pairs whose alias is None are ignored.

    Args:
        accounts: iterable of ``(aws_account_id, alias)`` pairs. Duplicates
            are expected (one pair per resource row); each account ID is
            attempted at most once.

    Returns:
        ``{}`` on success. Otherwise a dict with ``"account_fetch"`` (the
        existing-account listing failed; nothing was created) or
        ``"account_create"`` (body of the last failed creation) mapped to the
        response text.
    """
    logger.info("Creating missing accounts")
    url_base = catalog.get_catalog_gql_endpoint()
    all_query = "query GetAccounts { accounts { aws_account_id alias } }"
    resp = s2s_session.post(url_base, data=json.dumps({"query": all_query}))
    if resp.status_code != 200:
        logger.error("Did not successfully fetch catalog existing account data: "
                     "{0}".format(resp.text))
        return {"account_fetch": resp.text}
    try:
        existing_accounts = resp.json()['data']['accounts'] or []
    except (ValueError, KeyError, TypeError):
        # GraphQL can report failures with HTTP 200 (e.g. "data": null plus an
        # "errors" list), so an unusable body is treated as a failed fetch.
        logger.error("Unexpected catalog account data: {0}".format(resp.text))
        return {"account_fetch": resp.text}
    logger.info("Successfully retrieved account data from service catalog")
    # Membership is by AWS account ID only: an existing account whose alias
    # is null still counts as present.
    known_ids = set(a['aws_account_id'] for a in existing_accounts)

    gql_create = "mutation CreateAccount($account:AccountInput!) { createAccount(account:$account) { id } }"
    errors = {}
    # Iterate over accounts we want to ensure are created
    for account in accounts:
        aws_account_id, alias = account[0], account[1]
        if alias is None or aws_account_id in known_ids:
            continue
        logger.warning("Account {0} was not found in DB, attempting "
                       "creation".format(alias))
        gql_vars = {
            "account": {
                "alias": alias,
                "aws_account_id": aws_account_id
            }
        }
        request_data = json.dumps({"query": gql_create, "variables": gql_vars})
        # Remember the ID regardless of outcome, so later rows for the same
        # account neither create a duplicate nor retry a failing request.
        known_ids.add(aws_account_id)
        resp = s2s_session.post(url_base, data=request_data)
        if resp.status_code != 200:
            logger.error("Account creation unsuccessful: reason is "
                         "{0}".format(resp.text))
            # Keep going: one bad account should not block the others.
            errors["account_create"] = resp.text
            continue
        logger.info("Account creation successful")
    return errors


# Local development entry point: run the handler once with the current UTC
# time (naive ISO-8601 string) as the event's 'time'.
if __name__ == "__main__":
    lambda_handler({'time': datetime.datetime.utcnow().isoformat()}, None)
