import logging
import time
from collections import defaultdict
import itertools
from datetime import datetime

from infra.qyp.account_manager.src.lib.gutils import idle_iter
from infra.qyp.proto_lib import vmset_pb2
from yt import yson


# Module-level logger used by ExpiredVmWorker for progress and error messages.
log = logging.getLogger(__name__)


class ExpiredVmWorker:
    """
    Backs up and then deallocates QYP VMs whose shelf time has expired.

    Each ``process()`` pass advances every expired VM by one step:
    no usable backup -> backup requested (owners notified) -> backup in
    progress (wait) -> backup completed -> VM deallocated (owners notified
    with the backup URL).
    """
    JNS_BACKUP_TEMPLATE = 'expired_vm_backup'
    JNS_REMOVED_TEMPLATE = 'expired_vm_removed'

    # Values returned as the first element of _get_backup_status().
    BACKUP_STATUS_NONE = 0  # no backup after expiry, or the last one neither in progress nor completed
    BACKUP_STATUS_IN_PROGRESS = 1
    BACKUP_STATUS_COMPLETED = 2

    def __init__(self, yp_client_list, vmproxy_client, jns_client, qdm_client):
        """
        :type yp_client_list: list[infra.qyp.account_manager.src.model.yp_client.YpClient]
        :type vmproxy_client: infra.qyp.vmctl.src.api.VMProxyClient
        :type jns_client:  infra.qyp.account_manager.src.lib.jns_client.JNSClient
        :type qdm_client:  infra.qyp.account_manager.src.lib.qdm_client.QDMClient
        """
        self.yp_client_list = yp_client_list
        self.vmproxy_client = vmproxy_client
        self.jns_client = jns_client
        self.qdm_client = qdm_client

    def process(self):
        """
        Run one pass over the expired VMs of every configured cluster.

        A failure while handling a single VM is logged (with traceback) and
        does not prevent the remaining VMs from being processed.
        """
        vms_by_cluster = self._get_expired_vms_by_cluster()

        for cluster, cluster_vms in vms_by_cluster.iteritems():
            for vm_id, vm_data in cluster_vms.iteritems():
                try:
                    self._process_vm(cluster, vm_id, vm_data)
                except Exception:
                    # Isolate per-VM failures: one broken VM must not block the rest.
                    log.exception('Failed to process expired vm {} in cluster {}'.format(vm_id, cluster))

    def _process_vm(self, cluster, vm_id, vm_data):
        """
        Advance one expired VM by a single step based on its backup status.

        :type cluster: str
        :type vm_id: str
        :type vm_data: dict
        :rtype: None
        """
        status, url = self._get_backup_status(cluster, vm_id, vm_data)
        if status == self.BACKUP_STATUS_NONE:
            self._backup_vm(vm_id, cluster, vm_data)
        elif status == self.BACKUP_STATUS_IN_PROGRESS:
            log.info('Backup for vm {} is in progress'.format(vm_id))
        else:
            # BACKUP_STATUS_COMPLETED: url is the finished backup's location.
            self._deallocate_vm(cluster, vm_id, vm_data, url)

    def run(self):
        """
        Run process() and log, rather than propagate, any exception it raises.
        """
        try:
            self.process()
        except Exception as e:
            # log.exception records the traceback in addition to the message.
            log.exception('run failed: {}'.format(e))

    def _get_expired_vms_by_cluster(self):
        """
        Collect expired VMs from every YP client, keyed by cluster name.

        :rtype: dict
        {
            cluster1: {
                vm_id1: {
                    shelf_time: int
                    creation_time: int
                    logins: list(str)
                },
                vm_id2...
            }
        }
        """
        vms = defaultdict(dict)
        for yp_client in self.yp_client_list:
            cluster_vms = self._get_vms(yp_client)
            vms[yp_client.cluster.get('cluster_name')] = cluster_vms
        return vms

    def _get_vms(self, yp_client):
        """
        List QYP pods that carry a shelf-time label and whose shelf time has elapsed.

        :type yp_client: infra.qyp.account_manager.src.model.yp_client.YpClient
        :rtype: dict
        :return: {pod_id: {'shelf_time': int, 'creation_time': int, 'logins': list}}
        """
        query = '[/labels/deploy_engine] = "QYP" AND NOT is_null([/labels/qyp_vm_shelf_time_seconds])'
        # Order matters: values are unpacked positionally below.
        selectors = [
            '/meta/id',
            '/meta/creation_time',
            '/labels/qyp_vm_shelf_time_seconds',
            '/annotations/owners/logins',
        ]
        pods = yp_client.list_pods(query=query, selectors=selectors)
        vms = {}
        for pod in idle_iter(pods.results):
            pod_id, creation_time, shelf_time, logins = [yson.loads(val) for val in pod.values]
            # creation_time is compared against time.time() below, so scale it down by 10**6
            # (microseconds -> seconds). Floor division keeps it an int on Python 2 and 3 alike.
            creation_time //= 10 ** 6
            if time.time() - creation_time > shelf_time:
                vms[pod_id] = {
                    'shelf_time': shelf_time,
                    'creation_time': creation_time,
                }
                # '#' presumably is the YSON entity (unset) literal; treat it like "no logins".
                vms[pod_id]['logins'] = logins if logins and logins != '#' else []
        return vms

    def _get_backup_status(self, cluster, vm_id, vm_data):
        """
        Classify the VM's most recent backup relative to its shelf-time expiry.

        :type cluster: str
        :type vm_id: str
        :type vm_data: dict
        :rtype: (int, str)
        :return: (status, url), where status is one of BACKUP_STATUS_NONE (0),
            BACKUP_STATUS_IN_PROGRESS (1), BACKUP_STATUS_COMPLETED (2). The url is
            non-empty only for BACKUP_STATUS_COMPLETED.
        """
        backups = self.qdm_client.backup_list(vm_id=vm_id, yp_cluster=cluster)
        if not backups:
            return self.BACKUP_STATUS_NONE, ''
        # NOTE(review): assumes backup_list returns backups oldest-first — TODO confirm.
        last_backup = backups[-1]
        # Only a backup taken after the shelf time expired counts.
        if last_backup.meta.creation_time.seconds > vm_data['creation_time'] + vm_data['shelf_time']:
            if last_backup.status.state == vmset_pb2.BackupStatus.IN_PROGRESS:
                return self.BACKUP_STATUS_IN_PROGRESS, ''
            elif last_backup.status.state == vmset_pb2.BackupStatus.COMPLETED:
                return self.BACKUP_STATUS_COMPLETED, last_backup.status.url
        return self.BACKUP_STATUS_NONE, ''

    def _backup_vm(self, vm_id, cluster, vm_data):
        """
        Request a backup of the VM and, if the request succeeds, notify its owners.

        A failed backup request is logged and the VM is retried on the next pass.

        :type vm_id: str
        :type cluster: str
        :type vm_data: dict
        :rtype: None
        """
        log.info('sending backup request for vm: {} in {}'.format(vm_id, cluster))
        creation_date = datetime.utcfromtimestamp(vm_data['creation_time'])
        params = {
            'cluster': cluster,
            'vm_id': vm_id,
            'shelf_time_in_hours': str(vm_data['shelf_time'] // 3600),
            'creation_date': creation_date.strftime('%d %b %Y %H:%M (UTC)')
        }
        try:
            self.vmproxy_client.backup(cluster, vm_id)
        except Exception as e:
            log.warning('Error creating backup for vm {} in cluster {}: {}'.format(vm_id, cluster, e))
            return
        # Outside the try so notification errors are not reported as backup errors.
        self._send_notification(vm_data['logins'], self.JNS_BACKUP_TEMPLATE, params)

    def _deallocate_vm(self, cluster, vm_id, vm_data, backup_url):
        """
        Deallocate the VM and, only once that succeeds, tell its owners where the backup is.

        :type cluster: str
        :type vm_id: str
        :type vm_data: dict
        :type backup_url: str
        :rtype: None
        """
        params = {
            'vm_id': vm_id,
            'cluster': cluster,
            'backup_url': backup_url,
        }
        try:
            self.vmproxy_client.deallocate(cluster, vm_id)
        except Exception as e:
            log.warning('Error sending deallocation request for vm {} in cluster {}: {}'.format(vm_id, cluster, e))
            # Don't claim the VM was removed when it wasn't; it will be retried next pass.
            return
        self._send_notification(vm_data['logins'], self.JNS_REMOVED_TEMPLATE, params)

    def _send_notification(self, logins, template, params):
        """
        Send a JNS notification and log whether the client reported success.

        :type logins: list
        :type template: str
        :type params: dict
        :rtype: None
        """
        if self.jns_client.send(logins, template, params):
            log.info('sent notification to {} with params: {}'.format(template, params))
        else:
            log.warning('notification to {} failed. params: {}, logins: {}'.format(template, params, logins))

    @staticmethod
    def split_iterable_into_batches(iterable, batch_size=50):
        """
        Yield consecutive lists of up to ``batch_size`` items from ``iterable``.

        The last batch may be shorter, and an empty iterable yields nothing.

        :type iterable: collections.Iterable
        :type batch_size: int
        :rtype: generator of list
        """
        it = iter(iterable)
        batch = list(itertools.islice(it, batch_size))
        while len(batch) > 0:
            yield batch
            batch = list(itertools.islice(it, batch_size))
