# encoding: UTF-8

import time
import uuid

import enum
import sqlalchemy as sa
import sqlalchemy.event as event
import sqlalchemy.orm as orm
import sqlalchemy.types as types
from dns import rdata
from dns import rdataclass
from dns import rdatatype
from sqlalchemy import inspect
from sqlalchemy import update
from sqlalchemy.orm import object_session
from sqlalchemy.orm import Session

from dns_hosting.utils.timeuuid import uuid_from_time


class UUID(types.TypeDecorator):
    """Column type storing :class:`uuid.UUID` values as 32-char hex strings.

    Bound values are written as ``UUID.hex`` (no dashes) and fetched values
    are converted back into :class:`uuid.UUID`.  ``None`` passes through in
    both directions.
    """
    impl = types.CHAR
    # The type carries no per-instance state that affects the generated SQL,
    # so SQLAlchemy 1.4+ may cache statements using it (older versions
    # simply ignore this attribute).
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Use ``CHAR(32)`` on every backend.

        A bare ``CHAR`` is treated as ``CHAR(1)`` by e.g. PostgreSQL and
        MySQL, which cannot hold a 32-character hex UUID.
        """
        return dialect.type_descriptor(types.CHAR(32))

    def process_bind_param(self, value, dialect):
        """Convert *value* to its 32-character hex form for the database.

        Accepts a :class:`uuid.UUID` (or any object exposing ``.hex``) and,
        for convenience, a UUID string in any form :class:`uuid.UUID` parses.
        """
        if value is None:
            return None
        if isinstance(value, str):
            value = uuid.UUID(value)
        return value.hex

    def process_result_value(self, value, dialect):
        """Convert a stored hex string back into a :class:`uuid.UUID`."""
        if value is None:
            return None
        return uuid.UUID(hex=value)


def get_values(e):
    """Return the persisted string values of the enum class *e*.

    Intended for ``sa.Enum(..., values_callable=get_values)``.  SQLAlchemy
    expects that callable to return a list of plain strings, one per member
    in ``__members__`` order, so the database stores ``member.value``
    rather than the member name.

    :param e: an :class:`enum.Enum` subclass.
    :return: list of member values (aliases included) in definition order.
    """
    # ``.value`` matters: returning the members themselves leaks enum
    # objects (whose str() is e.g. 'Operation.DOMAIN_ADD') into SQLAlchemy.
    return [member.value for member in e.__members__.values()]


class Domain(object):
    """Mapped row of the ``domains`` table: one hosted DNS domain.

    ``records`` holds the domain's :class:`Record` rows; ``serial`` can be
    bumped in the database with :meth:`update_serial`.
    """
    __tablename__ = 'domains'

    id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.String)
    serial = sa.Column(sa.Integer)
    org_id = sa.Column(sa.Integer)
    revision = sa.Column(sa.BigInteger)
    is_technical = sa.Column(sa.Boolean)
    pdd_sync_enabled = sa.Column(sa.Boolean)
    pdd_domain_id = sa.Column(sa.Integer)

    records = orm.relationship(
        lambda: Record,
        uselist=True,
        # Deleting a domain deletes its records, and records removed from
        # this collection are deleted as orphans.
        cascade='merge, expunge, delete, delete-orphan',
        back_populates='domain',
    )

    def __init__(
            self,
            id=None,
            serial=None,
            name=None,
            org_id=None,
            revision=None,
            is_technical=False,
            pdd_sync_enabled=True,
            pdd_domain_id=None,
    ):
        """Set every column from the matching keyword argument.

        ``is_technical`` defaults to ``False`` and ``pdd_sync_enabled`` to
        ``True``; every other column defaults to ``None``.
        """
        self.id = id
        self.serial = serial
        self.name = name
        self.org_id = org_id
        self.revision = revision
        self.is_technical = is_technical
        self.pdd_sync_enabled = pdd_sync_enabled
        self.pdd_domain_id = pdd_domain_id

    def update_serial(self):
        """Increment ``serial`` by one directly in the database.

        Emits ``UPDATE domains SET serial = serial + 1`` for this row, so the
        increment is computed by the database rather than from the in-memory
        value.  It then expires the instance so that the new value is
        reloaded on next access.  Does nothing unless the instance is
        persistent.
        """
        if not inspect(self).persistent:
            return
        session = object_session(self)  # type: Session
        # The ``update(table, values=...)`` keyword form was deprecated in
        # SQLAlchemy 1.4 and removed in 2.0; generative ``.values()`` works
        # on all versions.
        stmt = (
            update(Domain)
            .where(Domain.id == self.id)
            .values({Domain.serial: Domain.serial + 1})
        )
        session.execute(stmt)
        # Expire the whole instance so every attribute, serial included, is
        # reloaded from the database on next access.
        session.expire(self)


class RecordType(str, enum.Enum):
    """DNS resource-record types a :class:`Record` may have.

    Members subclass ``str`` and each value equals its name, so a member
    compares equal to the plain type string (``RecordType.MX == 'MX'``).
    This enum backs the ``type`` columns of both ``records`` and
    ``change_log``.  Do not reorder or rename members, because they define
    the persisted enum.
    """
    A = 'A'
    AAAA = 'AAAA'
    CNAME = 'CNAME'
    MX = 'MX'
    TXT = 'TXT'
    SRV = 'SRV'
    NS = 'NS'
    CAA = 'CAA'


class Record(object):
    """Mapped row of the ``records`` table: one DNS resource record.

    ``type`` holds a :class:`RecordType` and ``content`` holds the record
    data as text.  The :attr:`rdata` property converts ``content`` to and
    from a dnspython rdata object.
    """
    __tablename__ = 'records'

    id = sa.Column(sa.Integer, primary_key=True)
    domain_id = sa.Column(sa.ForeignKey(Domain.id))
    name = sa.Column(sa.String)
    type = sa.Column(sa.Enum(RecordType))
    content = sa.Column(sa.String)
    ttl = sa.Column(sa.Integer)

    @property
    def rdata(self):
        """Parse ``content`` as class-IN rdata of this record's ``type``."""
        # ``rdata`` here is the ``dns.rdata`` module: function bodies do not
        # see the class-level property of the same name.
        return rdata.from_text(
            rdataclass.IN,
            rdatatype.from_text(self.type),
            self.content,
        )

    @rdata.setter
    def rdata(self, value):
        """Store *value* (a dnspython rdata object) as ``content`` text.

        Only ``content`` is updated; ``type`` is left unchanged.
        """
        # Named ``value`` so the ``dns.rdata`` module is not shadowed here.
        self.content = value.to_text()

    domain = orm.relationship(
        lambda: Domain,
        cascade='save-update, merge',
        back_populates='records',
    )

    def __init__(
            self,
            id=None,
            domain_id=None,
            name=None,
            type=None,
            content=None,
            ttl=None,
    ):
        """Set every column from the matching keyword (all default ``None``)."""
        self.id = id
        self.domain_id = domain_id
        self.name = name
        self.type = type
        self.content = content
        self.ttl = ttl


