import threading
import subprocess

import yaml
from jinja2 import Template

from logger import log


from models import ServiceModel
from config import cfg

class SaltCloudTemplateGenerator():
    """Render salt-cloud profile and provider config files from Jinja templates.

    The generator reads network information (default security group, subnets
    per availability zone) from the supplied ``ec2`` helper once at
    construction time, then renders two files below
    ``cfg.SALT_CONFIG_LOCATION``:

    * ``cloud.profiles.d/ec2.cluster.profiles.conf`` (``render_profiles``)
    * ``cloud.providers.d/ec2.cluster.providers.conf`` (``render_providers``)
    """

    def __init__(self, ec2):
        """Collect network data through *ec2* and load both Jinja templates.

        :param ec2: project EC2 helper; this class reads ``region``,
            ``my_vpc`` and ``admin_private_ip`` from it and calls
            ``get_vpc_default_sg``, ``get_subnets``, ``get_sgs_by_names``,
            ``create_sg`` and ``get_public_ami_by_name``.
        """
        self.ec2 = ec2

        # Cluster objects currently rendered, and security-group name -> id.
        self.clusters = []
        self.sg_map = {}

        self.bebo_region = cfg.REGION
        self.aws_region = self.ec2.region
        self.vpc = self.ec2.my_vpc

        self._populate_data()

        self.profiles_template_file = 'stark/cloud_templates/ec2.cluster.profiles.conf.jinja'
        self.providers_template_file = 'stark/cloud_templates/ec2.cluster.providers.conf.jinja'

        self.profiles_file = cfg.SALT_CONFIG_LOCATION + '/cloud.profiles.d/ec2.cluster.profiles.conf'
        self.providers_file = cfg.SALT_CONFIG_LOCATION + '/cloud.providers.d/ec2.cluster.providers.conf'

        self.profiles_template = self._load_template(self.profiles_template_file)
        self.providers_template = self._load_template(self.providers_template_file)

        # Last rendered text of each file; None until the first render.
        self.rendered_profiles = None
        self.rendered_providers = None

    @staticmethod
    def _load_template(path):
        """Read *path* and return it as a Jinja ``Template``, closing the file."""
        with open(path, 'r') as template_file:
            return Template(template_file.read())

    def update_clusters(self, clusters):
        """Adopt *clusters* and re-render the profiles if they differ.

        Clusters are compared by their ``__dict__`` contents, ordered by the
        ``name`` attribute, so a reordered but otherwise identical list is
        treated as unchanged and nothing is rendered.
        """
        def by_name(attributes):
            return attributes['name']

        incoming = sorted((cluster.__dict__ for cluster in clusters), key=by_name)
        current = sorted((cluster.__dict__ for cluster in self.clusters), key=by_name)

        if incoming == current:
            return

        log.info('Updating SaltApi clusters to: %s from %s' % (clusters, self.clusters))
        self.clusters = clusters
        self.ensure_sgs()
        self.render_profiles()

    def ensure_sgs(self):
        """Create any security group named by a cluster that does not exist yet.

        Dumb logic -- if an sg with the same name doesn't exist, create it.
        No diffing of the rules of groups that already exist. ``self.sg_map``
        is updated with the id of every group found or created.
        """
        clusters_w_sg = [
            cluster for cluster in self.clusters
            if cluster.security_group and 'name' in cluster.security_group
        ]
        required_sg_names = [cluster.security_group['name'] for cluster in clusters_w_sg]
        present_sgs = self.ec2.get_sgs_by_names(required_sg_names)

        # Single pass so this also works if the helper returns a one-shot iterable.
        present_sg_names = set()
        for sg in present_sgs:
            present_sg_names.add(sg['GroupName'])
            self.sg_map[sg['GroupName']] = sg['GroupId']

        for cluster in clusters_w_sg:
            sg_name = cluster.security_group['name']
            if sg_name in present_sg_names:
                log.info('Skipping sg creation for {} because it already exists'.format(sg_name))
                continue
            self.sg_map[sg_name] = self.ec2.create_sg(sg_name, cluster.security_group)
            # Remember it so a second cluster sharing the name does not re-create it.
            present_sg_names.add(sg_name)
        log.info("sg_map {}".format(self.sg_map))

    def _populate_data(self):
        """Load the default security group and per-AZ subnet ids for the VPC.

        Fills three parallel lists ordered by availability zone name:
        ``availability_zones``, ``private_subnet_ids`` and
        ``public_subnet_ids``.

        :raises KeyError: if an availability zone has no subnet recognised as
            private, or none recognised as public.
        """
        self.default_sg = self.ec2.get_vpc_default_sg(self.vpc)

        subnets = self.ec2.get_subnets(self.vpc)
        self.availability_zones = []
        self.private_subnet_ids = []
        self.public_subnet_ids = []

        for az in self.reduce_network_info(subnets):
            self.availability_zones.append(az["AvailabilityZone"])
            self.private_subnet_ids.append(az["private_subnet"])
            self.public_subnet_ids.append(az["public_subnet"])

    def reduce_network_info(self, subnets):
        """Group subnet descriptions by availability zone.

        A subnet is classified by its ``Name`` tag: a value containing
        ``"private"`` makes it the zone's ``private_subnet``, otherwise a
        value containing ``"public"`` makes it the zone's ``public_subnet``.
        When several subnets of one zone match the same kind, the last one
        wins. Subnets without tags, or without a matching ``Name`` tag, only
        contribute their zone.

        :param subnets: iterable of dicts with ``AvailabilityZone``,
            ``SubnetId`` and optionally ``Tags`` (list of ``Key``/``Value``
            dicts).
        :returns: list of dicts sorted by ``AvailabilityZone``, each holding
            ``AvailabilityZone`` plus whichever of ``private_subnet`` /
            ``public_subnet`` were found.
        """
        zones = {}

        for sn in subnets:
            zone_name = sn["AvailabilityZone"]
            entry = zones.setdefault(zone_name, {})

            # "Tags" may be absent on an untagged subnet; treat that as no tags.
            for tag in sn.get("Tags") or []:
                if tag["Key"] != "Name":
                    continue
                if "private" in tag["Value"]:
                    entry["private_subnet"] = sn["SubnetId"]
                    break
                if "public" in tag["Value"]:
                    entry["public_subnet"] = sn["SubnetId"]
                    break
            entry["AvailabilityZone"] = zone_name

        return sorted(zones.values(), key=lambda az: az["AvailabilityZone"])

    def render_profiles(self):
        """Render the profiles template for ``self.clusters`` and write it out.

        Each cluster's public dict gets an ``ami_id`` resolved from its
        ``ami_name`` through the ec2 helper before rendering. The result is
        kept in ``self.rendered_profiles`` and written to
        ``self.profiles_file``.
        """
        clusters = []
        for cluster in self.clusters:
            d = cluster.get_public()
            d['ami_id'] = self.ec2.get_public_ami_by_name(d['ami_name'])
            clusters.append(d)

        self.rendered_profiles = self.profiles_template.render(
            vpc=self.bebo_region,
            clusters=clusters,
            availability_zones=self.availability_zones,
            default_sg=self.default_sg,
            sg_map=self.sg_map
        )

        with open(self.profiles_file, 'w') as f:
            f.write(self.rendered_profiles)

    def render_providers(self):
        """Render the providers template and write it to ``self.providers_file``.

        The AWS credentials from ``cfg`` are passed to the template along
        with the region, subnet and master-ip information.
        """
        self.rendered_providers = self.providers_template.render(
            aws_id=cfg.AWS_ACCESS_KEY_ID,
            aws_key=cfg.AWS_SECRET_ACCESS_KEY,
            vpc=self.bebo_region,
            full_region=self.aws_region,
            availability_zones=self.availability_zones,
            private_subnets=self.private_subnet_ids,
            public_subnets=self.public_subnet_ids,
            master_ip=self.ec2.admin_private_ip
            )

        with open(self.providers_file, 'w') as f:
            f.write(self.rendered_providers)

    def render_all(self):
        """Render and write both the profiles and the providers file."""
        self.render_profiles()
        self.render_providers()

    def get_profiles(self):
        """Return the top-level keys (profile names) of the rendered profiles.

        Renders the profiles first if that has not happened yet. An empty,
        whitespace-only or otherwise content-free rendering (one that YAML
        parses to nothing) yields an empty list.

        :returns: list of profile names.
        """
        if self.rendered_profiles is None:
            self.render_profiles()

        if not self.rendered_profiles or not self.rendered_profiles.strip():
            return []

        # The YAML is our own rendered template, not external input.
        parsed = yaml.full_load(self.rendered_profiles)
        if not parsed:
            # e.g. a rendering that only contains comments parses to None.
            return []

        return list(parsed.keys())

