import logging
import dns.resolver
import boto3
import json
import sys
import os

# Logging: replace whatever handlers are already attached to the root logger
# (the "default" ones) with a single stream handler. logging.basicConfig() does
# nothing when the root logger already has handlers, so they must be removed first.
def reset_root_logging(stream=None, level=logging.INFO):
    """Replace every root-logger handler with one StreamHandler.

    :param stream: stream for the new handler; ``None`` means ``sys.stdout``
        as it is at call time.
    :param level: level set on the root logger.
    :return: the root logger.
    """
    root = logging.getLogger()
    # Iterate over a snapshot: removeHandler() mutates root.handlers, and
    # removing from the live list while iterating it skips every other handler.
    for handler in list(root.handlers):
        root.removeHandler(handler)
    logging.basicConfig(stream=sys.stdout if stream is None else stream, level=level)
    return root


logger = reset_root_logging()


class DnsValidator(object):
    """Check that a host's DNS A record matches the IP reported in a message.

    Messages whose DNS matches are published to the sanitized SNS topic.
    Messages whose DNS does not (yet) match are sent to the retry (delay) SQS
    queue once; if they fail again when coming from that queue, they fail.
    """

    def __init__(self,
                 retry_queue,
                 sanitized_topic,
                 region,
                 ):
        """
        :param retry_queue: SQS queue URL that failed validations are sent to.
        :param sanitized_topic: SNS topic ARN that validated messages are published to.
        :param region: AWS region used for the SQS and SNS clients.
        """
        super(DnsValidator, self).__init__()

        self.retry_queue = retry_queue
        self.sanitized_topic = sanitized_topic
        self.region = region

        # Instantiate clients
        self.sqs = boto3.client('sqs', region_name=region)
        self.sns = boto3.client('sns', region_name=region)

    def sm_determine(self, hostname, message, source):
        """Route a message that failed validation through the retry "state machine".

        :param hostname: host name; used only for logging.
        :param message: message body (payload JSON string) to send to the retry queue.
        :param source: event source ARN of the record. An ARN containing 'delay'
            means this attempt already came from the retry queue. May be None.
        :return: True if the message was handed to the retry queue; False if it
            should fail (already retried, or the send itself failed).
        """
        # A message already coming from the retry (delay) queue gets no further
        # retries; returning False lets it fail into the retry DLQ.
        if source and 'delay' in source:
            logger.info(" Retry for %s failed. Putting into DLQ", hostname)
            return False

        try:
            self.sqs.send_message(QueueUrl=self.retry_queue, MessageBody=message)
        except Exception as e:  # broad on purpose: boto3 failure types are not enumerated here
            logger.error(" Error sending message for %s to delay queue: %s", hostname, e)
            return False
        # The message "successfully" moved on to the next stage of the workflow (retry queue)
        return True

    def publish_topic(self, contents):
        """Publish *contents* as JSON to the sanitized SNS topic.

        NOTE: Any errors with this function will result in the lambda exiting
        without going through the retry workflow.

        :param contents: validated message payload (a dict).
        :return: True if the publish call succeeded, False otherwise.
        """
        try:
            self.sns.publish(TopicArn=self.sanitized_topic, Message=json.dumps(contents))
            logger.info(" Successfully publish machine: %s to topic: %s",
                        contents.get("hostname"), self.sanitized_topic)
        except Exception as e:
            logger.error(" Failed to publish %s to topic: %s, Exception: %s",
                         contents.get("hostname"), self.sanitized_topic, e)
            return False
        return True

    @staticmethod
    def _parse_record(record):
        """Decode a message body into ``(raw_message, contents)``; None if malformed."""
        try:
            raw_message = json.loads(record)
            # The direct SQS SendMessage format is the payload itself, while the
            # SNS -> SQS format wraps the payload JSON string in the 'Message' key.
            if isinstance(raw_message, dict) and 'Message' in raw_message:
                contents = json.loads(raw_message['Message'])
            else:
                contents = raw_message
        except (TypeError, ValueError) as e:  # JSONDecodeError subclasses ValueError
            logger.error(" Failed to load message json: %s", e)
            return None
        if not isinstance(contents, dict):
            logger.error(" Malformed message is not a JSON object")
            return None
        return raw_message, contents

    @staticmethod
    def _retry_body(record, raw_message):
        """Return the payload JSON string to forward to the retry queue.

        For the SNS envelope format this is the inner 'Message' string; for the
        direct format it is the record body itself, so the retry attempt always
        sees the real payload rather than a placeholder.
        """
        return raw_message.get('Message', record)

    def handle_message(self, record, source):
        """Validate one queue message and route it.

        :param record: SQS message body: either an SNS notification envelope whose
            'Message' key holds the payload JSON string, or the payload JSON itself
            (the form sm_determine sends to the retry queue).
        :param source: event source ARN of the SQS record.
        :return: True if the message was handled (published, skipped because it
            carried an empty IP, or handed to the retry queue); False otherwise.
        """
        parsed = self._parse_record(record)
        if parsed is None:
            return False
        raw_message, contents = parsed

        # Ensure twitch_role is given
        if 'twitch_role' not in contents:
            logger.error(" Malformed message did not include 'twitch_role' key")
            return False

        hostname = contents.get('hostname')
        new_ip = contents.get('new_ip_address')

        # On bootup, dhclient may send a PREINIT message with no IPs. If there's no IP given, let it pass.
        if new_ip == '':
            logger.error(" Received an SNS message with an empty IP from %s", hostname)
            return True

        if not hostname:
            logger.error(" Malformed message did not include 'hostname' key")
            return False

        retry_body = self._retry_body(record, raw_message)

        # Find the value for the A record. Any resolver error (NXDOMAIN, no answer,
        # no nameservers, timeout) goes through the retry workflow.
        try:
            dns_records = dns.resolver.query(hostname)
        except dns.exception.DNSException as e:
            logger.error(" Failed to lookup IP for %s: %s", hostname, e)
            return self.sm_determine(hostname, retry_body, source)

        # Every returned record must equal the IP provided by the message
        for answer in dns_records:
            ip = answer.to_text()
            if ip != new_ip:
                logger.error(" IP in DNS for %s (%s) does not match message from SNS (%s)",
                             hostname, ip, new_ip)
                return self.sm_determine(hostname, retry_body, source)

        # Found the DNS, IP matches, so publish to a sanitized topic
        return self.publish_topic(contents)


