#!/usr/bin/env python

#from pytz import timezone
from datetime import datetime, timedelta
from sys import stdout,stderr
import threading
import logging
import time
import socket
import re
import json

# Example of the PG STATS section of the `db` status response parsed below:
#=== PG STATS ===
#host=pgload03h.mail.yandex.net port=6432 dbname=maildb user=mxfront: 
#	free_connections: 20
#	busy_connections: 0
#	pending_connections: 0
#	max_connections: 20
#	queue_size: 0
#	average_request_roundtrip: 43 ms
#	average_request_db_latency: 8 ms
#	average_wait_time: 53 ms
#	dropped_connections_timed_out: 0
#	dropped_connections_failed: 0
#	dropped_connections_busy: 0
#	dropped_connections_with_result: 0
#=== ===

# Maps raw PG STATS field names (plus the summed 'dropped' entry that
# parse_status adds) to the short metric names emitted in the dump.
# Fields not listed here -- e.g. max_connections and the individual
# dropped_* counters -- are ignored by parse_status.
field_map = {
    'free_connections': 'free-conn',
    'busy_connections': 'busy-conn',
    'pending_connections': 'pending-conn',
    'queue_size': 'queue-size',
    'average_request_roundtrip': 'avg-rtt',
    'average_request_db_latency': 'avg-db-lat',
    'average_wait_time': 'avg-wait',
    'dropped': 'dropped',
}

def split_field(f):
    """Parse one ``name: value`` stats line into ``(name, float_value)``.

    The value may carry an ``ms`` unit suffix (e.g. ``43 ms``), which is
    dropped. Surrounding whitespace -- including a trailing ``\\r\\n`` left
    on the last field of a section by the splitting in parse_status -- is
    ignored, so ``"53 ms\\r\\n"`` parses as 53.0.

    Raises ValueError if the line has no ``': '`` separator or the value
    is not numeric.
    """
    name, raw_val = f.split(': ', 1)
    raw_val = raw_val.strip()
    # Strip whitespace *before* removing the unit: rstrip(' ms') on
    # "53 ms\r\n" would stop at the '\n' and leave the unit in place.
    if raw_val.endswith('ms'):
        raw_val = raw_val[:-2]
    return (name.strip(), float(raw_val))

def parse_status(status, regex):
    """Extract per-host PG pool metrics from a raw status dump.

    ``regex.findall(status)[0]`` is taken as the PG STATS section text, so
    the pattern should have a single capturing group around it. Inside the
    section each host entry starts with ``\\r\\nhost=`` and its fields
    follow on ``\\r\\n\\t``-prefixed lines. All ``dropped*`` counters of a
    host are summed into an extra ``dropped`` field.

    Returns ``{(short_hostname, metric_name): [value]}`` holding only the
    fields listed in ``field_map`` (renamed through it), or None after
    logging the error if anything in the dump cannot be parsed.
    """
    try:
        section = regex.findall(status)[0]
        parsed = {}
        # Splitting on the host marker leaves whatever precedes the first
        # host (normally an empty string) in slot 0, so it is skipped.
        for entry in section.split('\r\nhost=')[1:]:
            lines = entry.split('\r\n\t')
            # Header is "<fqdn> port=... dbname=...: "; keep the short hostname.
            host = lines[0].split(' ', 1)[0].partition('.')[0]
            host_stats = dict(split_field(line) for line in lines[1:])
            host_stats['dropped'] = sum(
                value for name, value in host_stats.items()
                if name.startswith('dropped'))
            for name, value in host_stats.items():
                metric = field_map.get(name)
                if metric is not None:
                    parsed[(host, metric)] = [value]
        return parsed
    except Exception as ex:
        log.error("Parsing failed with %s for %s", ex, status.replace('\n', '\\n'), exc_info=True)
        return None

