import math
import json
import datetime

from etcd_api import EtcD

from starkutil import StarkStates, health_status_changed
from route53 import Route53Controller
from logger import log

# Fallback `ami_name` for a ClusterModel whose stored record has none.
DEFAULT_AMI_NAME = 'debian-stretch-hvm-x86_64-gp2-2018-10-01-66564'

# Fallback `block_device_mappings` for a ClusterModel whose stored record has
# none. NOTE(review): `size` is presumably GiB — confirm with the consumer.
DEFAULT_BLOCK_MAPPINGS = [dict(name="xvda", size=10)]

# Module-wide EtcD client shared by every model class in this file for all
# reads, writes and deletes. It is constructed once, at import time.
etcd_api = EtcD()

class BaseModel(object):
    """Minimal persistence base for models stored as JSON documents in etcd.

    Subclasses set ``model`` (the key prefix) and ``pkey`` (the constructor
    keyword holding the primary key), and assign ``self.key`` in
    ``__init__``. Each record lives at ``<model>/<key>`` and is rebuilt by
    calling ``cls(**decoded_json)``.
    """
    model = ""
    key = ""
    pkey = "id"

    def get_public(self):
        """Return a shallow copy of the instance attributes as a dict."""
        return dict(self.__dict__)

    def __repr__(self):
        data = self.get_public()
        return json.dumps(data, sort_keys=True, default=str)

    def save(self):
        """Serialize all instance attributes to JSON and store them at ``<model>/<key>``."""
        data = json.dumps(self.__dict__, default=str, sort_keys=True)
        etcd_api.set("{}/{}".format(self.model, self.key), data)

    def delete(self, key=None):
        """Delete the record stored at ``<model>/<key>``.

        Args:
            key: record key to delete; defaults to this instance's own key.
        """
        if key is None:
            key = self.key
        # Same key layout as save()/get_one()/get_all().
        etcd_api.delete("{}/{}".format(self.model, key))

    @classmethod
    def get_all(cls):
        """Load every record listed under ``cls.model``.

        Returns:
            A list of ``cls`` instances. Entries that cannot be parsed are
            logged and skipped; an empty list is returned when the listing is
            empty.
        """
        values = []
        list_all = etcd_api.list(cls.model)
        if not list_all:
            return values
        for server in list_all:
            # Read the payload once, tolerating odd entries, so the error
            # handlers below can never raise while logging.
            raw = getattr(server, 'value', None)
            try:
                if server and raw:
                    values.append(cls(**json.loads(raw)))
                elif server.dir and not server._children:
                    # An empty directory node carries no record; skip it
                    # instead of discarding records already collected.
                    continue
                else:
                    log.warn('unable to parse server from etcd: {}'.format(server))
            except TypeError:
                # Stored keys did not match the constructor signature.
                log.exception("failed get_all {}".format(raw))
            except Exception:
                log.exception("failed to parse get_all model {}".format(raw))
        return values

    @classmethod
    def get_one(cls, key):
        """Load the record at ``<model>/<key>``.

        Returns:
            A ``cls`` instance, or None when the key is absent or the stored
            data cannot be decoded into a model (the failure is logged).
        """
        data = etcd_api.get("{}/{}".format(cls.model, key))
        if not data:
            return None

        try:
            return cls(**json.loads(data))
        except (TypeError, AttributeError):
            log.exception("failed get_one {}".format(data))
        except Exception:
            log.exception("failed to parse model in get_one key: {}".format(key))
        return None

    @classmethod
    def upsert(cls, _id, **kwargs):
        """Create or update the record whose primary key is ``_id`` and save it.

        When no record exists, a new model is built from ``kwargs`` with
        ``cls.pkey`` set to ``_id``. Otherwise each keyword overwrites the
        matching attribute on the loaded model.

        Returns:
            The saved model instance.
        """
        model = cls.get_one(_id)
        if model is None:
            kwargs[cls.pkey] = _id
            model = cls(**kwargs)
        else:
            model.__dict__.update(kwargs)
        model.save()
        return model

