import logging
import multiprocessing
import sys
import time
import traceback

from .abstract import ThreadedModule, run_command, iso_date, measure, resource_usage, Warnings
from .ewma import EWMA

try:
    from typing import Optional, List
except ImportError:
    pass

# Module-level logger named after this module.
LOG = logging.getLogger(__name__)


def clean_user_map(m):  # type: (dict) -> None
    """Remove idle ``EWMA`` entries from *m* in place.

    An entry is dropped when its value is an ``EWMA`` whose ``dict()`` reports
    zero for ``avg``, ``m1``, ``m5`` and ``m15`` and whose ``values`` is falsy.
    Values of any other type are left alone.
    """
    stale = []
    for key, meter in m.items():
        if isinstance(meter, EWMA):
            stats = meter.dict()
            idle = all(stats[name] == 0 for name in ('avg', 'm1', 'm5', 'm15'))
            if idle and not meter.values:
                stale.append(key)
    # Delete after the scan so the mapping is not mutated while iterating it.
    for key in stale:
        del m[key]


def parse_vmstat(warnings):  # type: (Warnings) -> float
    """Return the idle-CPU value reported by ``vmstat 1 2``.

    The second output line is treated as the column header. The ``id``
    column is then read from the last output line, presumably the
    one-second interval sample rather than the since-boot average (TODO
    confirm on target OS).

    On any failure a warning is logged through *warnings* and ``0.`` is
    returned. NOTE: the FreeBSD caller computes ``100 - idle``, so a failure
    shows up as 100% CPU usage.
    """
    lines = run_command('vmstat 1 2', lines=True).out
    try:
        header_line = lines[1]
        headers = header_line.split()
        if 'id' in headers:
            idx = headers.index('id')
            return float(lines[-1].split()[idx])
        # Lazy formatting, consistent with the other warnings.log calls.
        warnings.log('no id in string %s', header_line)
    except Exception:
        warnings.log("Failed to parse `vmstat` output [%s]: %s", lines, traceback.format_exc())

    return 0.


def process_sysctl(warnings):  # type: (Warnings) -> Optional[dict]
    """Collect memory and swap statistics via ``sysctl`` and ``swapinfo -k``.

    Returns a dict with ``total``, ``cache``, ``free``, ``used`` (bytes) and
    ``changeTime``. ``swapTotal``/``swapUsed``/``swapFree`` (bytes) are added
    when at least one swap device row is found. Returns ``None`` after
    logging a warning if any command fails or its output cannot be parsed.
    """
    try:
        mem_hw = int(run_command('sysctl -n hw.physmem').out)
        mem_pgs = int(run_command('sysctl -n hw.pagesize').out)
        # Page counters are converted to bytes using the page size.
        mem_inactive = int(run_command('sysctl -n vm.stats.vm.v_inactive_count').out) * mem_pgs
        mem_cache = int(run_command('sysctl -n vm.stats.vm.v_cache_count').out) * mem_pgs
        mem_free = int(run_command('sysctl -n vm.stats.vm.v_free_count').out) * mem_pgs

        res = dict(
            total=mem_hw,
            cache=mem_cache,
            # Inactive pages are counted as free memory.
            free=mem_inactive + mem_free,
            changeTime=iso_date()
        )
        res['used'] = res['total'] - res['free'] - res['cache']

        # [total, used, free] in bytes, summed over every device row
        # (rows starting with '/'), so multiple swap devices are all counted
        # instead of only the last one.
        swap = None
        for s in run_command('swapinfo -k', lines=True).out:
            if s.startswith('/'):
                ss = s.split()
                # swapinfo -k reports 1K blocks.
                row = [int(ss[1]) * 1024, int(ss[2]) * 1024, int(ss[3]) * 1024]
                swap = row if swap is None else [a + b for a, b in zip(swap, row)]
        if swap is not None:
            res['swapTotal'], res['swapUsed'], res['swapFree'] = swap
        return res
    except Exception:
        warnings.log("error getting memory stats: %s", traceback.format_exc())

    return None


