# encoding: UTF-8

import typing

import sqlalchemy.exc
import sqlalchemy.orm.exc
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload

from appcore.data.model import Page
from appcore.data.repository import Repository
from dns_hosting.models.domains import (
    ChangeLog,
    RecordType,
)
from dns_hosting.models.domains import Domain
from dns_hosting.models.domains import Record


class DomainRepository(Repository[Domain]):
    """Repository for :class:`Domain` entities.

    Name-based lookups compare against the normalized form of the name:
    lower-cased, with a trailing ``'.'`` appended when missing.
    """

    def __init__(self, session_factory):
        super(DomainRepository, self).__init__(Domain, session_factory)

    @staticmethod
    def _normalize_name(name):
        """Return *name* with a trailing '.' appended if absent, lower-cased."""
        if not name.endswith('.'):
            name += '.'
        return name.lower()

    def _domain_names_query(self):
        """Return a query selecting ``Domain.name`` in ascending order."""
        return (
            self.session
                .query(Domain.name)
                .order_by(Domain.name.asc())
        )

    def find_by_org_id(self, org_id):
        # type: (...) -> typing.List[Domain]
        """Return a list of all domains whose ``org_id`` equals *org_id*."""
        return list(self.find_iter(Domain.org_id == org_id))

    def find_by_name(self, name):
        # type: (...) -> Domain
        """Return the domain whose name equals the normalized *name*.

        ``'Example.COM'`` and ``'example.com.'`` therefore match the same row.
        The result for "no match" is whatever ``Repository.find_one`` returns
        (not visible here).
        """
        return self.find_one(
            filter=(Domain.name == self._normalize_name(name))
        )

    def get_domain_serial(self, name):
        """Return the ``serial`` column of the domain named *name*.

        *name* is normalized before the lookup.

        :raises LookupError: if no row matches, or the matching row's serial
            is NULL (``Query.scalar`` returns ``None`` in both cases).
        """
        normalized = self._normalize_name(name)
        result = (
            self.session
                .query(Domain.serial)
                .filter(Domain.name == normalized)
                .scalar()
        )
        if result is None:
            raise LookupError(normalized)
        return result

    def get_domain_names(self, yield_per=10000):
        """Return a generator over all domain names in ascending order.

        Rows are fetched in batches of *yield_per*. ``stream_results`` is the
        SQLAlchemy execution option that requests a server-side cursor (where
        the driver supports it), so the full result is not buffered
        client-side.

        The SQL is issued when this method is called, because a generator
        expression evaluates ``iter(query)`` immediately. Rows are then
        consumed lazily.
        """
        query = (
            self._domain_names_query()
                .yield_per(yield_per)
                .execution_options(stream_results=True)
        )
        return (row[0] for row in query)

    def get_all_domain_names(self):
        """Return a generator over all domain names in ascending order.

        No batching or streaming options are applied.
        """
        return (row[0] for row in self._domain_names_query())

    def save(self, entity, flush=False):
        """Call ``entity.update_serial()``, then delegate to ``Repository.save``."""
        entity.update_serial()
        return super(DomainRepository, self).save(entity, flush)


