import json
import datetime
import botocore

from starkutil import es_event, StarkStates
from models import EC2Model, BaseModel, ServiceModel
from route53 import Route53Controller
from config import cfg

from logger import log

class StarkLogic(object):
    """Reconcile the boxes of one cluster towards its configured sizing.

    Each ``scale`` pass looks at every EC2Model whose hostname contains
    ``-<cluster name>-`` and, using its AWS status, health status, DNS state
    and pending stark state, starts / creates / stops / terminates boxes so
    the cluster approaches ``min_instances`` available boxes and
    ``min_stopped`` stopped boxes.
    """

    def __init__(self, ec2, salt_api):
        """Store the collaborators used to act on boxes.

        :param ec2: object used for start_instance / stop_instance /
            destroy_instance / get_hostname calls.
        :param salt_api: object used for create(hostname) / destroy(hostname).
        """
        self.ec2 = ec2
        self.salt_api = salt_api
        # Created lazily on the first scale() call.
        self.route53 = None
        # Cluster currently being reconciled; assigned by scale().
        self.cluster = None

    def scale(self, cluster):
        """Run one reconciliation pass for ``cluster``.

        Completed transitions are cleared first (check_old_stark_state) so the
        stark-state filters used by the later checks see up-to-date states.
        Stop/terminate only run when ``cluster.autoscale`` is set; route53
        endpoints are ensured last, followed by a status log/event.
        """
        if not self.route53:
            # NOTE(review): the controller is bound to the name of the first
            # cluster passed in; if one StarkLogic is reused for several
            # clusters, confirm Route53Controller does not rely on that name.
            self.route53 = Route53Controller(cluster.name, self)

        self.cluster = cluster
        self.check_old_stark_state()

        self.check_start()
        self.check_create()
        if self.cluster.autoscale:
            self.check_stop()
            self.check_terminate()

        self.check_timed_out()

        if self.cluster.public:
            log.info('[{}] running public route53 checks'.format(self.cluster.name))
            self.route53.ensure_public_regional_endpoint()
            self.route53.ensure_global_endpoint()

        self.route53.ensure_private_regional_endpoint()

        self._log_status()

    def get_cluster_boxes(self):
        """Return every EC2Model whose hostname contains ``-<cluster name>-``.

        EC2Model.get_all() is re-queried on every call (nothing is cached here).
        """
        servers = EC2Model.get_all()
        if not servers:
            return []

        cluster_string = '-{}-'.format(self.cluster.name)
        return [server for server in servers if cluster_string in server.hostname]

    def check_old_stark_state(self):
        """Reset to NONE the stark state of boxes whose transition has finished.

        A transition counts as complete when:
          * CREATING: the box has an instance_id, is 'running' and DNS is correct;
          * STARTING: the box is 'running' and DNS is correct;
          * STOPPING: the box is 'stopped';
          * TERMINATING / ERROR: the box is 'terminated' and has no resolved_ip.
        While a TERMINATING/ERROR box is incomplete in a public cluster, its
        public route53 record is destroyed on every pass.
        """
        for box in self.get_cluster_boxes():
            stark_state = box.stark_state
            if not stark_state:
                continue

            dns_state_correct = self.dns_state_correct(box)
            complete_stark_state = False

            if stark_state is StarkStates.CREATING:
                # DNS is now tied to the box itself, so no route53 upsert is
                # done here; wait for the instance id / running state.
                if box.instance_id and dns_state_correct and box.aws_status == 'running':
                    complete_stark_state = True
            elif stark_state is StarkStates.STARTING:
                # No route53 upsert here either: it would be retried on every
                # pass. Just wait for the running state.
                if box.aws_status == 'running' and dns_state_correct:
                    complete_stark_state = True
            elif stark_state is StarkStates.STOPPING:
                # The route53 record is intentionally kept while stopped: it is
                # upserted on create/start and removed on terminate.
                if box.aws_status == 'stopped':
                    complete_stark_state = True
            elif stark_state is StarkStates.TERMINATING or stark_state is StarkStates.ERROR:
                if box.aws_status == 'terminated' and not box.resolved_ip:
                    complete_stark_state = True
                elif self.cluster.public:
                    self.route53.destroy_public_record(box)
                # otherwise: wait for the terminated state
            elif stark_state is StarkStates.NONE:
                pass
            else:
                log.warning('unnacounted for state in check_old_stark_state: {} {}'.format(stark_state, box))

            if complete_stark_state:
                log.info('Completing stark state for {}'.format(box))
                box.set_stark_state(StarkStates.NONE)

    def check_start(self):
        """Start stopped boxes until enough boxes are (or will be) available.

        available = healthy + unhealthy + pending start + pending create.
        At most ``min_instances - available`` stopped boxes are started; a box
        whose start call raises ClientError is marked ERROR and is not counted.

        NOTE(review): a running box still in CREATING/STARTING that is not yet
        healthy appears in both a pending list and get_unhealthy(), so it is
        counted twice in ``available`` -- confirm this is intended.
        """
        stopped_boxes = self.get_stopped()

        num_healthy = len(self.get_healthy())
        num_pending_start = len(self.get_pending_start())
        num_pending_create = len(self.get_pending_create())
        num_unhealthy = len(self.get_unhealthy())
        available = num_pending_start + num_pending_create + num_unhealthy + num_healthy

        min_healthy = self._min_instances()

        num_stopped = len(stopped_boxes)
        num_to_start = min(max(min_healthy - available, 0), num_stopped)

        log.info('[{}] Starting {} because min_healthy is {}, available is {}, and num_stopped is {}'.format(self.cluster.name, num_to_start, min_healthy, available, num_stopped))
        if not num_to_start:
            return

        num_starting = 0
        for box in stopped_boxes:
            box.set_stark_state(StarkStates.STARTING)
            try:
                self.ec2.start_instance(box)
                num_starting += 1
            except botocore.exceptions.ClientError as e:
                log.warning(e)
                box.set_stark_state(StarkStates.ERROR)

            if num_starting >= num_to_start:
                break

    def check_create(self):
        """Create new boxes to cover both the available and stopped minimums.

        ``num_create`` is the shortfall against min_instances (counting pending
        starts/creates, unhealthy and healthy boxes as available);
        ``num_create_stopped`` is the shortfall against min_stopped (counting
        unhealthy, stopped and stopping boxes). Each new box gets a hostname
        from ec2, is marked CREATING and handed to salt_api.create.
        """
        min_healthy = self._min_instances()
        min_stopped = self._min_stopped()

        num_pending_start = len(self.get_pending_start())
        num_pending_create = len(self.get_pending_create())
        # Fetched once so the logged hosts match the counted ones.
        unhealthy_boxes = self.get_unhealthy()
        num_unhealthy = len(unhealthy_boxes)
        num_stopped = len(self.get_stopped())
        num_healthy = len(self.get_healthy())
        num_stopped_pending = len(self.get_pending_stop())

        available = num_pending_start + num_pending_create + num_unhealthy + num_healthy
        num_create = max(min_healthy - available, 0)

        num_create_stopped = max(min_stopped - (num_unhealthy + num_stopped + num_stopped_pending), 0)

        num_to_create = num_create + num_create_stopped

        log.info('[{}] Creating {} because num_create is {} and num_create_stopped is {} -- available: {}, num_pending_create: {}, num_unhealthy: {}'.format(self.cluster.name, num_to_create, num_create, num_create_stopped, available, num_pending_create, num_unhealthy))
        if unhealthy_boxes:
            log.warning("[{}] unhealthy hosts > 0 {}".format(self.cluster.name, unhealthy_boxes))

        for _ in range(num_to_create):
            hostname = self.ec2.get_hostname(self.cluster.name)
            box = EC2Model(hostname=hostname)
            box.set_stark_state(StarkStates.CREATING)
            self.salt_api.create(hostname)

    def check_stop(self):
        """Stop surplus stoppable boxes while keeping min_instances healthy.

        Stops at most ``min(num_stoppable, num_healthy - min_instances)`` boxes;
        a box whose stop call raises ClientError is marked ERROR.
        """
        stoppable_boxes = self.get_stoppable()
        num_stoppable = len(stoppable_boxes)
        num_healthy = len(self.get_healthy())

        min_healthy = self._min_instances()
        surplus = num_healthy - min_healthy

        num_to_stop = 0
        if surplus > 0:
            num_to_stop = min(num_stoppable, surplus)

            if num_to_stop == 0:
                log.info('[{}] check_stop num_to_stop = 0'.format(self.cluster.name))
                return

        log.info('[{}] Stopping {} boxes because num_healthy {} > min_healthy: {} and num_stoppable {}'.format(self.cluster.name, num_to_stop, num_healthy, min_healthy, num_stoppable))

        for box in stoppable_boxes[:num_to_stop]:
            # Defensive: get_stoppable() only returns boxes whose stark_state
            # is NONE, so this should not normally trigger.
            if box.stark_state is StarkStates.STOPPING:
                continue

            box.set_stark_state(StarkStates.STOPPING)
            try:
                self.ec2.stop_instance(box)
            except botocore.exceptions.ClientError as e:
                log.warning(e)
                box.set_stark_state(StarkStates.ERROR)

    def check_terminate(self):
        """Terminate stopped boxes in excess of min_stopped via salt_api.destroy.

        Any exception from the destroy call marks the box ERROR and is logged.
        """
        stopped_boxes = self.get_stopped()
        num_stopped = len(stopped_boxes)
        min_stopped = self._min_stopped()

        num_to_terminate = max(num_stopped - min_stopped, 0)

        log.info('[{}] Terminating {} because num_stopped: {} and min_stopped: {}'.format(self.cluster.name, num_to_terminate, num_stopped, min_stopped))

        for box in stopped_boxes[:num_to_terminate]:
            try:
                box.set_stark_state(StarkStates.TERMINATING)
                self.salt_api.destroy(box.hostname)
            except Exception:
                box.set_stark_state(StarkStates.ERROR)
                log.exception("[{}] Failed to terminate box: {}".format(self.cluster.name, box))

    @staticmethod
    def box_timed_out(box, timeout_in_minutes):
        """Return True if ``box`` was last updated more than ``timeout_in_minutes`` ago.

        ``box.updated_dttm`` is parsed with ``%Y-%m-%d %H:%M:%S``; a box with an
        empty/missing updated_dttm never times out.

        NOTE(review): compares against naive ``datetime.now()`` (local time);
        confirm updated_dttm is written in the same timezone.
        """
        updated_dttm = box.updated_dttm
        if not updated_dttm:
            return False

        last_update = datetime.datetime.strptime(updated_dttm, "%Y-%m-%d %H:%M:%S")
        deadline = last_update + datetime.timedelta(minutes=timeout_in_minutes)
        return datetime.datetime.now() > deadline

    def hard_terminate(self, unhealthy_box):
        """Destroy the box's instance directly through ec2, then save the box.

        A failure of the destroy call is logged and does not prevent the save.
        """
        try:
            self.ec2.destroy_instance(unhealthy_box.instance_id)
        except Exception:
            log.exception("failed to terminate errored box @johnny")

        unhealthy_box.save(update=True)

    def attempt_terminate(self, unhealthy_box):
        """Mark the box TERMINATING and ask salt_api to destroy it.

        On any exception the box is marked ERROR and the failure is logged.
        """
        log.info('Attempting to terminate unhealth box: {}'.format(unhealthy_box))
        try:
            unhealthy_box.set_stark_state(StarkStates.TERMINATING)
            self.salt_api.destroy(unhealthy_box.hostname)
        except Exception:
            unhealthy_box.set_stark_state(StarkStates.ERROR)
            log.exception("failed to terminate unhealthy_box")

    def check_create_timeout(self):
        """Terminate boxes that stayed CREATING longer than cfg.CREATE_TIMEOUT minutes."""
        for box in self.get_pending_create():
            if StarkLogic.box_timed_out(box, cfg.CREATE_TIMEOUT):
                log.info('[{}] Terminating {} because create timed out'.format(self.cluster.name, box))
                self.attempt_terminate(box)

    def check_start_timeout(self):
        """Terminate boxes that stayed STARTING longer than cfg.START_TIMEOUT minutes."""
        for box in self.get_pending_start():
            if StarkLogic.box_timed_out(box, cfg.START_TIMEOUT):
                log.info('[{}] Terminating {} because start timed out'.format(self.cluster.name, box))
                self.attempt_terminate(box)

    def check_stop_timeout(self):
        """Terminate boxes that stayed STOPPING longer than cfg.STOP_TIMEOUT minutes."""
        for box in self.get_pending_stop():
            if StarkLogic.box_timed_out(box, cfg.STOP_TIMEOUT):
                log.info('[{}] Terminating {} because stop timed out'.format(self.cluster.name, box))
                self.attempt_terminate(box)

    def check_unhealthy_timeout(self):
        """Terminate unhealthy boxes not updated within cfg.UNHEALTHY_TIMEOUT minutes."""
        for box in self.get_unhealthy():
            if StarkLogic.box_timed_out(box, cfg.UNHEALTHY_TIMEOUT):
                log.info('[{}] Terminating {} because unhealthy time out'.format(self.cluster.name, box))
                self.attempt_terminate(box)

    def check_terminate_timeout(self):
        """Hard-terminate boxes that stayed TERMINATING longer than cfg.TERMINATE_TIMEOUT minutes."""
        for box in self.get_pending_terminate():
            if StarkLogic.box_timed_out(box, cfg.TERMINATE_TIMEOUT):
                log.info('[{}] Hard terminating {} because terminate timed out'.format(self.cluster.name, box))
                self.hard_terminate(box)

    def check_timed_out(self):
        """Run every timeout check; stop/unhealthy timeouts only for autoscale clusters."""
        self.check_create_timeout()
        self.check_start_timeout()
        if self.cluster.autoscale:
            self.check_stop_timeout()
            self.check_unhealthy_timeout()
        self.check_terminate_timeout()

    # PENDING

    def get_pending_start(self):
        """Boxes whose stark state is STARTING."""
        return [box for box in self.get_cluster_boxes() if box.stark_state is StarkStates.STARTING]

    def get_pending_create(self):
        """Boxes whose stark state is CREATING."""
        return [box for box in self.get_cluster_boxes() if box.stark_state is StarkStates.CREATING]

    def get_pending_stop(self):
        """Boxes whose stark state is STOPPING."""
        return [box for box in self.get_cluster_boxes() if box.stark_state is StarkStates.STOPPING]

    def get_pending_terminate(self):
        """Boxes whose stark state is TERMINATING."""
        return [box for box in self.get_cluster_boxes() if box.stark_state is StarkStates.TERMINATING]

    # ACTIONABLE

    def get_stoppable(self):
        """Provisioned boxes whose health_status reports ``stoppable``."""
        return [box for box in self._get_provisioned() if box.health_status.get('stoppable', False)]

    def get_healthy(self):
        """Provisioned boxes that can take load (health_status ``full`` not set)."""
        return [box for box in self._get_provisioned() if not box.health_status.get('full', False)]

    def get_full(self):
        """Provisioned boxes that cannot take load (health_status ``full`` set)."""
        return [box for box in self._get_provisioned() if box.health_status.get('full', False)]

    def get_unhealthy(self):
        """Running boxes whose services are unhealthy or whose DNS is wrong.

        Boxes being stopped or terminated are excluded.
        """
        return [
            box for box in self.get_cluster_boxes()
            if box.aws_status == 'running'
            and (not box.health_status.get('healthy', False) or not self.dns_state_correct(box))
            and box.stark_state not in (StarkStates.STOPPING, StarkStates.TERMINATING)
        ]

    def get_error(self):
        """Boxes in the ERROR stark state (an action on them failed)."""
        return [box for box in self.get_cluster_boxes() if box.stark_state is StarkStates.ERROR]

    #  AWS STATUS

    def get_running(self):
        """Running boxes with no pending stark state."""
        return [box for box in self.get_cluster_boxes() if box.aws_status == 'running' and box.stark_state is StarkStates.NONE]

    def get_stopped(self):
        """Stopped boxes with no pending stark state."""
        return [box for box in self.get_cluster_boxes() if box.aws_status == 'stopped' and box.stark_state is StarkStates.NONE]

    def get_terminated(self):
        """Terminated boxes with no pending stark state."""
        return [box for box in self.get_cluster_boxes() if box.aws_status == 'terminated' and box.stark_state is StarkStates.NONE]

    # DISTINGUISHERS

    def _get_provisioned(self):
        """Healthy, running boxes with no pending stark state and correct DNS.

        These MIGHT be able to take load; see get_healthy / get_full.
        """
        return [
            box for box in self.get_cluster_boxes()
            if box.health_status.get('healthy', False)
            and box.stark_state is StarkStates.NONE
            and box.aws_status == 'running'
            and self.dns_state_correct(box)
        ]

    def dns_state_correct(self, box):
        """Return whether the box's resolved_ip matches what the cluster expects.

        Public clusters expect resolved_ip to equal the box's public_ip;
        non-public clusters expect resolved_ip to be empty.
        """
        if self.cluster.public:
            return box.resolved_ip == box.public_ip

        return not box.resolved_ip

    def _log_status(self):
        """Log per-category box counts and emit them as an es_event."""
        log_dict = {
            'num_total': len(self.get_cluster_boxes()),

            'num_healthy': len(self.get_healthy()),
            'num_unhealthy': len(self.get_unhealthy()),
            'num_full': len(self.get_full()),
            'num_stoppable': len(self.get_stoppable()),

            'num_running': len(self.get_running()),
            'num_stopped': len(self.get_stopped()),
            'num_terminated': len(self.get_terminated()),

            'num_creating': len(self.get_pending_create()),
            'num_starting': len(self.get_pending_start()),
            'num_stopping': len(self.get_pending_stop()),
            'num_terminating': len(self.get_pending_terminate()),
            'min_healthy': self._min_instances(),
            'min_stopped': self._min_stopped(),
        }

        log_string = ", ".join("{}: {}".format(key, value) for key, value in log_dict.items())

        log.info("[{}] nums: {}".format(self.cluster.name, log_string))

        es_event('cluster', 'status', label_tx=self.cluster.name, data=log_dict)

    def _min_instances(self):
        """Minimum number of available boxes for the current cluster."""
        return self.cluster.min_instances

    def _min_stopped(self):
        """Minimum number of stopped boxes for the current cluster."""
        return self.cluster.min_stopped
