import datetime
import itertools
import collections

from sqlalchemy import or_, sql, select
from sqlalchemy.dialects.postgresql import aggregate_order_by

from infra.cauth.server.common.alchemy import Session
from infra.cauth.server.common.models import (
    Access,
    User,
    PublicKey,
    UserGroupRelation,
    Group,
    ServerResponsible,
    Server,
    Source,
    ServerGroup,
    gr_m2m,
    sg_m2m,
)

from infra.cauth.server.public.constants import SOURCE_NAME


def _get_access_group_dst_filter(server, sources):
    """Resolve the ids of the server groups that *server* belongs to.

    :param server: object exposing ``id``.
    :param sources: ``None`` to take every group of the server straight from
        the ``sg_m2m`` association; otherwise an iterable of source objects
        (each exposing ``id`` and ``name``). Only groups whose ``source_id`` is
        one of those sources, or one of the extra sources returned by
        ``SOURCE_NAME.get_addition_sources`` for their names, are kept.
    :returns: a query yielding single-column ``(group_id,)`` rows, or ``[]``
        when nothing can match.
    """
    if sources is None:
        # No source restriction: read the server<->group association directly.
        return Session.query(sg_m2m.c.group_id).filter(sg_m2m.c.server_id == server.id)

    if not sources:
        return []

    requested_names = {src.name for src in sources}
    extra_names = SOURCE_NAME.get_addition_sources(requested_names)

    extra_ids = set()
    if extra_names:
        extra_rows = Session.query(Source).filter(Source.name.in_(extra_names)).all()
        extra_ids = {row.id for row in extra_rows}

    allowed_ids = extra_ids | {src.id for src in sources}
    # Guard kept for a `sources` value that is truthy yet yields no ids
    # (e.g. a lazy iterable); avoids issuing an IN () query.
    if not allowed_ids:
        return []

    membership = Session.query(ServerGroup.id).join(ServerGroup.servers)
    return membership.filter(
        Server.id == server.id,
        ServerGroup.source_id.in_(allowed_ids),
    )


def get_access_dst_filters(server, type_, sources):
    """Build filter clauses selecting live, approved accesses that target *server*.

    An ``Access`` row matches when it:

    * has the requested ``type_``;
    * is approved (both ``approver`` and ``approve_date`` are set);
    * names a source user or a source group;
    * has not expired (``until`` is unset or later than today);
    * targets the server directly, or one of its groups as resolved by
      ``_get_access_group_dst_filter(server, sources)``.

    :returns: a list of clauses intended to be AND-ed together, e.g. by
        unpacking them into ``Query.filter``.
    """
    has_source = or_(
        Access.src_user_id.isnot(None),
        Access.src_group_id.isnot(None),
    )
    not_expired = or_(
        Access.until.is_(None),
        Access.until > datetime.date.today(),
    )
    clauses = [
        Access.type == type_,
        Access.approver.isnot(None),
        Access.approve_date.isnot(None),
        has_source,
        not_expired,
    ]

    # The group lookup runs eagerly here so its ids can be inlined into IN (...).
    matching_group_ids = [gid for (gid,) in _get_access_group_dst_filter(server, sources)]

    destination = Access.dst_server_id == server.id
    if matching_group_ids:
        destination = or_(destination, Access.dst_group_id.in_(matching_group_ids))
    clauses.append(destination)

    return clauses


def get_responsible_uids(server, sources=None):
    """Map uids responsible for *server* to the ids of the sources granting it.

    Two kinds of responsibility are collected:

    * direct ``ServerResponsible`` rows for the server;
    * uids linked through ``gr_m2m`` to a ``ServerGroup`` that contains the
      server (server/group membership taken from ``sg_m2m``).

    :param server: object exposing ``id``.
    :param sources: ``None`` to accept every source; otherwise an iterable of
        source objects (exposing ``id``), and only rows whose ``source_id`` is
        among them are kept. An empty collection short-circuits without
        querying.
    :returns: ``collections.defaultdict(set)`` mapping uid -> set of source
        ids. The same type is returned on every path, including the
        short-circuit, so callers can rely on defaultdict semantics.
    """
    result = collections.defaultdict(set)
    if sources is not None and not sources:
        return result

    server_responsibles = (
        Session.query(ServerResponsible.uid, ServerResponsible.source_id)
        .filter(ServerResponsible.server_id == server.id)
    )

    groups_responsibles = (
        Session.query(gr_m2m.c.uid, ServerGroup.source_id)
        .select_from(Server)
        .join(sg_m2m, Server.id == sg_m2m.c.server_id)
        .join(ServerGroup, sg_m2m.c.group_id == ServerGroup.id)
        .join(gr_m2m, ServerGroup.id == gr_m2m.c.group_id)
        .filter(Server.id == server.id)
    )

    if sources is not None:
        sources_ids = {source.id for source in sources}

        server_responsibles = server_responsibles.filter(
            ServerResponsible.source_id.in_(sources_ids)
        )
        groups_responsibles = groups_responsibles.filter(
            ServerGroup.source_id.in_(sources_ids)
        )

    # A uid may be responsible through several sources; collect all of them.
    for uid, source_id in itertools.chain(server_responsibles, groups_responsibles):
        result[uid].add(source_id)

    return result