class RecordRepository(Repository[Record]):
    """Repository for :class:`Record` entities, queried via their domain.

    Domain names passed to these methods are compared verbatim against
    ``Domain.name``. No lower-casing or trailing-dot normalization is applied.
    """

    def __init__(self, session_factory):
        super(RecordRepository, self).__init__(Record, session_factory)

    def _query_joined_on_domain(self, domain_name):
        """Return a record query INNER JOINed to its domain and filtered by name.

        ``contains_eager`` populates ``Record.domain`` from this same join.
        Combining ``joinedload`` with an explicit join would instead add a
        second, aliased join to the domains table.
        """
        return (
            self.query
                .join(Record.domain)
                .options(contains_eager(Record.domain))
                .filter(Domain.name == domain_name)
        )

    def find_paged_by_domain_name(self, pageable, domain_name):
        # type: (...) -> Page[Record]
        """Return one page (per *pageable*) of the records of *domain_name*."""
        return self.find_paged(
            pageable=pageable,
            filter=[
                Record.domain.has(Domain.name == domain_name),
            ]
        )

    def find_all_by_domain_name(self, domain_name):
        """Return a list of all records belonging to the domain *domain_name*."""
        return self.find(domain_name)

    def find(self, domain_name, record_type=None, name=None):
        # type: (basestring, RecordType, basestring) -> typing.Iterable[Record]
        """Return a list of records of *domain_name*, optionally narrowed.

        :param record_type: when truthy, keep only records of this type.
        :param name: when truthy, keep only records with exactly this name.

        Falsy filter values (``None``, ``''``) mean "do not filter".
        """
        query = self.query.filter(Record.domain.has(Domain.name == domain_name))
        if record_type:
            query = query.filter(Record.type == record_type)
        if name:
            query = query.filter(Record.name == name)
        return query.all()

    def find_all_by_domain_name_and_record_type(self, domain_name, record_type):
        # type: (basestring, RecordType) -> typing.Iterable[Record]
        """Return a list of *record_type* records of the domain *domain_name*.

        Each returned record has ``domain`` already loaded.

        :raises TypeError: if *record_type* is not a ``RecordType`` member.
        """
        # Explicit check rather than ``assert``, which ``python -O`` strips.
        if record_type not in RecordType:
            raise TypeError(
                'record_type must be a RecordType member, got %r'
                % (record_type,)
            )
        return (
            self._query_joined_on_domain(domain_name)
                .filter(Record.type == record_type)
                .all()
        )

    def find_by_domain_name_and_id(self, domain_name, record_id):
        # type: (basestring, int) -> Record
        """Return the record *record_id* if it belongs to *domain_name*.

        :raises sqlalchemy.orm.exc.NoResultFound: if no such record exists.
        :raises sqlalchemy.orm.exc.MultipleResultsFound: if several rows match.
        """
        return (
            self._query_joined_on_domain(domain_name)
                .filter(Record.id == record_id)
                .one()
        )

    def save(self, entity, flush=False):
        """Call ``update_serial()`` on the record's domain, then save the record."""
        entity.domain.update_serial()
        return super(RecordRepository, self).save(entity, flush)

    def delete(self, entity):
        """Call ``update_serial()`` on the record's domain, then delete the record."""
        entity.domain.update_serial()
        super(RecordRepository, self).delete(entity)


class ChangeLogRepository(Repository[ChangeLog]):
    """Repository for :class:`ChangeLog` entries."""

    # Substring of an OperationalError message that find_until_iter treats as
    # resumable: it re-issues the query after the last version yielded.
    _RESUMABLE_ERROR = 'SSL SYSCALL error: EOF detected'

    def __init__(self, session_factory):
        super(ChangeLogRepository, self).__init__(ChangeLog, session_factory)

    def find_until_iter(self, from_, until, batch_size):
        # type: (...) -> typing.Iterable[ChangeLog]
        """Yield entries with ``from_ < version < until`` in ascending version.

        Rows are fetched in batches of *batch_size* with eager loading
        disabled. If iteration fails with an ``OperationalError`` whose text
        contains ``_RESUMABLE_ERROR``, the query is re-issued starting strictly
        after the last version already yielded. Any other error is re-raised.

        NOTE(review): resuming on ``version > last`` assumes versions are
        unique. Retries are unbounded, and the session is not rolled back
        before re-querying. Confirm ``self.query`` is usable after a dropped
        connection.
        """
        last_from = from_
        while True:
            try:
                changes = (
                    self.query
                        .filter(ChangeLog.version > last_from)
                        .filter(ChangeLog.version < until)
                        .order_by(ChangeLog.version.asc())
                        .yield_per(batch_size)
                        .enable_eagerloads(False)
                )
                for change in changes:
                    last_from = change.version
                    yield change
            except sqlalchemy.exc.OperationalError as e:
                # str(e) works on Python 2 and 3. BaseException.message is
                # deprecated (PEP 352) and absent on Python 3, where reading
                # it would raise AttributeError and mask the original error.
                if self._RESUMABLE_ERROR in str(e):
                    continue
                raise
            return
