# import boto3
import os
import pytest
import logging
import json

import dns.resolver

from .. import dns_validator

AWS_REGION = "us-west-2"

#################
# Payloads
#################

# Identity of the sample machine; these values are the defaults of dhcp_payload below.
HOST = "video-edge-c2a9ac"
POP = "rio01"
FQDN = ".".join((HOST, POP, "justin.tv"))
TWITCH_ROLE = "video-edge"
IP = "52.223.208.18"


# Sample payload from the DHCP hook, as published to SNS
def dhcp_payload(host=HOST, twitch_role=TWITCH_ROLE, new_ip=IP, pop=POP):
    """Return the JSON-encoded DHCP hook message for one machine.

    The arguments are the fields a test may want to vary; every other field is
    fixed. Keys are inserted in a fixed order, so two calls with the same
    arguments serialize to byte-identical strings (tests compare them as text).
    """
    hostname = "{}.{}.justin.tv".format(host, pop)
    message = dict(
        source="dhclient",
        reason="RENEW",
        hostname=hostname,
        timestamp=1625435276,
        pop=pop,
        twitch_role=twitch_role,
        twitch_environment="production",
        old_ip_address="1.2.3.4",  # can be anything
        new_ip_address=new_ip,
    )
    return json.dumps(message)


# Sample event from the DHCP hook path: SNS -> SQS -> Lambda
def render_event(payload=dhcp_payload()):
    """Wrap ``payload`` in a one-record SQS-triggered Lambda event.

    The record body is a JSON object whose ``Message`` is ``payload``. The
    default is evaluated once, at import time; it is a string, so sharing it
    between calls is safe.
    """
    envelope = json.dumps({"Message": payload})
    record = {
        "body": envelope,
        "eventSourceARN": "arn:aws:sqs:us-west-2:123456789012:MyQueue",
    }
    return {"Records": [record]}


#########################
# Basic components setup
#########################


@pytest.fixture
def retry_queue(sqs_client):
    """Create the retry SQS queue and return its URL (``None`` if absent from the response)."""
    response = sqs_client.create_queue(QueueName="RetryQueue")
    queue_url = response.get("QueueUrl")
    return queue_url


@pytest.fixture
def sanitized_queue(sqs_client):
    """Create the SQS queue for sanitized messages and return its URL.

    The ``sanitized_topic`` fixture subscribes this queue to the sanitized SNS
    topic so tests can read back what the validator published.
    """
    # Named for what it is: the local previously reused the ``retry_queue`` name.
    response = sqs_client.create_queue(QueueName="SanitizedQueue")
    return response.get("QueueUrl")


@pytest.fixture
def sanitized_topic(sns_client, sqs_client, sanitized_queue):
    """Create the sanitized SNS topic, subscribe ``sanitized_queue`` to it, and return the topic ARN.

    The SQS subscription exists only so tests can inspect what gets published.
    """
    topic_arn = sns_client.create_topic(Name="SanitizedTopic").get("TopicArn")

    # GetQueueAttributes returns no attributes unless they are requested by name,
    # so ask for QueueArn explicitly rather than relying on a lenient mock.
    attributes = sqs_client.get_queue_attributes(
        QueueUrl=sanitized_queue,
        AttributeNames=["QueueArn"],
    )["Attributes"]
    queue_arn = attributes["QueueArn"]

    # Setup SNS -> SQS for inspection
    sns_client.subscribe(TopicArn=topic_arn, Protocol="sqs", Endpoint=queue_arn)

    return topic_arn


#############################################################################
# Mock DNS response since valid response is only available on Twitch network
#############################################################################

class MockAnswer:
    """Stand-in for one DNS answer record; only ``to_text`` is provided."""

    @staticmethod
    def to_text():
        """Return the canned address ``IP`` as the record's text form."""
        answer_text = IP
        return answer_text


# Mock: the machine's DNS record exists
def dns_found(*args, **kwargs):
    """Ignore the resolver arguments and lazily yield a single ``MockAnswer``."""
    for answer in (MockAnswer(),):
        yield answer


def nx_domain(*args, **kwargs):
    """Mock: the machine has no DNS record; ignore arguments and raise ``dns.resolver.NXDOMAIN``."""
    error = dns.resolver.NXDOMAIN()
    raise error

#########################
# Tests
#########################


@pytest.fixture(autouse=True)
def setup(retry_queue, sanitized_topic, caplog, monkeypatch):
    """
    Set up the environment variables and log level used by every test in this module.

    AWS_REGION, RETRY_QUEUE and SANITIZED_TOPIC are set through ``monkeypatch``
    so they are restored at teardown. Writing ``os.environ`` directly leaked
    them into every later test in the session. ``monkeypatch`` is the same
    per-test instance the tests request, so its undo also covers their
    ``setattr`` mocks.
    """
    monkeypatch.setenv("AWS_REGION", AWS_REGION)
    monkeypatch.setenv("RETRY_QUEUE", retry_queue)
    monkeypatch.setenv("SANITIZED_TOPIC", sanitized_topic)

    # Tests assert on INFO-level messages via caplog.messages.
    caplog.set_level(logging.INFO)


