# -*- coding: utf-8 -*-
import logging
import json
import re
import requests
import time
from urlparse import urljoin

from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from misc import retry
from spikes import get_task_type


# Service ids of load-testing stands, e.g. "addrs_upper_l1"
# (presumably "l" = load; confirm against Nanny naming).
LOAD_RGX = re.compile(
    r'addrs_[a-z]+_l\d+'
)
# Presumably the acceptance-testing ("priemka") wizard services, e.g.
# "addrs_wizard2_p1"; Nanny.loadstand_to_prod skips any dst service that
# matches this at its start.
PRIEMKA_WIZARD_RGX = re.compile(
    r'addrs_[a-z]+2_p\d+'
)


class Nanny(object):
    '''Thin client for the Nanny HTTP API (https://nanny.yandex-team.ru/).

    All requests go through one ``requests.Session`` that carries the OAuth
    token and a JSON content type. Read methods are wrapped in
    ``misc.retry(tries=3, delay=2)``; write methods are not. Rejected writes
    raise ``SandboxTaskFailureError``.
    '''

    def __init__(self, oauth_token):
        '''
        :param oauth_token: Nanny OAuth token, sent as ``Authorization: OAuth <token>``
        '''
        self.url = 'https://nanny.yandex-team.ru/'
        self.oauth_token = oauth_token
        self.session = requests.Session()
        self.session.headers['Authorization'] = 'OAuth {}'.format(self.oauth_token)
        self.session.headers['Content-Type'] = 'application/json'

    # -- internal helpers ---------------------------------------------------

    def _get_current_state(self, service):
        '''Return the ``content`` of GET /v2/services/<service>/current_state/.'''
        url = urljoin(self.url, '/v2/services/%s/current_state/' % service)
        return self.session.get(url).json().get('content')

    @staticmethod
    def _check_response(response, service):
        '''Log a write response at debug level and raise if it failed.

        :raises SandboxTaskFailureError: when ``response.ok`` is false
        '''
        logging.debug('Response code: %s' % response.status_code)
        logging.debug('Response text: %s' % response.text)
        if not response.ok:
            raise SandboxTaskFailureError('Error on working with %s: %s'
                                          % (service, response.text))

    def _put_resources(self, service, content, comment):
        '''PUT ``content`` as runtime_attrs/resources of ``service``.

        The request is based on the service's current snapshot id.

        :raises SandboxTaskFailureError: when Nanny rejects the request
        '''
        url = urljoin(self.url, '/v2/services/%s/runtime_attrs/resources/' % service)
        data = {'snapshot_id': self.get_snapshot_id(service),
                'comment': comment,
                'content': content}
        logging.info('Request data: %s' % data)
        response = self.session.put(url, data=json.dumps(data))
        if response.ok:
            logging.info('Changed sandbox_files list for %s successfuly' % service)
        else:
            logging.info('Failed to change sandbox_files list for %s' % service)
        self._check_response(response, service)

    # -- reads ----------------------------------------------------------------

    @retry(tries=3, delay=2)
    def get_load_stands(self):
        '''Return (category, category) choice pairs for load-testing stands.

        Lists services in the ``/rcss/addrs/load`` category and keeps the
        categories that contain ``load``. One pair is produced per service,
        so a category holding several services appears several times.
        '''
        url = urljoin(self.url, '/v2/services/?category=/rcss/addrs/load')
        data = self.session.get(url).json()
        categories = [service['info_attrs']['content']['category']
                      for service in data.get('result')]
        return [(category, category) for category in categories if 'load' in category]

    @retry(tries=3, delay=2)
    def expand_stand(self, stand):
        '''Return ids of the services in Nanny category ``stand``.

        Results without an ``_id`` and ids containing ``e1`` are excluded.
        '''
        url = urljoin(self.url, '/v2/services/?category=%s' % stand)
        data = self.session.get(url).json()
        service_ids = [service.get('_id') for service in data.get('result')]
        return [sid for sid in service_ids if sid and 'e1' not in sid]

    @retry(tries=3, delay=2)
    def get_conf(self, service):
        '''Return the ``content`` of the service's runtime_attrs/resources.'''
        url = urljoin(self.url, '/v2/services/%s/runtime_attrs/resources/' % service)
        return self.session.get(url).json().get('content')

    @retry(tries=3, delay=2)
    def get_snapshot_id(self, service):
        '''Return the ``_id`` of the service's current runtime_attrs (its snapshot id).'''
        url = urljoin(self.url, '/v2/services/%s/runtime_attrs/' % service)
        return self.session.get(url).json().get('_id')

    @retry(tries=3, delay=2)
    def get_service_state(self, service):
        '''Return current_state.summary.value of ``service`` (e.g. ``ONLINE``).'''
        return self._get_current_state(service).get('summary').get('value')

    def get_stand_states(self, stand):
        '''Return the set of distinct service states across all services of ``stand``.'''
        return set(self.get_service_state(service) for service in self.expand_stand(stand))

    @retry(tries=3, delay=2)
    def get_snapshot_state(self, service, snapshot_id):
        '''Return the state of ``snapshot_id`` among the service's active snapshots.

        Returns None when the snapshot is not (yet) listed there.
        '''
        snapshots = self._get_current_state(service).get('active_snapshots') or []
        for snapshot in snapshots:
            if snapshot.get('snapshot_id') == snapshot_id:
                return snapshot.get('state')
        return None

    @retry(tries=3, delay=2)
    def get_instances(self, service):
        '''Return the ``content`` of the service's runtime_attrs/instances.'''
        url = urljoin(self.url, '/v2/services/%s/runtime_attrs/instances/' % service)
        return self.session.get(url).json().get('content')

    @retry(tries=3, delay=2)
    def get_isolated_instances(self, service, location=False):
        '''Return the current instances of ``service``.

        :param location: when truthy, return only the ``container_hostname`` of
            instances whose first itag contains this substring
        :raises SandboxTaskFailureError: when Nanny returns an empty body
        '''
        url = urljoin(self.url, '/v2/services/%s/current_state/instances/' % service)
        response = self.session.get(url)
        data = response.json()
        logging.info('nanny.get_isolated_instances data: {}'.format(data))
        if not data:
            logging.info('%s response code: %s' % (url, response.status_code))
            logging.info(response.text)
            raise SandboxTaskFailureError('Could not get instances list from Nanny API')
        instances = data.get('result') or []
        if not location:
            return instances
        hostnames = []
        for instance in instances:
            itags = instance.get('itags') or []
            # Only the first itag is checked (presumably the location tag -- confirm).
            if itags and location in itags[0]:
                hostnames.append(instance.get('container_hostname'))
        return hostnames

    @retry(tries=3, delay=2)
    def get_gencfg_hosts(self, service, location):
        '''Return gencfg hosts of the service's first group named ``location*``.

        Best effort: any error while reading the groups or querying gencfg is
        logged and an empty list is returned. An empty list is also returned
        when no group name starts with ``location``.
        '''
        url_attrs = urljoin(self.url, '/v2/services/{}/runtime_attrs/'.format(service))
        data = self.session.get(url_attrs).json()
        try:
            gencfg_name = ''
            for group in data['content']['instances']['extended_gencfg_groups']['groups']:
                if group['name'].startswith(location):
                    gencfg_name = group['name']
                    break
            logging.info('Got gencfg name: {}'.format(gencfg_name))
            if gencfg_name:
                # NOTE(review): this request goes through self.session, so the
                # Nanny OAuth header is sent to gencfg over plain HTTP -- confirm
                # whether gencfg needs it at all.
                url_gencfg = 'http://api.gencfg.yandex-team.ru/trunk/groups/{}'.format(gencfg_name)
                hosts = self.session.get(url_gencfg).json()['hosts']
                logging.debug('Got hosts from gencfg: {}'.format(', '.join(hosts)))
                return hosts
        except Exception as e:
            logging.error('Failed to get hosts from gencfg: {}'.format(e))
        return []

    def get_yp_pods(self, service):
        '''Return the unique YP pod ids of ``service``, sorted by numeric suffix.

        Pod ids are expected to contain ``<service with '_' -> '-'>-<N>``; a
        pod id without that pattern raises IndexError.
        '''
        pod_rgx = re.compile(r'{}-(\d+)'.format(re.escape(service.replace('_', '-'))))

        def pod_number(pod):
            # Numeric sort key so that "x-10" sorts after "x-9".
            return int(pod_rgx.findall(pod)[0])

        instances_data = self.get_instances(service)
        pods = {pod.get('pod_id') for pod in instances_data['yp_pod_ids']['pods']}
        return sorted(pods, key=pod_number)

    # -- config transformations ---------------------------------------------

    def make_zerodiff(self, src_config, dst_config):
        '''Toggle a leading newline in src's instancectl.conf relative to dst's.

        This is done separately for ``static_files`` (text in ``content``) and
        ``template_set_files`` (text in ``layout``). In each section where dst
        has an instancectl.conf that does not start with a newline, a newline
        is prepended to src's instancectl.conf in that section. Presumably this
        makes the copied config differ from dst so that Nanny creates a new
        snapshot. Sections where dst has no instancectl.conf are left untouched.

        Mutates and returns ``src_config``.
        '''
        for section, text_key in (('static_files', 'content'),
                                  ('template_set_files', 'layout')):
            dst_text = None
            for resource in dst_config.get(section) or []:
                if resource['local_path'] == 'instancectl.conf':
                    dst_text = resource[text_key]
            # Nothing to compare against, or dst already starts with '\n'.
            if dst_text is None or dst_text.startswith('\n'):
                continue
            for resource in src_config.get(section) or []:
                if resource['local_path'] == 'instancectl.conf':
                    resource[text_key] = '\n%s' % resource[text_key]
        logging.info(src_config)
        return src_config

    def add_topology(self, instances, topology):
        '''Prepend a copy of every non-trunk gencfg group with release ``topology``.

        The original groups are kept after the copies. Mutates and returns
        ``instances``.
        '''
        groups = instances['extended_gencfg_groups']['groups']
        # 'trunk' groups are never duplicated.
        new_groups = [dict(group, release=topology) for group in groups
                      if group['release'] != 'trunk']
        new_groups.extend(groups)
        instances['extended_gencfg_groups']['groups'] = new_groups
        return instances

    def delete_topology(self, instances, topology):
        '''Drop gencfg groups whose release equals ``topology``; mutates and returns ``instances``.'''
        groups = instances['extended_gencfg_groups']['groups']
        instances['extended_gencfg_groups']['groups'] = [
            group for group in groups if group['release'] != topology]
        return instances

    # -- writes ---------------------------------------------------------------

    def put_instances(self, service, attrs):
        '''Replace runtime_attrs/instances of ``service`` with ``attrs``.

        :returns: the snapshot id read before the PUT (the one the change is based on)
        :raises SandboxTaskFailureError: when Nanny rejects the request
        '''
        url = urljoin(self.url, '/v2/services/%s/runtime_attrs/instances/' % service)
        service_snapshot_id = self.get_snapshot_id(service)
        data = {'snapshot_id': service_snapshot_id,
                'comment': '',
                'content': attrs}
        response = self.session.put(url, data=json.dumps(data))
        self._check_response(response, service)
        return service_snapshot_id

    def compare_services(self, prod_service, load_service):
        '''Return True when both services have equal runtime_attrs/resources.'''
        return self.get_conf(prod_service) == self.get_conf(load_service)

    def compare_stands(self, src_stand, dst_stand):
        '''Return (src, dst) service pairs whose runtime resources differ.

        The src counterpart of a dst service is its id with the last
        ``_``-separated part removed (``a_b_c`` -> ``a_b``). dst services with
        no such counterpart in ``src_stand`` are logged and skipped.
        '''
        src_services = sorted(self.expand_stand(src_stand))
        dst_services = sorted(self.expand_stand(dst_stand))
        logging.info(src_services)
        logging.info(dst_services)
        known_src = set(src_services)
        services_to_change = []
        for dst_service in dst_services:
            norm_service_name = '_'.join(dst_service.split('_')[:-1])
            if norm_service_name not in known_src:
                logging.warning("Can't find apropriate src service for {0}, norm name={1}".format(dst_service, norm_service_name))
                continue
            if not self.compare_services(norm_service_name, dst_service):
                services_to_change.append((norm_service_name, dst_service))
        return services_to_change

    def copy_service(self, src_service, dst_service, author, force_snapshot):
        '''Copy runtime_attrs/resources of ``src_service`` onto ``dst_service``.

        :param force_snapshot: if true, pass the copied config through
            make_zerodiff first (leading-newline toggle in instancectl.conf)
        :returns: the dst snapshot id the change was based on
        :raises SandboxTaskFailureError: when Nanny rejects the request
        '''
        logging.info('Copying configuration from %s to %s'
                     % (src_service, dst_service))
        src_config = self.get_conf(src_service)
        dst_config = self.get_conf(dst_service)
        if force_snapshot:
            src_config = self.make_zerodiff(src_config, dst_config)
        dst_snapshot_id = self.get_snapshot_id(dst_service)
        url = urljoin(self.url, '/v2/services/%s/runtime_attrs/resources/' % dst_service)
        data = {'snapshot_id': dst_snapshot_id,
                'comment': 'Copyied from %s by %s' % (src_service, author),
                'content': src_config}
        logging.debug('Destination service URL: %s' % url)
        logging.debug('Request data: %s' % data)
        # NOTE: nothing here verifies that dst_service is a load-testing
        # service; callers are responsible for passing the right target.
        response = self.session.put(url, data=json.dumps(data))
        if not response.ok:
            logging.info('Appling configuration from %s to %s failed!'
                         % (src_service, dst_service))
        self._check_response(response, dst_service)
        logging.info('Applied configuration from %s to %s'
                     % (src_service, dst_service))
        return dst_snapshot_id

    @retry(tries=3, delay=2)
    def get_sandbox_files(self, service):
        '''Return the ``content`` of the service's runtime_attrs/resources.'''
        url = urljoin(self.url, '/v2/services/%s/runtime_attrs/resources/' % service)
        return self.session.get(url).json().get('content')

    def replace_sandbox_file(self, sandbox_resources, service, author):
        '''Point the service's sandbox_files at the given Sandbox resources.

        For each resource id, every ``sandbox_files`` entry with the same
        ``resource_type`` is rewritten to that resource (task type, task id,
        resource type, resource id). The result is then PUT back to Nanny.

        :raises SandboxTaskFailureError: when no entry matched any of the
            resources, or when Nanny rejects the request
        '''
        logging.info('Replacing sandbox file')
        nanny_data = self.get_sandbox_files(service)
        sandbox_files = nanny_data.get('sandbox_files') or []
        # (task_type, task_id) of the last resource that actually replaced an
        # entry; it goes into the snapshot comment.
        last_replaced = None
        for sandbox_resource in sandbox_resources:
            resource = channel.sandbox.get_resource(sandbox_resource)
            logging.info('Current sandbox files')
            logging.info(sandbox_files)
            task_type = None
            for sandbox_file in sandbox_files:
                if sandbox_file.get('resource_type') != resource.type:
                    continue
                if task_type is None:
                    # Resolve the task type once per resource, not per entry.
                    logging.info('%s task id = %s' % (resource.type, resource.task_id))
                    task_type = get_task_type(resource.task_id)
                    logging.info('%s task type = %s' % (resource.type, task_type))
                sandbox_file['task_type'] = task_type
                sandbox_file['task_id'] = str(resource.task_id)
                sandbox_file['resource_type'] = resource.type
                sandbox_file['resource_id'] = str(sandbox_resource)
                last_replaced = (task_type, resource.task_id)
            logging.info('Modified sandbox files: %s' % sandbox_files)
        if last_replaced is None:
            raise SandboxTaskFailureError('No sandbox files of resources %s found in %s'
                                          % (sandbox_resources, service))
        nanny_data['sandbox_files'] = sandbox_files
        task_type, task_id = last_replaced
        self._put_resources(service, nanny_data,
                            'Replaced %s with #%s by %s' % (task_type, task_id, author))

    def replace_shardmap(self, service, author, resource_id, resource_type='', task_id='', task_type='', ssd_deploy=True):
        '''Point the service's sandbox_bsc_shard.sandbox_shardmap at a Sandbox resource.

        ``task_id``, ``task_type`` and ``resource_type`` default to the values of
        resource ``resource_id`` and of the task that produced it.

        :param ssd_deploy: when false, drop ``sandbox_bsc_shard.storage`` (if set)
        :raises SandboxTaskFailureError: when Nanny rejects the request
        '''
        nanny_data = self.get_sandbox_files(service)
        bsc_shard = nanny_data['sandbox_bsc_shard']
        logging.info('bsc_resource %s' % bsc_shard)
        shardmap_resource = bsc_shard['sandbox_shardmap']
        resource = channel.sandbox.get_resource(resource_id)
        if not task_id or not task_type:
            # Fetch the producing task whenever either of its attributes is missing.
            task = channel.sandbox.get_task(resource.task_id)
            task_id = task_id or str(task.id)
            task_type = task_type or task.type
        resource_type = resource_type or resource.type
        # NOTE(review): resource_id itself is not stored in sandbox_shardmap;
        # presumably Nanny resolves it from task_id + resource_type -- confirm.
        shardmap_resource['task_type'] = task_type
        shardmap_resource['task_id'] = task_id
        shardmap_resource['resource_type'] = resource_type
        logging.info('Modified shardmap_resources: %s' % shardmap_resource)
        if not ssd_deploy and bsc_shard.get('storage'):
            del bsc_shard['storage']
        self._put_resources(service, nanny_data,
                            'Replaced %s with #%s by %s' % (task_type, task_id, author))

    def get_receipe(self, service):
        '''Return the ids of the service's recipes (info_attrs.recipes.content).'''
        url = urljoin(self.url, '/v2/services/%s/info_attrs' % service)
        content = self.session.get(url).json().get('content')
        return [recipe.get('id') for recipe in content.get('recipes').get('content')]

    def activate_service(self, dst_service, comment, startrek_task=''):
        '''Set the service's current snapshot to ACTIVE and block until it is.

        The service's first recipe is used. After the request the method sleeps
        60s, then polls get_snapshot_state every 30s with no upper bound.

        :param startrek_task: currently unused
        :raises SandboxTaskFailureError: when Nanny rejects the event
        '''
        logging.info('Activating %s' % dst_service)
        url = urljoin(self.url, '/v2/services/%s/events/' % dst_service)
        dst_snapshot_id = self.get_snapshot_id(dst_service)
        request = {'type': 'SET_SNAPSHOT_STATE',
                   'content': {'snapshot_id': dst_snapshot_id,
                               'comment': comment,
                               'recipe': self.get_receipe(dst_service)[0],
                               'state': 'ACTIVE'}}
        logging.debug('Snapshot activating URL: %s' % url)
        logging.debug('Request data: %s' % request)
        response = self.session.post(url, data=json.dumps(request))
        if not response.ok:
            logging.info('Activation of %s failed!' % dst_service)
        self._check_response(response, dst_service)
        time.sleep(60)
        logging.info('    Waiting %s`s snapshot %s to be activated' %
                     (dst_service, dst_snapshot_id))
        snapshot_state = None
        while snapshot_state != 'ACTIVE':
            time.sleep(30)
            snapshot_state = self.get_snapshot_state(dst_service, dst_snapshot_id)
        logging.info('Snapshot activated!')

    def loadstand_to_prod(self, changing_list, stand, author, acivate, force_snapshot):
        '''Copy configuration for each (src, dst) pair and optionally activate it.

        Pairs are skipped when dst matches PRIEMKA_WIZARD_RGX or contains
        ``base``. With ``acivate``, each copied service is activated, and then
        the whole ``stand`` is polled every 30s until ONLINE is the only state
        reported across it.

        :param changing_list: iterable of (src_service, dst_service) pairs,
            e.g. the result of compare_stands
        '''
        for src_service, dst_service in changing_list:
            if PRIEMKA_WIZARD_RGX.match(dst_service) or 'base' in dst_service:
                continue
            self.copy_service(src_service, dst_service, author, force_snapshot)
            if not acivate:
                logging.info('Changes copied to stand %s' % stand)
                continue
            logging.info('Activating stand %s' % stand)
            self.activate_service(dst_service, author)
            status = None
            # Wait until every service in the stand reports ONLINE.
            while status != ['ONLINE']:
                logging.info('Sleeping for 30 seconds...')
                time.sleep(30)
                status = self.check_stand_status(stand)
                logging.info('%s status is %s' % (stand, status))
            logging.info('Load stand %s activated' % stand)

    def check_stand_status(self, stand):
        '''Return the distinct service states across ``stand`` as a list.'''
        return list(self.get_stand_states(stand))

    @retry(tries=3, delay=2)
    def get_configuration_id(self, service):
        '''Return ``conf_id`` of the service's first ACTIVE snapshot, or None.'''
        snapshots = self._get_current_state(service).get('active_snapshots') or []
        for snapshot in snapshots:
            if snapshot.get('state') == 'ACTIVE':
                return snapshot.get('conf_id')
        return None
