import datetime
import collections

from django.conf import settings
from sqlalchemy import orm, or_

from infra.cauth.server.common.alchemy import Session
from infra.cauth.server.common.models import Access, User, Group, UserGroupRelation, PublicKey

from infra.cauth.server.master.notify import events
from infra.cauth.server.master.importers.base import GeneralImporter, AnalyzeResult

# Comparable snapshot of one user, as stored in the DB or as sent by the source.
UserDesc = collections.namedtuple(
    'UserDesc',
    'uid gid login first_name last_name join_date shell is_fired is_robot',
)

# Comparable snapshot of one group, as stored in the DB or as sent by the source.
GroupDesc = collections.namedtuple(
    'GroupDesc',
    'gid parent_gid type name service_id staff_id',
)

# One (group, user) membership; ``until`` is the hold deadline, None when not held.
GroupMemberDesc = collections.namedtuple(
    'GroupMemberDesc',
    'gid uid until',
)


class NewUserImporter(GeneralImporter):
    """Keep ``User`` rows in sync with the ``persons`` list of the input.

    Both sides are keyed by uid and compared through ``UserDesc`` tuples.
    """

    __item_attrs__ = UserDesc._fields

    TARGET = 'database'
    sanity_limits = {'any': 3000}

    def load_existing_data(self):
        """Snapshot every stored user as ``{uid: UserDesc}``.

        The ORM rows are remembered in ``self._user_map`` so that ``remove``
        and ``update_one`` can modify them later. Columns are coerced to
        plain ``int``/``str``/``bool`` values before comparison.
        """
        self._user_map = {}
        snapshot = {}
        for row in User.query:
            self._user_map[row.uid] = row
            # NOTE(review): str() turns a NULL column into the string 'None';
            # if any of these columns are nullable that value never equals a
            # None coming from the input — confirm against the model.
            snapshot[row.uid] = UserDesc(
                uid=int(row.uid),
                gid=int(row.gid),
                login=str(row.login),
                first_name=str(row.first_name),
                last_name=str(row.last_name),
                join_date=str(row.join_date),
                shell=str(row.shell),
                is_fired=bool(row.is_fired),
                is_robot=bool(row.is_robot),
            )
        return snapshot

    def load_new_data(self, input):
        """Build ``{uid: UserDesc}`` from ``input['persons']``.

        Keys of a person document that are not ``UserDesc`` fields are
        ignored; a missing field makes ``UserDesc`` raise ``TypeError``.
        """
        wanted = self.__item_attrs__
        incoming = {}
        for person in input['persons']:
            desc = UserDesc(**{
                name: value
                for name, value in person.items()
                if name in wanted
            })
            incoming[desc.uid] = desc
        return incoming

    def add(self, keys):
        """Create a ``User`` row for each uid in *keys* and commit once."""
        for uid in keys:
            desc = self.new_data[uid]
            # UserDesc fields map one-to-one onto User constructor kwargs.
            new_user = User(**desc._asdict())
            self.logger.info('Adding new_user %s', desc.login)
            Session.add(new_user)
        Session.commit()

    def remove(self, keys):
        """Delete the stored ``User`` row for each uid in *keys* and commit once."""
        for uid in keys:
            doomed = self._user_map[uid]
            self.logger.info('Deleting new_user %s', doomed.login)
            Session.delete(doomed)
        Session.commit()

    def update_one(self, key):
        """Apply the ``(attr, old, new)`` changes computed for *key* and commit."""
        row = self._user_map[key]

        described = []
        for field, before, after in self.res.to_update[key]:
            setattr(row, field, after)
            described.append('%s: %s -> %s' % (field, before, after))
        self.logger.info('Updating new_user %s %s', row.login, ', '.join(described))

        Session.commit()


