#!/usr/bin/env python
# -*- coding: utf-8 -*-

import requests
import logging
import time
import base64
import six
from retrying import retry

from saas.library.python.token_store import PersistentTokenStore
from saas.tools.devops.lib23 import nanny_helpers


class NannyReplicationPolicy(object):
    """
    Manage the replication policy of a Nanny service.

    Every policy operation is delegated to ``nanny_helpers.NannyService`` created
    for ``service_name``; this class only supplies SaaS-specific defaults.
    """
    API = 'http://nanny.yandex-team.ru/api/repo'

    # Passed to create_policy()/update_policy(): pods are grouped by the value found
    # at this path and at most MAX_UNAVAILABLE_IN_GROUP pods of a group may be
    # unavailable at once.
    POD_GROUP_ID_PATH = '/labels/shard'
    MAX_UNAVAILABLE_IN_GROUP = 1

    def __init__(self, service_name, token=''):
        """
        :param service_name: Nanny service id
        :param token: Nanny OAuth token; when empty it is read through
            PersistentTokenStore.get_token_from_store_env_or_file('nanny')
        """
        self._service_name = service_name
        # Nanny auth token
        self._auth_token = token or PersistentTokenStore.get_token_from_store_env_or_file('nanny')

        # Connection settings. requests.Session() is the supported constructor
        # (requests.session() is a deprecated alias); the headers are merged into the
        # session defaults instead of replacing the whole header mapping.
        self._connection = requests.Session()
        self._connection.headers.update({
            'Content-Type': 'application/json',
            'Authorization': 'OAuth %s' % self._auth_token
        })

    def _nanny_service(self):
        """Return a NannyService helper bound to this service."""
        return nanny_helpers.NannyService(self._service_name)

    def _policy_kwargs(self, enable_replication, wait_duration, snapshot_priority,
                       max_unavailable, toleration_duration, replication_method):
        """Build the keyword arguments shared by create_policy() and update_policy()."""
        return {
            'enable_replication': enable_replication,
            'wait_duration': wait_duration,
            'snapshot_priority': snapshot_priority,
            'max_unavailable': max_unavailable,
            'toleration_duration': toleration_duration,
            'replication_method': replication_method,
            'pod_group_id_path': self.POD_GROUP_ID_PATH,
            'max_unavailable_in_group': self.MAX_UNAVAILABLE_IN_GROUP,
        }

    def is_policy_exists(self):
        """Return the result of NannyService.is_policy_exists() for this service."""
        return self._nanny_service().is_policy_exists()

    def create_policy(self, enable_replication=True, wait_duration=3600, snapshot_priority='NORMAL',
                      max_unavailable=1, toleration_duration=43200, replication_method='MOVE'):
        """
        Create the replication policy for the service.

        All parameters are passed through to NannyService.create_policy() together
        with the shard-based pod grouping (POD_GROUP_ID_PATH, MAX_UNAVAILABLE_IN_GROUP).
        :return: whatever NannyService.create_policy() returns
        """
        return self._nanny_service().create_policy(**self._policy_kwargs(
            enable_replication, wait_duration, snapshot_priority,
            max_unavailable, toleration_duration, replication_method))

    def update_policy(self, enable_replication=True, wait_duration=3600, snapshot_priority='NORMAL', max_unavailable=1,
                      toleration_duration=43200, replication_method='MOVE'):
        """
        Update the replication policy of the service.

        Parameters mirror create_policy() and are passed through to
        NannyService.update_policy().
        :return: whatever NannyService.update_policy() returns
        """
        return self._nanny_service().update_policy(**self._policy_kwargs(
            enable_replication, wait_duration, snapshot_priority,
            max_unavailable, toleration_duration, replication_method))

    def get_policy(self):
        """Return the result of NannyService.get_policy() for this service."""
        return self._nanny_service().get_policy()

    def remove_policy(self):
        """Return the result of NannyService.remove_policy() for this service."""
        return self._nanny_service().remove_policy()