class EC2Model(BaseModel):
    """Persisted state of a single EC2 instance, keyed by its hostname.

    Each record is written as JSON to ``ec2/<hostname>`` by ``save``. It
    combines fields taken from an EC2 describe response (``from_ec2_info``),
    the Stark lifecycle state, health data read from etcd and the IP that the
    instance's public hostname resolves to in Route53.
    """
    model = "ec2"
    pkey = "hostname"

    # strftime format for every *_dttm field this model writes.
    _DTTM_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, **kwargs):
        self.hostname = kwargs.get('hostname')
        self.key = self.hostname
        self.instance_id = kwargs.get('instance_id')
        self.health_status = kwargs.get('health_status', {})
        self.stark_state = kwargs.get('stark_state', StarkStates.NONE)
        self.aws_status = kwargs.get('aws_status')
        self.instance_type = kwargs.get('instance_type')
        self.public_ip = kwargs.get('public_ip')
        self.private_ip = kwargs.get('private_ip')
        self.resolved_ip = kwargs.get('resolved_ip')
        self.created_dttm = kwargs.get('created_dttm')
        self.updated_dttm = kwargs.get('updated_dttm')
        self.deleted_dttm = kwargs.get('deleted_dttm')

        # Records store the enum's raw value (see save); rebuild the member.
        # Unrecognised values are logged and kept as-is.
        try:
            self.stark_state = StarkStates(self.stark_state)
        except ValueError:
            log.warn('stark_state is not in StarkStates enum -- {}'.format(self))

        # The public hostname is the hostname with every "aws" label removed.
        if self.hostname:
            labels = self.hostname.split('.')
            self.public_hostname = '.'.join(label for label in labels if label != 'aws')
        else:
            self.public_hostname = None

    @staticmethod
    def _now():
        """Return the current local time formatted with ``_DTTM_FORMAT``."""
        # NOTE(review): naive local time, while from_ec2_info formats EC2's
        # LaunchTime without any zone conversion — confirm both agree.
        return datetime.datetime.now().strftime(EC2Model._DTTM_FORMAT)

    @staticmethod
    def from_ec2_info(ec2_info_dict):
        """Build an EC2Model from one reservation entry of an EC2 describe call.

        Only ``ec2_info_dict['Instances'][0]`` is used; its ``Name`` tag
        becomes the hostname and its LaunchTime seeds both timestamps.

        Returns:
            An EC2Model, or None when the instance list is empty, the instance
            has no ``Name`` tag, or a required field is missing.
        """
        try:
            instance = ec2_info_dict['Instances'][0]

            hostname = None
            for tag_dict in instance['Tags']:
                if tag_dict['Key'] == 'Name':
                    hostname = tag_dict['Value']

            if hostname is None:
                return None

            data = {
                'hostname': hostname,
                'instance_id': instance['InstanceId'],
                'aws_status': instance['State']['Name'],
                'instance_type': instance['InstanceType'],
            }
            launch_time = instance['LaunchTime'].strftime(EC2Model._DTTM_FORMAT)
            data['created_dttm'] = launch_time
            data['updated_dttm'] = launch_time

            # IP addresses are optional in the response.
            if 'PublicIpAddress' in instance:
                data['public_ip'] = instance['PublicIpAddress']
            if 'PrivateIpAddress' in instance:
                data['private_ip'] = instance['PrivateIpAddress']

        except (KeyError, IndexError):
            # Missing field, or an empty 'Instances' list.
            return None

        return EC2Model(**data)

    def update_keys(self, newer_model):
        """Copy EC2-sourced fields from ``newer_model`` into this model and save.

        ``updated_dttm`` is bumped to now only when ``aws_status`` changes;
        otherwise it is taken from ``newer_model`` only if this model has none.
        """
        if self.aws_status != newer_model.aws_status:
            log.info("update_keys: aws_status: {} --> {}".format(self.aws_status, newer_model.aws_status))
            self.aws_status = newer_model.aws_status
            self.updated_dttm = self._now()

        self.instance_type = newer_model.instance_type
        self.instance_id = newer_model.instance_id
        self.created_dttm = newer_model.created_dttm
        self.public_ip = newer_model.public_ip
        self.private_ip = newer_model.private_ip
        if not self.updated_dttm:
            self.updated_dttm = newer_model.updated_dttm
            log.info('Model:update_keys model doesnt have updated_dttm and newer_model does -- taking newer_models {}'.format(newer_model.updated_dttm))

        self.save()

    def update_resolved_ip(self):
        """Refresh ``resolved_ip`` from Route53 and save (bumping updated_dttm) if it changed."""
        resolved_ips = Route53Controller.get_ips_for_record(self.public_hostname)
        resolved_ip = resolved_ips[0] if resolved_ips else None

        if resolved_ip != self.resolved_ip:
            self.resolved_ip = resolved_ip
            self.save(update=True)

    def update_health(self):
        """Synchronise ``health_status`` with etcd key ``health/<hostname>``.

        Instances that are stopping/terminating (per ``stark_state``) or
        stopped/terminated (per ``aws_status``) get an empty health status.
        Otherwise the stored value is reloaded; a change detected in either
        direction by ``health_status_changed`` is logged and saved.
        """
        if self.stark_state in (StarkStates.STOPPING, StarkStates.TERMINATING) or self.aws_status in ('stopped', 'terminated'):
            if self.health_status != {}:
                self.health_status = {}
                self.save(update=True)
            return

        new_health_status = etcd_api.get_json("health/{}".format(self.hostname), default={})

        # Compare both ways in case health_status_changed is asymmetric.
        changed = health_status_changed(self.hostname, self.health_status, new_health_status) or health_status_changed(self.hostname, new_health_status, self.health_status)

        if changed:
            log.info('[{}] health_status changes from {} to {}'.format(self.hostname, self.health_status, new_health_status))
            self.health_status = new_health_status
            self.save(update=True)

    def set_stark_state(self, stark_state):
        """Set ``stark_state`` and save if it changed; non-StarkStates values are logged and ignored."""
        if not isinstance(stark_state, StarkStates):
            log.warn('set_stark_state called with invalid stark state -- {} {}'.format(stark_state, self))
            return

        if stark_state is not self.stark_state:
            self.stark_state = stark_state
            self.save(update=True)

    def save(self, update=False):
        """Write this model as JSON to ``ec2/<hostname>``.

        Args:
            update: when true, set ``updated_dttm`` to the current time first.
        """
        if update:
            self.updated_dttm = self._now()

        output = self.get_public()

        # Persist the enum's raw value so __init__ can rebuild it with
        # StarkStates(value). Test membership only: a falsy enum member must
        # still be stored by value, not via default=str's str(member).
        if isinstance(self.stark_state, StarkStates):
            output['stark_state'] = self.stark_state.value

        data = json.dumps(output, default=str, sort_keys=True)
        etcd_api.set("{}/{}".format(self.model, self.key), data)

