# encoding: UTF-8

import dns.exception
import dns.name
import dns.rdataclass
import dns.rdatatype
import dns.rdtypes
import dns.rdtypes.ANY.CAA
import dns.rdtypes.ANY.CNAME
import dns.rdtypes.ANY.MX
import dns.rdtypes.ANY.NS
import dns.rdtypes.ANY.TXT
import dns.rdtypes.IN.A
import dns.rdtypes.IN.AAAA
import dns.rdtypes.IN.SRV
import marshmallow as ma

from appcore.data.schema import EnumField
from appcore.data.schema import ModelSchema
from dns_hosting.models.domains import Domain
from dns_hosting.models.domains import Record
from dns_hosting.models.domains import RecordType


class NameField(ma.fields.String):
    """String field whose deserialized value is lower-cased.

    Dumping is inherited unchanged from ``ma.fields.String``; only loading
    normalizes case (presumably because DNS names compare
    case-insensitively).
    """

    def _deserialize(self, value, attr, data):
        text = super(NameField, self)._deserialize(value, attr, data)
        return text.lower()


class DomainSchema(ModelSchema):
    """(De)serialize ``Domain`` models.

    On load the ``name`` is IDNA (punycode) encoded, must parse as a DNS
    name and must be absolute (end with a dot). On dump it is decoded back
    from IDNA.
    """

    MSG_INVALID_DOMAIN_NAME = 'Domain name must end with dot.'

    class Meta:
        model = Domain

    id = ma.fields.Integer(required=True)
    name = NameField(required=True)
    serial = ma.fields.Integer(dump_only=True)
    pdd_sync_enabled = ma.fields.Boolean(dump_only=True)

    @ma.validates('name')
    def _validate_name(self, name):
        """Reject names dnspython cannot parse or that lack a trailing dot.

        :raises ma.ValidationError: with the parser's message when
            ``dns.name.from_text`` fails, or ``MSG_INVALID_DOMAIN_NAME``
            when ``name`` does not end with ``'.'``.
        """
        try:
            dns.name.from_text(name)
        except (dns.exception.DNSException, ValueError) as e:
            # ValueError covers UnicodeError from IDNA processing of names
            # that ``_punycode_name`` could not encode. ``str(e)`` replaces
            # ``e.message``: BaseException.message is deprecated since
            # Python 2.6 and removed in Python 3.
            raise ma.ValidationError(str(e))

        if not name.endswith('.'):
            raise ma.ValidationError(self.MSG_INVALID_DOMAIN_NAME)

    @ma.pre_load
    def _punycode_name(self, item):
        """Replace an incoming ``name`` with its IDNA (punycode) encoding.

        Values that cannot be IDNA-encoded (non-strings, empty or over-long
        labels) are left untouched, so the field and ``_validate_name``
        report them as validation errors instead of this hook crashing.
        """
        if 'name' in item:
            try:
                item['name'] = item['name'].encode('idna')
            except (AttributeError, UnicodeError):
                pass  # deliberately deferred to field/validator errors
        return item

    @ma.post_dump
    def _un_punycode_name(self, item):
        """Decode the dumped IDNA ``name`` back to Unicode when present.

        Skips a missing or empty/None ``name`` (e.g. when dumping with
        ``only=`` or for a model without a name) instead of raising.
        """
        if item.get('name'):
            item['name'] = item['name'].decode('idna')
        return item