class NewGroupsImporter(GeneralImporter):
    """Keep ``Group`` rows in sync with the ``groups`` list of the input.

    After the import, ``Access`` rules are reconciled with the current users
    and groups (see ``post_import``).
    """

    __item_attrs__ = GroupDesc._fields

    TARGET = 'database'
    sanity_limits = {'any': 2000}

    def load_existing_data(self):
        """Return stored groups as ``{gid: GroupDesc}``; keep ORM rows in ``_group_map``."""
        snapshot = {}
        self._group_map = {}

        for row in Group.query:
            self._group_map[row.gid] = row
            # GroupDesc field names are the Group column names.
            snapshot[row.gid] = GroupDesc._make(
                getattr(row, field) for field in GroupDesc._fields
            )

        return snapshot

    def load_new_data(self, input):
        """Return ``input['groups']`` as ``{gid: GroupDesc}``.

        Every ``GroupDesc`` field must be present in each group document
        (a missing one raises ``KeyError``).
        """
        incoming = {}

        for doc in input['groups']:
            incoming[doc['gid']] = GroupDesc._make(
                doc[field] for field in GroupDesc._fields
            )

        return incoming

    def add(self, keys):
        """Create a ``Group`` row for each gid in *keys* and commit once."""
        for gid in keys:
            desc = self.new_data[gid]
            fresh = Group(
                gid=gid,
                parent_gid=desc.parent_gid,
                name=desc.name,
                type=desc.type,
                service_id=desc.service_id,
                staff_id=desc.staff_id,
            )
            self.logger.info('Adding new_group %s', fresh.name)
            Session.add(fresh)
        Session.commit()

    def remove(self, keys):
        """Delete the stored ``Group`` row for each gid in *keys* and commit once."""
        for gid in keys:
            doomed = self._group_map[gid]
            self.logger.info('Deleting new_group %s', doomed.name)
            Session.delete(doomed)
        Session.commit()

    def update_one(self, key):
        """Apply the ``(attr, old, new)`` changes computed for *key* and commit."""
        row = self._group_map[key]
        described = []

        for field, before, after in self.res.to_update[key]:
            setattr(row, field, after)
            described.append('%s: %s -> %s' % (field, before, after))

        self.logger.info('Updating new_group %s %s', row.name, ', '.join(described))

        Session.commit()

    @staticmethod
    def _update_access_rename_src():
        """Rewrite ``Access.src`` where it no longer matches the linked user/group.

        Rules linked through ``src_user`` get the user's current login; rules
        linked through ``src_group`` get the group's current name.
        """
        stale_user_rules = Session.query(Access, User.login).join(Access.src_user)
        stale_user_rules = stale_user_rules.filter(Access.src != User.login)
        for access_rule, current_login in stale_user_rules:
            access_rule.src = current_login

        stale_group_rules = Session.query(Access, Group.name).join(Access.src_group)
        stale_group_rules = stale_group_rules.filter(Access.src != Group.name)
        for access_rule, current_name in stale_group_rules:
            access_rule.src = current_name

        Session.commit()

    @staticmethod
    def _update_access_restore_src_ids():
        """Link ``Access`` rules that have neither a source user nor group id.

        Such a rule gets the uid of the user whose login equals ``src`` and
        the gid of the group whose name equals ``src`` (``None`` where there
        is no match; rules matching neither are skipped).
        """
        # NOTE(review): if a login and a group name both equal ``src``, both
        # ids end up set on the rule — confirm that is intended.
        unlinked = Session.query(Access, User.uid, Group.gid)
        unlinked = unlinked.outerjoin(User, User.login == Access.src)
        unlinked = unlinked.outerjoin(Group, Group.name == Access.src)
        unlinked = unlinked.filter(
            Access.src_user_id.is_(None),
            Access.src_group_id.is_(None),
            or_(User.uid.isnot(None), Group.gid.isnot(None)),
        )

        for access_rule, matched_uid, matched_gid in unlinked:
            access_rule.src_user_id = matched_uid
            access_rule.src_group_id = matched_gid

        Session.commit()

    def post_import(self):
        """Reconcile ``Access`` rules: refresh ``src`` names, then restore missing ids."""
        self._update_access_rename_src()
        self._update_access_restore_src_ids()