class NannyYpAPI(object):
    """
    Class with basic nanny yp api handles.

    Every handle POSTs a JSON payload to ``YP_API + '<Handle>/'`` and returns the
    raw ``requests.Response``; callers are responsible for checking ``resp.ok``.
    """
    YP_API = 'https://yp-lite-ui.nanny.yandex-team.ru/api/yplite/pod-sets/'

    def __init__(self, service_name, token='', abc_group='664', timeout=None, verify_ssl=False):
        """
        :param service_name: Nanny service id, sent as 'serviceId'
        :param token: Nanny OAuth token; when empty it is read through
            PersistentTokenStore.get_token_from_store_env_or_file('nanny')
        :param abc_group: ABC service id used for the 'ABC_SERVICE' quota mode
        :param timeout: timeout passed to every request (None means no timeout)
        :param verify_ssl: value passed as the requests ``verify`` argument (bool or
            CA bundle path). Defaults to False, i.e. certificates are NOT verified.
        """
        self._abc_group = abc_group
        self._service_name = service_name
        self._timeout = timeout
        # NOTE(review): TLS verification is off by default to keep the existing
        # behaviour; prefer verify_ssl=True (or a CA bundle path) where possible.
        self._verify_ssl = verify_ssl

        # Nanny auth token
        self._auth_token = token or PersistentTokenStore.get_token_from_store_env_or_file('nanny')

        # Connection settings. requests.Session() is the supported constructor
        # (requests.session() is a deprecated alias); headers are merged into the
        # session defaults instead of replacing the whole header mapping.
        self._connection = requests.Session()
        self._connection.headers.update({
            'Content-Type': 'application/json',
            'Authorization': 'OAuth %s' % self._auth_token
        })

    def _post(self, handle, request_data):
        """POST ``request_data`` as JSON to the ``handle`` endpoint and return the response."""
        return self._connection.post(self.YP_API + handle + '/', json=request_data,
                                     verify=self._verify_ssl, timeout=self._timeout)

    def _abc_quota_settings(self):
        """Quota settings charging pods to the configured ABC service."""
        return {
            'abcServiceId': self._abc_group,
            'mode': 'ABC_SERVICE'
        }

    def _prepare_allocation_data(self, location, request_data, junk_account=False, node_segment_id=None):
        """
        Prepare data structure for yp pods allocation request.
        :param location: type string, possible values: MAN, SAS, VLA
        :param request_data: type dict; 'networkMacro' is required, the other keys
            ('mem_guarantee', 'cpu_guarantee', 'cpu_limit', 'volumes_data', ...) are optional
        :param junk_account: type bool, use the 'TMP_ACCOUNT' quota instead of the ABC service quota
        :param node_segment_id: type str or None; omitted from the request when None
        :return: type dict, CreatePodSet request payload
        """
        allocation_data = {
            'allocation_request': {
                # The memory limit is set equal to the guarantee.
                'memoryGuaranteeMegabytes': request_data.get('mem_guarantee'),
                'memoryLimitMegabytes': request_data.get('mem_guarantee'),
                'networkMacro': request_data['networkMacro'],
                'persistentVolumes': request_data.get('volumes_data', []),
                'virtualDisks': request_data.get('virtual_disks', []),
                'replicas': request_data.get('instances_count', 1),
                'rootFsQuotaMegabytes': request_data.get('hdd_root_capacity'),
                'rootVolumeStorageClass': request_data.get('storage_type', 'hdd'),
                'snapshotsCount': request_data.get('snapshots', 10),
                'vcpuGuarantee': request_data.get('cpu_guarantee'),
                # A missing or zero cpu_limit falls back to the guarantee.
                'vcpuLimit': request_data.get('cpu_limit') or request_data.get('cpu_guarantee'),
                'workDirQuotaMegabytes': request_data.get('workdir_capacity'),
                'podNamingMode': request_data.get('pod_naming_mode'),
                'rootBandwidthGuaranteeMegabytesPerSec': 1,
                'rootBandwidthLimitMegabytesPerSec': 10
            },
            'antiaffinityConstraints': {
                'nodeMaxPods': request_data.get('node_max_pods', 1),
                'rackMaxPods': request_data.get('rack_max_pods', 1)
            },
            'cluster': location,
            'serviceId': self._service_name,
        }

        if node_segment_id is not None:
            allocation_data['nodeSegmentId'] = node_segment_id

        if junk_account:
            allocation_data['quotaSettings'] = {
                'mode': 'TMP_ACCOUNT'
            }
        else:
            allocation_data['quotaSettings'] = self._abc_quota_settings()

        return allocation_data

    def create_pod_set(self, location, request_data, junk_account=False, node_segment_id='default'):
        """Create a pod set in ``location`` (see _prepare_allocation_data for the payload)."""
        request_data = self._prepare_allocation_data(location, request_data, junk_account=junk_account,
                                                     node_segment_id=node_segment_id)
        return self._post('CreatePodSet', request_data)

    def copy_pod(self, location, pod_id):
        """
        Fetch a pod description.

        Despite the name, this only sends the GetPod request (the same request as
        get_pod(), without retries); nothing is copied server-side here.
        """
        return self._post('GetPod', {
            'cluster': location,
            'podId': pod_id
        })

    def get_pods(self, location):
        """Fetch the service's pod set in ``location`` (GetPodSet handle)."""
        return self._post('GetPodSet', {
            'cluster': location,
            'serviceId': self._service_name
        })

    @retry(stop_max_attempt_number=3)
    def get_pod(self, location, pod_id):
        """Fetch one pod (GetPod handle); retried up to 3 times if an exception is raised."""
        return self._post('GetPod', {
            'cluster': location,
            'podId': pod_id
        })

    @retry(stop_max_attempt_number=3)
    def list_pods(self, location):
        """List the service's pods in ``location`` (ListPods handle); retried up to 3 times on exceptions."""
        return self._post('ListPods', {
            'cluster': location,
            'serviceId': self._service_name
        })

    def list_pods_groups(self, location):
        """List the service's pod groups in ``location`` (ListPodsGroups handle)."""
        return self._post('ListPodsGroups', {
            'cluster': location,
            'serviceId': self._service_name
        })

    def list_pod_configuration_instances(self, location, pod_id, snapshot_id, target_state, version):
        """Query the ListPodConfigurationInstances handle for one pod."""
        return self._post('ListPodConfigurationInstances', {
            'cluster': location,
            'version': version,
            'podId': pod_id,
            'serviceId': self._service_name,
            'snapshotId': snapshot_id,
            'targetState': target_state
        })

    def update_pods(self, location, node_max_pods=1, rack_max_pods=0):
        """Update the pod set's anti-affinity constraints; the ABC service quota is always sent."""
        return self._post('UpdatePodSet', {
            'antiaffinityConstraints': {
                'nodeMaxPods': node_max_pods,
                'rackMaxPods': rack_max_pods
            },
            'cluster': location,
            'quotaSettings': self._abc_quota_settings(),
            'serviceId': self._service_name,
        })

    def remove_pods(self, location):
        """Remove the service's whole pod set in ``location`` (RemovePodSet handle)."""
        return self._post('RemovePodSet', {
            'cluster': location,
            'serviceId': self._service_name
        })

    def remove_pod(self, location, pod_id, version):
        """Remove a single pod (RemovePod handle)."""
        return self._post('RemovePod', {
            'cluster': location,
            'version': version,
            'podId': pod_id
        })

    def set_instance_state(self, location, pod_id, snapshot_id, target_state, version):
        """Set the target state of a pod's instance for a snapshot (SetInstanceTargetState handle)."""
        return self._post('SetInstanceTargetState', {
            'cluster': location,
            'version': version,
            'podId': pod_id,
            'serviceId': self._service_name,
            'snapshotId': snapshot_id,
            'targetState': target_state
        })