class ASchema(ma.Schema):
    """Load an IPv4 ``address`` payload as a dnspython ``IN A`` rdata."""

    MSG_INVALID_ADDRESS = 'Invalid IPv4 address.'

    address = ma.fields.String(required=True)

    @ma.validates('address')
    def _validate_address(self, address):
        """Raise ``ma.ValidationError`` if ``dns.ipv4.inet_aton`` rejects it.

        ``UnicodeEncodeError`` is handled like a syntax error; it is
        presumably what dnspython raises for non-ASCII text.
        """
        try:
            dns.ipv4.inet_aton(address)
        except (dns.exception.SyntaxError, UnicodeEncodeError):
            raise ma.ValidationError(self.MSG_INVALID_ADDRESS)

    @ma.post_load
    def _make_rdata(self, data):
        """Wrap the validated fields in a ``dns.rdtypes.IN.A.A`` instance."""
        rdata_cls = dns.rdtypes.IN.A.A
        rdclass, rdtype = dns.rdataclass.IN, dns.rdatatype.A
        return rdata_cls(rdclass, rdtype, **data)


class AAAASchema(ma.Schema):
    """Load an IPv6 ``address`` payload as a dnspython ``IN AAAA`` rdata."""

    MSG_INVALID_ADDRESS = 'Invalid IPv6 address.'

    address = ma.fields.String(required=True)

    @ma.validates('address')
    def _validate_address(self, address):
        """Raise ``ma.ValidationError`` if ``dns.ipv6.inet_aton`` rejects it.

        ``UnicodeEncodeError`` is handled like a syntax error; it is
        presumably what dnspython raises for non-ASCII text.
        """
        try:
            dns.ipv6.inet_aton(address)
        except (dns.exception.SyntaxError, UnicodeEncodeError):
            raise ma.ValidationError(self.MSG_INVALID_ADDRESS)

    @ma.post_load
    def _make_rdata(self, data):
        """Wrap the validated fields in a ``dns.rdtypes.IN.AAAA.AAAA``."""
        rdata_cls = dns.rdtypes.IN.AAAA.AAAA
        rdclass, rdtype = dns.rdataclass.IN, dns.rdatatype.AAAA
        return rdata_cls(rdclass, rdtype, **data)


class CNAMESchema(ma.Schema):
    """Validate a CNAME ``target`` and load it as an ``IN CNAME`` rdata."""

    MSG_INVALID_TARGET = 'Invalid target.'

    target = NameField(required=True)

    @ma.validates('target')
    def _validate_target(self, target):
        """Raise ``ma.ValidationError`` unless ``target`` parses as a name.

        The whole ``dns.exception.DNSException`` hierarchy is caught, not
        just ``SyntaxError``: dnspython reports a name longer than 255
        octets with ``dns.name.NameTooLong`` (a ``FormError``) and IDNA
        failures with ``IDNAException``, neither of which is a
        ``SyntaxError``. ``ValueError`` covers ``UnicodeError`` from IDNA
        encoding in older dnspython releases.
        """
        try:
            dns.name.from_text(target)
        except (dns.exception.DNSException, ValueError):
            raise ma.ValidationError(self.MSG_INVALID_TARGET)

    @ma.post_load
    def _make_rdata(self, data):
        """Build the CNAME rdata from the validated ``target``.

        NOTE(review): ``from_text`` uses its default origin here, so a
        relative target presumably becomes absolute -- confirm intended.
        """
        return dns.rdtypes.ANY.CNAME.CNAME(
            dns.rdataclass.IN,
            dns.rdatatype.CNAME,
            dns.name.from_text(data['target']),
        )


class MXSchema(ma.Schema):
    """Validate MX fields and load them as an ``IN MX`` rdata."""

    MSG_INVALID_PREFERENCE = 'Invalid preference.'
    MSG_INVALID_EXCHANGE = 'Invalid exchange.'

    preference = ma.fields.Integer(required=True)
    exchange = NameField(required=True)

    @ma.validates('preference')
    def _validate_preference(self, preference):
        """Require ``preference`` to fit an unsigned 16-bit integer."""
        if not (0 <= preference <= 65535):
            raise ma.ValidationError(self.MSG_INVALID_PREFERENCE)

    @ma.validates('exchange')
    def _validate_exchange(self, exchange):
        """Raise ``ma.ValidationError`` unless ``exchange`` parses as a name.

        ``DNSException`` rather than only ``SyntaxError`` is caught so that
        ``dns.name.NameTooLong`` (a ``FormError``) and IDNA errors are
        reported as validation errors instead of escaping the schema.
        """
        try:
            dns.name.from_text(exchange)
        except (dns.exception.DNSException, ValueError):
            raise ma.ValidationError(self.MSG_INVALID_EXCHANGE)

    @ma.post_load
    def _make_rdata(self, data):
        """Build the MX rdata from the validated fields."""
        return dns.rdtypes.ANY.MX.MX(
            dns.rdataclass.IN,
            dns.rdatatype.MX,
            data['preference'],
            dns.name.from_text(data['exchange']),
        )