class Operation(str, enum.Enum):
    """Kinds of change written to ``ChangeLog.operation``.

    Values are hyphenated lower-case strings (e.g. ``'domain-add'``) that
    differ from the member names.  The ``change_log`` column is declared
    with ``values_callable=get_values``, which is intended to make the
    database store these values rather than the names.  ``DOMAIN_CLEAR``
    is not emitted by any listener visible in this module.
    """
    DOMAIN_ADD = 'domain-add'
    DOMAIN_DELETE = 'domain-delete'
    DOMAIN_CLEAR = 'domain-clear'
    RECORD_ADD = 'record-add'
    RECORD_DELETE = 'record-delete'
    RECORD_UPDATE_TTL = 'record-update-ttl'


class ChangeLog(object):
    """One entry of the ``change_log`` table.

    ``version`` (a :class:`UUID` column) is the primary key.  ``origin``
    and ``serial`` carry the domain name and serial at the time of the
    change, and ``operation`` says what happened.  ``name``, ``type``,
    ``content`` and ``ttl`` describe the affected record; the domain-level
    listeners in this module leave them ``None``.
    """
    __tablename__ = 'change_log'

    version = sa.Column(UUID, primary_key=True)
    origin = sa.Column(sa.String)
    serial = sa.Column(sa.Integer)
    operation = sa.Column(sa.Enum(Operation, values_callable=get_values))
    name = sa.Column(sa.String)
    type = sa.Column(sa.Enum(RecordType))
    content = sa.Column(sa.String)
    ttl = sa.Column(sa.Integer)

    def __init__(self, version=None, origin=None, serial=None,
                 operation=None, name=None, type=None, content=None,
                 ttl=None):
        """Populate every column from the matching keyword (all optional)."""
        # Identity of the change and the domain state it refers to.
        self.version, self.origin, self.serial = version, origin, serial
        self.operation = operation
        # Record-level payload; unset for domain-level operations.
        self.name, self.type = name, type
        self.content, self.ttl = content, ttl


