import datetime
import logging
import shlex
import subprocess
import threading
import time
from collections import namedtuple
from copy import deepcopy
from threading import Timer

from .ewma import EWMA

try:
    from typing import Any, Dict  # noqa
except ImportError:
    pass

# Module-level logger named after this module; the Warnings and ThreadedModule
# classes below report through it (log_time deliberately uses the root logger).
LOG = logging.getLogger(__name__)


def log_time(func):
    """Decorate *func* so every call logs its wall-clock duration.

    The duration is logged at INFO level on the root logger after the call
    returns; the wrapped function's return value is passed through unchanged.
    Nothing is logged when *func* raises.
    """
    import functools

    root_logger = logging.getLogger()

    @functools.wraps(func)
    def timer(*arg, **kw):
        t1 = time.time()
        res = func(*arg, **kw)
        t2 = time.time()
        # ``__name__`` exists on Python 2 and 3; ``func_name`` was Python 2 only.
        root_logger.info('## function %s done in %.4f s', func.__name__, t2 - t1)
        return res

    return timer


CommandResult = namedtuple('CommandResult', ['returncode', 'out', 'err', 'has_timeout', 'elapsed'])


def kill_process(process, dto):
    """Timer callback: flag a timeout in ``dto["value"]`` and kill *process*.

    timed out recipe from
      https://stackoverflow.com/questions/1191374/using-module-subprocess-with-timeout/10768774#10768774
    """
    # Set the flag *before* killing: once the process dies the caller's
    # communicate() returns, and the flag must already be visible by then.
    dto["value"] = True
    try:
        process.kill()
    except OSError:
        # The process exited between the timer firing and the kill.
        pass


def run_command(args, lines=False, saved_output=None, timeout_sec=60, exception_on_timeout=True):
    """Run a command, capture stdout/stderr, and kill it after ``timeout_sec``.

    :param args: argv list, or a string that is split with :func:`shlex.split`
        (no shell is involved).
    :param lines: if true, ``out`` and ``err`` are returned as lists of their
        non-empty lines instead of single blobs.
    :param saved_output: optional path; the raw stdout is written there in
        binary mode.
    :param timeout_sec: seconds after which the process is killed.
    :param exception_on_timeout: if true, raise ``Exception`` on timeout
        instead of returning a result with ``has_timeout=True``.
    :returns: :class:`CommandResult`; ``elapsed`` is the wall-clock time spent
        running the command.
    :raises Exception: on timeout when ``exception_on_timeout`` is true.
    """
    # ``type(u'')`` is ``unicode`` on Python 2 and ``str`` on Python 3.
    if isinstance(args, (str, type(u''))):
        args = shlex.split(args)

    cmdline = " ".join(args)

    started_time = time.time()

    proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    timeout_dto = {"value": False}
    timer = Timer(timeout_sec, kill_process, [proc, timeout_dto])

    timer.start()
    try:
        out, err = proc.communicate()
    finally:
        # Always stop the timer, even if communicate() raised; join() makes
        # sure a kill_process call already in flight has finished before
        # timeout_dto is read below.
        timer.cancel()
        timer.join()

    elapsed_time = time.time() - started_time

    if saved_output:
        # Captured output is bytes; binary mode writes it verbatim on py2 and py3.
        with open(saved_output, 'wb') as f:
            f.write(out)

    if lines:
        # Materialize lists: filter() is a lazy one-shot iterator on Python 3.
        out = [line for line in out.splitlines() if line]
        err = [line for line in err.splitlines() if line]

    if exception_on_timeout and timeout_dto["value"]:
        raise Exception("got timeout (%r sec) on [%s]" % (timeout_sec, cmdline))

    return CommandResult(returncode=proc.returncode, out=out, err=err, has_timeout=timeout_dto["value"],
                         elapsed=elapsed_time)