def get_status(host, port, timeout=0.2):
    """Fetch the raw ``db`` status dump from the stats port at ``host:port``.

    Sends ``db\\n`` and reads until the ``\\r\\ndone\\r\\n`` terminator has
    been received or the peer closes the connection. The socket is always
    closed, even when connecting or reading fails.

    :param host: hostname or IP of the status endpoint.
    :param port: TCP port of the status endpoint.
    :param timeout: socket timeout in seconds applied to connect/send/recv
        (default 0.2, the previously hard-coded value).
    :returns: ``(response, end, duration)``: the raw data received, the
        ``time.time()`` when reading finished, and the seconds elapsed
        since the call started.
    :raises socket.error: (including socket.timeout) on connect/read failure.
    """
    terminator = b'\r\ndone\r\n'
    start = time.time()
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.settimeout(timeout)
        s.connect((host, port))
        s.sendall(b'db\n')
        total_data = []
        tail = b''
        while True:
            data = s.recv(65535)
            if not data:
                break  # peer closed the connection
            total_data.append(data)
            # The terminator can straddle two recv() chunks, so search the
            # end of the previous chunk joined with the new one; checking
            # only `data` would miss it and block until the timeout.
            window = tail + data
            if terminator in window:
                break
            tail = window[-(len(terminator) - 1):]
    finally:
        s.close()
    response = b''.join(total_data)
    end = time.time()
    duration = end - start
    return (response, end, duration)

class DumperTread(object):
    """Daemon thread that writes the previous second's stats to stdout.

    Every second, ``offset`` seconds past the whole-second boundary, it
    renders the bucket for the previous second from the shared ``stats``
    dict via ``gen_dump`` and writes the result to stdout. It then drops
    buckets older than the one just dumped, so the dict (filled by the
    collector thread) does not grow without bound.
    """

    def __init__(self, stats, offset=0.2):
        self.stats = stats      # {unix_second: {...}} dict shared with the collector
        self.offset = offset    # seconds past each whole second at which to dump
        thread = threading.Thread(target=self.run, args=(), name='dumper')
        thread.daemon = True                            # Daemonize thread
        thread.start()                                  # Start the execution

    def _prune_old_stats(self, before):
        """Delete stats buckets for seconds strictly earlier than ``before``."""
        # Snapshot the keys first: the collector thread may insert concurrently.
        for ts in [k for k in list(self.stats) if k < before]:
            self.stats.pop(ts, None)

    def _delay_until_next_tick(self, started, current):
        """Seconds to sleep from ``current`` until ``offset`` past the second after ``started``."""
        # Measured against the current time so a slow dump does not make us
        # oversleep; clamped because time.sleep() rejects negative values.
        return max(0.0, int(started) + 1. + self.offset - current)

    def run(self):
        """ Method that runs forever """
        while True:
            now = time.time()
            try:
                stat = gen_dump(self.stats, now)
                log.info("Got stat for dump: %s", stat)
                if stat:
                    stdout.write(stat)
                    stdout.flush()
            except Exception as ex:
                log.error("Error dumping data: %s", ex, exc_info=True)
            # gen_dump handles second int(now - 1); older buckets are done.
            self._prune_old_stats(int(now - 1))
            delta = self._delay_until_next_tick(now, time.time())
            log.info("dumper sleep for: %s", delta)
            time.sleep(delta)


