#!/usr/bin/env python3

import os
import re
import sys

from collections import defaultdict, namedtuple
from datetime import datetime, timedelta
from urllib.parse import urlparse, parse_qs

# Field names of the two record types shared by the parser and the analyzers.
datapoint_fields = ['time', 'calltype', 'req', 'host']
consul_query_fields = ['method', 'path_params', 'count']

# Every field defaults to None so a record can be built from partial data
# (e.g. a log line that carries no request details).
#
# namedtuple() only accepts ``defaults=`` from Python 3.7 on; older
# interpreters need the defaults patched onto ``__new__``. The whole version
# tuple is compared (not just ``minor``) so the check stays correct across
# major versions.
if sys.version_info < (3, 7):
    DataPoint = namedtuple('DataPoint', datapoint_fields)
    DataPoint.__new__.__defaults__ = (None,) * len(datapoint_fields)

    ConsulQuery = namedtuple('ConsulQuery', consul_query_fields)
    ConsulQuery.__new__.__defaults__ = (None,) * len(consul_query_fields)
else:
    DataPoint = namedtuple('DataPoint', datapoint_fields, defaults=(None,) * len(datapoint_fields))
    ConsulQuery = namedtuple('ConsulQuery', consul_query_fields,
                             defaults=(None,) * len(consul_query_fields))

# Default width, in seconds, of one reporting window (see -i/--interval).
DEFAULT_INTERVAL = 60
# Default value for --timezone-offset: the assumed clock difference between
# hosts configured in UTC and hosts configured in PST.
DEFAULT_TIMEZONE_OFFSET = 8  # hours

# Telemetry data comes in increments of 10 seconds; used as the window width
# when -t/--telemetry is given.
TELEMETRY_INTERVAL = 10


def safe_div(x, y):
    """Return ``x / y``, or 0 when ``y`` equals zero instead of raising."""
    return 0 if y == 0 else x / y


def perinterval(start: datetime, end: datetime, interval: timedelta):
    """Yield evenly spaced points of the half-open range ``[start, end)``.

    Args:
        start: first value yielded (if it is before ``end``).
        end: exclusive upper bound.
        interval: step added between consecutive values.

    Yields:
        ``start``, ``start + interval``, ... while the value is before ``end``.

    Raises:
        ValueError: on first iteration, when the range is non-empty but
            ``interval`` does not move forward (zero or negative). Without
            this check the loop would never terminate.
    """
    if not start < end:
        return

    if not start + interval > start:
        raise ValueError("interval must be positive, got {!r}".format(interval))

    current = start
    while current < end:
        yield current
        current += interval