class SaaSNannyYpWorkflow(NannyYpAPI):
    """
    Class for pods allocation for SaaS services.

    Keeps the persistent volumes (a '/logs' hdd volume is added on construction)
    and LVM virtual disks that are sent with every pod set created by allocate_pods().
    """
    def __init__(self, service_name, abc_group='', token='', flexible_quota=False, log_volume_size=5120, timeout=None):
        """
        :param service_name: Nanny service id
        :param abc_group: ABC service id used for the project quota
        :param token: Nanny OAuth token (read from the token store when empty)
        :param flexible_quota: when project-quota allocation fails, retry with the junk (TMP_ACCOUNT) quota
        :param log_volume_size: size of the default '/logs' hdd volume, Mbytes
        :param timeout: timeout passed to every request
        """
        super(SaaSNannyYpWorkflow, self).__init__(service_name, abc_group=abc_group, token=token, timeout=timeout)
        self._locations = ['MAN', 'SAS', 'VLA']
        self._flexible_quota = flexible_quota

        self._persistent_volumes = []
        self.add_persistent_volume('/logs', 'hdd', log_volume_size)

        self._virtual_disks = []

    @staticmethod
    def _response_json(resp):
        """Decode the response body as a JSON object; return {} if it is not valid JSON or not an object."""
        try:
            data = resp.json()
        except ValueError:
            return {}
        return data if isinstance(data, dict) else {}

    @classmethod
    def _response_message(cls, resp):
        """Error text from the JSON 'message' field, falling back to the raw body."""
        return cls._response_json(resp).get('message', resp.text)

    @staticmethod
    def _is_int(value):
        """True for integers (int, and long on Python 2), excluding bool."""
        return isinstance(value, six.integer_types) and not isinstance(value, bool)

    @staticmethod
    def _normalize_mount_point(mount_point):
        """Prefix the mount point with '/' when it is missing."""
        return mount_point if mount_point.startswith('/') else '/' + mount_point

    def _normalize_locations(self, locations):
        """
        Build a fresh list of upper-cased, de-duplicated location names.

        An empty value means all known locations. A string may hold one location or
        a comma-separated list. A new list is always returned, so neither
        self._locations nor the caller's list is ever mutated.
        """
        if not locations:
            return list(self._locations)
        if isinstance(locations, six.string_types):
            locations = locations.split(',')
        result = []
        for loc in locations:
            loc = loc.strip().upper()
            if loc and loc not in result:
                result.append(loc)
        return result

    def _allocate_pod_request(self, location, requirements, junk_account=False, node_segment_id=None):
        """
        Send a single CreatePodSet request.

        :return: type list, structure [{'cluster': location, 'pod_id': pod_id}]; empty on failure
        """
        if junk_account:
            logging.info('Trying to allocate pods using junk account')
        resp = self.create_pod_set(location, requirements, junk_account=junk_account, node_segment_id=node_segment_id)
        if not resp.ok:
            logging.error(self._response_message(resp))
            return []
        logging.debug('Successfully allocated pods, %s', resp.text)
        pod_ids = self._response_json(resp).get('podIds') or []
        return [{'cluster': location, 'pod_id': pod_id} for pod_id in pod_ids]

    def add_virtual_disk(self, disk_id, storage_type, capacity, bandwidth=30):
        """
        Method for adding or modifying virtual disk drives (LVM)
        :param disk_id: type str; an existing disk with the same id is updated
        :param storage_type: type str in ['ssd', 'hdd']
        :param capacity: type int (Mbytes)
        :param bandwidth: bandwidth guarantee in Megabytes/sec; the limit is set to 110% of it
        """
        if not self._is_int(capacity):
            logging.error('Parameter capacity must be int type')
            return
        ALLOWED_STORAGE_TYPES = ['ssd', 'hdd']
        if storage_type not in ALLOWED_STORAGE_TYPES:
            logging.error('Parameter storage_type must be in {}'.format(ALLOWED_STORAGE_TYPES))
            return

        virtual_disk_params = {
            'id': disk_id,
            'diskQuotaMegabytes': capacity,
            'storageClass': storage_type,
            'bandwidthGuaranteeMegabytesPerSec': int(bandwidth),
            'bandwidthLimitMegabytesPerSec': int(bandwidth * 1.1)
        }
        for virtual_disk in self._virtual_disks:
            if disk_id == virtual_disk['id']:
                virtual_disk.update(virtual_disk_params)
                break
        else:
            self._virtual_disks.append(virtual_disk_params)

    def add_persistent_volume(self, mount_point, storage_type, capacity, bandwidth=20, disk_id=None):
        """
        Method for adding or modifying persistent volumes
        :param mount_point: type str; '/' is prepended when missing
        :param storage_type: 'ssd', 'lvm' or 'hdd'; 'lvm' also registers an ssd virtual disk ``disk_id``
        :param capacity: type int (Mbytes)
        :param bandwidth: bandwidth guarantee in Megabytes/sec (ignored for 'hdd', see FIXME below)
        :type bandwidth: int
        :param disk_id: virtual disk id, required for 'lvm'
        """
        mount_point = self._normalize_mount_point(mount_point)
        ALLOWED_STORAGE_TYPES = ['ssd', 'hdd', 'lvm']
        if storage_type not in ALLOWED_STORAGE_TYPES:
            logging.error('Parameter storage_type must be in {}'.format(ALLOWED_STORAGE_TYPES))
            return
        if not self._is_int(capacity):
            logging.error('Parameter capacity must be int type')
            return
        if storage_type == 'lvm' and not disk_id:
            logging.error('Parameter storage_type with value "lvm" requires additional parameter "disk_id"')
            return

        volume_params = {
            'mountPoint': mount_point,
            'diskQuotaMegabytes': capacity
        }
        if storage_type == 'lvm':
            volume_params.update({'storageProvisioner': 'LVM',
                                  'virtualDiskId': disk_id})
            self.add_virtual_disk(disk_id, 'ssd', capacity, bandwidth)
        elif storage_type == 'hdd':
            # FIXME: temporary workaround -- hdd volumes always get 1 MB/s guarantee and 10 MB/s limit.
            volume_params.update({
                'storageClass': storage_type,
                'bandwidthGuaranteeMegabytesPerSec': 1,
                'bandwidthLimitMegabytesPerSec': 10
            })
        else:
            volume_params.update({
                'storageClass': storage_type,
                'bandwidthGuaranteeMegabytesPerSec': int(bandwidth),
                'bandwidthLimitMegabytesPerSec': max(int(bandwidth * 1.2), 100)
            })

        for index, volume in enumerate(self._persistent_volumes):
            if mount_point == volume['mountPoint']:
                # Replace rather than merge, so keys of a previous storage type
                # (e.g. 'virtualDiskId' after switching lvm -> hdd) do not linger.
                self._persistent_volumes[index] = volume_params
                break
        else:
            self._persistent_volumes.append(volume_params)

    def remove_persistent_volume(self, mount_point):
        """
        Method for removing persistent volumes by mount point
        :param mount_point: type string; '/' is prepended when missing
        """
        mount_point = self._normalize_mount_point(mount_point)
        for index, volume in enumerate(self._persistent_volumes):
            if mount_point == volume['mountPoint']:
                del self._persistent_volumes[index]
                break

    def allocate_pods(self, cpu_guaranteee, memory_guarantee, hdd_root_capacity, network_macros,
                      root_storage_type='hdd', cpu_limit=0, instances_count=1, locations='',
                      node_max_pods=1, snapshot_count=10, rack_max_pods=1, workdir_capacity=0, retry_count=3,
                      hr_pod_names=True, node_segment_id='default'):
        """
        Method for allocating pods for service with the specified requirements.

        Locations that already have pods are not allocated again; all of their
        existing pods are included in the result instead.
        :param cpu_guaranteee: type int (ms)
        :param memory_guarantee: type int (Mbytes)
        :param hdd_root_capacity: type int (Mbytes)
        :param network_macros: type str
        :param root_storage_type: type str: 'hdd' or 'ssd'
        :param cpu_limit: type int (ms); 0 means "same as the guarantee"
        :param instances_count: type int
        :param locations: type str (one location or comma-separated) or list, possible values: MAN, VLA, SAS;
            empty means all locations. The argument is never modified.
        :param node_max_pods: type int
        :param rack_max_pods: type int
        :param snapshot_count: type int
        :param workdir_capacity: type int (Mbytes); 0 means "same as hdd_root_capacity"
        :param retry_count: type int, number of junk-quota attempts per location (flexible_quota only)
        :param hr_pod_names: type bool
        :param node_segment_id: type str
        :return: type list, structure [{'cluster': location, 'pod_id': pod_id}]
        """
        # Always work on a private copy: mutating self._locations here would break
        # every later call of get_list_pods()/calc_service_resources().
        locations = self._normalize_locations(locations)

        # Collect pod requirements
        pods_requirements = {
            'cpu_guarantee': cpu_guaranteee,
            'cpu_limit': cpu_limit,
            'mem_guarantee': memory_guarantee,
            'hdd_root_capacity': hdd_root_capacity,
            'workdir_capacity': workdir_capacity if workdir_capacity else hdd_root_capacity,
            'storage_type': root_storage_type,
            'instances_count': instances_count,
            'node_max_pods': node_max_pods,
            'rack_max_pods': rack_max_pods,
            'snapshots': snapshot_count,
            'volumes_data': self._persistent_volumes,
            'virtual_disks': self._virtual_disks,
            'pod_naming_mode': 1 if hr_pod_names else 0,
            'networkMacro': network_macros
        }

        # Check existing pod sets: get_list_pods() returns one entry per pod, so every
        # pod of a requested location is kept and the location is skipped below.
        pods_list = []
        existing_locations = set()
        for pod in self.get_list_pods():
            cluster = pod['cluster']
            if cluster not in locations:
                continue
            if cluster not in existing_locations:
                existing_locations.add(cluster)
                logging.warning('Found already existing podset for location %s, id - %s', cluster, pod['pod_id'])
            pods_list.append(pod)

        # Allocate pods
        for location in locations:
            if location in existing_locations:
                continue
            pods = self._allocate_pod_request(location, pods_requirements, node_segment_id=node_segment_id)
            if not pods and self._flexible_quota:
                logging.warning('Can\'t allocate pods in project quota, falling back to junk quota')
                # Each location gets its own retry budget.
                attempts_left = retry_count
                while not pods and attempts_left > 0:
                    # NOTE(review): junk-quota requests are sent without node_segment_id -- confirm intended.
                    pods = self._allocate_pod_request(location, pods_requirements, junk_account=True)
                    if not pods:
                        logging.debug('Could not allocate pods in junk quota, retrying')
                        attempts_left -= 1
                        if attempts_left > 0:
                            time.sleep(1)

            pods_list.extend(pods)
        return pods_list

    def get_list_pods(self):
        """
        Get list of attached pods to service.

        Locations whose ListPods request fails are skipped (a warning is logged).
        :return: type list, structure [{'cluster': location, 'pod_id': pod_id}]
        """
        pods_list = []
        for location in self._locations:
            resp = self.list_pods(location)
            if not resp.ok:
                logging.warning('Could not list pods in %s: %s', location, self._response_message(resp))
                continue
            for pod in self._response_json(resp).get('pods') or []:
                pods_list.append({'cluster': location, 'pod_id': pod['meta']['id']})
        return pods_list

    @staticmethod
    def _pod_shard_range(pod):
        """
        Shard range of a pod taken from its base64-encoded 'shard' label.

        The last two '_'-separated parts of the label are joined with '-';
        '' is returned for pods without a 'shard' label.
        """
        for label in pod.get('labels', {}).get('attributes', []):
            if label['key'] == 'shard':
                pod_shard_label = six.ensure_str(base64.b64decode(label['value']))
                return '-'.join(pod_shard_label.split('_')[-2:])
        return ''

    def get_pods_by_shards(self, locations=None):
        """
        Map the service's pods to shard ranges.
        :param locations: type str or list (case-insensitive); empty means all locations
        :return: type set, structure {(shard_range, 'persistent_fqdn:80')}
        """
        result = set()
        for loc in self._normalize_locations(locations):
            resp = self.list_pods(loc)
            if not resp.ok:
                continue
            for pod in resp.json()['pods']:
                pod_instance = pod['status']['dns']['persistentFqdn'] + ':80'
                result.add((self._pod_shard_range(pod), pod_instance))
        return result

    @retry(stop_max_attempt_number=3)
    def calc_service_resources(self):
        """
        Calculate service resources used (the whole calculation is retried up to 3 times on exceptions)
        :return: type dict format {'LOC': { 'CPU': ..., 'CPU_LIMIT': ..., 'RAM': ..., 'HDD': ..., 'SSD': ...,
            'HDD_READ_BW': ..., 'SSD_READ_BW': ...}}
        """
        # storageClass -> (capacity key, bandwidth key) in the report
        volume_keys = {
            'hdd': ('HDD', 'HDD_READ_BW'),
            'ssd': ('SSD', 'SSD_READ_BW'),
        }
        # Prepare report construction
        pods_resource_report = {}
        for loc in self._locations:
            pods_resource_report[loc] = {res: 0 for res in ['CPU', 'CPU_LIMIT', 'RAM', 'SSD', 'HDD', 'HDD_READ_BW', 'SSD_READ_BW']}
        # Resource calculation
        for pod in self.get_list_pods():
            report = pods_resource_report[pod['cluster']]
            pod_spec = self.get_pod(pod['cluster'], pod['pod_id']).json()['pod']['spec']
            # Volumes calculation
            for volume in pod_spec['diskVolumeRequests']:
                keys = volume_keys.get(volume['storageClass'])
                if keys is None:
                    continue
                capacity_key, bandwidth_key = keys
                report[capacity_key] += int(volume['quotaPolicy']['capacity'])
                report[bandwidth_key] += int(volume['quotaPolicy']['bandwidthGuarantee'])
            # CPU & RAM calculation
            resource_requests = pod_spec['resourceRequests']
            report['RAM'] += int(resource_requests['memoryGuarantee'])
            report['CPU_LIMIT'] += int(resource_requests['vcpuLimit'])
            report['CPU'] += int(resource_requests['vcpuGuarantee'])
        return pods_resource_report
