import argparse
import datetime
import json
import logging
import os
import socket

from pyemf.constants import *

class EMFWriter(object):
    '''
    Base EMF Writer, emits to stdout.

    ``send_message`` assembles a MetricContext; subclasses override
    ``emit_message`` to deliver it somewhere other than stdout.
    '''
    def __init__(self, args, **kwargs):
        '''
        :param args: parsed arguments; reads ``namespace``, ``now``,
            ``log_group`` and ``log_stream``.
        :param kwargs: accepted so subclasses can take extra options;
            unused here.
        '''
        self.namespace = args.namespace
        self.now = args.now
        self.logGroupName = args.log_group
        self.logStreamName = args.log_stream

    def send_message(self, **kwargs):
        '''
        Build an EMF MetricContext from metrics/dimensions and emit it.

        :keyword metrics: iterable of metric dicts with ``Name``, ``Value``
            and optional ``Unit``. Each ``Value`` is moved into the
            top-level properties under the metric's name. The remaining
            keys go into the metric directive.
        :keyword dimensions: iterable of rollups. Each rollup is a list of
            dicts with ``Name``/``Value``. Values become top-level
            properties; the names form one dimension set.
        :keyword properties: extra top-level properties. The mapping is
            copied and never modified in place.
        '''
        metrics = kwargs.get('metrics') or []
        dimensions = kwargs.get('dimensions') or []
        # Copy so metric/dimension values are not written back into the
        # caller's dict (which would leak into later calls reusing it).
        ctx_properties = dict(kwargs.get('properties') or {})
        ctx_metrics = list()
        ctx_dimensions = list()

        for metric in metrics:
            m = metric.copy()
            ctx_properties.update({m['Name']: m.pop('Value')})
            ctx_metrics.append(m)

        for rollup in dimensions:
            for d in rollup:
                ctx_properties.update({d['Name']: d['Value']})
            ctx_dimensions.append([i['Name'] for i in rollup])

        ctx = MetricContext(
            logGroupName=self.logGroupName,
            logStreamName=self.logStreamName,
            namespace=self.namespace,
            now=self.now,
            dimensions=ctx_dimensions,
            metrics=ctx_metrics,
            properties=ctx_properties,
            )

        self.emit_message(ctx)

    def emit_message(self, ctx):
        '''Print ``ctx`` to stdout as JSON indented by two spaces.'''
        print(json.dumps(ctx, indent=2))


class UDPWriter(EMFWriter):
    '''
    Subclass of EMFWriter that writes to a Cloudwatch Agent listening on udp.

    The ``host`` and ``port`` keyword arguments override the EMF_HOST and
    EMF_PORT defaults.
    '''
    def __init__(self, args, **kwargs):
        super(UDPWriter, self).__init__(args, **kwargs)
        self.host = kwargs.get('host', EMF_HOST)
        self.port = kwargs.get('port', EMF_PORT)

    def emit_message(self, ctx):
        '''Send ``ctx`` as one newline-terminated, UTF-8 encoded JSON datagram.'''
        message = json.dumps(ctx) + "\n"
        # The context manager closes the socket even if sendto raises.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.sendto(message.encode('UTF-8'), (self.host, self.port))