class ConsulMonitorAnalyzer:
    """Turn Consul monitor log lines and RPC telemetry into request statistics.

    Two sources are supported:

    * a directory of log files named ``<host>.<dc>...`` (``parse_files``),
      reported with ``summary`` and ``time_trend_by_dc``;
    * a live stream on stdin (``pipe_input``), reported once per window.

    Args:
        interval: width of one reporting window / trend bucket, in seconds.
        timezone_offset: hours subtracted from a computed time span when the
            span is at least that long (see ``_get_interval_from_datapoints``).
        detailed: break reports down further (per host for the time trend,
            per datacenter for the summary).
    """

    def __init__(self, interval: int = DEFAULT_INTERVAL, timezone_offset: int = DEFAULT_TIMEZONE_OFFSET,
                 detailed: bool = False):
        super().__init__()

        self.interval = interval
        self.detailed = detailed
        self.timezone_offset = timezone_offset

        # Datapoints collected by parse_files(), keyed by datacenter name.
        self.dc_datapoints = defaultdict(list)

        # format: yyyy/mm/dd hh:MM:ss [DEBUG|INFO|...] http|agent|consul.rpc extras...
        # Groups: time, log level, call type, rest of the line
        self.log_regex = re.compile(r'^([\d/]{10} [\d:]{8}) (\[\w+\]) ([\w.]+): (.*)$')

        # format: [yyyy-mm-dd hh:MM:ss +0000 UTC][S] 'consul.rpc.<action>': Count: <value> ...
        # Groups: time, action, count
        self.rpc_telemetry_regex = re.compile(r'^\[([\d\-\s\:\+\w]+)\].*\'consul.rpc.([\w\-]+).*\': Count: ([\d\.]+) .+$') # noqa

    def parse_files(self, folder: str = '.') -> dict:
        """Parse every ``<host>.<dc>...`` log file found directly in ``folder``.

        The first dot-separated component of the file name is taken as the
        host and the second as the datacenter. Lines that do not parse as a
        log line are ignored. Entries that are not regular files, or whose
        name has no dot, are skipped with a note on stderr.

        Returns:
            ``self.dc_datapoints``: the parsed datapoints keyed by datacenter.
        """
        # Sorted so the order of datacenters in reports does not depend on
        # the arbitrary order os.listdir() returns.
        for filename in sorted(os.listdir(folder)):
            path = os.path.join(folder, filename)
            name_parts = filename.split('.')

            if len(name_parts) < 2 or not os.path.isfile(path):
                print("Skipping {}: expected a regular file named <host>.<dc>".format(path), file=sys.stderr)
                continue

            host, dc = name_parts[0], name_parts[1]

            with open(path) as f:
                for line in f:
                    dp = self.parse_logline(line=line.strip(), host=host)

                    # parse_logline() returns None for lines it cannot use; storing
                    # those would break every later access to dp.time / dp.calltype.
                    if dp is not None:
                        self.dc_datapoints[dc].append(dp)

        return self.dc_datapoints

    def _get_interval_from_datapoints(self, datapoints: list) -> int:
        """Return the number of seconds spanned by ``datapoints`` (0 if empty).

        The span is measured between the earliest and latest timestamp, so it
        does not depend on the order the datapoints were collected in.
        """
        if not datapoints:
            return 0

        times = [dp.time for dp in datapoints]
        interval = int((max(times) - min(times)).total_seconds())

        # Hack for UTC. If interval is greater than timezone offset, subtract that offset delta.
        # This assumes we do not collect data greater than timezone_offset.
        # TODO: A better way is to take an average of the aggregate of dcs, Another day.
        if interval >= self.timezone_offset * 3600:
            interval -= self.timezone_offset * 3600

        return interval

    def _report_window(self, stamp: datetime, datapoints: list, calltype: str):
        """Print one ``HH:MM:SS >> <stats>`` line for a window of streamed datapoints."""
        print("{} >> ".format(stamp.strftime('%H:%M:%S')), end='')
        self.http_rpc_analyzer(datapoints, interval=self.interval, calltype=calltype)

    def pipe_input(self, calltype: str = 'http'):
        """Read stdin and print statistics for ``calltype`` once per window.

        Each line is tried as a log line first and as an RPC telemetry line
        second; anything else is ignored. A window is reported as soon as a
        datapoint arrives at least ``self.interval`` seconds after the window
        was opened, and once more when the input ends or is interrupted.

        Args:
            calltype: ``'http'`` for HTTP request stats, anything else for
                RPC telemetry stats.
        """
        window = timedelta(seconds=self.interval)
        future = None
        point_in_time = None
        current = None
        datapoints = []

        try:
            for line in sys.stdin:
                line = line.strip()
                dp = self.parse_logline(line=line)

                if dp is None:
                    dp = self.parse_rpc_telemetry(line=line)

                if dp is None:
                    continue

                # Log timestamps are naive while telemetry timestamps carry a UTC offset;
                # comparing the two directly raises TypeError. Window bookkeeping therefore
                # uses the wall-clock value only.
                current = dp.time.replace(tzinfo=None)

                # Initial bootstrap
                if future is None:
                    future = current + window
                    point_in_time = current

                if current >= future:
                    self._report_window(current, datapoints, calltype)

                    future = current + window
                    point_in_time = current

                    # The datapoint that closes a window is the first one of the next.
                    datapoints = [dp]
                elif current >= point_in_time:
                    # Make sure that we add datapoints that are current or later - to solve duplicate telemetry
                    # data.
                    # E.g. When requesting telemetry data from consul by issuing SIGUSR1, consul outputs the last 4
                    # 10s intervals of telemetry data. If interval is set < 40s, consul outputs the same data again
                    # and we should ignore data that has already been processed.
                    datapoints.append(dp)
        except (KeyboardInterrupt, SystemExit):
            pass
        finally:
            # Nothing to report if no line ever parsed (e.g. empty stdin).
            if current is not None:
                self._report_window(current, datapoints, calltype)

    def http_rpc_analyzer(self, datapoints: list = None, interval: int = None, calltype: str = 'http'):
        """Print request statistics for the datapoints of one call type.

        Args:
            datapoints: datapoints to analyze; those of another call type, or
                without request details, are ignored.
            interval: time span in seconds used for the per-second rates.
                Derived from the selected datapoints when omitted.
            calltype: ``'http'`` selects the HTTP breakdown, any other value
                the RPC breakdown.
        """
        selected = [dp for dp in (datapoints or []) if dp.calltype == calltype and dp.req is not None]

        if interval is None:
            interval = self._get_interval_from_datapoints(selected)

        if calltype == 'http':
            self._http_analyzer(datapoints=selected, interval=interval)
        else:
            self._rpc_analyzer(datapoints=selected, interval=interval)

    def _http_analyzer(self, interval: int, datapoints: list = None):
        """Print GET / non-GET / blocking ("index") / stale query counts, rates and shares."""
        total = 0
        gets = 0
        wait_completed = 0
        stale = 0

        for dp in datapoints or []:
            total += dp.req.count

            if dp.req.method == 'GET':
                gets += dp.req.count

            # keep_blank_values so that a bare "?stale" still shows up as a key.
            queries = parse_qs(urlparse(dp.req.path_params).query, keep_blank_values=True)

            if 'index' in queries:
                wait_completed += 1

            if 'stale' in queries:
                stale += 1

        non_gets = total - gets

        print("GET: {:4} {:>4.2f}qps ({:>6.2f}%) "
              "non-GET: {:4} {:4.2f}qps ({:>6,.2f}%) "
              "waits: {:4} {:>4.2f}qps ({:>6.2f}%) "
              "stale: {:4} {:>4.2f}qps ({:>6.2f}%) "
              "Total: {}".format(gets, safe_div(gets, interval), safe_div(gets, total) * 100,
                                 non_gets, safe_div(non_gets, interval), safe_div(non_gets, total) * 100,
                                 wait_completed, safe_div(wait_completed, interval), safe_div(wait_completed, total) * 100, # noqa
                                 stale, safe_div(stale, interval), safe_div(stale, total) * 100,
                                 total))

    def _rpc_analyzer(self, interval: int, datapoints: list = None):
        """Print request / query / cross-dc RPC counts, rates and shares."""
        total = 0
        per_action = defaultdict(int)

        for dp in datapoints or []:
            total += dp.req.count
            per_action[dp.req.method] += dp.req.count

        request = per_action.get('request', 0)
        query = per_action.get('query', 0)
        cross_dc = per_action.get('cross-dc', 0)

        print("request: {:5} {:>7.2f}qps ({:>6.2f}%) "
              "query: {:5} {:>7.2f}qps ({:>6,.2f}%) "
              "cross-dc: {:5} {:>7.2f}qps ({:>6.2f}%) "
              "Total: {}".format(request, safe_div(request, interval), safe_div(request, total) * 100,
                                 query, safe_div(query, interval), safe_div(query, total) * 100,
                                 cross_dc, safe_div(cross_dc, interval), safe_div(cross_dc, total) * 100, # noqa
                                 total))

    def parse_logline(self, line: str, host: str = 'stdin') -> DataPoint:
        """Parse one monitor log line.

        Returns:
            A ``DataPoint`` with a naive timestamp. For ``http`` lines ``req``
            holds the method and path (with ``count=1``); for every other
            call type ``req`` is None. Returns None when the line does not
            look like a log line or cannot be parsed (the offending line is
            printed in the latter case).
        """
        try:
            g = self.log_regex.search(line)

            if g is None:
                return None

            time, _loglevel, calltype, query = g.groups()
            time = datetime.strptime(time, '%Y/%m/%d %H:%M:%S')

            if calltype != "http":
                return DataPoint(time=time, calltype=calltype, host=host)

            # Sample:
            # Request GET /v1/health/service/twitch_ping_service?dc=fra06 (1m2.726677928s) from=127.0.0.1:34994
            _, method, path_params, *_ = query.split(" ")

            return DataPoint(time=time, calltype=calltype,
                             req=ConsulQuery(method=method, path_params=path_params, count=1), host=host)
        except (TypeError, ValueError) as ex:
            # ValueError: bad timestamp or too few tokens; TypeError: ``line`` is not a string.
            print("Offending line: {}, {}".format(line, ex))
            return None

    def parse_rpc_telemetry(self, line: str) -> DataPoint:
        """Parse one ``consul.rpc.<action>`` telemetry line.

        Returns:
            A ``DataPoint`` with calltype ``'rpc'``, a timezone-aware
            timestamp and ``req`` holding the action and its count, or None
            when the line does not match or cannot be parsed.
        """
        try:
            g = self.rpc_telemetry_regex.search(line)

            if g is None:
                return None

            time, action, count = g.groups()
            time = datetime.strptime(time, '%Y-%m-%d %H:%M:%S %z %Z')

            # The pattern also accepts counts written with a decimal point, which int() alone rejects.
            return DataPoint(time=time, calltype='rpc',
                             req=ConsulQuery(method=action, count=int(float(count))),
                             host='stdin')
        except (TypeError, ValueError):
            return None

    def _print_trend(self, datapoints: list, start: datetime, end: datetime):
        """Print one bar line per ``self.interval``-wide bucket between ``start`` and ``end`` inclusive."""
        step = timedelta(seconds=self.interval)

        # Bucket index = whole number of steps elapsed since ``start``.
        counts = defaultdict(int)
        for dp in datapoints:
            counts[(dp.time - start) // step] += 1

        # perinterval() excludes its upper bound; nudge it by the smallest representable
        # amount so that a bucket starting exactly at ``end`` is still reported.
        for index, time in enumerate(perinterval(start, end + timedelta.resolution, step)):
            num_calls = counts.get(index, 0)
            print("{}: {} {}q".format(time, '|' * round(safe_div(num_calls, self.interval)), num_calls))

    def time_trend_by_dc(self):
        """
        Print the number of calls per interval for each datacenter.

        detailed: break down by hosts

        Raises:
            ValueError: if ``self.interval`` is not positive (the buckets
                would never advance).
        """
        if self.interval <= 0:
            raise ValueError("interval must be a positive number of seconds, got {!r}".format(self.interval))

        for dc, datapoints in self.dc_datapoints.items():
            print("{}".format('-' * 8))
            print('Datacenter: {}'.format(dc))
            print("{}".format('-' * 8))

            if not datapoints:
                continue

            times = [dp.time for dp in datapoints]
            start, end = min(times), max(times)

            if self.detailed:
                # Every host is plotted on the datacenter-wide time axis.
                for host in sorted({dp.host for dp in datapoints}):
                    print("Host: {}".format(host))
                    self._print_trend([dp for dp in datapoints if dp.host == host], start, end)
            else:
                self._print_trend(datapoints, start, end)

    def summary(self):
        """
        Print totals, shares and rates per call type, followed by the HTTP breakdown.

        detailed: break by DC
        """
        everything = [dp for sublist in self.dc_datapoints.values() for dp in sublist]
        calltypes = sorted({dp.calltype for dp in everything})

        data = self.dc_datapoints if self.detailed else {'all': everything}
        row = "dc: {:<5} calltype: {:<15} total: {:>9} percentage: {:>6.2f}% qps: {:>6.2f}qps {}"

        for dc, datapoints in data.items():
            print("{}".format('-' * 8))
            print('Datacenter: {}'.format(dc))
            print("{}".format('-' * 8))

            interval = self._get_interval_from_datapoints(datapoints)
            total = len(datapoints)

            for calltype in calltypes:
                results_total = sum(1 for dp in datapoints if dp.calltype == calltype)
                qps = safe_div(results_total, interval)

                print(row.format(dc, calltype, results_total, safe_div(results_total, total) * 100,
                                 qps, '|' * round(qps)))

            total_qps = safe_div(total, interval)
            print(row.format(dc, 'all', total, 100, total_qps, '|' * round(total_qps)))

            print("{}".format('-' * 8))
            print("Breakdown of HTTP calls for dc: {}".format(dc))
            self.http_rpc_analyzer(datapoints)


if __name__ == '__main__':
    # Command-line entry point: analyze either a directory of log files
    # (-d) or a live stream on stdin (-p).
    import argparse

    parser = argparse.ArgumentParser()

    exclusive = parser.add_mutually_exclusive_group(required=True)
    exclusive.add_argument('-d', '--dir', help='dir of files')
    exclusive.add_argument('-p', '--pipe', help='Analyze consul http calls from stdin', action='store_true')

    parser.add_argument('--time-trend-by-dc', action='store_true',
                        help='time series trend by DC. Pass --detailed to break down by individual host')

    parser.add_argument('--summary', action='store_true',
                        help='Summary of all call types. Pass --detailed to break it down by dc')

    parser.add_argument('-i', '--interval', type=int, default=DEFAULT_INTERVAL,
                        help="Interval in seconds. Defaults to {}".format(DEFAULT_INTERVAL))

    parser.add_argument('-t', '--telemetry', action='store_true',
                        help='Can only be used with --pipe. Expects telemetry data')

    parser.add_argument('--detailed', action='store_true', default=False)
    parser.add_argument('--timezone-offset', type=int, default=DEFAULT_TIMEZONE_OFFSET,
                        help="Timezone offset in hours. This is to account for some host that are configured "
                             "in UTC instead of PST. Defaults to {}".format(DEFAULT_TIMEZONE_OFFSET))

    args = parser.parse_args()

    # Enforce what the --telemetry help text promises; otherwise the flag would
    # silently replace --interval while analyzing a directory.
    if args.telemetry and not args.pipe:
        parser.error('-t/--telemetry can only be used with -p/--pipe')

    if args.telemetry:
        # Telemetry arrives in fixed-size increments, so the window is not configurable.
        interval = TELEMETRY_INTERVAL
        calltype = 'rpc'
    else:
        interval = args.interval
        calltype = 'http'

    # --timezone-offset must reach the analyzer, otherwise the option has no effect.
    cma = ConsulMonitorAnalyzer(interval=interval, timezone_offset=args.timezone_offset,
                                detailed=args.detailed)

    if args.dir is not None:
        cma.parse_files(args.dir)

        # Run each requested report explicitly rather than matching option
        # names against analyzer attributes.
        if args.time_trend_by_dc:
            cma.time_trend_by_dc()

        if args.summary:
            cma.summary()
    else:
        cma.pipe_input(calltype=calltype)
