#!/skynet/python/bin/python

import os
import sys
import fcntl
import msgpack
import subprocess
import argparse
from collections import OrderedDict

# Refuse to start anywhere but Cygwin: the rest of the file depends on the
# Cygwin build of setproctitle (imported below).
if sys.platform != 'cygwin':
    # sys.exit() with a string writes it to stderr and exits with status 1,
    # so the diagnostic does not end up in piped stdout.
    sys.exit('This tool works on cygwin only!')


try:
    from setproctitle import PROCTITLE_FILENAME, setproctitle
except ImportError:
    sys.exit('Cygwin`s setproctitle is not installed')


def set_proc_title(pid, title):
    """Store *title* for *pid* in the msgpack mapping kept in PROCTITLE_FILENAME.

    The file is read, updated and rewritten while holding an exclusive
    ``flock`` on it. The file is created when it does not exist yet.

    :param pid: key under which the title is stored.
    :param title: value stored for that key.
    """
    mode = 'r+b' if os.path.exists(PROCTITLE_FILENAME) else 'w+b'

    with open(PROCTITLE_FILENAME, mode) as f:
        # Acquire the lock before entering the try block so that a failed
        # acquisition does not trigger an unlock of a lock we never held.
        fcntl.flock(f, fcntl.LOCK_EX)
        try:
            data = msgpack.load(f) or {}
            data[pid] = title
            f.seek(0)
            msgpack.dump(data, f)
            # Push buffered bytes to the file and cut off whatever is left of
            # a longer previous payload *before* the lock is released;
            # otherwise another process could lock the file and read stale
            # or trailing data.
            f.flush()
            f.truncate()
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)


def create_test_data():
    """Reset PROCTITLE_FILENAME and fill it with three sample pid/title pairs."""
    # Opening in 'wb' mode and closing right away leaves an empty file.
    open(PROCTITLE_FILENAME, 'wb').close()

    samples = (
        (7956, 'skynet.copier'),
        (1201, 'old_dead_process'),
        (9004, 'heartbeat-client'),
    )
    for sample_pid, sample_title in samples:
        set_proc_title(sample_pid, sample_title)


class Column(object):
    """One cell of a ps output row: a value plus the format spec used to print it."""

    def __init__(self, width, value):
        self._value = value
        self._width = width

    @property
    def format(self):
        """Format-spec suffix (e.g. ``':>7'``) that right-aligns to the column width."""
        return ':>{0}'.format(self._width)

    @property
    def value(self):
        """Current cell value."""
        return self._value

    @value.setter
    def value(self, new_value):
        self._value = new_value


class EmptyColumn(Column):
    """Placeholder column with zero width and an empty string value."""

    def __init__(self):
        super(EmptyColumn, self).__init__(width=0, value='')


class UidStrColumn(Column):
    """``UID`` column holding a textual user id.

    The width is 8 when the first character of the value has a code point
    below 128; otherwise half of the value's length is added on top.
    """

    NAME = 'UID'

    def __init__(self, value):
        # NOTE(review): the wider field presumably compensates for multi-byte
        # encoded names whose length exceeds their on-screen width - confirm.
        # Floor division keeps the width an integer regardless of the
        # interpreter's division semantics; a float width would turn into a
        # precision spec (':>12.0') in Column.format.
        width = 8 if ord(value[0]) < 128 else 8 + len(value) // 2
        super(UidStrColumn, self).__init__(width, value)


class UidIntColumn(Column):
    """``UID`` column with a fixed width of 11 (used by the long layout)."""

    NAME = 'UID'

    def __init__(self, value):
        super(UidIntColumn, self).__init__(width=11, value=value)


class PidColumn(Column):
    """``PID`` column with a fixed width of 7."""

    NAME = 'PID'

    def __init__(self, value):
        super(PidColumn, self).__init__(width=7, value=value)


class PpidColumn(Column):
    """``PPID`` column with a fixed width of 7."""

    NAME = 'PPID'

    def __init__(self, value):
        super(PpidColumn, self).__init__(width=7, value=value)


class TtyColumn(Column):
    """``TTY`` column of width 4; unlike the base class it is left-aligned."""

    NAME = 'TTY'

    def __init__(self, value):
        super(TtyColumn, self).__init__(width=4, value=value)

    @property
    def format(self):
        """Format-spec suffix that left-aligns to the column width."""
        return ':<{0}'.format(self._width)


class LongTtyColumn(Column):
    """``TTY`` column with a fixed width of 5 (used by the long layout)."""

    NAME = 'TTY'

    def __init__(self, value):
        super(LongTtyColumn, self).__init__(width=5, value=value)


class StimeColumn(Column):
    """``STIME`` column with a fixed width of 12 (used by the standard layout)."""

    NAME = 'STIME'

    def __init__(self, value):
        super(StimeColumn, self).__init__(width=12, value=value)


class LongStimeColumn(Column):
    """``STIME`` column with a fixed width of 8 (used by the long layout)."""

    NAME = 'STIME'

    def __init__(self, value):
        super(LongStimeColumn, self).__init__(width=8, value=value)


