#!/skynet/python/bin/python -u

# Deadline in seconds for each batch of hosts polled by main().
TIMEOUT = 120

class SkyPortodGrep(object):
    """cqueue job: report porto error/warning counters and explaining log snippets.

    Calling an instance (see ``__call__``) connects to porto, reads the
    ``porto_stat[errors]`` / ``porto_stat[warnings]`` counters and, when any is
    non-zero, extracts the relevant snippets from the portod logs.
    """

    @property
    def marshaledModules(self):
        # Presumably read by cqueue to ship this module with the job — confirm.
        return [__name__]

    @property
    def osUser(self):
        # Presumably the OS user cqueue runs the job as — confirm.
        return "max7255"

    @property
    def asUser(self):
        # Presumably the user cqueue authenticates as — confirm.
        return "max7255"

    class LogCtx(object):
        """Accumulates the log lines of one worker request.

        Attributes:
            body   -- list of raw log lines, in the order they were appended
            taints -- number of lines appended with taint=True (WRN/ERR)
            stk    -- True once a line was appended with stk=True (STK)
        """

        def __init__(self, line=None, taint=False, stk=False):
            self.body = []
            self.taints = 0
            self.stk = False
            if line:
                self.append(line, taint, stk)

        def append(self, line, taint=False, stk=False):
            """Add ``line``; count it if tainted and latch the stack-trace flag."""
            self.body += [line]
            self.taints += 1 if taint else 0
            self.stk = self.stk or stk

    class ShortLogCtx(LogCtx):
        """LogCtx keeping only the last ``depth`` lines until a stack trace is seen."""

        def __init__(self, line=None, taint=False, stk=False, depth=5):
            self.body = []
            self.taints = 0
            self.stk = False
            self.depth = depth
            if line:
                self.append(line, taint, stk)

        def append(self, line, taint=False, stk=False):
            """Add ``line``; trim to the newest ``depth`` lines unless stk is set."""
            self.stk = self.stk or stk
            self.taints += 1 if taint else 0
            self.body += [line]
            # Once a stack trace was seen the whole context is kept.
            if not self.stk and len(self.body) > self.depth:
                self.body = self.body[1:self.depth + 1]

    def grep_log_lines(self, filename, backlog=None, gzipped=False):
        """Split one portod log file into snippets worth reporting.

        Worker lines are grouped per request (REQ ... RSP). Lines of the
        auxiliary portod processes go into a short rolling ShortLogCtx window.
        A snippet is kept when it contains a WRN/ERR line or a STK line. A
        "SYS Started" line discards everything collected before it.

        :param filename: path of the log file to read.
        :param backlog: lines processed after this file's lines; process_log
            passes the ``prev`` list returned for the next (newer) file here.
        :param gzipped: open ``filename`` with gzip.
        :returns: ``(run_hdr, snippets, prev)``.

            - ``run_hdr`` is the last "SYS Started" line or None.
            - ``snippets`` is the list of LogCtx/ShortLogCtx objects to report.
            - ``prev`` holds the worker lines seen without an open request
              since the last "SYS Started".

            If the file cannot be opened, returns ``(None, [], backlog)``.
        """
        from re import compile as re_compile
        from gzip import open as gzopen
        from itertools import chain

        backlog = [] if backlog is None else backlog

        try:
            if gzipped:
                f = gzopen(filename, "r")
            else:
                f = open(filename, "r")
        except IOError:
            # A missing/unreadable (e.g. not yet rotated) log is not fatal.
            # Keep the 3-tuple shape callers unpack and hand the backlog on.
            return (None, [], list(backlog))

        date_re = r"[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}"
        re_worker = r"(?P<worker>portod-worker[0-9]+\[[0-9]+\])"
        re_aux = (r"(?P<aux>(portod|portod-slave|portod-spawn-p|portod-spawn-c"
                  r"|portod-event0)\[[0-9]+\])")
        re_process = "({}|{}):".format(re_worker, re_aux)
        re_msg = "(?P<msg>(WRN|ERR|EVT|REQ|RSP|ACT|SYS|STK|   ))"

        re_obj = re_compile("{} {} {}".format(date_re, re_process, re_msg))
        re_sys = re_compile("{} {} SYS (Started)".format(date_re, re_process))

        run_hdr = None
        last = "aux"
        sessions = {"aux": self.ShortLogCtx()}
        result = []
        prev = []

        try:
            for line in chain(f, backlog):
                m = re_obj.search(line)
                if not m:
                    # No timestamp/process prefix: continuation of the
                    # session that logged last.
                    sessions[last].append(line)
                    continue

                msg = m.group("msg")
                worker = m.group("worker")

                if msg == "SYS" and re_sys.search(line):
                    # portod (re)started: drop everything from the old run.
                    last = "aux"
                    sessions = {"aux": self.ShortLogCtx()}
                    result = []
                    prev = []
                    run_hdr = line
                    continue

                if worker is not None:
                    s = sessions.get(worker)
                    if msg == "REQ":
                        # A previous request of this worker never got its
                        # RSP; don't lose its warnings.
                        if s is not None and (s.taints or s.stk):
                            result.append(s)
                        sessions[worker] = self.LogCtx(line)
                        last = worker
                    elif s is None:
                        # No open request for this worker (e.g. its REQ is in
                        # an older log file): hand the line back to the caller.
                        prev.append(line)
                        last = "aux"
                    elif msg == "RSP":
                        if s.taints or s.stk:
                            result.append(s)
                        sessions[worker] = None
                        last = "aux"
                    else:
                        s.append(line, taint=(msg in ("WRN", "ERR")),
                                 stk=(msg == "STK"))
                        last = worker
                else:
                    s = sessions["aux"]
                    last = "aux"
                    if msg in ("STK", "ERR", "WRN"):
                        # Close the aux window around the interesting line.
                        s.append(line, taint=(msg != "STK"), stk=(msg == "STK"))
                        result.append(s)
                        sessions["aux"] = self.ShortLogCtx()
                    else:
                        s.append(line)
        finally:
            f.close()

        # Requests still open at the end of input are reported if interesting.
        for s in sessions.values():
            if s is not None and (s.taints or s.stk):
                result.append(s)

        return (run_hdr, result, prev)

    def process_log(self, num, log_dir="/var/log"):
        """Collect log snippets explaining ``num`` porto errors/warnings.

        Reads ``portod.log``, then ``portod.log.1.gz`` ... ``portod.log.7.gz``
        in that order. It stops once the kept snippets hold at least ``num``
        WRN/ERR lines or a "SYS Started" header is found. If still short, it
        also reads ``portoloop.log``.

        :param num: number of errors + warnings reported by porto.
        :param log_dir: directory containing the logs.
        :returns: report text. Each snippet is followed by a ``--`` line, and
            the text is prefixed with the header line when one was found.
        """
        from os.path import join

        def count_taints(snippets):
            return sum(s.taints for s in snippets if s is not None)

        (hdr, result, backlog) = self.grep_log_lines(
            join(log_dir, "portod.log"), gzipped=False)
        found = count_taints(result)

        for i in range(1, 8):
            if found >= num or hdr:
                break

            (hdr, older, backlog) = self.grep_log_lines(
                join(log_dir, "portod.log.{}.gz".format(i)), backlog, gzipped=True)

            # Older file's snippets go in front; count only the new ones.
            result = older + result
            found += count_taints(older)

        if found < num and not hdr:
            (hdr, older, _) = self.grep_log_lines(join(log_dir, "portoloop.log"))
            result = older + result

        output = "".join("".join(snippet.body) + "--\n" for snippet in result)

        if hdr:
            output = hdr + "--\n" + output

        return output

    def __call__(self):
        """Return ``(errors, warnings, report)`` for the local porto.

        Old porto versions (not 3.x), or a failing version query, yield
        ``(0, 0, [])``. ``report`` is None when both counters are zero. Any
        other failure is returned as ``(-1, -1, traceback text)``.
        """
        try:
            import porto
            c = porto.Connection()
            try:
                if c.Version()[0][0:1] != "3":
                    # Skipping old porto for now
                    return (0, 0, [])
            except Exception:
                return (0, 0, [])

            err = int(c.GetProperty("/", "porto_stat[errors]"))
            wrn = int(c.GetProperty("/", "porto_stat[warnings]"))

            msgs = None
            if err + wrn > 0:
                msgs = self.process_log(err + wrn)

            return (err, wrn, msgs)

        except BaseException:
            # Everything is caught on purpose: the traceback becomes the report.
            import traceback
            return (-1, -1, traceback.format_exc())