def process_free(warnings):  # type: (Warnings) -> Optional[dict]
    """Collect memory and swap statistics (bytes) from ``free -b``.

    Returns a dict (possibly empty) with whichever of ``total``, ``cache``,
    ``used``, ``free``, ``swapTotal``, ``swapUsed`` and ``swapFree`` could be
    parsed. Returns ``None`` after logging a warning if running or parsing
    ``free`` failed.

    Two output layouts are supported:

    * older procps: header ``total used free shared buffers cached``, plus a
      separate ``-/+ buffers/cache:`` row that gives used/free excluding caches;
    * newer procps-ng: header ``total used free shared buff/cache available``
      and no ``-/+ buffers/cache:`` row. Here ``used`` already excludes
      buffers/cache, so ``free`` is derived as ``total - used``.
    """
    res = {}
    try:
        # Column name -> index in a split data row, built from the header line.
        # Data rows start with a label such as "Mem:", hence the +1.
        columns = {}
        for s in run_command('free -b', lines=True).out:
            ss = s.split()
            if s.startswith('Mem:'):
                res['total'] = int(ss[1])
                if 'buff/cache' in columns:
                    # Newer layout: there is no "-/+ buffers" row to read from.
                    res['cache'] = int(ss[columns['buff/cache']])
                    res['used'] = int(ss[columns['used']])
                    res['free'] = res['total'] - res['used']
                else:
                    res['cache'] = int(ss[columns.get('cached', 6)])
            elif s.startswith('-/+ buffers'):
                res['used'] = int(ss[2])
                res['free'] = int(ss[3])
            elif s.startswith('Swap:'):
                res['swapTotal'] = int(ss[1])
                res['swapUsed'] = int(ss[2])
                res['swapFree'] = int(ss[3])
            elif 'total' in ss:
                columns = {name: i + 1 for i, name in enumerate(ss)}
        return res
    except Exception:
        warnings.log("failed to get `free` output: %s", traceback.format_exc())
    return None


def get_stat_cpu():  # type: () -> Optional[list]
    """Return the aggregate CPU tick counters from ``/proc/stat``.

    Reads the first line starting with ``'cpu '``, i.e. the summary line and
    not the per-core ``cpuN`` lines, and returns its numeric fields as a list
    of ints. Returns ``None`` if no such line exists.
    """
    with open('/proc/stat', 'r') as f:
        for line in f:
            if line.startswith('cpu '):
                # A real list (not a lazy ``map``) so callers can index it,
                # e.g. ``r[3]`` in AgentModule.process_stat, on Python 3 as well.
                return [int(x) for x in line.split()[1:]]
    return None