class CommandColumn(Column):
    """``COMMAND`` column; printed as-is, without padding or alignment."""

    NAME = 'COMMAND'

    def __init__(self, value):
        super(CommandColumn, self).__init__(width=0, value=value)

    @property
    def format(self):
        """Empty format spec: the value is emitted unpadded."""
        return ''


class PgidColumn(Column):
    """``PGID`` column with a fixed width of 7."""

    NAME = 'PGID'

    def __init__(self, value):
        super(PgidColumn, self).__init__(width=7, value=value)


class WinpidColumn(Column):
    """``WINPID`` column with a fixed width of 10."""

    NAME = 'WINPID'

    def __init__(self, value):
        super(WinpidColumn, self).__init__(width=10, value=value)


class FlagColumn(Column):
    """``FLAG`` column with a fixed width of 1."""

    NAME = 'FLAG'

    def __init__(self, value):
        super(FlagColumn, self).__init__(width=1, value=value)


class ColumnFactory(object):
    """Turn one whitespace-separated text line into an ordered set of columns.

    ``known_columns`` is the ordered list of column classes; each class
    exposes a ``NAME`` attribute and is constructed from a single raw value.
    """

    def __init__(self, known_columns):
        self._known_columns = known_columns

    def create_columns(self, line):
        """Return an OrderedDict mapping column NAME to column instance.

        An empty OrderedDict is returned for lines that do not match the
        expected layout.
        """
        expected = len(self._known_columns)
        raw_values = line.split()

        # One missing field is tolerated at this point because a subclass
        # may re-insert an omitted column while pre-processing.
        if len(raw_values) >= expected - 1:
            prepared = self._preprocess_columns(raw_values)
            if len(prepared) == expected:
                return OrderedDict(
                    (column_class.NAME, column_class(raw))
                    for column_class, raw in zip(self._known_columns, prepared)
                )

        # incorrect line, nothing to build
        return OrderedDict()

    def _preprocess_columns(self, column_values):
        """Hook for subclasses to normalise raw values; the base keeps them as-is."""
        return column_values


class StandardColumnFactory(ColumnFactory):
    """Column factory for the six-field layout: UID PID PPID TTY STIME COMMAND."""

    def __init__(self):
        super(StandardColumnFactory, self).__init__(
            [
                UidStrColumn,
                PidColumn,
                PpidColumn,
                TtyColumn,
                StimeColumn,
                CommandColumn
            ]
        )

    def _preprocess_columns(self, column_values):
        """Re-join fields that whitespace splitting tore apart.

        Lines too short for a given fix-up are returned untouched so that
        ``create_columns`` rejects them by length instead of this method
        raising IndexError.
        """
        # handle defunct, we have to join it with command field
        if len(column_values) > 1 and column_values[-1] == '<defunct>':
            column_values[-2] = column_values[-2] + ' <defunct>'
            del column_values[-1]

        # handle date field, it can be split, we have to join it again;
        # only possible when a field actually follows it
        if len(column_values) > 5 and column_values[4].isalpha():
            column_values[4] = column_values[4] + column_values[5]
            del column_values[5]

        return column_values


class LongColumnFactory(ColumnFactory):
    """Column factory for the nine-field layout:
    FLAG PID PPID PGID WINPID TTY UID STIME COMMAND.
    """

    def __init__(self):
        super(LongColumnFactory, self).__init__(
            [
                FlagColumn,
                PidColumn,
                PpidColumn,
                PgidColumn,
                WinpidColumn,
                LongTtyColumn,
                UidIntColumn,
                LongStimeColumn,
                CommandColumn
            ]
        )

    def _preprocess_columns(self, column_values):
        """Re-join split fields and restore an omitted flag column.

        Lines too short for a given fix-up are returned untouched so that
        ``create_columns`` rejects them by length instead of this method
        raising IndexError.
        """
        # handle defunct, we have to join it with command field
        if len(column_values) > 1 and column_values[-1] == '<defunct>':
            column_values[-2] = column_values[-2] + ' <defunct>'
            del column_values[-1]

        # flag can be omitted, so we have to add column manually
        if column_values and not column_values[0].isalpha():
            column_values.insert(0, ' ')

        # handle date field, it can be split, we have to join it again;
        # only possible when a field actually follows it
        if len(column_values) > 8 and column_values[7].isalpha():
            column_values[7] = column_values[7] + column_values[8]
            del column_values[8]

        return column_values