class ClusterModel(BaseModel):
    """Configuration of a named cluster, persisted at ``cluster/<name>``."""
    model = "cluster"
    pkey = "name"

    def __init__(self, **kwargs):
        self.name = kwargs.get('name')
        self.key = self.name

        self.service_names = kwargs.get('service_names', [])
        self.instance_types = kwargs.get('instance_types', [])
        self.public = kwargs.get('public', False)
        self.min_instances = kwargs.get('min_instances', 0)
        if 'min_stopped' in kwargs:
            self.min_stopped = kwargs['min_stopped']
        else:
            # Default: half of min_instances, rounded up. Dividing by 2.0
            # keeps this true division (and int() an int result) on Python 2
            # as well; only computed when no value was supplied.
            self.min_stopped = int(math.ceil(self.min_instances / 2.0))
        self.ignore = kwargs.get('ignore', False)
        self.security_group = kwargs.get('security_group', {})
        self.autoscale = kwargs.get('autoscale', True)
        if 'block_device_mappings' in kwargs:
            self.block_device_mappings = kwargs['block_device_mappings']
        else:
            # Copy the module-level default so edits to one cluster's
            # mappings never leak into the constant or other clusters.
            self.block_device_mappings = [dict(mapping) for mapping in DEFAULT_BLOCK_MAPPINGS]
        self.ami_name = kwargs.get('ami_name', DEFAULT_AMI_NAME)

class ServiceModel(BaseModel):
    """A service definition, persisted at ``service/<name>``.

    Missing constructor keywords fall back to: ``env`` -> [],
    ``version`` -> None, ``volumes`` -> [].
    """
    model = "service"
    pkey = "name"

    def __init__(self, **kwargs):
        lookup = kwargs.get
        self.name = lookup('name')
        self.key = self.name

        self.env = lookup('env', [])
        self.version = lookup('version')
        self.volumes = lookup('volumes', [])