class TXTSchema(ma.Schema):
    """Validate TXT ``strings`` and load them as an ``IN TXT`` rdata."""

    MSG_EMPTY_STRINGS = 'At least one string must present.'
    MSG_STRING_TOO_LONG = 'Each string must be at most 255 bytes long.'

    strings = ma.fields.List(ma.fields.String(), required=True)

    @ma.validates('strings')
    def _validate_strings(self, value):
        """Require a non-empty list whose strings fit DNS character-strings.

        Each TXT character-string has a one-octet length prefix (RFC 1035,
        section 3.3), so it can carry at most 255 octets on the wire.
        Longer strings used to pass validation and only fail later when the
        rdata was rendered. The length is measured on the UTF-8 encoding --
        TODO confirm this matches the encoding dnspython applies to text.
        """
        if not value:
            raise ma.ValidationError(self.MSG_EMPTY_STRINGS)
        for text in value:
            if len(text.encode('utf-8')) > 255:
                raise ma.ValidationError(self.MSG_STRING_TOO_LONG)

    @ma.post_load
    def _make_rdata(self, data):
        """Build the TXT rdata from the validated ``strings``."""
        return dns.rdtypes.ANY.TXT.TXT(
            dns.rdataclass.IN,
            dns.rdatatype.TXT,
            **data
        )


class SRVSchema(ma.Schema):
    """Validate SRV fields and load them as an ``IN SRV`` rdata."""

    # Each field gets its own message; previously all four read
    # 'Invalid target.', which mislabelled priority/weight/port errors.
    MSG_INVALID_PRIORITY = 'Invalid priority.'
    MSG_INVALID_WEIGHT = 'Invalid weight.'
    MSG_INVALID_PORT = 'Invalid port.'
    MSG_INVALID_TARGET = 'Invalid target.'

    priority = ma.fields.Integer(required=True)
    weight = ma.fields.Integer(required=True)
    port = ma.fields.Integer(required=True)
    target = NameField(required=True)

    @ma.validates('priority')
    def _validate_priority(self, priority):
        """Require ``priority`` to fit an unsigned 16-bit integer."""
        if not (0 <= priority <= 65535):
            raise ma.ValidationError(self.MSG_INVALID_PRIORITY)

    @ma.validates('weight')
    def _validate_weight(self, weight):
        """Require ``weight`` to fit an unsigned 16-bit integer."""
        if not (0 <= weight <= 65535):
            raise ma.ValidationError(self.MSG_INVALID_WEIGHT)

    @ma.validates('port')
    def _validate_port(self, port):
        """Require ``port`` to fit an unsigned 16-bit integer."""
        if not (0 <= port <= 65535):
            raise ma.ValidationError(self.MSG_INVALID_PORT)

    @ma.validates('target')
    def _validate_target(self, target):
        """Raise ``ma.ValidationError`` unless ``target`` parses as a name.

        ``DNSException`` rather than only ``SyntaxError`` is caught so that
        ``dns.name.NameTooLong`` (a ``FormError``) and IDNA errors are
        reported as validation errors instead of escaping the schema.
        """
        try:
            dns.name.from_text(target)
        except (dns.exception.DNSException, ValueError):
            raise ma.ValidationError(self.MSG_INVALID_TARGET)

    @ma.post_load
    def _make_rdata(self, data):
        """Build the SRV rdata without mutating the loaded ``data`` dict."""
        return dns.rdtypes.IN.SRV.SRV(
            dns.rdataclass.IN,
            dns.rdatatype.SRV,
            priority=data['priority'],
            weight=data['weight'],
            port=data['port'],
            target=dns.name.from_text(data['target']),
        )


