from sqlalchemy import orm

from infra.cauth.server.common.models import Access

from infra.cauth.server.master.api.idm.srcs import GroupSrc, PersonSrc
from infra.cauth.server.master.api.idm.dsts import ServerDst, GroupDst
from infra.cauth.server.master.api.idm.roles import SshRole, SudoRole, DstRole, EineRole


class RoleStream(object):
    """Base class for a paginated source of IDM roles.

    Subclasses set ``id`` and implement ``get_page``. ``next_stream`` is
    assigned by StreamSet when streams are chained together.
    """

    # Key under which StreamSet registers this stream.
    id = None
    # Maximum number of roles returned per page.
    page_size = 100
    # Names of the keyword parameters get_page accepts for pagination.
    page_param_keys = ()
    # Stream that follows this one in its StreamSet chain, if any.
    next_stream = None

    @classmethod
    def get_page(cls, **params):
        """Return a page of roles for the pagination ``params``.

        Abstract: every concrete stream must override this.
        """
        raise NotImplementedError()


class StreamSet(object):
    """Ordered collection of RoleStream subclasses, addressable by ``id``.

    Construction links the streams into a chain: each stream class's
    ``next_stream`` is set to the stream passed after it. This mutates the
    stream *classes* themselves; the last stream's ``next_stream`` is left
    untouched.
    """

    def __init__(self, *streams):
        self._map = {}
        self._streams = []

        prev_stream = None
        for cls in streams:
            # Explicit check instead of ``assert``: it survives ``python -O``
            # and gives one consistent error type for non-classes and for
            # classes that are not RoleStream subclasses.
            if not (isinstance(cls, type) and issubclass(cls, RoleStream)):
                raise TypeError('%r is not a RoleStream subclass' % (cls,))
            self._map[cls.id] = cls
            self._streams.append(cls)

            if prev_stream is not None:
                prev_stream.next_stream = cls
            prev_stream = cls

    def __getitem__(self, key):
        """Return the stream registered under ``key``; KeyError if unknown."""
        return self._map[key]

    def get_first_stream(self):
        """Return the first stream passed to the constructor.

        Raises IndexError if the set was constructed with no streams.
        """
        return self._streams[0]


class EmptyPage(Exception):
    """Raised by a stream's get_page when the requested page has no roles."""


class InvalidParams(Exception):
    """Raised by a stream's get_page when its page parameters are unusable."""


class RolePage(object):
    """A single page of roles produced by a RoleStream.

    ``next_page_params`` is a dict intended to be passed back to the
    stream's ``get_page`` as keyword arguments to fetch the following page.
    ``is_last`` flags the final page.
    """

    # Class-level placeholders; __init__ overwrites both on every instance.
    next_page_params = None
    roles = None

    def __init__(self, roles, next_page_params=None, is_last=False):
        # Any falsy value (None, {}) becomes a fresh empty dict, so instances
        # never share a mutable default.
        if not next_page_params:
            next_page_params = {}
        self.roles = roles
        self.next_page_params = next_page_params
        self.is_last = is_last


