# coding=utf-8
"""
Collect HAProxy Stats
"""
import re
import glob
import socket
import csv
import errno
import diamond.collector
from diamond.collector import str_to_bool

try:
    from cStringIO import StringIO
except ImportError:
    from io import StringIO


class HAProxySocketCollector(diamond.collector.Collector):
    """
    Collect 'show info' and 'show stat' metrics from HAProxy stats sockets.

    Every socket matched by the ``search_path`` glob must be named
    ``.../haproxy-<instance>-stats.<process>.sock``; the instance name and
    process number parsed from that path become part of each metric name.
    Metrics are published per process, or summed across all processes when
    ``aggregate`` is enabled.
    """

    SOCKET_BUFFER_SIZE = 1024

    # Parses (instance, process) out of a socket path. The default
    # 'search_path' below is chosen so that every glob match fits this.
    SOCKET_PATH_PATTERN = re.compile(r".*/haproxy-([\w-]+)-stats\.(\d+)\.sock")

    # connect() errnos meaning "socket refused the connection or vanished
    # since glob() listed it": logged and treated as "no data".
    _UNAVAILABLE_SOCKET_ERRNOS = (errno.ECONNREFUSED, errno.ENOENT)

    def get_default_config_help(self):
        config_help = super(HAProxySocketCollector, self).get_default_config_help()
        config_help.update({
            'aggregate': 'Send aggregate stats instead of per process stats',
            'search_path': 'Glob matching the HAProxy stats sockets to query; '
                           'each match must be named '
                           'haproxy-<instance>-stats.<process>.sock',
            'path': 'Metric path for this collector (default: haproxy)',
            # Keys must match get_default_config() exactly.
            'name_exclude': 'Exclude "show stat" CSV row if first column matches',
            'type_exclude': 'Exclude "show stat" CSV row if second column matches',
        })
        return config_help

    def get_default_config(self):
        """
        Returns the default collector settings
        """
        config = super(HAProxySocketCollector, self).get_default_config()
        config.update({
            'aggregate': False,
            # Must produce paths that SOCKET_PATH_PATTERN can parse.
            'search_path': '/var/run/haproxy-*-stats.*.sock',
            'path': 'haproxy',
            'name_exclude': '^stats-.*$',
            'type_exclude': '^(?!frontend|backend).*$'
        })
        return config

    def get_socket_data(self, socket_path, command):
        """
        Send ``command`` to the HAProxy socket at ``socket_path`` and return
        the complete response as a native ``str``.

        If the connection is refused or the socket file no longer exists,
        the error is logged and an empty string is returned, so callers see
        "no data" rather than a crash. Any other socket error is re-raised.
        The socket is always closed before returning.
        """
        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            try:
                s.connect(socket_path)
            except IOError as e:
                if e.errno in self._UNAVAILABLE_SOCKET_ERRNOS:
                    self.log.error("Cannot connect to HAProxy socket %s: %s",
                                   socket_path, e)
                    return ''
                raise

            # encode() yields str on Python 2 and bytes on Python 3; both
            # are what sendall() expects on that interpreter.
            s.sendall((command + '\n').encode('ascii'))

            chunks = []
            while True:
                try:
                    chunk = s.recv(self.SOCKET_BUFFER_SIZE)
                except IOError as e:
                    # Transient interruptions: retry the read.
                    if e.errno in (errno.EAGAIN, errno.EINTR):
                        continue
                    raise
                if not chunk:
                    # HAProxy closes the connection once the answer is sent.
                    break
                chunks.append(chunk)
        finally:
            s.close()

        data = b''.join(chunks)
        if not isinstance(data, str):
            # Python 3: recv() returns bytes; the parsers expect text.
            data = data.decode('utf-8', 'replace')
        return data

    def parse_show_info(self, socket_path):
        """
        Returns dict of 'show info' data from HAProxy socket.

        Only "key: value" lines whose value is made entirely of digits are
        kept; values are returned as strings.
        """
        raw_show_info_data = self.get_socket_data(socket_path, 'show info')

        stats = {}
        for line in raw_show_info_data.splitlines():
            parsed_line = line.split(': ')
            if len(parsed_line) == 2:
                key, value = parsed_line
                if value.isdigit():
                    stats[key] = value

        return stats

    def collect_show_info(self, socket_path, instance, process):
        """
        Further parses 'show info' data from HAProxy socket
        into dictionary of metric "key: value" pairs

        ex: {'video_edge_ssl.process.2.maxsock': 4183.0}
        """
        stats = self.parse_show_info(socket_path)

        collected_stats = {}
        for raw_metric_name, raw_metric_value in stats.items():
            base_metric_name = self._sanitize(raw_metric_name.lower())
            metric_name = '%s.process.%s.%s' % (instance, process, base_metric_name)
            collected_stats[metric_name] = float(raw_metric_value)

        return collected_stats

    def collect_show_stat(self, socket_path, instance=None, process=None):
        """
        Parses 'show stat' data from HAProxy socket,
        into dictionary of metric "key: value" pairs

        ex: {'video_edge_ssl.frontend.https-keepalive.2.scur': 10}

        'show stat' is CSV: the first line holds the column headings. Each
        metric is named '<instance>.<second column>.<first column>.<process>'
        followed by '.<heading>', and only integer-valued cells are kept.
        Rows whose first column matches 'name_exclude' or whose second
        column matches 'type_exclude' are skipped.
        """
        name_exclude_pattern = re.compile(self.config['name_exclude'])
        type_exclude_pattern = re.compile(self.config['type_exclude'])

        raw_show_stat_data = self.get_socket_data(socket_path, 'show stat')

        full_csv_data = list(csv.reader(raw_show_stat_data.split('\n')))
        if not full_csv_data:
            return {}
        csv_headings = [self._sanitize(heading) for heading in full_csv_data[0]]

        results = {}
        # Everything after the heading line is data.
        for row in full_csv_data[1:]:
            # Blank lines (e.g. the trailing one) and short rows carry no stats.
            if len(row) <= 2:
                continue

            # ex: https-keepalive
            first_column = self._sanitize(row[0].lower())
            # ex: frontend, backend
            second_column = self._sanitize(row[1].lower())

            if (name_exclude_pattern.match(first_column) or
                    type_exclude_pattern.match(second_column)):
                continue

            base_metric_name = '%s.%s.%s.%s' % (instance, second_column, first_column, process)

            # zip() stops at the shorter sequence, so a row with more cells
            # than there are headings cannot raise IndexError.
            for heading, metric_string in zip(csv_headings, row):
                try:
                    metric_value = int(metric_string)
                except ValueError:
                    # Empty or non-numeric cell.
                    continue
                results['%s.%s' % (base_metric_name, heading)] = metric_value

        return results

    def collect_for_socket(self, socket_path):
        """
        Collect both 'show stat' and 'show info' metrics for one socket.

        Returns {'show stat': {...}, 'show info': {...}}. A path that does
        not match SOCKET_PATH_PATTERN is logged and yields empty dicts, so
        one oddly named file does not abort the whole collection.
        """
        process_info = self.SOCKET_PATH_PATTERN.match(socket_path)
        if process_info is None:
            self.log.warning(
                "Skipping %s: expected a name like "
                "haproxy-<instance>-stats.<process>.sock", socket_path)
            return {'show stat': {}, 'show info': {}}

        instance_name, process_number = process_info.group(1, 2)
        return {'show stat': self.collect_show_stat(socket_path, instance_name, process_number),
                'show info': self.collect_show_info(socket_path, instance_name, process_number)}

    def collect_combined_socket(self, stats_per_socket):
        """
        Sum per-process metrics across all sockets.

        The process component of each metric name is replaced with 'all',
        and values sharing the resulting name are added together. Every
        metric is summed regardless of what it measures.
        """
        combined_show_info = {}
        combined_show_stat = {}

        for info_and_show_stat in stats_per_socket.values():
            for metric_name, metric_value in info_and_show_stat['show info'].items():
                # follows format '<instance>.process.<SOCKET_NUM>.<STAT_TYPE>'
                split_metric_name = metric_name.split('.')
                combined_metric_name = "{}.{}.all.{}".format(split_metric_name[0],
                                                             split_metric_name[1],
                                                             split_metric_name[3])
                combined_show_info[combined_metric_name] = (
                    combined_show_info.get(combined_metric_name, 0) + metric_value)

            for metric_name, metric_value in info_and_show_stat['show stat'].items():
                # follows format '<instance>.<backend|frontend>.<name>.<SOCKET_NUM>.<STAT_TYPE>'
                split_metric_name = metric_name.split('.')
                combined_metric_name = "{}.{}.{}.all.{}".format(split_metric_name[0],
                                                                split_metric_name[1],
                                                                split_metric_name[2],
                                                                split_metric_name[4])
                combined_show_stat[combined_metric_name] = (
                    combined_show_stat.get(combined_metric_name, 0) + metric_value)

        return {'show info': combined_show_info,
                'show stat': combined_show_stat}

    def collect(self):
        """
        Query every socket matching 'search_path' and publish its metrics as
        gauges: per process by default, or summed across processes when
        'aggregate' is true.
        """
        search_path = self.config['search_path']
        collect_aggregate = self.config['aggregate']

        stats_per_socket = {socket_path: self.collect_for_socket(socket_path)
                            for socket_path in glob.glob(search_path)}

        if not str_to_bool(collect_aggregate):
            self.log.debug('Running per process stats')
            for info_and_show_stat in stats_per_socket.values():
                self._publish_stats(info_and_show_stat)
        else:
            # publishing combined stats across all sockets if aggregate setting is True
            self.log.debug("Running aggregated stats")
            self._publish_stats(self.collect_combined_socket(stats_per_socket))

    def _publish_stats(self, info_and_show_stat):
        """Publish the 'show stat' then the 'show info' metrics as gauges."""
        for section in ('show stat', 'show info'):
            self.log.debug('%s:', section)
            for metric_name, metric_value in info_and_show_stat[section].items():
                self.publish(metric_name, metric_value, metric_type='GAUGE')
                self.log.debug("%s: %s", metric_name, metric_value)

    def _sanitize(self, s):
        """Sanitize the name of a metric to remove unwanted chars

        Every character other than a word character or '-' becomes '_'.
        Then everything up to and including the last 'haproxy-' is removed.
        """
        sanitized_s = re.sub(r'[^\w-]', '_', s)
        sanitized_s = re.sub(r'^.*haproxy-', '', sanitized_s)
        return sanitized_s
