from email.errors import HeaderParseError
from email.header import Header
from email.headerregistry import Address, parser
from email.utils import formataddr
from django.utils.encoding import force_str
from django.core.mail.backends import smtp as smtp_lib
from django.core.mail import message as message_lib
import logging

# Backport of the Django 2.2 fix for https://code.djangoproject.com/ticket/31784.
# NOTE(review): upstream release notes describe this ticket as a crash when
# sending to addresses with long (>75 char) display names on recent Python
# patch releases -- confirm against the ticket before relying on this summary.

# NOTE(review): the logger name 'wiki.tasks' is kept unchanged so existing log
# routing is unaffected, even though this module deals with email addresses.
logger = logging.getLogger('wiki.tasks')


def punycode(domain):
    """Encode *domain* with the IDNA codec and return it as an ASCII ``str``.

    Non-ASCII labels become ``xn--`` Punycode labels, while a pure-ASCII
    domain comes back unchanged. Errors raised by the ``idna`` codec (for
    example an empty or over-long label) propagate to the caller.
    """
    ascii_bytes = domain.encode('idna')
    return str(ascii_bytes, 'ascii')


def sanitize_address_backport(addr, encoding):
    """
    Format a pair of (name, address) or an email address string.

    Replacement for Django's ``sanitize_address`` that ``BackportMonkeyPatch``
    installs in ``django.core.mail.message`` and the SMTP backend.

    :param addr: a ``(display_name, address)`` tuple, or anything else, which
        is coerced to ``str`` with ``force_str`` and parsed as one mailbox
        (e.g. ``'Jane <jane@example.com>'``).
    :param encoding: charset handed to ``email.header.Header`` when the
        display name or local part is not pure ASCII.
    :returns: the address string built by ``email.utils.formataddr``.
    :raises ValueError: if the string form cannot be parsed, parses only
        partially, the tuple form's address has no ``@``, or any part
        contains CR/LF. ``UnicodeError`` (a ``ValueError`` subclass) may also
        propagate from IDNA-encoding an invalid domain.
    """
    # Lazy %-formatting: the message is only rendered if INFO is enabled.
    # NOTE(review): this logs the full address on every call.
    logger.info('Patched fn is used for %s', addr)

    if not isinstance(addr, tuple):
        addr = force_str(addr)
        try:
            token, rest = parser.get_mailbox(addr)
        except (HeaderParseError, ValueError, IndexError) as exc:
            raise ValueError('Invalid address "%s"' % addr) from exc
        if rest:
            # The entire email address must be parsed.
            raise ValueError('Invalid address; only %s could be parsed from "%s"' % (token, addr))
        nm = token.display_name or ''
        localpart = token.local_part
        domain = token.domain or ''
    else:
        nm, address = addr
        # Split on the LAST '@' (same as rsplit('@', 1)), but fail with a
        # meaningful message instead of a tuple-unpacking error.
        localpart, sep, domain = address.rpartition('@')
        if not sep:
            raise ValueError('Invalid address "%s"' % address)

    # Reject header-injection attempts before any encoding happens.
    address_parts = nm + localpart + domain
    if '\n' in address_parts or '\r' in address_parts:
        raise ValueError('Invalid address; address parts cannot contain newlines.')

    # Avoid UTF-8 encode, if it's possible.
    try:
        nm.encode('ascii')
    except UnicodeEncodeError:
        nm = Header(nm, encoding).encode()
    else:
        nm = Header(nm).encode()
    try:
        localpart.encode('ascii')
    except UnicodeEncodeError:
        localpart = Header(localpart, encoding).encode()
    domain = punycode(domain)

    # Only the addr-spec comes from Address; the (possibly Header-encoded)
    # display name is combined with it by formataddr.
    parsed_address = Address(username=localpart, domain=domain)
    return formataddr((nm, parsed_address.addr_spec))


class BackportMonkeyPatch:
    """Context manager that temporarily installs ``sanitize_address_backport``.

    On entry, the module-level ``sanitize_address`` name is rebound in both
    ``django.core.mail.backends.smtp`` and ``django.core.mail.message``. On
    exit, the functions that were installed before entry are restored. This
    happens whether or not the ``with`` body raised, and any exception is
    propagated.

    NOTE(review): rebinding module globals makes the patch visible to every
    thread for the duration of the ``with`` block.
    """

    # Functions installed before __enter__ ran; None until the patch is applied.
    old_fn_smtp = None
    old_fn_mess = None

    def __enter__(self):
        logger.info('Applying monkey patch')
        self.old_fn_smtp = smtp_lib.sanitize_address
        self.old_fn_mess = message_lib.sanitize_address

        smtp_lib.sanitize_address = sanitize_address_backport
        message_lib.sanitize_address = sanitize_address_backport
        # Return self so ``with BackportMonkeyPatch() as patch:`` binds the manager.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore only what __enter__ actually saved.
        if self.old_fn_smtp is not None:
            logger.info('Releasing monkey patch')
            smtp_lib.sanitize_address = self.old_fn_smtp
        if self.old_fn_mess is not None:
            message_lib.sanitize_address = self.old_fn_mess
        # Implicitly returns None (falsy), so exceptions from the body propagate.
