import base64
import json
import logging
import waffle

from sqlalchemy import not_
from django.core.management.base import BaseCommand

from infra.cauth.server.common.alchemy import Session
from infra.cauth.server.common.constants import IDM_STATUS
from infra.cauth.server.common.models import Server, ServerGroup
from infra.cauth.server.master.api.tasks import push_idm_object
from infra.cauth.server.master.metrics.functions import is_not_pushed_q, is_in_ignorelist
from infra.cauth.server.master.constants import BATCH_OPERATION
from infra.cauth.server.master.utils.mongo import get_mongo_database
from infra.cauth.server.master.utils.tasks import lock_manager


# Module-level logger used by the command for progress and warning messages.
log = logging.getLogger(__name__)


def get_dst_from_message(message):
    """Extract the push destination from a queued message document.

    ``message['payload']`` is a JSON string whose ``body`` field holds a
    base64-encoded JSON document; the destination is the first element of
    that document's ``args`` list.

    ``base64.b64decode`` is used because it accepts both ``str`` and
    ``bytes`` input. The former ``base64.decodestring`` was removed in
    Python 3.9, and on Python 3 it rejected the ``str`` returned by
    ``json.loads``.
    """
    body = json.loads(message['payload'])['body']
    decoded = json.loads(base64.b64decode(body))
    return decoded['args'][0]


def get_queued_dsts():
    """Return the set of destinations that already have a pending push message.

    Every document in the mongo ``messages`` collection whose ``queue`` field
    equals ``'push'`` is decoded with ``get_dst_from_message``.
    """
    pending = get_mongo_database().messages.find({'queue': 'push'})
    return {get_dst_from_message(msg) for msg in pending}


class Command(BaseCommand):
    """Re-schedule IDM pushes for servers and server groups that look stuck.

    A destination (a server fqdn or a server-group name) is selected in two cases:

    * it matches ``is_not_pushed_q``, or
    * the ``cauth.poke_dirty_hosts`` waffle switch is active and its
      ``idm_status`` is not ``IDM_STATUS.ACTUAL``. Servers matching
      ``is_in_ignorelist`` are excluded from this status check.

    Destinations that already have a message in the mongo ``push`` queue are
    skipped. Each remaining one is passed to ``push_idm_object.delay`` with
    ``BATCH_OPERATION.POKE_HANGING_HOSTS``.
    """

    LOCK_NAME = 'cron.poke_hanging_hosts'

    def _handle(self, **options):
        """Find hanging destinations and enqueue a push for each one.

        ``options`` is accepted for symmetry with ``handle`` and is unused.
        """
        session = Session()

        # Only names are needed (they are what push_idm_object receives), so
        # collect sets of names rather than mapping names to whole result rows.
        dirty_dsts = set()
        if waffle.switch_is_active('cauth.poke_dirty_hosts'):
            dirty_dsts.update(
                row.fqdn
                for row in session.query(Server.fqdn).filter(
                    Server.idm_status != IDM_STATUS.ACTUAL,
                    not_(is_in_ignorelist(Server)),
                ).all()
            )
            dirty_dsts.update(
                row.name
                for row in session.query(ServerGroup.name).filter(
                    ServerGroup.idm_status != IDM_STATUS.ACTUAL,
                ).all()
            )

        not_pushed_dsts = {
            row.fqdn
            for row in session.query(Server.fqdn).filter(is_not_pushed_q(Server)).all()
        }
        not_pushed_dsts.update(
            row.name
            for row in session.query(ServerGroup.name).filter(is_not_pushed_q(ServerGroup)).all()
        )

        # Not-pushed destinations that the idm_status check did not flag.
        # NOTE(review): when the waffle switch is off, dirty_dsts is still empty
        # here, so every not-pushed destination is counted in this warning.
        not_pushed_actual = not_pushed_dsts - dirty_dsts
        if not_pushed_actual:
            log.warning('Found %s not pushed dsts which are considered actual', len(not_pushed_actual))
        dirty_dsts |= not_pushed_dsts

        log.info('Found %d dirty objects', len(dirty_dsts))

        # Skip destinations that already have a pending message in the push queue.
        for dst in dirty_dsts - get_queued_dsts():
            push_idm_object.delay(dst, BATCH_OPERATION.POKE_HANGING_HOSTS)

    def handle(self, *args, **options):
        """Run ``_handle`` while holding the ``LOCK_NAME`` lock.

        ``*args`` is accepted to match Django's ``BaseCommand.handle(*args, **options)``
        signature; positional arguments are ignored.
        """
        # block=False: presumably the lock is not waited for when another run
        # holds it; the exact behaviour depends on lock_manager - confirm.
        with lock_manager.lock(self.LOCK_NAME, block=False):
            return self._handle(**options)