class AgentModule(ThreadedModule):
    """Sample RAM, swap and CPU usage plus per-process consumers.

    Linux: memory from ``free -b`` (``process_free``), CPU from ``/proc/stat``.
    FreeBSD: memory from ``sysctl``/``swapinfo`` (``process_sysctl``), CPU from
    ``vmstat``. On both platforms ``ps aux`` is parsed to attribute RSS and
    CPU to ``(process name, user)`` pairs. Samples are smoothed with ``EWMA``
    objects and reported by ``get_value``.
    """
    default_config = {
        'circle_time': 10.,
        'interval': 60,
        'procnames': 'skynet, java, mapreduce, srch-mmeta-, srch-base-, apache, nginx, /gcc/, /glusterfs'
    }

    def __init__(self, *args, **kw):
        super(AgentModule, self).__init__(*args, **kw)
        # EWMA-smoothed samples. 'total' and 'swapTotal' are added as plain
        # values by loop() once a memory sample has been taken.
        self.memory = {'used': EWMA(), 'swapUsed': EWMA(), 'cpu': EWMA()}
        # (process name, user) -> EWMA of RSS bytes / CPU percent.
        self.ram2user = {}
        self.cpu2user = {}
        self.num_cpu = multiprocessing.cpu_count()
        # Previous /proc/stat "cpu" counters, used to compute deltas.
        self.prev_stat = None

    def loop(self):
        """Take one memory/CPU/process sample for the platform in ``self.arch``."""
        if self.arch.startswith('linux'):
            res = process_free(self.warnings)
            if res:
                self._update_memory(res)
                self._update_cpu(self.process_stat())
            self.process_ps()
        elif self.arch.startswith('freebsd'):
            res = process_sysctl(self.warnings)
            if res:
                self._update_memory(res)
                self._update_cpu(100 - parse_vmstat(self.warnings))
            self.process_ps()

    def _update_memory(self, res):  # type: (dict) -> None
        """Record one memory sample *res* (a process_free/process_sysctl dict)."""
        self.memory['total'] = res.get('total')
        self.memory['used'].update(res.get('used'))
        self.memory['swapTotal'] = res.get('swapTotal')
        self.memory['swapUsed'].update(res.get('swapUsed'))

    def _update_cpu(self, busy):  # type: (Optional[float]) -> None
        """Feed a CPU-busy percentage into the EWMA, skipping missing samples."""
        if busy is not None:
            self.memory['cpu'].update(busy)

    def get_value(self):  # type: () -> Optional[List[dict]]
        """Build the 'meminfo', 'memory_usage' and 'cpu_usage' answers.

        Returns ``None`` until a memory sample with a ``total`` has been
        recorded. Stale per-process entries are pruned first.
        """
        # self.memory always holds its EWMA entries, so test for actual data.
        if self.memory.get('total') is None:
            return None

        clean_user_map(self.cpu2user)
        clean_user_map(self.ram2user)

        resources = {
            'RAM': dict(
                registrationTime=iso_date(),
                name='RAM',
                capacity=self.memory.get('total'),
                usage=measure(self.memory.get('used')),
                consumerToUsage=resource_usage(self.ram2user))
        }
        if self.memory.get('swapTotal'):
            resources['Swap'] = dict(
                registrationTime=iso_date(),
                name='Swap',
                capacity=self.memory.get('swapTotal'),
                usage=measure(self.memory.get('swapUsed')),
                consumerToUsage=[])

        return [
            self.format_answer('meminfo', dict(
                capacity=self.memory.get('total'),
                swapCapacity=self.memory.get('swapTotal')
            )),
            self.format_answer('memory_usage', dict(
                name='Memory',
                resources=resources)),
            self.format_answer('cpu_usage', dict(
                registrationTime=iso_date(),
                name='CPU',
                capacity=100,
                usage=measure(self.memory.get('cpu')),
                consumerToUsage=resource_usage(self.cpu2user))),
        ]

    def process_stat(self):  # type: () -> Optional[float]
        """Return the CPU-busy percentage derived from ``/proc/stat``.

        The first call uses the counters accumulated since boot. Later calls
        use the difference from the previous call. Busy is ``100 - idle%``,
        where idle is the 4th counter of the ``cpu`` line. Returns ``None``
        if no counters were read or no ticks elapsed between calls.
        """
        r = get_stat_cpu()
        if not r:
            return None
        r = list(r)  # tolerate a lazy iterable from get_stat_cpu
        if self.prev_stat:
            delta = [cur - prev for prev, cur in zip(self.prev_stat, r)]
        else:
            delta = r
        self.prev_stat = r
        total = sum(delta)
        if total <= 0:
            # Two samples within the same clock tick: nothing to measure,
            # and dividing would raise ZeroDivisionError.
            return None
        return 100 - float(delta[3]) / total * 100

    def process_ps(self):  # type: () -> None
        """Feed non-zero RSS/CPU figures from ``parse_ps`` into per-consumer EWMAs."""
        for (proc, user), d in self.parse_ps().items():
            if d['rss']:
                self.ram2user.setdefault((proc, user), EWMA()).update(d['rss'])
            if d['cpu']:
                self.cpu2user.setdefault((proc, user), EWMA()).update(d['cpu'])

    def parse_ps(self):  # type: () -> dict
        """Aggregate ``ps aux`` output per ``(process name, user)``.

        The process name is the first configured ``procnames`` entry that is
        a substring of the command line (stripped of '-' and '/'), or ``None``.
        Each value is ``{'rss': bytes, 'mem': %MEM sum, 'cpu': %CPU sum}``.
        Unparseable rows are logged through ``self.warnings`` and skipped.
        """
        names = self.config.get('procnames')
        if names is None:
            # NOTE(review): unclear whether ThreadedModule merges default_config
            # into config; fall back explicitly rather than crash on None.
            names = self.default_config['procnames']
        # Build a real list: it is scanned once per ps row, and a lazy ``map``
        # would be exhausted after the first row on Python 3. Empty entries
        # (e.g. from a trailing comma) are dropped because '' matches every cmd.
        proc_names = [p.strip() for p in names.split(',')]
        proc_names = [p for p in proc_names if p]

        res = {}
        for line in run_command('ps aux', lines=True).out[1:]:
            if not line.strip():
                continue
            try:
                (user, _pid, cpu, mem, _vsz, rss,
                 _tt, _stat, _started, _time, cmd) = line.split(None, 10)
                # ps reports RSS in KiB.
                rss_bytes = int(rss) * 1024
                mem_pct = float(mem)
                cpu_pct = float(cpu)
            except ValueError:
                self.warnings.log("failed to parse `ps` output [%s]: %s", line, traceback.format_exc())
                continue
            if cmd.strip() == '[idle]':
                continue
            prc_name = None
            for p in proc_names:
                if p in cmd:
                    prc_name = p.strip('-/')
                    break
            d = res.setdefault((prc_name, user), {'rss': 0, 'mem': 0., 'cpu': 0.})
            d['rss'] += rss_bytes
            d['mem'] += mem_pct
            d['cpu'] += cpu_pct
        return res


if __name__ == '__main__':
    # Manual smoke test: sample 20 times and print the smoothed CPU EWMA.
    logging.basicConfig(level='INFO')
    module = AgentModule(sys.platform, config={'circle_time': 2., 'interval': 2})
    module.start()
    try:
        for _ in range(20):
            module.loop()
            # A single parenthesized argument prints the same on Python 2 and 3.
            print(module.memory['cpu'])
            time.sleep(.2)
    finally:
        # Stop the module even if a sample raises.
        module.stop()
