from django.core.management import BaseCommand, CommandError
from django.core.paginator import Paginator

from infra.cauth.server.common.models import Server, ServerGroup, Source

from infra.cauth.server.master.api.idm.client import IdmClient
from infra.cauth.server.master.api.idm.update import create_dst_requests
from infra.cauth.server.master.constants import BATCH_OPERATION


class Command(BaseCommand):
    """Send IDM "dst" requests for servers or server groups in batches.

    Objects are selected by ``get_objects`` (optionally restricted to one
    source and/or a minimum id), ordered by id and split into pages of
    ``--batch-count`` objects. For every object ``create_dst_requests`` builds
    the requests; all requests of one page are sent with a single
    ``IdmClient.perform_batch`` call. Pages that produce no requests are
    reported and skipped.
    """

    help = 'Send IDM dst requests for servers or server groups in batches.'

    def add_arguments(self, parser):
        parser.add_argument(
            '--object-type', type=str, default='server', choices=['server', 'group'],
            help='Kind of objects to process: servers or server groups.',
        )
        parser.add_argument(
            '--source-name', type=str, default='',
            help='Only process objects attached to the source with this name (empty: no filter).',
        )
        parser.add_argument(
            '--start-from-id', type=int, default=0,
            help='Only process objects with id >= this value (0: no filter).',
        )
        parser.add_argument(
            '--batch-count', type=int, default=50,
            help='Number of objects whose requests are sent in one IDM batch (must be >= 1).',
        )
        parser.add_argument(
            '--comment', type=str, default=BATCH_OPERATION.BATCH_PUSH_NODE_COMMAND,
            help='Comment passed to IDM with every batch.',
        )

    def handle(self, *args, **options):
        """Iterate over the selected objects page by page and push each page to IDM.

        Raises:
            CommandError: if ``--batch-count`` is not a positive integer.
        """
        batch_count = options['batch_count']
        if batch_count < 1:
            # Paginator cannot paginate with a non-positive page size: it would
            # either crash with an obscure error or silently yield no pages.
            raise CommandError('--batch-count must be a positive integer, got {}'.format(batch_count))

        objects = self.get_objects(options)

        idm_client = IdmClient()
        paginator = Paginator(objects, batch_count)
        pages_count = paginator.num_pages

        for page_num in paginator.page_range:
            requests = []
            for obj in paginator.page(page_num):
                # The id is printed so an interrupted run can be resumed
                # with --start-from-id.
                self.stdout.write('  add object id={} {}'.format(obj.id, obj))
                obj_requests = create_dst_requests(idm_client, obj)
                requests.extend(obj_requests)

            if requests:
                self.stdout.write('send page {}/{} batch'.format(page_num, pages_count))
                idm_client.perform_batch(requests, options['comment'])
            else:
                self.stdout.write('nothing to send in {}/{} batch'.format(page_num, pages_count))

    def get_objects(self, options):
        """Build the id-ordered query of objects to process.

        Args:
            options: parsed command options; reads ``object_type``,
                ``source_name`` and ``start_from_id``.

        Returns:
            A query over ``Server`` (joined with ``Server.sources``, made
            ``distinct``) or ``ServerGroup`` (joined with ``ServerGroup.source``),
            ordered by id and filtered by the given options.
        """
        if options['object_type'] == 'server':
            obj_class = Server
            # distinct() collapses duplicate rows from joining the (presumably
            # to-many) sources relationship.
            queryset = Server.query.join(Server.sources).distinct()
        else:
            obj_class = ServerGroup
            queryset = ServerGroup.query.join(ServerGroup.source)
        # NOTE(review): assuming SQLAlchemy semantics, join() is an inner join,
        # so objects without any source are never selected - confirm intended.

        # A deterministic order is required: Paginator slices the query with
        # offsets, and an unordered query could skip or repeat rows across pages.
        queryset = queryset.order_by(obj_class.id)

        if options['source_name']:
            queryset = queryset.filter(Source.name == options['source_name'])

        if options['start_from_id']:
            queryset = queryset.filter(obj_class.id >= options['start_from_id'])

        return queryset
