# encoding: UTF-8

import itertools
import logging
import time

import dns.rdata
import dns.rdataclass
import dns.rdatatype
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as psql

from dns_hosting.models.domains import Record, ChangeLog, Operation
from dns_hosting.services.migrator.bucket import RecordsBucket
from dns_hosting.services.migrator.qbundle import QueryBundle
from dns_hosting.utils.iterators import ijoin
from dns_hosting.utils.timeuuid import uuid_from_time

logger = logging.getLogger(__name__)


# NOTE: tuple parameter unpacking in function signatures is not supported in
#       Python 3: https://www.python.org/dev/peps/pep-3113/
def make_migration(args):
    # type: (((RecordsBucket, RecordsBucket), QueryBundle)) -> str|None
    """Build the SQL that syncs our copy of one domain with its PDD state.

    :param args: a ``((pdd_bucket, own_bucket), query_bundle)`` tuple passed
        as a single positional argument. Either bucket may be ``None`` when
        the domain is missing on that side.
    :returns: SQL text to execute, or ``None`` when sync is disabled for the
        domain, when there is nothing to change, or when building the
        migration failed (the failure is logged).
    """
    # Unpack in the body rather than in the signature: tuple parameters are
    # Python 2 only (PEP 3113). Callers keep passing one nested tuple.
    (pdd_bucket, own_bucket), query_bundle = args

    domain_id = (
            (pdd_bucket and pdd_bucket.domain_id) or
            (own_bucket and own_bucket.domain_id)
    )
    origin = (
            (pdd_bucket and pdd_bucket.origin) or
            (own_bucket and own_bucket.origin)
    )
    # A domain that does not exist on our side yet is always synced.
    sync_enabled = (own_bucket is None or own_bucket.sync_enabled)

    if not sync_enabled:
        logger.info('Sync was disabled for origin \'%s\'', origin)
        return None

    try:
        # Canonical absolute text form of the origin (with trailing dot).
        origin = dns.name.from_text(origin).to_text()
        pdd_rrset = prepare_records(
            origin,
            domain_id,
            records=pdd_bucket and pdd_bucket.records,
            fixup=True,
        )
        own_rrset = prepare_records(
            origin,
            domain_id,
            records=own_bucket and own_bucket.records,
            fixup=False,
        )

        # The SOA is not diffed as a record; a PDD zone without a SOA (or
        # without records at all) is treated as deleted.
        if pdd_rrset and pdd_rrset.get(('@', 'SOA')):
            del pdd_rrset['@', 'SOA']
        else:
            pdd_rrset = None

        version = uuid_from_time(time.time()).hex
        if pdd_rrset is None:
            # Domain is gone (or has no SOA) in PDD: remove it entirely.
            return compile_statement(
                sa
                    .text(query_bundle.get('delete_domain'))
                    .bindparams(domain_id=domain_id,
                                origin=origin,
                                version=version),
            )
        elif own_rrset is None:
            # Domain is new to us: create it together with all PDD records.
            queries = make_diff(origin, pdd_rrset, {})
            # Match on a label boundary and case-insensitively; a bare
            # endswith('yaconnect.com.') would also flag unrelated domains
            # such as 'myyaconnect.com.'.
            lowered_origin = origin.lower()
            technical = (
                lowered_origin == 'yaconnect.com.' or
                lowered_origin.endswith('.yaconnect.com.')
            )
            return compile_statement(
                sa
                    .text(query_bundle.get('insert_domain'))
                    .bindparams(domain_id=domain_id,
                                origin=origin,
                                technical=technical,
                                version=version),
            ).format(queries=queries)
        elif not pdd_rrset and own_rrset:
            # PDD has nothing besides the SOA: drop all of our records.
            return compile_statement(
                sa
                    .text(query_bundle.get('clear_domain'))
                    .bindparams(domain_id=domain_id,
                                origin=origin,
                                version=version),
            )
        else:
            # Both sides exist: apply only the record-level differences.
            queries = make_diff(origin, pdd_rrset, own_rrset)
            if not queries:
                return None

            return compile_statement(
                sa
                    .text(query_bundle.get('apply_domain_changes'))
                    .bindparams(domain_id=domain_id, origin=origin)
            ).format(queries=queries)
    except Exception:
        # Best effort: the failure is logged and reported as "no migration".
        logger.exception('Failed to make migration of origin \'%s\'', origin)
        return None


