import socket
import json
import argparse
import collections
import multiprocessing
from threading import Thread
from time import sleep

import requests
import docker
import sentry_sdk
from bebo.utils import init_sentry

from config import cfg
from etcd_api import EtcdApi
from util import write_docker_event, get_private_node_hostname, calc_cpu_percentage

from logger import log

# Default 'capacity' value for a service record: 75.0 multiplied by the number
# of CPUs reported by multiprocessing.cpu_count(). Used whenever a service does
# not supply a usable capacity of its own.
CPU_CAPACITY_PERCENTAGE = 75.0 * multiprocessing.cpu_count()

# Command-line interface: the optional --cluster-name flag is stored on the
# parsed namespace as `cluster_name`.
parser = argparse.ArgumentParser()
parser.add_argument('--cluster-name', dest='cluster_name')

# The bare string below is never used at runtime; it documents the keys and
# value types of the per-service state records built by the Warg class.
'''
{
    "healthy": bool,
    "full": bool,
    "stoppable": bool,
    "capacity": float,
    "load": float,
    "deprecated": bool
}
'''

# Runs at import time, so Sentry is initialised before any Warg instance exists.
init_sentry(cfg.SENTRY_URL)


class Warg(object):
    def __init__(self, args):
        """Create the etcd/docker clients and resolve the cluster name.

        Args:
            args: Parsed command-line namespace (or a falsy value). When it
                carries a truthy ``cluster_name`` that value is used as-is.

        Raises:
            IndexError: If no cluster name was given and the first label of
                the host's FQDN contains no ``-`` separated second field.
        """
        self.etcd = EtcdApi()
        self.docker_client = docker.from_env()
        self.private_node_hostname = get_private_node_hostname()

        requested_name = args.cluster_name if args else None
        if requested_name:
            self.cluster_name = requested_name
        else:
            # Fall back to the host name: take the first DNS label and use
            # its second dash-separated field as the cluster name.
            host_label = socket.getfqdn().split('.')[0]
            self.cluster_name = host_label.split('-')[1]

        # Per-service bookkeeping used by write_docker_stats/get_cpu_percent.
        self.docker_stats_counter = {}
        self.cpu_percentage_history = collections.defaultdict(list)
        self.mem_usage_history = collections.defaultdict(list)

        log.info('Warg started -- cluster_name is {}'.format(self.cluster_name))

    def start(self):
        """Run ``tick`` forever, pausing ``cfg.poll_tick`` between rounds.

        Any ``Exception`` raised by a tick is reported to Sentry and logged,
        then the loop carries on. ``KeyboardInterrupt`` is re-raised so the
        process can be stopped; this method never returns normally.
        """
        while True:
            try:
                self.tick()
            except KeyboardInterrupt:
                # Let Ctrl-C terminate the loop instead of being swallowed.
                raise
            except Exception:
                sentry_sdk.capture_exception()
                log.exception('Uncaught error')
            sleep(cfg.poll_tick)

    def tick(self):
        """Perform one polling round and publish the resulting box health.

        Lists every container known to the docker client (including stopped
        ones, via ``all=True``), collects a state record per service,
        aggregates them and writes the JSON-encoded result to etcd.
        """
        containers = self.docker_client.containers.list(all=True)

        state_by_service = {}
        self.update_health_state(state_by_service, containers)

        box_health = self.determine_health(state_by_service)
        log.info('Determined Health: {}'.format(box_health))

        payload = json.dumps(box_health)
        self.update_etcd(payload)

    def determine_health(self, current_state):
        """Aggregate per-service state records into one box-level health dict.

        Args:
            current_state: Mapping of service name to that service's state
                dict. Missing flags default to ``healthy=False``,
                ``full=False``, ``stoppable=True`` and ``deprecated=False``.

        Returns:
            dict: ``healthy``, ``full``, ``stoppable`` and ``deprecated`` as
            real booleans. With no services at all the result is
            ``healthy=True, full=False, stoppable=True, deprecated=False``
            (the previous ``min``/``max`` implementation raised ``ValueError``
            on an empty mapping).
        """
        records = list(current_state.values())

        # all()/any() accept empty input and any mix of truthy/falsy values
        # (e.g. None from a service payload), unlike min()/max().
        return {
            # not healthy unless all are healthy
            "healthy": all(record.get('healthy', False) for record in records),
            # full if even one is full
            "full": any(record.get('full', False) for record in records),
            # not stoppable unless all are stoppable
            "stoppable": all(record.get('stoppable', True) for record in records),
            # deprecated if even one is deprecated
            "deprecated": any(record.get('deprecated', False) for record in records),
        }

    def update_etcd(self, health):
        """Publish ``health`` for this node to etcd with the configured TTL.

        Args:
            health: Value to store; ``tick`` passes a JSON-encoded string.
        """
        node = self.private_node_hostname
        ttl_value = cfg.etcd_ttl
        self.etcd.set_box_health(node, health, ttl=ttl_value)

    def update_health_state(self, state, docker_services):
        cluster = self.etcd.get_cluster(self.cluster_name)
        cluster_service_names = cluster.get('service_names', [])
        docker_service_names = [service.name for service in docker_services]

        #  log.debug('cluster_service_names: {}, docker_service_names: {}'.format(cluster_service_names, docker_service_names))

        for service_name in cluster_service_names:
            if service_name not in docker_service_names:
                #  log.info('{} is not running -- cluster demands it is'.format(service_name))
                state[service_name] = {
                    'healthy': False,
                    'full': False,
                    'stoppable': True,
                    'capacity': CPU_CAPACITY_PERCENTAGE,
                    'load': 0.0,
                    'deprecated': False,
                }

        threads = []
        for docker_service in docker_services:
            thread = Thread(target=self.get_docker_service_state, args=(docker_service, state,))
            thread.start()
            threads.append(thread)

        for thread in threads:
            thread.join()

    def write_docker_stats(self, docker_service):
        """Sample one container's CPU/memory usage and periodically emit it.

        Every call appends the current CPU percentage and memory usage to the
        per-service histories. Once the per-service call counter has reached
        ``cfg.poll_tick * 30`` the counter is reset, the peak CPU and memory
        values are sent through ``write_docker_event`` and both histories are
        cleared.

        Args:
            docker_service: Container object providing ``name``,
                ``stats(...)`` and ``image.attrs``.

        A ``KeyError`` while reading the stats payload is logged as a warning
        and otherwise ignored.
        """
        service_name = docker_service.name
        current_tick_number = self.docker_stats_counter.get(service_name, 0)
        self.docker_stats_counter[service_name] = current_tick_number + 1
        stats = {}

        docker_stats = docker_service.stats(decode=False, stream=False)
        try:
            current_cpu = calc_cpu_percentage(docker_stats)
            # 'usage' is divided by 1024 twice -- presumably bytes -> MB.
            current_memory = docker_stats.get('memory_stats', {}).get('usage', 0) / 1024 / 1024

            self.cpu_percentage_history[service_name].append(current_cpu)
            self.mem_usage_history[service_name].append(current_memory)

            if current_tick_number < cfg.poll_tick * 30:
                return

            self.docker_stats_counter[service_name] = 0
            log.debug("writing stats for {}".format(service_name))

            cpu_percent = self.get_cpu_percent(service_name)
            mem_usage = max(self.mem_usage_history[service_name]) or 0.0

            self.cpu_percentage_history[service_name].clear()
            self.mem_usage_history[service_name].clear()

            log.debug('cpu % for {}: {}'.format(service_name, cpu_percent))
            log.debug('mem MB for {}: {}'.format(service_name, mem_usage))

            # The version is the tag of the image's first repo tag. Untagged
            # images (missing, None or empty RepoTags) and tag-less names fall
            # back to 'unknown' instead of raising IndexError/TypeError, which
            # the KeyError handler below would not catch.
            version = 'unknown'
            repo_tags = docker_service.image.attrs.get('RepoTags') or []
            if repo_tags:
                # Split on the LAST colon so a registry port such as
                # "host:5000/name:tag" is not mistaken for the tag.
                _, separator, tag = repo_tags[0].rpartition(':')
                if separator and tag and '/' not in tag:
                    version = tag

            stats['cpu_percent_dec'] = cpu_percent
            stats['mem_usage_mb_dec'] = mem_usage
            stats['service_tx'] = service_name
            stats['version_tx'] = version
            stats['category_tx'] = 'stats'
            stats['action_tx'] = service_name
            stats['label_tx'] = version

            stats['routing_key'] = stats['category_tx'] + '.' + stats['action_tx'] + '.' + stats['label_tx']

            write_docker_event(stats)
        except KeyError as e:
            log.warning('KeyError trying to get docker stats for  {}: {}, stats: {}'.format(docker_service.name, e, stats))

    def get_docker_service_state(self, docker_service, state):
        """Build one container's state record and merge it into ``state``.

        Runs as a thread target from ``update_health_state``. The record
        starts out healthy; it becomes unhealthy when the container is not
        running, and is replaced by the service's own ``/active`` payload when
        one is available. ``capacity`` and ``load`` are always present as
        floats in the committed record, which is also sent through
        ``write_docker_event`` with a ``service_tx`` key added.

        Args:
            docker_service: Container object (``name``, ``status``, ...).
            state: Shared dict keyed by service name, mutated in place.
        """
        docker_service_name = docker_service.name

        # Stats emission is best-effort: a failure there must not stop this
        # thread before the health record below has been committed.
        try:
            self.write_docker_stats(docker_service)
        except Exception:
            log.exception('Failed to write docker stats for {}'.format(docker_service_name))

        service_state = {
            "healthy": True,
            "full": False,
            "stoppable": True,
            "deprecated": False
        }

        if not self.check_docker_service_running(docker_service):
            service_state["healthy"] = False
        else:
            active_status = self.check_docker_service_active(docker_service_name)
            if isinstance(active_status, dict):
                # Copy so the normalisation below works on our own dict.
                service_state = dict(active_status)
            elif active_status is not None:
                # A non-dict payload cannot carry health flags; treat the
                # service as unhealthy rather than crashing on item access.
                log.warning('Ignoring non-dict /active status for {}: {!r}'.format(
                    docker_service_name, active_status))
                service_state["healthy"] = False

        def ensure_float(key, fallback):
            """Force service_state[key] to a float, using fallback() if absent or unusable."""
            try:
                service_state[key] = float(service_state[key])
            except (KeyError, TypeError, ValueError):
                # TypeError covers values such as None or lists, which the
                # original ValueError-only handler let escape.
                service_state[key] = fallback()

        ensure_float('capacity', lambda: CPU_CAPACITY_PERCENTAGE)
        ensure_float('load', lambda: self.get_cpu_percent(docker_service_name))

        # Merge into any record already stored under this name, then publish.
        merged = state.setdefault(docker_service_name, {})
        merged.update(service_state)
        merged["service_tx"] = docker_service_name
        write_docker_event(merged)

    def check_docker_service_running(self, docker_service):
        return docker_service.status == 'running'

    def get_cpu_percent(self, service_name):
        history = self.cpu_percentage_history[service_name]
        if history:
            return max(history) or 0.0
        return 0.0

    def check_docker_service_active(self, docker_service_name):

        etcd_service = self.etcd.get_service(docker_service_name)

        health_port = 0
        if etcd_service and 'health_port' in etcd_service:
            health_port = etcd_service['health_port']

        if not health_port:
            return None

        try:
            url = 'http://localhost:{}/active'.format(health_port)
            #  log.info('Trying to connect to {} @ {}'.format(docker_service_name, url))
            response = requests.get(url, timeout=1)
            #  status_code = response.status_code
            #  log.info('Successful contact to {} @ {} -- status_code: {}'.format(docker_service_name, health_port, status_code))
            return response.json()
        except Exception as e:
            log.warn('Failed to contact {} @ {} -- {}'.format(docker_service_name, health_port, e))

        return {
            "healthy": False,
            "full": False,
            "stoppable": True,
            "deprecated": False
        }

if __name__ == '__main__':
    # Script entry point: parse the CLI flags, build the monitor and block
    # in its polling loop.
    args = parser.parse_args()
    W = Warg(args)
    # Log before entering the loop: start() never returns normally, so a
    # statement placed after it would be unreachable.
    log.info('warg started')
    W.start()