def iso_date(s=None):
    """Render *s* as an ISO-8601 UTC timestamp string ending in ``'Z'``.

    Accepted inputs:

    * ``None`` or ``''`` -- the current UTC time.
    * ``datetime.datetime`` -- naive values are taken to already be UTC;
      timezone-aware values are converted to UTC first.
    * a real number (``int``/``float``, plus ``long`` on Python 2) -- a
      non-negative value is a POSIX timestamp, a negative value means
      "that many seconds before now".
    * a string -- returned unchanged.

    Any other type (including ``bool``) yields ``None``.
    """
    import numbers

    if s is None or s == '':
        return datetime.datetime.utcnow().isoformat() + 'Z'
    if isinstance(s, datetime.datetime):
        offset = s.utcoffset()
        if offset is not None:
            # Aware datetime: normalise to naive UTC, otherwise isoformat()
            # would emit '+HH:MM' right before the 'Z' suffix.
            s = (s - offset).replace(tzinfo=None)
        return s.isoformat() + 'Z'
    # bool is a numbers.Real subclass but was never treated as a timestamp.
    if isinstance(s, numbers.Real) and not isinstance(s, bool):
        if s < 0:
            d = datetime.datetime.utcnow() - datetime.timedelta(seconds=-s)
        else:
            d = datetime.datetime.utcfromtimestamp(s)
        return d.isoformat() + 'Z'
    # ``type(u'')`` is ``unicode`` on Python 2 and ``str`` on Python 3; the
    # bare name ``unicode`` raises NameError on Python 3.
    if isinstance(s, (str, type(u''))):
        return s
    return None


def event(category, name, event_time, status='INFO', reg_time=None, payload=None):
    """Build an event record dict.

    Both ``event_time`` and ``reg_time`` are normalised through
    :func:`iso_date` (event time first, then registration time). ``payload``
    is attached under ``'payload'`` only when it is truthy.
    """
    record = dict(
        category=category,
        name=name,
        status=status,
        eventTime=iso_date(event_time),
        registrationTime=iso_date(reg_time),
    )
    if not payload:
        return record
    record['payload'] = payload
    return record


def gauge(name, reg_time, val, val1=None, val5=None, val15=None):
    """Build a gauge record: a name, an ISO registration time and a measure.

    ``val``/``val1``/``val5``/``val15`` are forwarded positionally to
    :func:`measure`.
    """
    registration = iso_date(reg_time)
    measured = measure(val, val1, val5, val15)
    return {'name': name, 'registrationTime': registration, 'measure': measured}


def measure(val, val1=None, val5=None, val15=None, avg=None, min_=None, max_=None):
    """Build a measure dict with last/average/min/max and 1/5/15 averages.

    If ``val`` is an :class:`EWMA`, every field is taken from ``val.dict()``
    and the other arguments are ignored. Otherwise each field comes from the
    matching argument, and missing (``None``) values become the string 'NaN'.
    """
    if isinstance(val, EWMA):
        stats = val.dict()
        return {
            'lastValue': stats['last'],
            'average': stats['avg'],
            'min': stats['min'],
            'max': stats['max'],
            'm1Average': stats['m1'],
            'm5Average': stats['m5'],
            'm15Average': stats['m15'],
        }

    def _or_nan(value):
        return 'NaN' if value is None else value

    return {
        'lastValue': _or_nan(val),
        'average': _or_nan(avg),
        'min': _or_nan(min_),
        'max': _or_nan(max_),
        'm1Average': _or_nan(val1),
        'm5Average': _or_nan(val5),
        'm15Average': _or_nan(val15),
    }


def resource_usage(d):
    """Turn a ``{consumer_key: value}`` mapping into a list of usage records.

    A key that is a list/tuple is unpacked as ``(name, owner)`` (it must have
    exactly two items); any other key is the owner, with ``name`` set to None.
    Each value is rendered with :func:`measure`.
    """
    def _consumer(key):
        if isinstance(key, (list, tuple)):
            name, owner = key
        else:
            name, owner = None, key
        return {'owner': owner, 'name': name}

    return [
        {'resourceConsumer': _consumer(key), 'usage': measure(value)}
        for key, value in d.items()
    ]