def rdata_dbl_conv(origin, type, content):
    """Round-trip ``content`` through dnspython's rdata parser.

    ``content`` is parsed as IN-class rdata of record type ``type``, relative
    to ``origin`` and without relativizing names. The result is rendered
    back to text. ``prepare_records`` compares that text with its input to
    decide whether a TXT content can safely be wrapped in quotes.

    :returns: the re-rendered rdata text, or ``None`` if parsing fails for
        any reason.
    """
    try:
        rdtype = dns.rdatatype.from_text(type)
        parsed = dns.rdata.from_text(
            dns.rdataclass.IN, rdtype, content, origin, relativize=False,
        )
        return parsed.to_text()
    except Exception:
        return None


def prepare_records(origin, domain_id, records, fixup=False):
    """Validate one domain's records and group them into rrsets.

    :param origin: domain origin as text (parsed with ``dns.name.from_text``).
    :param domain_id: domain identifier; only used in error messages.
    :param records: list of record dicts with ``id``, ``name``, ``type``,
        ``content`` and ``ttl`` keys, or ``None``. The list is modified in
        place: it is always sorted, and with ``fixup=True`` contents may be
        normalized and offending records deleted from it.
    :param fixup: if True, normalize contents first and drop offending
        records with a log message; if False, raise on the first one.
    :returns: ``None`` when ``records`` is ``None``; otherwise a dict
        mapping ``(name, type)`` to ``(ttl, [(id, content), ...])``, where
        ``ttl`` is the smallest TTL among the rrset's records.
    :raises RuntimeError: (only with ``fixup=False``) for a record that fails
        validation, duplicates another record, conflicts with another CNAME
        of the same name, or shares its name with a CNAME.
    """
    if records is None:
        return None

    origin = dns.name.from_text(origin)

    if fixup:
        # Normalize contents before validation.
        for record in records:
            if record['type'] == 'TXT':
                # Quote the TXT content only when dnspython renders the
                # quoted form back unchanged; otherwise leave it as is.
                wrapped_content = '"' + record['content'] + '"'
                dbl_converted = rdata_dbl_conv(origin, 'TXT', wrapped_content)
                if wrapped_content == dbl_converted:
                    record['content'] = wrapped_content
            # Hostname-bearing types: make the target absolute by appending
            # the trailing dot (for MX and SRV this assumes the target
            # hostname is the last field of the content).
            elif record['type'] == 'CNAME' and not record['content'].endswith('.'):
                record['content'] += '.'
            elif record['type'] == 'MX' and not record['content'].endswith('.'):
                record['content'] += '.'
            elif record['type'] == 'SRV' and not record['content'].endswith('.'):
                record['content'] += '.'
            elif record['type'] == 'NS' and not record['content'].endswith('.'):
                record['content'] += '.'

    # Equal records become adjacent, and the one with the highest id comes
    # first (see record_sort_key), so it is the one that survives.
    records.sort(key=record_sort_key)

    i = 0
    # Identities (name, type, content) of the records accepted so far.
    known_records = set()
    # CNAME name -> (id, index in records) of the CNAME kept for that name.
    known_cnames = {}
    # Index-based loop: records may be deleted, or written to an earlier
    # index, while scanning.
    while i < len(records):
        record = records[i]
        error = validate_content(origin, record)
        identity = record_identity(record)

        error_message = None
        if error:
            error_message = make_record_error_message(
                record['id'], origin, domain_id, error,
            )
        elif identity in known_records:
            error_message = make_record_error_message(
                record['id'], origin, domain_id, 'duplicates another one.',
            )
        elif record['type'] == 'CNAME':
            # Only one CNAME per name is allowed; the highest id wins.
            id, pos = known_cnames.get(record['name'], (None, None))
            if id is None:
                known_cnames[record['name']] = (record['id'], i)
            elif record['id'] <= id:
                error_message = make_record_error_message(
                    record['id'], origin, domain_id, 'duplicates CNAME %s.' % id,
                )
            else:
                # The current record wins: move it into the kept slot and
                # report the previously kept CNAME. Index i is then rejected
                # below, which (with fixup) deletes the now-duplicate copy.
                # NOTE(review): the winner's identity is never added to
                # known_records, so a later exact duplicate of it is reported
                # as 'duplicates CNAME' rather than 'duplicates another one'.
                records[pos] = record
                known_cnames[record['name']] = (record['id'], pos)
                error_message = make_record_error_message(
                    id, origin, domain_id, 'duplicates CNAME %s.' % record['id'],
                )

        if error_message is not None:
            if fixup:
                # Drop the record; do not advance i, the next one shifts in.
                del records[i]
                logger.warning(error_message)
            else:
                raise RuntimeError(error_message)
        else:
            i += 1
            known_records.add(identity)

    rrsets = {}
    # records is sorted by (name, type, ...), so each rrset is contiguous.
    it = itertools.groupby(records, key=rrset_identity)
    for (name, type), set_records_it in it:
        # Any non-CNAME data at a name that has a CNAME conflicts with it.
        if type != 'CNAME' and name in known_cnames:
            error_message = make_record_error_message(
                ','.join(str(record['id']) for record in set_records_it),
                origin,
                domain_id,
                'covered by CNAME',
            )
            if fixup:
                # Skip the whole rrset (it stays in the records list).
                logger.debug(error_message)
                continue
            else:
                raise RuntimeError(error_message)
        else:
            # The rrset TTL is the smallest TTL among its records.
            ttl = None
            set_records = []
            for record in set_records_it:
                if ttl is None or ttl > record['ttl']:
                    ttl = record['ttl']

                set_records.append((record['id'], record['content']))

            rrsets[name, type] = ttl, set_records

    return rrsets