@event.listens_for(Domain, 'after_insert')
def _make_domain_add_change(mapper, connection, domain):
    """Log a DOMAIN_ADD change after a domain row is inserted.

    SQLAlchemy ``after_insert`` mapper-event handler.  Builds a
    :class:`ChangeLog` stamped with a time-based UUID and carrying the
    domain's name and serial, then adds it to the session that owns
    *domain*.  ``mapper`` and ``connection`` are unused.
    """
    # NOTE(review): SQLAlchemy discourages Session.add() from flush-time
    # mapper events; presumably the entry is written by a later flush
    # (e.g. the one repeated during commit) -- confirm.
    stamp = uuid_from_time(time.time())
    entry = ChangeLog(
        version=stamp,
        origin=domain.name,
        serial=domain.serial,
        operation=Operation.DOMAIN_ADD,
    )
    owner = orm.object_session(domain)  # type: orm.Session
    owner.add(entry)


@event.listens_for(Domain, 'after_delete')
def _make_domain_delete_change(mapper, connection, domain):
    """Log a DOMAIN_DELETE change after a domain row is deleted.

    SQLAlchemy ``after_delete`` mapper-event handler.  Builds a
    :class:`ChangeLog` stamped with a time-based UUID and carrying the
    domain's name and serial, then adds it to the session that owns
    *domain*.  ``mapper`` and ``connection`` are unused.
    """
    # NOTE(review): Session.add() inside a flush-time mapper event relies on
    # a later flush to persist the entry -- confirm this is intended.
    stamp = uuid_from_time(time.time())
    entry = ChangeLog(
        version=stamp,
        origin=domain.name,
        serial=domain.serial,
        operation=Operation.DOMAIN_DELETE,
    )
    owner = orm.object_session(domain)  # type: orm.Session
    owner.add(entry)


@event.listens_for(Record, 'after_insert')
def _make_record_add_change(mapper, connection, record):
    """Log a RECORD_ADD change after a record row is inserted.

    SQLAlchemy ``after_insert`` mapper-event handler.  The entry carries
    the owning domain's name and serial plus the record's name, type,
    content and ttl.  It is added to the session that owns *record*.
    ``mapper`` and ``connection`` are unused.
    """
    # NOTE(review): Session.add() inside a flush-time mapper event relies on
    # a later flush to persist the entry -- confirm this is intended.
    stamp = uuid_from_time(time.time())
    zone = record.domain
    entry = ChangeLog(
        version=stamp,
        origin=zone.name,
        serial=zone.serial,
        operation=Operation.RECORD_ADD,
        name=record.name,
        type=record.type,
        content=record.content,
        ttl=record.ttl,
    )
    owner = orm.object_session(record)  # type: orm.Session
    owner.add(entry)