class CollectorTread(object):
    """Daemon thread that periodically polls the status port and records samples.

    Roughly every ``1/frequency`` seconds it fetches the status dump from
    ``host:port``, parses it with ``regex`` and appends the values to the
    shared ``stats`` dict as ``stats[unix_second][('db-stat', metric)]``:
    a list holding one value per backend host per poll in that second.
    """

    def __init__(self, stats, host, port, regex, frequency=1, offset=0.1):
        self.stats = stats              # dict shared with the dumper thread
        self.offset = offset            # NOTE(review): stored but never read in this class
        self.host = host
        self.port = port
        self.regex = regex              # compiled pattern handed to parse_status
        self.delay = 1.0/frequency      # target seconds between the starts of two polls
        self.last_run = None            # time.time() at the start of the previous poll
        thread = threading.Thread(target=self.run, args=(), name='colector')
        thread.daemon = True                            # Daemonize thread
        thread.start()                                  # Start the execution

    def adjusted_sleep(self):
        """Sleep for whatever remains of ``self.delay`` since the previous poll started."""
        now = time.time()
        delay = self.last_run + self.delay - now
        if delay > 0:
            log.info("%s sleep for: %ss", 'collector', delay)
            time.sleep(delay)
        else:
            # The previous poll overran its slot; start the next one right away.
            log.info("%s behind schedule for %ss", 'collector', delay)

    def _merge_sample(self, end, status_dict):
        """Append the per-host values of one parsed poll to ``self.stats[end]``."""
        bucket = self.stats.setdefault(end, {})
        for (_host, field), values in status_dict.items():
            # Values from all hosts (and all polls in this second) share one
            # list; setdefault avoids a KeyError when a field is new in this
            # second's bucket.
            bucket.setdefault(('db-stat', field), []).extend(values)

    def collect(self):
        """Poll the endpoint once and merge the parsed sample into ``self.stats``.

        Errors are logged and swallowed so a single failed poll does not kill
        the collector thread.
        """
        try:
            status, end, _duration = get_status(self.host, self.port)
            status_dict = parse_status(status, self.regex)
            log.debug("Got data after parsing %s", status_dict)
            if status_dict is None:
                # parse_status has already logged the failure; nothing to store.
                return
            self._merge_sample(int(end), status_dict)
        except Exception as ex:
            log.warning("Got error while getting data: %s", ex, exc_info=True)

    def run(self):
        """ Method that runs forever """
        while True:
            if self.last_run is not None:
                self.adjusted_sleep()
            self.last_run = time.time()
            self.collect()

def gen_dump(stats, now):
    """Serialize the stats bucket for the second before ``now`` as JSON lines.

    For every group (first element of each ``(group, field)`` key) one line
    ``{"fields": {...}, "name": group, "timestamp": t}`` is produced, where
    each field contributes ``<field>_avg`` and ``<field>_max`` over its
    collected values (both 0 for an empty list).

    Returns the concatenated lines ('' if the bucket is empty), or None
    after logging when there is no bucket for that second.
    """
    second = int(now - 1)
    if second not in stats:
        log.info("No stats for %s", second)
        return None

    bucket = stats[second]
    log.debug("Goin to dump some stats: %s", bucket)
    # 'duration' and 'count' entries are excluded from the metric keys.
    keys = [key for key in bucket.keys() if key not in ['duration', 'count']]

    lines = []
    for group in {key[0] for key in keys}:
        fields = {}
        for field in (key[1] for key in keys if key[0] == group):
            samples = bucket[(group, field)]
            if samples:
                mean, peak = float(sum(samples)) / len(samples), max(samples)
            else:
                mean, peak = 0, 0
            fields['{0}_avg'.format(field)] = mean
            fields['{0}_max'.format(field)] = peak
        record = {'fields': fields, 'name': group, 'timestamp': second}
        lines.append(json.dumps(record) + '\n')
    return ''.join(lines)


if __name__ == '__main__':
    log = logging.getLogger('fastsrv_db_stat')
    log.setLevel(logging.INFO)
    # StreamHandler writes to stderr by default, keeping stdout for the JSON stats.
    ch = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    log.addHandler(ch)

    # Non-greedy so the capture stops at the first "=== ===" after the
    # PG STATS header; a greedy (.*) with DOTALL would run on to the last
    # terminator and swallow any sections that follow.
    regex = re.compile(r'^=== PG STATS ===(.*?)^=== ===', flags=(re.MULTILINE | re.DOTALL))

    # Shared by both threads: {unix_second: {(group, field): [values]}}
    stats = {}
    host = 'localhost'
    port = 4321
    frequency = 5  # status polls per second
    collector = CollectorTread(stats, host, port, regex, frequency=frequency)
    dumper = DumperTread(stats)
    # Both workers are daemon threads; keep the main thread alive for them.
    while True:
        time.sleep(1)
