import re
import datetime
import collections

import ldap
import ldif
from ldap.modlist import modifyModlist, addModlist
from sqlalchemy import orm

from django.conf import settings

from infra.cauth.server.common.models import Access, User, Group, UserGroupRelation

from ..base import GeneralImporter

PersonDesc = collections.namedtuple('PersonDesc', (
    'uid',
    'cn',
    'sn',
    'uidNumber',
    'gidNumber',
    'loginShell',
    'homeDirectory',
))

GroupDesc = collections.namedtuple('GroupDesc', (
    'cn',
    'gidNumber',
    'memberUid',
))


class LdapImporterBase(GeneralImporter):
    """Note: На самом деле это не importer, а скорее uploader. Данные собираются
    из базы, приводятся к ldap-формату и выгружаются в ldap. Так как
    используется общий suite с to_database importer, в эти обработчики
    передаются данные, выгруженные из staff, но они не используются"""
    TREE_BASE = 'dc=yandex,dc=net'
    DESC_CLASS = None
    PK = None
    SCOPE = None
    CLASSES = None

    TARGET = 'ldap'

    def run_import(self):
        self._ldif = ldif.LDIFWriter(self.outstream)
        super(LdapImporterBase, self).run_import()

    @classmethod
    def _dn2key(cls, dn):
        dn_local, dn_suffix = dn.split(',', 1)
        pos = len(cls.PK) + 1
        if dn_local[:pos] == '%s=' % cls.PK:
            return dn_local[pos:]
        else:
            raise ValueError("Invalid dn in existing data: %s" % dn)

    @classmethod
    def _key2dn(cls, key):
        return ",".join(("%s=%s" % (cls.PK, key), cls.SCOPE, cls.TREE_BASE))

    def encode_values(self, collection):
        if isinstance(collection, set):
            return {elem.encode() for elem in collection}
        if isinstance(collection, list):
            return [elem.encode() for elem in collection]
        raise NotImplementedError

    def decode_values(self, collection):
        if isinstance(collection, set):
            return {elem.decode() for elem in collection}
        if isinstance(collection, list):
            return [elem.decode() for elem in collection]
        raise NotImplementedError

    def load_existing_data(self):
        connection = ldap.initialize('ldap://{}'.format(settings.LDAP_SERVER))
        result = connection.search_s(
            base=','.join((self.SCOPE, self.TREE_BASE)),
            scope=ldap.SCOPE_SUBORDINATE,
        )

        res = {}
        for dn, entry in result:
            key = self._dn2key(dn)
            attrs = {k: set() for k in self.__item_attrs__}

            for k, v in list(entry.items()):
                if k in self.__item_attrs__:
                    attrs[k] = self.decode_values(set(v))

            res[key] = self.DESC_CLASS(**attrs)

        return res

    def _add_extra(self, key):
        return {}

    def log_update(self, key):
        changes = []
        for k, old, new in self.res.to_update[key]:
            items = []
            for removed in old - new:
                items.append("-%s" % removed)
            for added in new - old:
                items.append("+%s" % added)
            changes.append("%s: {%s}" % (k, ", ".join(items)))

        self.logger.info("Updating %s %s %s", key, self.CLASSES[0],
                         ", ".join(changes))

    def add(self, keys):
        for key, desc in list(keys.items()):
            entry = {k: self.encode_values(getattr(desc, k)) for k in self.__item_attrs__}
            entry['objectClass'] = self.CLASSES
            entry.update(self._add_extra(key))

            self.logger.info("Adding %s %s", self.CLASSES[0].decode(), key)

            self._ldif.unparse(self._key2dn(key), addModlist(entry))

    def remove(self, keys):
        for key in keys:
            self.logger.info("Removing %s %s", self.CLASSES[0].decode(), key)

            self._ldif.unparse(self._key2dn(key), {'changetype': [b'delete']})

    def update_one(self, key):
        old = {}
        new = {}
        for k, _old, _new in self.res.to_update[key]:
            old[k] = self.encode_values(_old)
            new[k] = self.encode_values(_new)

        self.log_update(key)
        self._ldif.unparse(self._key2dn(key), modifyModlist(old, new))