def main():
    """Run SkyPortodGrep on a set of hosts and print a report per host.

    Expects exactly one argument. It is either a file with the host
    specification or the specification itself. Hosts are processed in batches
    of 500, each batch with a deadline of TIMEOUT seconds. Hosts with zero
    errors and warnings are not reported.

    :returns: 1 on wrong usage, 0 otherwise.
    """
    from api.cqueue import Client
    from library.sky.hosts import resolveHosts
    from library.sky.hostresolver import Resolver

    from sys import argv, stderr
    from time import time, sleep

    batch_size = 500

    if len(argv) != 2:
        stderr.write("Usage: portod_grep <hosts-file | hosts>\n")
        return 1

    try:
        with open(argv[1], "r") as hosts_file:
            hosts_str = hosts_file.read()
    except IOError:
        # Not a readable file: the argument itself is the host list.
        hosts_str = argv[1]

    try:
        hosts = list(resolveHosts(hosts_str.split())[0])
    except Exception:
        # Fall back to the other resolver implementation.
        hosts = list(Resolver().resolveHosts(hosts_str))

    print("Host total: {}".format(len(hosts)))

    for idx in range(0, len(hosts), batch_size):
        current = hosts[idx:idx + batch_size]
        expected = len(current)
        deadline = time() + TIMEOUT

        print("Starting porto checking on hosts indices: {}-{}, timeout {}...\n".format(
            idx, expected + idx, TIMEOUT))

        with Client(implementation="cqudp", netlibus=True) as client:
            with client.run(current, SkyPortodGrep()) as session:
                actual = 0

                while actual < expected and session.running:
                    sleep(0.05)

                    if time() > deadline:
                        print("Timeout for {} hosts".format(expected - actual))
                        break

                    for host, result, error in session.poll(0):
                        actual += 1

                        if error is not None:
                            body = "FAIL: {} failed: {}\n".format(host, error)
                        else:
                            try:
                                (err, wrn, msgs) = result
                            except (TypeError, ValueError):
                                print("{}: unexpected result: {!r}\n".format(host, result))
                                continue

                            if err == 0 and wrn == 0:
                                continue

                            body = "ERR: {} WRN: {}\n".format(err, wrn)
                            if msgs is not None:
                                body += msgs

                        print("==== REPORT START: {} ====\n".format(host)
                              + body
                              + "==== REPORT END: {} ====\n\n".format(host))

    return 0

if __name__ == '__main__':
    # Propagate main()'s return value (1 on wrong usage) as the exit status;
    # SystemExit(None) exits with status 0.
    raise SystemExit(main())