class SaltApi():
    """Thin wrapper that provisions and destroys hosts through salt-cloud.

    Owns a ``SaltCloudTemplateGenerator`` that keeps the salt-cloud config
    files in sync with the known clusters, and caches the names of the
    profiles those files define in ``self.avail_profiles``.
    """

    def __init__(self, clusters, ec2):
        """Render all salt-cloud config for *clusters* and cache the profiles.

        :param clusters: cluster objects handed to the template generator.
        :param ec2: project EC2 helper handed to the template generator.
        """
        self.templater = SaltCloudTemplateGenerator(ec2)
        self.update_clusters(clusters)
        self.templater.render_all()

        self.avail_profiles = self.templater.get_profiles()

        log.info('Salt cloud profiles: {}'.format(self.avail_profiles))

    def update_clusters(self, new_clusters):
        """Push *new_clusters* to the templater and refresh the cached profiles."""
        self.templater.update_clusters(new_clusters)
        self.avail_profiles = self.templater.get_profiles()

    def create(self, hostname):
        """Provision *hostname* with the first profile matching its cluster.

        The cluster name is the second dash-separated field of *hostname*;
        the profile chosen is the first one whose name contains
        ``-<cluster>-``. Nothing is executed when ``cfg.DRY_RUN`` is set.

        :raises IndexError: if *hostname* contains no dash, or if no
            available profile matches the cluster.
        """
        parts = hostname.split("-")
        if len(parts) < 2:
            # IndexError is what the plain list lookup raised; keep the type.
            raise IndexError(
                'Cannot derive a cluster name from hostname {!r}: '
                'expected at least two dash-separated fields'.format(hostname)
            )
        cluster_name = parts[1]
        cluster_name_string = '-{}-'.format(cluster_name)

        profile = next(
            (candidate for candidate in self.get_profiles() if cluster_name_string in candidate),
            None,
        )
        if profile is None:
            raise IndexError(
                'No salt-cloud profile matches cluster {!r} (hostname {!r})'.format(
                    cluster_name, hostname)
            )

        self._run_salt_cloud(
            ['/usr/bin/salt-cloud', '-p', profile, hostname, '--assume-yes', '-l', 'debug']
        )

    def destroy(self, hostname):
        """Destroy *hostname* via salt-cloud (skipped when ``cfg.DRY_RUN`` is set)."""
        self._run_salt_cloud(['/usr/bin/salt-cloud', '-y', '-d', hostname, '-l', 'debug'])

    def get_profiles(self):
        """Return the cached salt-cloud profile names."""
        return self.avail_profiles

    def _run_salt_cloud(self, command):
        """Log *command* and run it in the background unless in dry-run mode."""
        log.info('Provision Call: {}'.format(' '.join(command)))
        if cfg.DRY_RUN:
            log.warn('Not running command, as cfg.DRY_RUN is true')
            return
        async_command(command)

def async_command(command):
    """Run *command* in a background thread, logging every line of its stdout.

    The call returns as soon as the thread is started. The child's stderr is
    not captured; it goes wherever this process's stderr goes. A non-zero
    exit status is logged as a warning.

    :param command: argument list passed to ``subprocess.Popen``.
    """
    # TODO turn this into a class so we can catch errors and error out in etcd
    def run(cmd):
        # The context manager closes the pipe and waits for the child on exit.
        with subprocess.Popen(cmd, stdout=subprocess.PIPE, encoding='utf-8') as proc:
            # Iterate to EOF so output still buffered in the pipe when the
            # process exits is logged rather than dropped.
            for line in proc.stdout:
                log.info('SaltCloud stdout: {}'.format(line.strip()))

        if proc.returncode:
            log.warn('SaltCloud command exited with status {}: {}'.format(
                proc.returncode, ' '.join(cmd)))

    t = threading.Thread(target=run, args=(command,))
    t.start()