class Warnings(object):
    """Collects warning messages: logs each one now and keeps its text for later.

    :meth:`log` emits the message on ``LOG`` at WARNING level and stores the
    formatted text; :meth:`format_and_pop` returns the stored texts and clears
    them.
    """

    def __init__(self):
        super(Warnings, self).__init__()
        self.warnings = []

    def log(self, msg, *args):
        """Log ``msg % args`` as a warning and remember the formatted text.

        Formatting follows the :mod:`logging` rules:

        * ``msg`` is only %-formatted when ``args`` are given, so a plain
          ``'100%'`` is kept verbatim;
        * a single non-empty dict argument is used as the mapping for
          ``'%(key)s'`` placeholders.

        If formatting fails, a ``'Bad message format ...'`` description is
        logged and stored instead, so a malformed call never raises.
        """
        fmt_args = args
        if len(args) == 1 and isinstance(args[0], dict) and args[0]:
            fmt_args = args[0]
        try:
            # '%s' % (msg,) coerces a non-string msg to text, like logging does.
            text = msg % fmt_args if args else '%s' % (msg,)
        except (TypeError, ValueError, KeyError):
            # TypeError: wrong arg count/types; ValueError: bad or incomplete
            # conversion spec; KeyError: missing mapping key.
            text = 'Bad message format `%s`: [%s]' % (str(msg), ','.join(map(str, args)))
            LOG.warning(text)
        else:
            LOG.warning(msg, *args)
        self.warnings.append(text)

    def format_and_pop(self):
        """Return the stored warnings joined by newlines (None if empty) and clear them."""
        result = '\n'.join(self.warnings) if self.warnings else None
        self.warnings = []
        return result


class PlainModule(object):
    """Base class for a module that is polled through :meth:`get_value`.

    ``self.config`` is a deep copy of the class's ``default_config``, updated
    with the ``config`` passed to the constructor (if any). ``get_value`` raises
    NotImplementedError here and must be overridden; the other lifecycle hooks
    are no-ops.
    """
    default_config = {'heartbeat': 300.}

    def __init__(self, arch, config=None):
        self.arch = arch
        # Deep copy so instances never mutate the shared class-level defaults.
        self.config = deepcopy(self.default_config)
        self.warnings = Warnings()
        # A falsy config (None, {}) leaves the defaults untouched.
        self.config.update(config or {})

    def init(self):
        """One-time initialisation hook; does nothing by default."""

    def start(self):
        """Start hook; does nothing by default."""

    def format_answer(self, name, data):
        # type: (str, Any) -> Dict[str, Any]
        """Wrap *data* in the ``{'module': ..., 'data': ...}`` envelope."""
        return dict(module=name, data=data)

    def get_value(self):
        """Produce the module's current value; subclasses must override."""
        raise NotImplementedError('not implemented')

    def stop(self):
        """Stop hook; does nothing by default."""

    def pop_warnings(self):
        """Return and clear the accumulated warnings (see ``Warnings.format_and_pop``)."""
        return self.warnings.format_and_pop()


class ThreadedModule(PlainModule):
    """Module whose :meth:`loop` runs periodically on a background thread.

    Once :meth:`start` is called, ``loop`` runs whenever more than
    ``config['circle_time']`` seconds have passed since its previous run. The
    thread checks this every 0.1 s, until :meth:`stop` is called. Ordinary
    exceptions raised by ``loop`` are logged and do not stop the thread.
    """
    default_config = {'heartbeat': 600., 'circle_time': 2.}
    running = True
    last_loop = 0

    def __init__(self, *args, **kw):
        super(ThreadedModule, self).__init__(*args, **kw)
        self.thread = threading.Thread(target=self._run)
        # Daemon thread: it does not keep the interpreter alive if stop() is never called.
        self.thread.daemon = True

    def start(self):
        self.thread.start()

    def _run(self):
        while self.running:
            if time.time() - self.last_loop > self.config.get('circle_time'):
                self.last_loop = time.time()
                try:
                    self.loop()
                except Exception:
                    # Not a bare except: SystemExit and similar must not be swallowed.
                    LOG.exception('loop exception')
            time.sleep(0.1)

    def loop(self):
        """One unit of periodic work; subclasses must override."""
        # NotImplementedError subclasses Exception, so existing handlers still catch it.
        raise NotImplementedError('not implemented')

    def stop(self):
        """Ask the worker thread to exit and wait for it, if that is possible."""
        self.running = False
        # join() raises RuntimeError for a thread that was never started, and
        # for the current thread (e.g. stop() called from inside loop()).
        if self.thread.is_alive() and self.thread is not threading.current_thread():
            self.thread.join()