class CWLWriter(EMFWriter):
    '''
    Subclass of EMFWriter that writes to Cloudwatch Logs.
    NOTE: Requires botocore, boto3

    Construction does the following:
      - builds a ``logs`` client, via an assumed role when ``args.log_role``
        is set;
      - registers a handler that adds the ``x-amzn-logs-format: json/emf``
        header to PutLogEvents;
      - creates the log group and log stream if they are missing.
    '''
    def __init__(self, args, **kwargs):
        super(CWLWriter, self).__init__(args, **kwargs)
        self.client = self._make_client(args)
        self._register_emf_header()
        self._ensure_log_group()
        # Always assigned (possibly None), so emit_message can rely on it.
        self.sequencetoken = self._ensure_log_stream()

    @staticmethod
    def _make_client(args):
        '''Create the CloudWatch Logs client, assuming ``args.log_role`` if set.'''
        import boto3
        from botocore.config import Config

        config = Config(
            region_name=args.region,
            retries={
                'mode': BOTO_RETRY_MODE,
                'max_attempts': BOTO_MAX_ATTEMPTS,
            },
        )

        if not args.log_role:
            logging.debug("creating cloudwatch logs client with default role")
            return boto3.client('logs', config=config)

        logging.debug("creating cloudwatch logs client with role {}".format(args.log_role))
        sts = boto3.client('sts', config=config)
        creds = sts.assume_role(RoleArn=args.log_role, RoleSessionName='cwlogs')['Credentials']
        session = boto3.Session(
            aws_access_key_id=creds['AccessKeyId'],
            aws_secret_access_key=creds['SecretAccessKey'],
            aws_session_token=creds['SessionToken'],
        )
        return session.client('logs', config=config)

    def _register_emf_header(self):
        '''Tag every PutLogEvents request with the EMF log-format header.'''
        def add_emf_header(request, **kwargs):
            request.headers.add_header('x-amzn-logs-format', 'json/emf')
        self.client.meta.events.register_first(
            'before-sign.cloudwatch-logs.PutLogEvents', add_emf_header)

    def _ensure_log_group(self):
        '''Create the log group unless one with exactly this name is listed.'''
        resp = self.client.describe_log_groups(logGroupNamePrefix=self.logGroupName)
        if self.logGroupName in [lg['logGroupName'] for lg in resp['logGroups']]:
            return
        try:
            self.client.create_log_group(logGroupName=self.logGroupName)
        except self.client.exceptions.ResourceAlreadyExistsException:
            # Another writer created it between our describe and create calls.
            logging.debug("log group {} already exists".format(self.logGroupName))

    def _find_log_stream(self):
        '''Return the describe_log_streams entry for our stream, or None.'''
        resp = self.client.describe_log_streams(
            logGroupName=self.logGroupName,
            logStreamNamePrefix=self.logStreamName)
        for ls in resp['logStreams']:
            if ls['logStreamName'] == self.logStreamName:
                return ls
        return None

    def _ensure_log_stream(self):
        '''Create the log stream if missing; return its upload sequence token or None.'''
        stream = self._find_log_stream()
        if stream is None:
            try:
                self.client.create_log_stream(
                    logGroupName=self.logGroupName,
                    logStreamName=self.logStreamName)
            except self.client.exceptions.ResourceAlreadyExistsException:
                # Created concurrently; fall through and describe it again.
                logging.debug("log stream {} already exists".format(self.logStreamName))
            stream = self._find_log_stream()
        if stream is None:
            return None
        return stream.get('uploadSequenceToken', None)

    def emit_message(self, ctx):
        '''Write ``ctx`` as one JSON log event stamped with the current UTC time in ms.'''
        ts = int(datetime.datetime.now(datetime.timezone.utc).timestamp() * 1000)
        request = {
            'logGroupName': self.logGroupName,
            'logStreamName': self.logStreamName,
            'logEvents': [{'timestamp': ts, 'message': json.dumps(ctx) + "\n"}],
        }
        if self.sequencetoken:
            request['sequenceToken'] = self.sequencetoken

        resp = self.client.put_log_events(**request)

        if resp.get('rejectedLogEventsInfo'):
            logging.warning("cloudwatch logs rejected events: {}".format(resp['rejectedLogEventsInfo']))
        # nextSequenceToken is deprecated and may be absent; None means
        # "send no token next time".
        self.sequencetoken = resp.get('nextSequenceToken')


class Metric(dict):
    '''
    One EMF metric entry as a plain dict.

    The keys are ``Name`` and ``Value``, followed by ``Unit`` only when a
    truthy ``unit`` is supplied.
    '''
    def __init__(self, name, value, unit=None):
        super(Metric, self).__init__(Name=name, Value=value)
        if unit:
            self['Unit'] = unit


class Dimension(dict):
    '''
    One EMF dimension as a plain dict with ``Name`` and ``Value`` keys.
    '''
    def __init__(self, name, value):
        super(Dimension, self).__init__(Name=name, Value=value)