def make_record_error_message(id, origin, domain_id, msg):
    """Format a diagnostic about invalid record(s) ``id`` of a domain.

    ``id`` may be a single id or an already-joined list of ids.
    """
    template = 'Invalid record %s in %s (%s): %s'
    details = (id, origin, domain_id, msg)
    return template % details


def record_sort_key(r):
    """Sort key that groups identical records, highest ``id`` first.

    Records are ordered by ``(name, type, content)``. Among equal records,
    the one with the highest id sorts first, which gives it precedence over
    its duplicates.
    """
    name, rtype, content = r['name'], r['type'], r['content']
    return name, rtype, content, -r['id']


def record_identity(r):
    """Return the ``(name, type, content)`` tuple that makes a record unique."""
    return tuple(r[field] for field in ('name', 'type', 'content'))


def rrset_identity(r):
    """Return the ``(name, type)`` tuple identifying a record's rrset."""
    return tuple(r[field] for field in ('name', 'type'))


def validate_content(origin, record):
    """Check that a record's name and content parse as DNS data.

    The name is parsed relative to ``origin``; the content is parsed as
    IN-class rdata of the record's type.

    :returns: ``None`` if both parse, otherwise a non-empty error message.
        Callers treat any truthy return value as "invalid".
    """
    try:
        dns.name.from_text(record['name'], origin)

        dns.rdata.from_text(
            dns.rdataclass.IN,
            dns.rdatatype.from_text(record['type']),
            record['content'],
        )
    except Exception as e:
        try:
            message = str(e)
        except UnicodeError:
            # Python 2: str() of an exception carrying a non-ASCII unicode
            # message raises instead of returning text.
            message = repr(e)
        # An empty message would read as "valid" to callers, so fall back
        # to the exception class name.
        return message or e.__class__.__name__
    return None