class NSSchema(ma.Schema):
    """Validate an NS ``target`` and load it as an ``IN NS`` rdata."""

    MSG_INVALID_TARGET = 'Invalid target.'

    target = NameField(required=True)

    @ma.validates('target')
    def _validate_target(self, target):
        """Raise ``ma.ValidationError`` unless ``target`` parses as a name.

        ``DNSException`` rather than only ``SyntaxError`` is caught so that
        ``dns.name.NameTooLong`` (a ``FormError``) and IDNA errors are
        reported as validation errors instead of escaping the schema.
        """
        try:
            dns.name.from_text(target)
        except (dns.exception.DNSException, ValueError):
            raise ma.ValidationError(self.MSG_INVALID_TARGET)

    @ma.post_load
    def _make_rdata(self, data):
        """Build the NS rdata from the validated ``target``."""
        return dns.rdtypes.ANY.NS.NS(
            dns.rdataclass.IN,
            dns.rdatatype.NS,
            dns.name.from_text(data['target']),
        )


class CAASchema(ma.Schema):
    """Validate CAA fields and load them as an ``IN CAA`` rdata."""

    MSG_INVALID_FLAGS = 'Invalid flags.'
    MSG_INVALID_TAG = 'Invalid tag.'

    flags = ma.fields.Integer(required=True)
    tag = ma.fields.String(required=True)
    value = ma.fields.String(required=True)

    @ma.validates('flags')
    def _validate_flags(self, flags):
        """Require ``flags`` to fit a single unsigned octet."""
        if not (0 <= flags <= 255):
            raise ma.ValidationError(self.MSG_INVALID_FLAGS)

    @ma.validates('tag')
    def _validate_tag(self, tag):
        """Require a non-empty, at most 255 character, ASCII alphanumeric tag.

        ``isalnum()`` alone also accepts non-ASCII letters and digits
        (e.g. ``u'é'``). RFC 6844 limits tags to US-ASCII ``a-z``, ``A-Z``
        and ``0-9``, and the tag is written as raw octets behind a one-octet
        length, so non-ASCII characters are rejected as well. An empty tag
        fails ``isalnum()``.
        """
        if len(tag) > 255:
            raise ma.ValidationError(self.MSG_INVALID_TAG)
        if not tag.isalnum() or any(ord(char) > 127 for char in tag):
            raise ma.ValidationError(self.MSG_INVALID_TAG)

    @ma.post_load
    def _make_rdata(self, data):
        """Build the CAA rdata from the validated fields."""
        return dns.rdtypes.ANY.CAA.CAA(
            dns.rdataclass.IN,
            dns.rdatatype.CAA,
            **data
        )