class PeopleLdapImporter(LdapImporterBase):
    """Выгружает пользователей в ldap. В выгрузку не попадают пользователи
    указанные в настройке  LDAP_SKIP_USERS. Так же выгружается специальный
    пользователь cauth_watchdog
    """
    PK = 'uid'
    SCOPE = 'ou=people'
    CLASSES = (
        b'person',
        b'organizationalPerson',
        b'inetOrgPerson',
        b'posixAccount',
    )
    DESC_CLASS = PersonDesc
    __item_attrs__ = PersonDesc._fields

    DN_PATTERN = re.compile(r'^uid=(.+)$')

    def load_new_data(self, input):
        res = {}
        for user in User.query.filter(~User.login.in_(settings.LDAP_SKIP_USERS)):
            name = ' '.join((user.first_name, user.last_name)).strip()
            name = name or user.login

            res[user.login] = PersonDesc(
                uid={user.login},
                cn={name},
                sn={name},
                uidNumber={str(user.uid)},
                gidNumber={str(user.gid)},
                loginShell={user.shell},
                homeDirectory={user.home},
            )

        # special ldap user with cn = sn = timestamp updated hourly by cron
        res['cauth_watchdog'] = PersonDesc(
            uid={'cauth_watchdog'},
            cn={'wdog_timestamp'},
            sn={'wdog_timestamp'},
            uidNumber={'20000'},
            gidNumber={'20000'},
            loginShell={'/sbin/nologin'},
            homeDirectory={'/tmp'},
        )

        return res

    def update_one(self, key):
        if key != 'cauth_watchdog':
            super(PeopleLdapImporter, self).update_one(key)

    def _add_extra(self, key):
        return {'userPassword': self.encode_values(['{SASL}%s' % key])}


class GroupsLdapImporter(LdapImporterBase):
    """Собирает плоские данные о принадлежности пользователей группам и
    выгружает их в ldap.
    Каждая группа в ldap должна содержать в себе всех пользователей подгрупп,
    т.е. группа dpt_yandex должна содержать в себе практически всех сотрудников
    Яндекса (кроме асессоров, сотрудников Я.денег etc)
    Уволеные пользователи не попадают в обычные группы, однако должны
    присутствовать в специальной группе 'fired'
    Новые сотрудники (проработавшие меньше 90 дней) добавляются в специальную
    группу 'newbie'
    """
    PK = 'cn'
    SCOPE = 'ou=groups'
    CLASSES = (b'posixGroup',)
    DESC_CLASS = GroupDesc
    __item_attrs__ = GroupDesc._fields

    def _get_visible_groups(self):
        used_svc_groups = (
            Group.query
            .join(Access.src_group)
            .filter(Group.type.in_(('svc', 'svcrole')))
            .all()
        )
        groups = Group.query.filter(~Group.type.in_(('svc', 'svcrole'))).all()
        groups.extend(used_svc_groups)
        return groups

    def _get_fired_users(self):
        fired_users = User.query.filter(User.is_fired)
        return {user.login for user in fired_users}

    def _get_newbie_users(self):
        newbie_threshold = datetime.date.today() - datetime.timedelta(days=90)
        newbie_users = (
            User.query
            .filter(
                User.join_date >= newbie_threshold,
                User.is_fired.is_(False),
            )
        )
        return {user.login for user in newbie_users}

    def _get_group_members(self):
        memberships = (
            UserGroupRelation.get_active_query()
            .join(User)
            .filter(User.is_fired.is_(False))
            .options(
                orm.contains_eager(UserGroupRelation.user)
            )
        )

        group_members = collections.defaultdict(set)
        for membership in memberships:
            gid, login = membership.gid, membership.user.login
            group_members[gid].add(login)

        return group_members

    def _get_group_ancestors(self):
        groups = {group.gid: group for group in Group.query}

        group_ancestors = collections.defaultdict(set)
        for group in list(groups.values()):
            if group.parent_gid is None:
                continue

            parent = groups[group.parent_gid]

            ancestors = {parent.gid}
            while parent.parent_gid is not None:
                parent = groups[parent.parent_gid]
                ancestors.add(parent.gid)

            group_ancestors[group.gid] = ancestors

        return group_ancestors

    def load_new_data(self, input):
        group_members = self._get_group_members()
        group_ancestors = self._get_group_ancestors()
        visible_groups = self._get_visible_groups()

        flat_members = collections.defaultdict(set)
        for group in visible_groups:
            members = group_members[group.gid]
            flat_members[group.gid] |= members
            for ancestor_gid in group_ancestors[group.gid]:
                flat_members[ancestor_gid] |= members

        result = {}
        for group in visible_groups:
            result[group.name] = GroupDesc(
                cn={group.name},
                gidNumber={str(group.gid)},
                memberUid=flat_members[group.gid],
            )

        fired_users = self._get_fired_users()
        result['fired'] = GroupDesc(
            cn={'fired'},
            gidNumber={'20000'},
            memberUid=fired_users,
        )

        newbie_users = self._get_newbie_users()
        result['newbie'] = GroupDesc(
            cn={'newbie'},
            gidNumber={'19999'},
            memberUid=newbie_users,
        )

        return result