def get_access_uids(server, sources=None, can_use_access=True):
    """Map uids holding a matching SSH access to *server* to an admin flag.

    Accesses are selected with ``get_access_dst_filters(server, 'ssh', sources)``.
    A user-sourced access contributes its ``src_user_id``. A group-sourced
    access is expanded to every member of the group through
    ``UserGroupRelation``.

    :param can_use_access: when false, no query is issued and the result is
        empty.
    :returns: ``collections.defaultdict(bool)`` mapping uid -> True if any
        matching access row for that uid has a true ``ssh_is_admin``. The same
        type is returned on every path, including the short-circuit.
    """
    result = collections.defaultdict(bool)
    if not can_use_access:
        return result

    access_query = (
        Session.query(
            # Prefer the explicit user; otherwise the member uid produced by
            # the outer join with the source group.
            sql.func.coalesce(Access.src_user_id, UserGroupRelation.uid),
            Access.ssh_is_admin,
        )
        .outerjoin(
            UserGroupRelation,
            UserGroupRelation.gid == Access.src_group_id,
        )
        .filter(
            # Drop group accesses whose group has no members (outer join
            # leaves the member uid NULL).
            or_(
                Access.src_user_id.isnot(None),
                UserGroupRelation.uid.isnot(None),
            ),
            *get_access_dst_filters(server, 'ssh', sources)
        )
    )

    for uid, is_admin in access_query:
        # A uid can match several accesses; it is admin if any of them is.
        result[uid] |= is_admin

    return result


def get_root_uids(server, sources, can_use_access):
    """Return the set of uids that get root on *server*.

    A uid qualifies if it holds a matching SSH access with the admin flag set
    (as reported by ``get_access_uids``) or if it is listed as responsible for
    the server (as reported by ``get_responsible_uids``). *sources* is passed
    to both lookups; *can_use_access* only to the access lookup.
    """
    access_uids = get_access_uids(server, sources, can_use_access)
    admin_uids = {uid for uid, is_admin in access_uids.items() if is_admin}
    responsible_uids = get_responsible_uids(server, sources)
    return admin_uids | set(responsible_uids)


def get_keys(uids, entities=None):
    """Query the public keys of the given users, skipping fired users.

    :param uids: collection of user uids; a falsy value yields ``[]`` with no
        query issued.
    :param entities: optional extra columns/entities selected after
        ``PublicKey.key`` in every row.
    :returns: ``[]``, or a query ordered by ``User.login`` then ``PublicKey.id``.
    """
    if not uids:
        return []

    extra_columns = [] if entities is None else entities
    selected = (PublicKey.key, *extra_columns)

    keys_query = Session.query(*selected).join(PublicKey.user)
    keys_query = keys_query.filter(
        PublicKey.uid.in_(uids),
        User.is_fired.is_(False),
    )
    return keys_query.order_by(User.login, PublicKey.id)


def get_posix_groups_new(gids=None, uids=None):
    """Return one row per group: ``(Group.name, Group.gid, members)``.

    *members* is the comma-separated list of member ``User.login`` values,
    sorted by login. Rows are ordered by gid, then name.

    :param gids: ``None`` for no restriction; otherwise only groups whose gid
        is in this collection.
    :param uids: ``None`` for no restriction; otherwise only these users are
        considered, both when selecting groups and when building the member
        lists.
    :returns: a list of result rows. An explicitly empty *gids* or *uids*
        matches nothing and returns ``[]`` without querying.
    """
    # An empty id collection means "nothing", as in get_keys/get_users.
    # Treating it like None would drop the filter and return every group with
    # all of its members.
    if (gids is not None and not gids) or (uids is not None and not uids):
        return []

    separator = sql.literal_column("','")
    query = (
        Session.query(
            Group.name,
            Group.gid,
            sql.func.string_agg(User.login, aggregate_order_by(separator, User.login))
        )
        .join(UserGroupRelation, UserGroupRelation.gid == Group.gid)
        .join(User, UserGroupRelation.uid == User.uid)
        .group_by(Group.gid, Group.name)
        .order_by(Group.gid, Group.name)
    )
    if gids is not None:
        query = query.filter(Group.gid.in_(gids))
    if uids is not None:
        query = query.filter(User.uid.in_(uids))
    return query.all()


def get_posix_passwd_new():
    """Fetch the passwd fields of every user, ordered by login.

    :returns: all rows of ``(login, uid, gid, first_name, last_name, home,
        shell)`` read from ``User``.
    """
    passwd_columns = [
        User.login,
        User.uid,
        User.gid,
        User.first_name,
        User.last_name,
        User.home,
        User.shell,
    ]
    statement = select(passwd_columns).order_by(User.login)
    cursor = Session.execute(statement)
    return cursor.fetchall()


def get_users(uids):
    """Query the ``User`` rows for *uids*, skipping fired users.

    :param uids: collection of user uids; a falsy value yields ``[]`` with no
        query issued.
    :returns: ``[]``, or a query of ``User`` ordered by login.
    """
    if not uids:
        return []

    active_users = Session.query(User).filter(
        User.uid.in_(uids),
        User.is_fired.is_(False),
    )
    return active_users.order_by(User.login)


def get_users_gids(uids):
    """Query the gids of every group that any of *uids* belongs to.

    :param uids: collection of user uids used in an ``IN`` filter.
    :returns: a query (not yet executed) of single-column ``(gid,)`` rows. No
        DISTINCT is applied, so a gid appears once per matching membership.
    """
    # NOTE(review): unlike get_users/get_keys there is no empty-uids
    # short-circuit here; callers get a Query either way.
    gids_query = Session.query(UserGroupRelation.gid)
    return gids_query.filter(UserGroupRelation.uid.in_(uids))