class AccessStreamBase(RoleStream):
    """Keyset-paginated stream of IDM roles built from ``Access`` rules.

    Subclasses implement ``get_queryset`` (which rules to stream) and
    ``get_dst`` (how to wrap a rule's destination). Rows are ordered by
    ``Access.id``. The ``after_pk`` page parameter is the id of the last
    rule on the previous page.
    """

    page_param_keys = ('after_pk',)

    # Maps ``Access.type`` to the role class used to represent the rule.
    _role_classes = {
        'ssh': SshRole,
        'sudo': SudoRole,
        'eine': EineRole,
    }

    @classmethod
    def get_page(cls, after_pk=None):
        """Return the page of roles whose rule ids are greater than ``after_pk``.

        :param after_pk: id of the last rule on the previous page, or None for
            the first page. Any value ``int()`` accepts is allowed.
        :returns: RolePage whose ``next_page_params`` carry the last id seen;
            ``is_last`` is True when fewer than ``page_size`` rows came back.
        :raises InvalidParams: if ``after_pk`` cannot be converted to int.
        :raises EmptyPage: if no rules remain after ``after_pk``.
        :raises RuntimeError: for a rule with no source or an unknown type.
        """
        query = cls.get_queryset()

        if after_pk is not None:
            try:
                after_pk = int(after_pk)
            except (TypeError, ValueError):
                # TypeError covers non-scalar values (e.g. a list coming from
                # a multi-valued query string), which previously escaped as
                # an unhandled error instead of InvalidParams.
                raise InvalidParams
            query = query.filter(Access.id > after_pk)

        query = query.order_by(Access.id).limit(cls.page_size)

        roles = []
        last_pk = None
        for access in query:
            # Rows are ordered by id, so the last one seen is the page maximum.
            last_pk = access.id
            src = cls._get_src(access)
            role_cls = cls._get_role_cls(access)
            dst = cls.get_dst(access)
            roles.append(role_cls(src, access, parent=DstRole(src, dst)))

        if last_pk is None:
            raise EmptyPage

        return RolePage(
            roles=roles,
            next_page_params={'after_pk': last_pk},
            # A short page means the query ran out of rows.
            is_last=len(roles) < cls.page_size,
        )

    @classmethod
    def _get_src(cls, access):
        """Build the IDM source (person or group) for an access rule."""
        if access.src_user_id:
            return PersonSrc(access.src)
        if access.src_group_id:
            return GroupSrc(access.src_group.staff_id)
        raise RuntimeError('Access rule with no src')

    @classmethod
    def _get_role_cls(cls, access):
        """Return the role class matching ``access.type``."""
        role_cls = cls._role_classes.get(access.type)
        if role_cls is None:
            raise RuntimeError('Bad accessrule type: %s' % access.type)
        return role_cls

    @classmethod
    def get_queryset(cls):
        """Return the query of Access rules to stream; subclasses override."""
        raise NotImplementedError

    @classmethod
    def get_dst(cls, rule):
        """Return the IDM destination object for ``rule``; subclasses override."""
        raise NotImplementedError


class ServerAccessStream(AccessStreamBase):
    """Stream of roles whose destination is a single server (``dst_server``)."""

    id = 'server_access'

    @classmethod
    def get_queryset(cls):
        """Active access rules inner-joined to their destination server."""
        base = Access.get_active_query_with_empty_dst()
        joined = base.join(Access.dst_server)
        # dst_server is populated from the explicit join above. src_group and
        # sudo_role are eager-loaded, presumably to avoid per-row lazy loads
        # when get_page builds roles.
        eager = joined.options(
            orm.contains_eager(Access.dst_server),
            orm.joinedload(Access.src_group),
            orm.joinedload(Access.sudo_role),
        )
        # NOTE(review): the group stream does not apply DISTINCT; confirm
        # whether this one needs it.
        return eager.distinct()

    @classmethod
    def get_dst(cls, rule):
        """Wrap the rule's destination server in a ServerDst."""
        server = rule.dst_server
        return ServerDst(server)


class ServerGroupAccessStream(AccessStreamBase):
    """Stream of roles whose destination is a server group (``dst_group``)."""

    id = 'group_access'

    @classmethod
    def get_queryset(cls):
        """Active access rules inner-joined to their destination group."""
        base = Access.get_active_query_with_empty_dst()
        joined = base.join(Access.dst_group)
        # dst_group is populated from the explicit join above. src_group and
        # sudo_role are eager-loaded, presumably to avoid per-row lazy loads
        # when get_page builds roles.
        return joined.options(
            orm.contains_eager(Access.dst_group),
            orm.joinedload(Access.src_group),
            orm.joinedload(Access.sudo_role),
        )

    @classmethod
    def get_dst(cls, rule):
        """Wrap the rule's destination group in a GroupDst."""
        group = rule.dst_group
        return GroupDst(group)


# Registry of all role streams. Argument order matters: get_first_stream()
# returns ServerAccessStream, and its next_stream becomes ServerGroupAccessStream.
role_streams = StreamSet(ServerAccessStream, ServerGroupAccessStream)