@event.listens_for(Record, 'after_update')
def _make_record_update_change(mapper, connection, record):
    """Queue change-log entries describing an UPDATE of *record*.

    SQLAlchemy ``after_update`` mapper-event handler.  It reads the
    attribute history still attached to the instance and then:

    * if nothing changed, it logs nothing;
    * if only ``ttl`` changed, it logs one RECORD_UPDATE_TTL entry with
      the current values;
    * if anything else changed, it logs a RECORD_DELETE entry built from
      the previous name/type/content (no ttl), followed by a RECORD_ADD
      entry built from the new name/type/content/ttl.

    Entries are added to the session that owns *record*.  ``mapper`` and
    ``connection`` are unused.
    """
    session = orm.object_session(record)  # type: orm.Session
    state = sa.inspect(record)
    # Collect the keys of all mapped attributes whose history shows a change.
    # NOTE(review): ``True`` is passed as the ``passive`` flag of
    # InstanceState.get_history(); presumably meant to avoid loading
    # unloaded attributes -- confirm against the SQLAlchemy version in use.
    changed = set()
    for attr in state.attrs:
        hist = state.get_history(attr.key, True)
        if hist.has_changes():
            changed.add(attr.key)

    if not changed:
        return
    elif len(changed) == 1 and Record.ttl.key in changed:
        # TTL-only edit: name/type/content did not change, so a single
        # TTL-update entry describes it.
        change = ChangeLog(
            uuid_from_time(time.time()),
            record.domain.name,
            record.domain.serial,
            Operation.RECORD_UPDATE_TTL,
            record.name,
            record.type,
            record.content,
            record.ttl,
        )
        session.add(change)
    else:
        # Any other edit is logged as "delete old record" + "add new record".
        def get_deleted(state, column_key):
            # Previous value if the column changed, otherwise its current one.
            # NOTE(review): if the column changed but its old value was never
            # loaded (e.g. the instance was expired before assignment),
            # ``hist.deleted`` may be empty and ``[0]`` raises IndexError in
            # the middle of the flush -- consider active_history on columns.
            hist = state.get_history(column_key, True)
            return hist.deleted[0] if hist.has_changes() else hist.unchanged[0]

        def get_added(state, column_key):
            # New value if the column changed, otherwise its current one.
            hist = state.get_history(column_key, True)
            return hist.added[0] if hist.has_changes() else hist.unchanged[0]

        # Both entries use record.domain as it is now.
        # NOTE(review): if domain_id itself changed, the delete entry names
        # the new domain rather than the old one -- confirm that is acceptable.
        delete_change = ChangeLog(
            uuid_from_time(time.time()),
            record.domain.name,
            record.domain.serial,
            Operation.RECORD_DELETE,
            get_deleted(state, Record.name.key),
            get_deleted(state, Record.type.key),
            get_deleted(state, Record.content.key),
        )
        # NOTE(review): this version is generated after delete_change's;
        # whether it always sorts after it depends on uuid_from_time for
        # near-identical timestamps -- confirm.
        add_change = ChangeLog(
            uuid_from_time(time.time()),
            record.domain.name,
            record.domain.serial,
            Operation.RECORD_ADD,
            get_added(state, Record.name.key),
            get_added(state, Record.type.key),
            get_added(state, Record.content.key),
            get_added(state, Record.ttl.key),
        )
        session.add(delete_change)
        session.add(add_change)


@event.listens_for(Record, 'after_delete')
def _make_record_delete_change(mapper, connection, record):
    """Log a RECORD_DELETE change after a record row is deleted.

    SQLAlchemy ``after_delete`` mapper-event handler.  The entry carries
    the owning domain's name and serial plus the record's name, type and
    content; ``ttl`` is not recorded.  It is added to the session that owns
    *record*.  ``mapper`` and ``connection`` are unused.
    """
    # NOTE(review): Session.add() inside a flush-time mapper event relies on
    # a later flush to persist the entry -- confirm this is intended.
    stamp = uuid_from_time(time.time())
    zone = record.domain
    entry = ChangeLog(
        version=stamp,
        origin=zone.name,
        serial=zone.serial,
        operation=Operation.RECORD_DELETE,
        name=record.name,
        type=record.type,
        content=record.content,
    )
    owner = orm.object_session(record)  # type: orm.Session
    owner.add(entry)