def make_diff(origin, rrsets1, rrsets2):
    """Build SQL that turns the records in ``rrsets2`` into ``rrsets1``.

    Both arguments map ``(name, type)`` to ``(ttl, [(id, content), ...])``,
    as returned by ``prepare_records``.

    - Records present only in ``rrsets2`` are deleted.
    - Records present only in ``rrsets1`` are inserted.
    - Records present in both whose rrset TTLs differ are updated to the
      ``rrsets1`` TTL.

    :returns: the generated statements joined with ``'; '`` and terminated
        with ``';'``, or ``''`` when there is nothing to change.
    """
    # Sorted so that the generated SQL does not depend on set ordering.
    rrset_keys = sorted(set(rrsets1) | set(rrsets2))

    records_to_delete = []
    records_to_update = []
    records_to_insert = []
    for key in rrset_keys:
        name, rtype = key

        ttl1, records1 = rrsets1.get(key, (None, []))
        ttl2, records2 = rrsets2.get(key, (None, []))

        # Pair records of both sides by content. Presumably ijoin is a merge
        # join over sorted inputs; prepare_records emits each rrset's records
        # sorted by content. The comparator is a portable three-way compare:
        # the `cmp` builtin does not exist in Python 3.
        records_it = ijoin(
            records1,
            records2,
            cmp=lambda r1, r2: (r1[1] > r2[1]) - (r1[1] < r2[1]),
            default=(None, None),
        )
        for (_, content1), (id2, content2) in records_it:
            if content1 is None:
                # Only in rrsets2: delete it.
                records_to_delete.append((id2, name, rtype, ttl2, content2))
            elif content2 is None:
                # Only in rrsets1: insert it (make_inserts ignores the id).
                records_to_insert.append((None, name, rtype, ttl1, content1))
            elif ttl1 != ttl2:
                # On both sides, only the rrset TTL differs.
                records_to_update.append((id2, name, rtype, ttl1, content1))

    queries = []
    queries.extend(make_deletes(origin, records_to_delete))
    queries.extend(make_updates(origin, records_to_update))
    queries.extend(make_inserts(origin, records_to_insert))
    # The make_* helpers return ('', '') when they have nothing to emit.
    joined = '; '.join(query for query in queries if query)

    return joined + ';' if joined else ''


def make_deletes(origin, records):
    """Build SQL that deletes ``records`` and logs each deletion.

    :param origin: domain origin written into the change log rows.
    :param records: ``(id, name, type, ttl, content)`` tuples; ``ttl`` is
        not used here.
    :returns: ``(delete_sql, changelog_sql)``, or ``('', '')`` when there is
        nothing to delete. The SQL refers to ``found_domain_id`` and
        ``domain_serial`` as raw column expressions, presumably defined by the
        query template it is spliced into.
    """
    if not records:
        return '', ''

    found_domain_id = sa.literal_column('found_domain_id')
    serial = sa.literal_column('domain_serial')

    record_ids = []
    changelog_rows = []
    for record_id, name, rtype, _ttl, content in records:
        record_ids.append(record_id)
        row = {
            ChangeLog.version.key: uuid_from_time(time.time()),
            ChangeLog.origin.key: origin,
            ChangeLog.serial.key: serial,
            ChangeLog.operation.key: Operation.RECORD_DELETE,
            ChangeLog.name.key: name,
            ChangeLog.type.key: rtype,
            ChangeLog.content.key: content,
        }
        changelog_rows.append(row)

    delete_stmt = (
        sa.delete(Record)
        .where(Record.domain_id == found_domain_id)
        .where(Record.id.in_(record_ids))
    )
    changelog_stmt = sa.insert(ChangeLog).values(changelog_rows)

    delete_sql = compile_statement(delete_stmt)
    changelog_sql = compile_statement(changelog_stmt)
    return delete_sql, changelog_sql