class RDataField(ma.fields.Field):
    """Field holding a dnspython rdata object of one of several types.

    The nested schema is chosen from ``_RDATA_SCHEMA_MAP``. On dump the key
    is ``value.rdtype``; on load it is the ``type`` key of the raw input
    dict, because the rdata payload itself does not name its type.
    """

    MSG_UNSUPPORTED_TYPE = 'Unsupported record type.'

    _RDATA_SCHEMA_MAP = {
        dns.rdatatype.A: ASchema,
        dns.rdatatype.AAAA: AAAASchema,
        dns.rdatatype.MX: MXSchema,
        dns.rdatatype.CNAME: CNAMESchema,
        dns.rdatatype.TXT: TXTSchema,
        dns.rdatatype.SRV: SRVSchema,
        dns.rdatatype.NS: NSSchema,
        dns.rdatatype.CAA: CAASchema,
    }

    def _serialize(self, value, attr, obj):
        """Dump ``value`` with the schema registered for its ``rdtype``.

        ``None`` is passed through as ``None`` instead of failing on
        ``value.rdtype``.

        :raises ma.ValidationError: if the nested dump reports errors.
        """
        if value is None:
            return None
        schema_cls = self._RDATA_SCHEMA_MAP[value.rdtype]
        result, errors = schema_cls().dump(value)
        if errors:
            raise ma.ValidationError(errors, data=result)
        return result

    def _deserialize(self, value, attr, data):
        """Load ``value`` with the schema matching the raw ``data['type']``.

        Returns ``None`` when ``type`` is absent (presumably the enclosing
        schema's own ``type`` field reports that).

        :raises ma.ValidationError: ``MSG_UNSUPPORTED_TYPE`` when ``type``
            is not a known rdata type with a registered schema (previously
            an uncaught ``UnknownRdatatype``/``KeyError``), or the nested
            schema's errors when it rejects ``value``.
        """
        type_text = data.get('type')
        if type_text is None:
            return None

        try:
            rdtype = dns.rdatatype.from_text(type_text)
        except (dns.exception.DNSException, ValueError, AttributeError):
            # Unknown mnemonic, out-of-range TYPEnnn, or (presumably) a
            # non-string value that has no ``.upper()``.
            rdtype = None
        schema_cls = self._RDATA_SCHEMA_MAP.get(rdtype)
        if schema_cls is None:
            raise ma.ValidationError(self.MSG_UNSUPPORTED_TYPE)

        result, errors = schema_cls().load(value)
        if errors:
            raise ma.ValidationError(errors, data=result)
        return result


class RecordSchema(ModelSchema):
    """(De)serialize ``Record`` models.

    ``name`` must parse as a DNS name and be relative (no trailing dot).
    ``ttl`` must fit an unsigned 32-bit integer. NS and CNAME records named
    ``'@'`` are forbidden.
    """

    MSG_INVALID_NAME = 'Invalid name.'
    MSG_INVALID_TTL = 'Invalid ttl.'
    MSG_FORBIDDEN = 'Forbidden.'

    class Meta:
        model = Record

    id = ma.fields.Integer(required=True)
    name = NameField(required=True)
    type = EnumField(RecordType, required=True)
    rdata = RDataField(required=True)
    ttl = ma.fields.Integer(required=True)

    @ma.validates('name')
    def _validate_name(self, name):
        """Reject unparsable names and names ending with a dot.

        ``DNSException`` rather than only ``SyntaxError`` is caught so that
        ``dns.name.NameTooLong`` (a ``FormError``) and IDNA errors are
        reported as validation errors instead of escaping the schema.
        """
        try:
            dns.name.from_text(name)
        except (dns.exception.DNSException, ValueError):
            raise ma.ValidationError(self.MSG_INVALID_NAME)

        if name.endswith('.'):
            raise ma.ValidationError(self.MSG_INVALID_NAME)

    @ma.validates('ttl')
    def _validate_ttl(self, ttl):
        """Require ``ttl`` to fit an unsigned 32-bit integer."""
        if not (0 <= ttl <= 4294967295):
            raise ma.ValidationError(self.MSG_INVALID_TTL)

    @ma.validates_schema(skip_on_field_errors=True)
    def _validates_record_schema(self, record):
        """Forbid NS and CNAME records named ``'@'`` (presumably the apex).

        ``.get`` is used because ``name``/``type`` can be absent under a
        partial load, where indexing would raise ``KeyError``.
        """
        if (
                record.get('name') == '@' and
                record.get('type') in {RecordType.NS, RecordType.CNAME}
        ):
            raise ma.ValidationError(self.MSG_FORBIDDEN)