class NewGroupMembersImporter(GeneralImporter):
    """Keep ``UserGroupRelation`` rows in sync with group membership in the input.

    Memberships are keyed by ``(gid, uid)``. A user listed in a group's
    ``members`` is also made a member of every gid in that group's
    ``ancestors``. Memberships that disappear from the input are put on hold
    (``until`` set ``settings.HOLD_MEMBERSHIP_PERIOD`` days ahead) rather than
    deleted; see ``analyze``.
    """

    __item_attrs__ = ('until',)

    TARGET = 'database'
    sanity_limits = {'any': 20000}

    def load_existing_data(self):
        """Return stored memberships as ``{(gid, uid): UserGroupRelation}``.

        User and group are inner-joined and eagerly loaded in the same query
        because ``analyze`` and ``update`` read ``membership.user`` and
        ``membership.group``; rows whose user or group is missing are not
        returned.
        """
        memberships = (
            UserGroupRelation.query
            .join(UserGroupRelation.user)
            .join(UserGroupRelation.group)
            .options(
                orm.contains_eager(UserGroupRelation.user),
                orm.contains_eager(UserGroupRelation.group),
            )
        )
        return {(ms.gid, ms.uid): ms for ms in memberships}

    def load_new_data(self, input):
        """Return desired memberships as ``{(gid, uid): GroupMemberDesc}``.

        Member logins are resolved to uids through ``input['persons']``; an
        unknown login raises ``KeyError``. Desired memberships never carry a
        hold, so ``until`` is always ``None``.
        """
        uid_by_login = {doc['login']: doc['uid'] for doc in input['persons']}

        res = {}
        for doc in input['groups']:
            gids = [doc['gid']] + doc['ancestors']
            # Resolve logins once per group instead of once per
            # (ancestor, member) pair.
            member_uids = [uid_by_login[login] for login in doc['members']]
            for gid in gids:
                for uid in member_uids:
                    res[(gid, uid)] = GroupMemberDesc(
                        gid=gid,
                        uid=uid,
                        until=None,
                    )

        return res

    @classmethod
    def analyze(cls, ex_data, new_data):
        """Turn plain removals into holds.

        A membership missing from the input is removed only when it is
        expired or its user is fired. Otherwise, if it is not yet on hold, an
        update is scheduled that sets ``until`` to today plus
        ``settings.HOLD_MEMBERSHIP_PERIOD`` days. Memberships already on hold
        (and not expired) are neither removed nor updated.
        """
        result = super(NewGroupMembersImporter, cls).analyze(ex_data, new_data)

        to_remove = {}
        to_update = result.to_update

        today = datetime.date.today()
        until = today + datetime.timedelta(days=settings.HOLD_MEMBERSHIP_PERIOD)
        for key, membership in result.to_remove.items():
            if membership.is_expired or membership.user.is_fired:
                to_remove[key] = membership
            elif membership.until is None:
                to_update[key] = (('until', None, until),)

        return AnalyzeResult(
            to_add=result.to_add,
            to_update=to_update,
            to_remove=to_remove
        )

    def add(self, keys):
        """Create a membership for each ``(gid, uid)`` in *keys* and commit once."""
        for key in keys:
            gid, uid = key
            self.logger.info('Adding user %s to group %s', uid, gid)

            membership = UserGroupRelation(uid=uid, gid=gid)
            Session.add(membership)

        Session.commit()

    def remove(self, keys):
        """Delete the stored membership for each key in *keys* and commit once."""
        for key in keys:
            membership = self.ex_data[key]
            self.logger.info('Remove user %s from group %s',
                             membership.uid, membership.gid)

            Session.delete(membership)

        Session.commit()

    def update(self, keys):
        """Apply hold/unhold changes.

        *keys* maps ``(gid, uid)`` to a one-element sequence holding the
        single ``('until', old, new)`` change (``until`` is the only tracked
        attribute). A non-``None`` new value puts the membership on hold and
        emits a ``removed_from_group`` event; ``None`` lifts the hold.
        """
        for key, ((_, _, until),) in keys.items():
            membership = self.ex_data[key]

            if until is not None:
                self.logger.info('Hold user %s in group %s',
                                 membership.uid, membership.gid)
                events.removed_from_group(
                    login=membership.user.login,
                    group=membership.group.name,
                )
            else:
                self.logger.info('Unhold user %s in group %s',
                                 membership.uid, membership.gid)

            membership.until = until

        Session.commit()


class NewPublicKeyImporter(GeneralImporter):
    """Keep ``PublicKey`` rows in sync with the ``pub_keys`` lists of persons.

    Items are keyed by ``(uid, key)``: a key is either present or absent, so
    there are no attribute updates (see ``diff`` / ``update_one``).
    """

    TARGET = 'database'
    sanity_limits = {'any': 2000}

    def _merge_logins(self, pairs):
        """Merge ``(uid, login)`` pairs into the map used only for log messages.

        Both loaders feed this map. Merging rather than replacing it keeps
        logins known from either the database or the input, whichever loader
        runs last.
        """
        login_map = getattr(self, '_login_map', None)
        if login_map is None:
            login_map = self._login_map = {}
        login_map.update(pairs)

    def _login_for(self, uid):
        # Fall back to the raw uid so every key change is still logged.
        return getattr(self, '_login_map', {}).get(uid, uid)

    def load_existing_data(self):
        """Return stored keys as ``{(uid, key): PublicKey}``."""
        res = {(pk.uid, pk.key): pk for pk in PublicKey.query}
        self._merge_logins((user.uid, user.login) for user in User.query)
        return res

    def load_new_data(self, input):
        """Return desired keys as ``{(uid, key): None}``; values are placeholders."""
        res = {}
        logins = {}
        for user in input['persons']:
            logins[user['uid']] = user['login']
            for key in user['pub_keys']:
                res[(user['uid'], key)] = None
        self._merge_logins(logins)
        return res

    def add(self, keys):
        """Create a ``PublicKey`` row for each ``(uid, key)`` in *keys* and commit once."""
        for uid, key in keys:
            self.logger.info('Adding key for user %s', self._login_for(uid))
            Session.add(PublicKey(uid=uid, key=key))
        Session.commit()

    def remove(self, keys):
        """Delete stored keys and commit once.

        *keys* maps ``(uid, key)`` to the ``PublicKey`` row to delete.
        """
        for (uid, _), public_key in keys.items():
            self.logger.info('Removing key from user %s', self._login_for(uid))
            Session.delete(public_key)
        Session.commit()

    @classmethod
    def diff(cls, first, second):
        # Always None: presumably tells GeneralImporter that a present key
        # never needs an attribute update — confirm against the base class.
        return None

    def update_one(self, key):
        # Nothing to update: keys are only ever added or removed.
        pass