class ProcessInfo(object):
    """One process row (an ordered NAME -> column mapping) plus its child rows.

    ``columns`` may be empty (e.g. for a rejected line); ``pid`` and ``ppid``
    then evaluate to 0 and ``proctitle`` to ''.
    """

    def __init__(self, columns):
        self._columns = columns
        self._children = []

    @property
    def pid(self):
        """Value of the PID column as an int, or 0 when it is missing/empty."""
        value = self._columns.get(PidColumn.NAME, EmptyColumn()).value
        return int(value) if value else 0

    @property
    def ppid(self):
        """Value of the PPID column as an int, or 0 when it is missing/empty."""
        value = self._columns.get(PpidColumn.NAME, EmptyColumn()).value
        return int(value) if value else 0

    @property
    def proctitle(self):
        """Value of the COMMAND column, or '' when there is no such column."""
        return self._columns.get(CommandColumn.NAME, EmptyColumn()).value

    @proctitle.setter
    def proctitle(self, title):
        # Silently ignored when the row has no COMMAND column.
        if CommandColumn.NAME in self._columns:
            self._columns[CommandColumn.NAME].value = title

    def append_child(self, child):
        """Register *child* (a ProcessInfo) to be printed below this row."""
        self._children.append(child)

    @staticmethod
    def _tree_prefix(level):
        """Return the tree-drawing prefix put in front of the command at *level*."""
        if level <= 0:
            return ''
        if level == 1:
            return '\\_ '
        return '|   ' + '    ' * (level - 2) + '\\_ '

    def _format_line(self, prefix):
        """Render all columns as one line, each followed by a single space.

        *prefix* is prepended to the COMMAND value for display only; the
        stored proctitle is left untouched, so printing is repeatable.
        """
        cells = []
        for name, column in self._columns.items():
            shown = column.value
            if prefix and name == CommandColumn.NAME:
                shown = prefix + shown
            cells.append(('{0' + column.format + '} ').format(shown))
        return ''.join(cells)

    def print_info(self, level=-1):
        """Print this row (when ``level >= 0``) and then all descendants.

        With the default level of -1 the row itself is skipped and its
        children are printed at level 0.
        """
        if level >= 0:
            print(self._format_line(self._tree_prefix(level)))

        for child in self._children:
            child.print_info(level + 1)


def main():
    """Run ``ps``, rebuild its output as a process tree and print it.

    Command names are replaced by the titles stored in PROCTITLE_FILENAME;
    entries of that file whose pid is not in the current listing are removed
    from it.

    :returns: 0 on success, 1 when ``ps`` wrote anything to stderr.
    """
    setproctitle('cygps')
    args = parse_args()
    factory = LongColumnFactory() if args.long else StandardColumnFactory()

    # NOTE: create_test_data() can be called here to pre-populate the title
    # file when debugging.

    # dummy root (pid 1, ppid 0); it is never printed itself, only its children
    root_line = '1 0 0 0 ? 0 000 ---' if args.long else 'none 1 0 ? 000 ---'
    root = ProcessInfo(factory.create_columns(root_line))
    pids = {root.pid: root}
    title = ''

    # run ps
    ps_args = '-al' if args.long else '-af'
    process = subprocess.Popen(['ps', ps_args],
                               close_fds=True,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()

    if stderr:
        # Diagnostics belong on stderr, not in the listing on stdout.
        sys.stderr.write('Something goes wrong: %s\n' % stderr)
        return 1

    for line in stdout.split('\n'):
        # the first non-empty line is the header; keep it for re-printing
        if not title:
            title = line
            continue

        info = ProcessInfo(factory.create_columns(line))
        # ignore invalid lines and hide real ps process
        if info.pid and info.pid != process.pid:
            pids[info.pid] = info

    # build a tree; a process whose parent is not part of the listing is
    # attached to the dummy root instead of failing with KeyError
    for pinfo in pids.values():
        if pinfo.ppid:
            pids.get(pinfo.ppid, root).append_child(pinfo)

    # read real names from file and update ProcessInfo
    mode = 'r+b' if os.path.exists(PROCTITLE_FILENAME) else 'w+b'
    with open(PROCTITLE_FILENAME, mode) as f:
        # Lock before entering the try block so a failed acquisition does
        # not trigger an unlock of a lock we never held.
        fcntl.flock(f, fcntl.LOCK_EX)
        try:
            data = msgpack.load(f) or {}

            # Collect stale pids first so the mapping is not modified while
            # it is being iterated.
            stale = [pid for pid in data if pid not in pids]
            for pid in stale:
                del data[pid]

            for pid, name in data.items():
                pids[pid].proctitle = name

            if stale:
                f.seek(0)
                msgpack.dump(data, f)
                # Flush and cut off the tail of the longer previous payload
                # while still holding the lock. Truncating only here (not in
                # ``finally``) means a failed load never clips the file at an
                # arbitrary read offset.
                f.flush()
                f.truncate()
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)

    # show new output
    print(title)
    root.print_info()
    return 0


def parse_args():
    """Parse ``sys.argv`` and return the resulting argparse namespace.

    The only option is ``-l`` / ``--long`` (a boolean switch, off by default).
    """
    # NOTE: a '--debug' switch ('Pre-create test file') is currently disabled.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-l', '--long',
        action='store_true',
        help='show process uids, ppids, pgids, winpids',
    )
    options = cli.parse_args()
    return options


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == '__main__':
    exit_status = main()
    sys.exit(exit_status)
