#!/usr/bin/env python
# -*- coding: UTF-8 -*-

import sys
import re
from collections import defaultdict, Counter

# Prefixes accepted as the first '-'-separated component of a ycrid (after an
# optional 'rest_' prefix is dropped); anything else is reported as 'malformed'.
YCRIDS = {
    '-',
    'andr',
    'dav',
    'ios',
    'lnx',
    'mac',
    'mpfs',
    'public',
    'rest',
    'sdk',
    'web',
    'win',
    'wp',
}

# Accumulators filled while reading stdin and printed at the end of the script.
results_errors = Counter()               # error kind -> number of affected lines
results_timings = defaultdict(Counter)   # metric key -> {formatted duration: hits}
results_count_lines = Counter()          # metric key -> sum of per-query line counts
results_count_queries = Counter()        # metric key -> number of queries

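# Sample input record (TSKV; the parser expects tab-separated fields on one line,
# the record is wrapped here for readability):
# tskv  tskv_format=ydisk-mpfs-requests-log host=mpfs5j.disk.yandex.net name=mpfs.requests.postgres appname=disk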
#    unixtime=1510558856 timestamp=2017-11-13 10:40:56,496   timezone=+0300
#  ycrid=web-a1a9a21b964d05987d89e067e21395bc-ufo08e   request_id=758250_101065
#    pid=758250  module=logging
# message=dbname=diskdb user=disk_mpfs password=xxxxxxxxxxxxxxxxxxxx host=mpfsdb01f.disk.yandex.net
# port=6432 connect_timeout=10 "SELECT CAST('test unicode returns' AS VARCHAR(60)) AS anon_1" 1 0.002000


# Keywords whose presence anywhere in the message marks the query as a write.
_WRITE_KEYWORDS = ('INSERT', 'UPDATE', 'DELETE', 'insert', 'update', 'delete')


def _tskv_field(line, name):
    """Return the value of the tab-prefixed ``name=`` field of a TSKV line.

    The value runs up to the next tab, or to the end of the line when the
    field is the last one.  Returns None when the field is absent, so callers
    never slice from a bogus offset produced by ``str.find()`` returning -1.
    """
    marker = '\t' + name + '='
    start = line.find(marker)
    if start < 0:
        return None
    return line[start + len(marker):].split('\t', 1)[0]


# Aggregate query statistics from the TSKV log read on stdin into the
# module-level ``results_*`` accumulators.
for line in sys.stdin:
    # Only non-blank records that carry the 'module=logging' marker are counted.
    if not line.strip() or 'module=logging' not in line:
        continue

    ycrid = _tskv_field(line, 'ycrid')
    # Records without a usable request id (field missing, no '-' separator,
    # or the bare '-' placeholder) are skipped silently.
    if ycrid is None or '-' not in ycrid or ycrid == '-':
        continue

    # Reduce the ycrid to its leading component, dropping a 'rest_' prefix.
    ycrid = ycrid.split('-', 1)[0]
    if ycrid.startswith('rest_'):
        ycrid = ycrid[5:]
    if ycrid not in YCRIDS:
        ycrid = 'malformed'

    message = _tskv_field(line, 'message')
    if message is None:
        results_errors['parse_line'] += 1
        continue

    # The message is whitespace-separated: connection parameters first, then
    # the query text, then the line count and the duration as the last two
    # tokens.
    query_message = message.split()
    dbname = host = None
    for part in query_message:
        if 'dbname=' in part:
            dbname = part.split('=', 2)[1]
        elif 'host=' in part:
            # Dots are replaced so the host can be embedded in a metric name.
            host = part.split('=', 2)[1].replace('.', '_')
        if host and dbname:
            break

    if not (host and dbname):
        results_errors['parse_line'] += 1
        continue

    try:
        q_lines = int(query_message[-2])
        q_took_f = float(query_message[-1])
    except (ValueError, IndexError):
        results_errors['parse_line'] += 1
        continue
    # Durations are bucketed by their value rounded to three decimals.
    q_took = "%.3f" % q_took_f

    # Hosts ending in '_db_yandex_net' are grouped under the database name;
    # any other host is grouped under its first label minus the last character.
    if host.endswith('_db_yandex_net'):
        shard = dbname
    else:
        shard = host.split('_', 2)[0][:-1]

    # Each query is counted per host, in the shard total, and in the
    # per-client shard total.
    for last_part in (host, 'total', 'total_%s' % ycrid):
        key = '_'.join((dbname, shard, last_part))
        results_timings[key][q_took] += 1
        results_count_lines[key] += q_lines
        results_count_queries[key] += 1

    # Plain substring heuristic: the keyword may appear anywhere in the message.
    is_write = any(keyword in message for keyword in _WRITE_KEYWORDS)
    q_type = 'write' if is_write else 'read'
    results_count_queries['_'.join((dbname, shard, q_type))] += 1

    # Queries longer than 0.3 are additionally counted as slow.
    if q_took_f > 0.3:
        results_count_queries['_'.join(('slow', dbname, shard))] += 1


def print_timings(name, data):
    """Print one ``@<name>_<key> <value>@<count> ...`` line per non-empty entry.

    ``data`` maps a key to a mapping of value -> count.  Keys are emitted in
    sorted order, and the ``value@count`` pairs within a line are sorted too.
    Entries whose mapping is empty produce no output.
    """
    for key in sorted(data):
        histogram = data[key]
        if not histogram:
            continue
        pairs = ["%s@%s" % (value, count) for value, count in sorted(histogram.items())]
        print("@{}_{} {}".format(name, key, ' '.join(pairs)))


def print_counter(name, data):
    """Print one ``<name>_<key> <value>`` line per entry of ``data``, sorted by key."""
    for shard in sorted(data):
        line = "{}_{} {}".format(name, shard, data[shard])
        print(line)


# Emit the report: timing histograms first, then the plain counters.
print_timings('query_timings', results_timings)
for counter_name, counter in (
        ('query_count_lines', results_count_lines),
        ('query_count', results_count_queries),
        ('error', results_errors),
):
    print_counter(counter_name, counter)
