import socket
import random
from builtins import TimeoutError, ConnectionError

import requests

import boto3
import botocore

from starkutil import timeout
from config import cfg

from logger import log

class EC2():
    """Thin wrapper around a boto3 EC2 client for the instance this code runs on.

    On construction it reads the local instance-metadata service
    (169.254.169.254) to learn this instance's VPC, region, security groups,
    IP addresses and instance id; most queries below are then scoped to that
    VPC.
    """

    # Base URL of the EC2 instance-metadata service (link-local address).
    _METADATA_URL = 'http://169.254.169.254/latest/meta-data/'

    def __init__(self, bebo_region):
        """Discover this instance's AWS context and open an EC2 client.

        :param bebo_region: prefix used by ``get_hostname`` when building
            hostnames.
        Any ``requests`` error from the metadata lookups (0.5s timeout each)
        propagates to the caller.
        """
        import logging
        # NOTE(review): 'boto' is the legacy boto2 logger name; boto3 logs
        # under 'boto3' / 'botocore'. Confirm which output this meant to mute.
        logging.getLogger('boto').setLevel(logging.CRITICAL)
        self.bebo_region = bebo_region
        self.my_vpc = None
        self.region = None
        self._populate_aws_info()
        self.session = boto3.session.Session()
        self.client = self.session.client('ec2', region_name=self.region)

        # AMI name -> image id, filled lazily by get_public_ami_by_name.
        self.ami_name_cache = {}

    def _metadata(self, path):
        """Return the raw body of instance-metadata ``path`` (0.5s timeout).

        The HTTP status is not checked, so an entry that does not exist
        (e.g. ``public-ipv4`` on a private instance) yields the error body
        instead of raising.
        """
        return requests.get(self._METADATA_URL + path, timeout=0.5).text

    @staticmethod
    def _first_mac(listing):
        """Return the first MAC address from a ``network/interfaces/macs`` listing.

        The listing holds one ``<mac>/`` entry per line; the trailing slash is
        stripped so the MAC can be joined into a further metadata path.
        Returns '' for an empty listing.
        """
        entries = listing.split()
        return entries[0].rstrip('/') if entries else ''

    def _populate_aws_info(self):
        """Fill VPC, region, security-group, IP and instance-id attributes
        from the instance-metadata service."""
        my_mac = self._first_mac(self._metadata('network/interfaces/macs'))
        self.my_vpc = self._metadata('network/interfaces/macs/%s/vpc-id' % my_mac)
        availability_zone = self._metadata('placement/availability-zone')
        self.security_groups = self._metadata('security-groups').split("\n")
        self.admin_public_ip = self._metadata('public-ipv4')
        self.admin_private_ip = self._metadata('local-ipv4')
        self.admin_instance_id = self._metadata('instance-id')
        # Region is the AZ minus its trailing zone letter (us-east-1a -> us-east-1).
        self.region = availability_zone[:-1]

    def _get_client(self, service='ec2'):
        """Return the shared EC2 client; ``service`` is accepted but ignored."""
        return self.client

    def get_running_boxes(self):
        """Return reservations for this VPC's managed instances in state 'running'."""
        return self._get_all_boxes('running')

    def get_pending_boxes(self):
        """Return reservations for this VPC's managed instances in state 'pending'."""
        return self._get_all_boxes('pending')

    def get_stopped_boxes(self):
        """Return reservations for this VPC's managed instances in state 'stopped'."""
        return self._get_all_boxes('stopped')

    def get_stopping_boxes(self):
        """Return reservations for this VPC's managed instances in state 'stopping'."""
        return self._get_all_boxes('stopping')

    def get_terminated_boxes(self):
        """Return reservations for this VPC's managed instances in state 'terminated'."""
        return self._get_all_boxes('terminated')

    def get_shutting_down_boxes(self):
        """Return reservations for this VPC's managed instances in state 'shutting-down'."""
        return self._get_all_boxes('shutting-down')

    def get_vpc_tags(self):
        """Return this instance's VPC tags as a ``{key: value}`` dict.

        Returns ``{}`` when the VPC is not found or has no tags.
        """
        client = self._get_client()
        resp = client.describe_vpcs(
            Filters=[{
                'Name': 'vpc-id',
                'Values': [self.my_vpc]
                }]
            )

        try:
            vpc = resp['Vpcs'][0]
            return {tag['Key']: tag['Value'] for tag in vpc['Tags']}
        except (IndexError, KeyError):
            # IndexError: no matching VPC; KeyError: VPC carries no tags.
            return {}

    def start_instance(self, box):
        """Start ``box.instance_id``; return the boto3 response.

        Returns ``None`` without calling AWS when ``cfg.DRY_RUN`` is set.
        AWS errors are logged and re-raised unchanged.
        """
        instance_id = box.instance_id
        client = self._get_client()
        log.info('starting {} ({})'.format(instance_id, box))
        if cfg.DRY_RUN:
            log.info('Not starting {}, as cfg.DRY_RUN is true'.format(instance_id))
            return None

        try:
            response = client.start_instances(
                InstanceIds=[instance_id],
            )
        except Exception:
            log.exception("ec2.start_instance failed")
            raise
        log.debug("ec2 start_instance response {}".format(response))
        return response

    def stop_instance(self, box):
        """Stop ``box.instance_id``; return the boto3 response.

        Returns ``None`` without calling AWS when ``cfg.DRY_RUN`` is set.
        AWS errors are logged and re-raised with their original type and
        traceback.
        """
        instance_id = box.instance_id
        client = self._get_client()
        log.info('stopping {} ({})'.format(instance_id, box))
        if cfg.DRY_RUN:
            log.info('Not stopping {}, as cfg.DRY_RUN is true'.format(instance_id))
            return None

        try:
            response = client.stop_instances(
                InstanceIds=[instance_id],
            )
        except Exception:
            log.exception("ec2.stop_instance failed")
            raise
        log.debug("ec2.stop_instance response {}".format(response))
        return response

    def destroy_instance(self, instance_id):
        """Terminate ``instance_id``; return the boto3 response.

        Returns ``None`` without calling AWS when ``cfg.DRY_RUN`` is set.
        AWS errors are logged and re-raised with their original type and
        traceback.
        """
        client = self._get_client()
        log.info('ec2.destroy_instance instance_id: {}'.format(instance_id))
        if cfg.DRY_RUN:
            log.info('Not destroying {}, as cfg.DRY_RUN is true'.format(instance_id))
            return None

        try:
            response = client.terminate_instances(
                InstanceIds=[instance_id],
            )
        except Exception:
            log.exception("ec2.destroy_instance failed")
            raise
        log.debug("ec2.destroy_instance response {}".format(response))
        return response

    def _get_node(self, node_name):
        """Return the first instance dict whose ``Name`` tag equals
        ``node_name``, or ``None`` if there is none."""
        client = self._get_client()
        response = client.describe_instances(
            Filters=[
                {
                    'Name': 'tag:Name',
                    'Values': [
                        node_name
                        ]
                    }
                ]
            )

        try:
            return response['Reservations'][0]['Instances'][0]
        except (IndexError, KeyError):
            return None

    def get_public_ip_for_node(self, node_name):
        """Return the public IP of ``node_name``.

        Returns ``None`` if the node is not found or has no public IP
        (e.g. a private-subnet or stopped instance).
        """
        node = self._get_node(node_name)
        if node is None:
            return None
        return node.get('PublicIpAddress')

    def get_private_ip_for_node(self, node_name):
        """Return the private IP of ``node_name``.

        Returns ``None`` if the node is not found or reports no private IP.
        """
        node = self._get_node(node_name)
        if node is None:
            return None
        return node.get('PrivateIpAddress')

    def get_vpc_default_sg(self, vpc):
        """Return the GroupId of the security group named 'default' in ``vpc``.

        Raises IndexError if no such group exists.
        """
        client = self._get_client()
        response = client.describe_security_groups(
            Filters=[
                {
                    "Name": "group-name",
                    "Values": ["default"]
                },
                {
                    "Name": "vpc-id",
                    "Values": [vpc]
                }
            ]
        )
        return response["SecurityGroups"][0]["GroupId"]

    def get_subnets(self, vpc):
        """Return the list of subnet dicts that boto3 reports for ``vpc``."""
        client = self._get_client()

        filters = [{
            "Name": "vpc-id",
            "Values": [vpc]
            }]

        return client.describe_subnets(Filters=filters)["Subnets"]

    def get_hostname(self, cluster_name):
        """Build a randomised FQDN:
        ``<bebo_region>-<cluster_name>-<14-digit number>.<cfg.INTERNAL_DOMAIN>``.

        Uses ``random`` (not ``secrets``). This is fine for naming but does not
        guarantee uniqueness.
        """
        hostname_prefix = self.bebo_region
        domain = cfg.INTERNAL_DOMAIN
        hostname = '{0}-{1}-{2:014d}.{3}'.format(hostname_prefix, cluster_name, random.randint(1, 10000000000), domain)
        return hostname

    def get_public_ami_by_name(self, name):
        """Return the ImageId of the public AMI named ``name``.

        Results are cached per name on this instance. Raises IndexError when
        no public image has that name. If several match, the first one
        returned by AWS wins.
        """
        if name in self.ami_name_cache:
            return self.ami_name_cache[name]

        client = self._get_client()
        result = client.describe_images(
            ExecutableUsers=['all'],  # public images only
            Filters=[{
                "Name": "name",
                "Values": [name]
                }]
            )

        images = result['Images']
        if not images:
            raise IndexError('no public AMI named {!r}'.format(name))
        image_id = images[0]['ImageId']
        self.ami_name_cache[name] = image_id
        return image_id

    def get_sgs_by_names(self, names):
        """Return security-group dicts in this instance's VPC whose name is in ``names``."""
        client = self._get_client()
        filters = [
            {
                'Name': 'group-name',
                'Values': names
            }, {
                'Name': 'vpc-id',
                'Values': [self.my_vpc]
            }]
        return client.describe_security_groups(Filters=filters)['SecurityGroups']

    @staticmethod
    def _ip_range(cidr_key, cidr, rule):
        """Build one IpRanges/Ipv6Ranges entry.

        ``Description`` is included only when the rule has a 'desc'.
        """
        ip_range = {cidr_key: cidr}
        if 'desc' in rule:
            ip_range['Description'] = rule['desc']
        return ip_range

    def create_sg(self, group_name, group):
        """Create security group ``group_name`` in this VPC and return its GroupId.

        ``group`` may carry 'desc' (the group description) and 'rules', a
        list of dicts with optional keys:
          * 'from'     - a single port (used as both FromPort and ToPort)
          * 'protocol' - IpProtocol
          * 'to'       - IPv4 CIDR allowed in
          * 'to_ipv6'  - IPv6 CIDR allowed in
          * 'desc'     - description attached to the CIDR range(s)
        """
        client = self._get_client()
        description = group.get('desc', 'A Security Group generated by Stark')

        log.info('Creating security group: {}'.format(group_name))

        group_id = client.create_security_group(
            Description=description,
            GroupName=group_name,
            VpcId=self.my_vpc
            )['GroupId']

        ingress_rules = []

        for rule in group.get('rules', []):
            ip_permission = {}
            if 'from' in rule:
                # Single-port rule: 'to' names the CIDR, not an upper port.
                ip_permission['FromPort'] = rule['from']
                ip_permission['ToPort'] = rule['from']
            if 'protocol' in rule:
                ip_permission['IpProtocol'] = rule['protocol']
            if 'to' in rule:
                ip_permission['IpRanges'] = [self._ip_range('CidrIp', rule['to'], rule)]
            if 'to_ipv6' in rule:
                ip_permission['Ipv6Ranges'] = [self._ip_range('CidrIpv6', rule['to_ipv6'], rule)]

            if ip_permission:
                ingress_rules.append(ip_permission)

        if ingress_rules:
            log.info('Attaching ingress rules to security group: {}({}):{}'.format(group_id, group_name, ingress_rules))
            client.authorize_security_group_ingress(
                GroupId=group_id,
                IpPermissions=ingress_rules
            )

        return group_id

    def associate_sg_to_admin(self, sg):
        """Add security group ``sg`` to this (admin) instance's current groups.

        ``sg`` may be a security-group dict (with 'GroupId') or a bare group-id
        string, such as the value returned by ``create_sg``. The group is not
        duplicated if it is already attached.
        """
        new_group_id = sg if isinstance(sg, str) else sg['GroupId']
        group_ids = [group['GroupId'] for group in self.get_sgs_by_names(self.security_groups)]
        if new_group_id not in group_ids:
            group_ids.append(new_group_id)

        client = self._get_client()
        # modify_instance_attribute replaces the full group list, so the
        # existing groups must be sent along with the new one.
        client.modify_instance_attribute(
            Groups=group_ids,
            InstanceId=self.admin_instance_id
            )

    def _get_all_boxes(self, instance_state=None):
        """Return describe_instances Reservations for managed boxes in this VPC.

        Boxes are instances whose Name tag ends with ``cfg.INTERNAL_DOMAIN``,
        optionally limited to ``instance_state``. The AWS call is bounded by a
        10s ``timeout`` context.

        :raises ConnectionError: on a timeout or a botocore DataNotFoundError
            (the original error is chained as ``__cause__``).
        """
        filters = [
            {
                'Name': 'vpc-id',
                'Values': [self.my_vpc]
            },
            {
                'Name': 'tag:Name',
                'Values': [
                    '*{}'.format(cfg.INTERNAL_DOMAIN)
                ]
            }
        ]

        if instance_state is not None:
            filters.append({
                'Name': 'instance-state-name',
                'Values': [instance_state]
                })

        try:
            with timeout(seconds=10):
                # NOTE(review): only one describe_instances page is read; if
                # results are ever paginated (NextToken), later pages are lost.
                composer_instances = self.client.describe_instances(Filters=filters)
                reservations = composer_instances['Reservations']
            return reservations
        except (socket.timeout, TimeoutError) as e:
            # Since Python 3.10 socket.timeout *is* TimeoutError, so socket-level
            # and 10s-guard timeouts cannot be told apart; handle them together.
            log.warning('AWS API timed out (limit 10s): {!r}'.format(e))
            raise ConnectionError('Could Not Get All Boxes') from e
        except botocore.exceptions.DataNotFoundError as e:
            log.warning('botocore DataNotFoundError: {}'.format(e))
            raise ConnectionError('Could Not Get All Boxes') from e