def test_handler_ideal(caplog, monkeypatch, sqs_client, sanitized_queue):
    """
    SQS-triggered run for a machine whose DNS already matches the payload:
    the handler returns None and the unchanged payload reaches the sanitized
    queue via the sanitized topic.
    """
    # The machine's A record resolves to IP, the same address as the payload.
    monkeypatch.setattr(dns.resolver, "query", dns_found)

    assert dns_validator.lambda_handler(event=render_event(), context={}) is None

    # Read back what arrived on the sanitized queue and unwrap its "Message".
    received = sqs_client.receive_message(QueueUrl=sanitized_queue)
    first_body = received["Messages"][0]["Body"]
    forwarded = json.loads(first_body).get("Message")

    assert forwarded == dhcp_payload()


def test_handler_delay_bad_ip(monkeypatch, caplog):
    """
    DNS resolves, but to a different address than the payload reports: the
    handler returns None and logs the mismatch.
    """
    # DNS answers with IP (see MockAnswer); the payload below claims another address.
    monkeypatch.setattr(dns.resolver, "query", dns_found)

    bad_ip = "1.2.3.4"
    mismatched_event = render_event(payload=dhcp_payload(new_ip=bad_ip))
    status = dns_validator.lambda_handler(event=mismatched_event, context={})
    assert status is None

    expected_log = " IP in DNS for {} ({}) does not match message from SNS ({})".format(FQDN, IP, bad_ip)  # noqa: E501
    assert expected_log in caplog.messages


def test_handler_machine_slow(sqs_client, retry_queue, monkeypatch, caplog):
    """
    A machine whose DNS record appears late. The first attempt is parked on
    the retry queue. A retry that still cannot resolve exits with code 1. A
    retry after the record appears succeeds.
    """
    # First attempt: the machine's name does not resolve yet.
    monkeypatch.setattr(dns.resolver, "query", nx_domain)

    first_status = dns_validator.lambda_handler(event=render_event(), context={})
    assert first_status is None

    parked = sqs_client.receive_message(QueueUrl=retry_queue)
    assert len(parked.get("Messages")) == 1
    parked_body = parked.get("Messages")[0]["Body"]

    # Re-deliver the parked message as though it came from the delay queue.
    # Presumably the "delay" in the ARN marks it as a retry (compare
    # TestDnsValidator.test_sm_determine_delay).
    retry_event = {
        "Records": [{
            "body": parked_body,
            "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:crud-delay-queue"
        }]
    }

    # Still unresolvable on retry: the handler exits with code 1 (fail to DLQ).
    with pytest.raises(SystemExit) as exit_info:
        dns_validator.lambda_handler(event=retry_event, context={})

    assert exit_info.type == SystemExit
    assert exit_info.value.code == 1

    # The record exists now, so the same retry succeeds.
    monkeypatch.setattr(dns.resolver, "query", dns_found)
    final_status = dns_validator.lambda_handler(event=retry_event, context={})
    assert final_status is None

    # NOTE(review): the message received above was never deleted, so it is likely
    # hidden by the SQS visibility timeout rather than consumed. This check
    # mainly shows that no new message was re-queued.
    leftover = sqs_client.receive_message(QueueUrl=retry_queue)
    assert leftover.get("Messages") is None


class TestDnsValidator(object):
    """Direct tests of ``DnsValidator`` methods, bypassing ``lambda_handler``."""

    @pytest.fixture(autouse=True)
    def _build_validator(self, retry_queue, sanitized_topic):
        """Give each test a validator wired to this test's retry queue URL and sanitized topic ARN.

        This must be a fixture, not ``setup_class``. Only fixture injection
        yields the queue URL and topic ARN. Referenced directly, the
        module-level names ``retry_queue`` and ``sanitized_topic`` are the
        fixture functions themselves. The injected values are the same
        per-test instances the module-wide ``setup`` fixture uses.
        """
        self.dns_validator = dns_validator.DnsValidator(
            retry_queue=retry_queue,
            sanitized_topic=sanitized_topic,
            region=AWS_REGION,
        )

    def test_handle_message_no_ip(self, caplog):
        """A payload with an empty new IP is accepted (True) and the empty IP is logged."""
        record = dhcp_payload(new_ip="")
        assert self.dns_validator.handle_message(record=record, source="test") is True
        assert " Received an SNS message with an empty IP from video-edge-c2a9ac.rio01.justin.tv" in caplog.messages

    def test_handle_message_missing_machine_class(self, caplog):
        """A payload without 'twitch_role' is rejected (False) and logged as malformed."""
        record = json.dumps({})
        assert self.dns_validator.handle_message(record=record, source="test") is False
        assert " Malformed message did not include 'twitch_role' key" in caplog.messages

    def test_sm_determine_delay(self):
        """
        If this event comes from the delay queue
        """
        result = self.dns_validator.sm_determine(hostname="test.foo",
                                                 message="test message",
                                                 source="arn::test::delay")
        assert result is False