def make_updates(origin, records):
    """Build SQL that changes the TTL of ``records`` and logs each change.

    :param origin: domain origin written into the change log rows.
    :param records: ``(id, name, type, ttl, content)`` tuples, where ``ttl``
        is the new TTL for record ``id``.
    :returns: ``(update_sql, changelog_sql)``, where ``update_sql`` holds
        one UPDATE per record joined with ``';'``; ``('', '')`` when there
        is nothing to update.
    """
    if not records:
        return '', ''

    found_domain_id = sa.literal_column('found_domain_id')
    serial = sa.literal_column('domain_serial')

    changelog_rows = []
    for _id, name, rtype, new_ttl, content in records:
        changelog_rows.append({
            ChangeLog.version.key: uuid_from_time(time.time()),
            ChangeLog.origin.key: origin,
            ChangeLog.serial.key: serial,
            ChangeLog.operation.key: Operation.RECORD_UPDATE_TTL,
            ChangeLog.name.key: name,
            ChangeLog.type.key: rtype,
            ChangeLog.content.key: content,
            ChangeLog.ttl.key: new_ttl,
        })

    update_statements = (
        sa.update(Record)
        .where(Record.id == record_id)
        .where(Record.domain_id == found_domain_id)
        .values(ttl=new_ttl)
        for record_id, _name, _type, new_ttl, _content in records
    )
    update_sql = ';'.join(compile_statement(stmt) for stmt in update_statements)
    changelog_sql = compile_statement(
        sa.insert(ChangeLog).values(changelog_rows),
    )
    return update_sql, changelog_sql


def make_inserts(origin, records):
    """Build SQL that inserts ``records`` and logs each addition.

    :param origin: domain origin written into the change log rows.
    :param records: ``(id, name, type, ttl, content)`` tuples; the id field
        is ignored (no id is written for the new rows).
    :returns: ``(insert_sql, changelog_sql)``, or ``('', '')`` when there is
        nothing to insert.
    """
    if not records:
        return '', ''

    found_domain_id = sa.literal_column('found_domain_id')
    serial = sa.literal_column('domain_serial')

    record_rows = []
    changelog_rows = []
    for _id, name, rtype, ttl, content in records:
        record_rows.append({
            Record.domain_id.key: found_domain_id,
            Record.name.key: name,
            Record.type.key: rtype,
            Record.content.key: content,
            Record.ttl.key: ttl,
        })
        changelog_rows.append({
            ChangeLog.version.key: uuid_from_time(time.time()),
            ChangeLog.origin.key: origin,
            ChangeLog.serial.key: serial,
            ChangeLog.operation.key: Operation.RECORD_ADD,
            ChangeLog.name.key: name,
            ChangeLog.type.key: rtype,
            ChangeLog.content.key: content,
            ChangeLog.ttl.key: ttl,
        })

    insert_sql = compile_statement(psql.insert(Record).values(record_rows))
    changelog_sql = compile_statement(
        sa.insert(ChangeLog).values(changelog_rows),
    )
    return insert_sql, changelog_sql


def compile_statement(stmt):
    """Render a SQLAlchemy statement as self-contained PostgreSQL SQL text.

    Bound parameters are inlined as literals (``literal_binds``), so the
    result is plain SQL rather than a template plus parameters.
    """
    pg_dialect = psql.dialect()
    # Private dialect flag: stop escaping backslashes in rendered string
    # literals. Presumably the target server runs with
    # standard_conforming_strings=on — confirm.
    pg_dialect._backslash_escapes = False
    # NOTE(review): the default PostgreSQL dialect uses the pyformat
    # paramstyle, under which SQLAlchemy may double '%' in rendered text and
    # literals — confirm this matches how the resulting SQL is executed.
    return stmt.compile(
        dialect=pg_dialect,
        compile_kwargs={'literal_binds': True},
    ).string