def lambda_handler(event, context):
    """Validate every SQS record in the invocation's batch.

    Reads RETRY_QUEUE, SANITIZED_TOPIC and AWS_REGION from the environment.
    If any record fails, the process exits with status 1 so the whole batch
    is treated as failed.

    :param event: Lambda event; its 'Records' list holds SQS records with
        'body' and 'eventSourceARN' keys.
    :param context: Lambda context (unused).
    """
    # Lambda environment
    retry_queue = os.environ.get('RETRY_QUEUE')
    sanitized_topic = os.environ.get('SANITIZED_TOPIC')
    region = os.environ.get('AWS_REGION')

    dns_validator = DnsValidator(retry_queue=retry_queue,
                                 sanitized_topic=sanitized_topic,
                                 region=region)

    # Through event source mappings, multiple messages may be processed.
    # A missing or null 'Records' is treated as an empty batch rather than a crash.
    records = event.get('Records') or []
    for record in records:
        if not dns_validator.handle_message(record.get('body'), record.get('eventSourceARN')):
            # Even if 1 message in the batch fails, fail completely
            # 'All messages in a failed batch return to the queue, so your function code must be able to process the same message multiple times without side effects' # noqa: E501
            # https://docs.aws.amazon.com/lambda/latest/dg/with-sqs.html
            sys.exit(1)

    logger.info(" Processed %s messages", len(records))