class MetricContext(dict):
    '''
    Root EMF document as a dict.

    The document holds the ``_aws`` metadata block and
    ``collector_version``. Every entry of ``properties`` is then merged at
    the top level, last, so properties win over same-named keys.

    Keyword arguments:
      now: epoch seconds. Defaults to the current time when omitted or None.
        It is stored as integer milliseconds in ``_aws.Timestamp``.
      logGroupName, logStreamName, namespace: copied into ``_aws``.
      dimensions: list of dimension-name lists (default []).
      metrics: list of metric definitions (default []).
      properties: mapping merged into the top level (default {}).
    '''
    def __init__(self, **kwargs):
        properties = kwargs.get('properties') or {}
        ts = kwargs.get('now', None)
        if ts is None:
            # Previously this failed with TypeError on ``None * 1000``.
            ts = datetime.datetime.now(datetime.timezone.utc).timestamp()
        self.update(
            {
              '_aws': {
                'Timestamp': int(ts * 1000),
                'LogGroupName': kwargs.get('logGroupName'),
                'LogStreamName': kwargs.get('logStreamName'),
                'CloudWatchMetrics': [
                  {
                    'Namespace': kwargs.get('namespace'),
                    'Dimensions': kwargs.get('dimensions', []),
                    'Metrics': kwargs.get('metrics', []),
                  }
                ]
              },
            'collector_version': VERSION,
            }
        )
        self.update(properties)



def get_argparser(**kwargs):
    '''
    Build the command-line parser shared by EMF collectors.

    Keyword arguments supply flag defaults:
      namespace: default for --namespace.
      log_group: default for --log_group.
      logs_role: default for --log_role. ``log_role`` is accepted as an
        alias that matches the flag name; ``logs_role`` wins if both are
        given, and LOGS_ROLE is used if neither is.
    '''
    log_role_default = kwargs.get('logs_role', kwargs.get('log_role', LOGS_ROLE))

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--debug', action='store_true', default=False,
                        help='turn on debug logging')
    parser.add_argument('-m', '--metricname', type=str,
                        help='The name of the metric to push')
    parser.add_argument('-v', '--value', type=int, default=0,
                        help='The value of the metric')
    parser.add_argument('--namespace', type=str, default=kwargs.get('namespace'),
                        help='The cloudwatch namespace for this metric')
    parser.add_argument('--proto', type=str, default='udp', choices=['udp', 'logs', 'stdout'],
                        help='Send using: udp: local agent listener, logs: cw logs api, stdout: write to stdout')
    parser.add_argument('--log_group', type=str, default=kwargs.get('log_group'),
                        help='Log group name to use')
    parser.add_argument('--log_stream', type=str, default='{}-{}'.format(socket.gethostname(), os.getpid()),
                        help='Log stream name to use')
    parser.add_argument('--log_role', type=str, default=log_role_default,
                        help='ARN of iam role to use for cw logs api')
    parser.add_argument('--region', type=str, default=os.environ.get('AWS_REGION', AWS_REGION),
                        help='Region to make cloudwatch logs api calls to')
    # Aware UTC "now" (same epoch value as the naive local form), consistent
    # with the timestamps CWLWriter computes. Evaluated once, when the
    # parser is built.
    parser.add_argument('-n', '--now', type=float,
                        default=datetime.datetime.now(datetime.timezone.utc).timestamp(),
                        help='The timestamp when this metric was generated')
    return parser


def get_writer(args):
    '''
    Return the writer instance selected by ``args.proto``.

    ``stdout`` maps to EMFWriter, ``udp`` to UDPWriter and ``logs`` to
    CWLWriter. Any other value raises NotImplementedError.
    '''
    proto = args.proto
    if proto == 'stdout':
        writer_cls = EMFWriter
    elif proto == 'udp':
        writer_cls = UDPWriter
    elif proto == 'logs':
        writer_cls = CWLWriter
    else:
        raise NotImplementedError("{}: proto not implemented".format(proto))
    return writer_cls(args)

