# pylint:disable=C0103
#
# pylint:disable=W0703

# B333FC49\tRequestTime\tTotalTime\tIntStat\tLMStat\tExtTime\t ExtStat
# ServerName      Port    Request ReplyCode
"""yabs-server-phantom access log parser"""
import time
import re
import socket
import logging
import logging.handlers
import os
import os.path
import glob

from collections import defaultdict, OrderedDict

# Extracts the handle from the Request field: a leading word (presumably the
# HTTP method -- confirm against the log format), whitespace, one or more
# slashes, then the captured run of word characters that follows them.
HANDLE_RE = re.compile(r'^\w+\s+\/+(\w+).*')
# First dot-separated component of the metric name prefix built in dump_metrics().
METRIC_GROUP = 'one_sec'


def decode_ext_stat_entry(entry_str):
    """
    Decode one colon-separated ExtStat entry.

    The entry is read positionally as ``ext:flag:time:shard``; the call is
    reported as failed unless ``flag`` is exactly ``'+'``.

    An entry that cannot be decoded is logged and replaced by a placeholder,
    so this function never raises.

    :param str entry_str: a single ExtStat entry
    :return tuple: ext, shard, err, time -- ext and shard are strings,
        err is a bool, time is a float
    """
    try:
        fields = entry_str.split(':')
        return fields[0], fields[3], fields[1] != '+', float(fields[2])
    except Exception:
        logging.error("Bad ExtStat entry: %s", entry_str)
        # The placeholder shard is a string, like a real shard: callers build
        # per-shard keys with ``ext + shard`` and an int would raise TypeError.
        return '__unknown__', '0', False, 0.0


def update_dict(prev, update):
    """
    Merge ``update`` into ``prev`` in place.

    A key that is missing from ``prev`` (or whose current value is falsy)
    simply takes the incoming value. Otherwise the incoming value decides how
    the two are combined: lists are extended, dicts are merged recursively and
    numbers are added. Incoming values of any other type leave ``prev``
    untouched.

    :param dict prev: mapping that receives the data
    :param update: mapping to merge in; nothing happens when it is empty or None
    """
    if not update:
        return

    for key, incoming in update.items():
        if not prev.get(key):
            # Nothing meaningful stored yet: adopt the incoming value as is.
            prev[key] = incoming
        elif isinstance(incoming, list):
            prev[key].extend(incoming)
        elif isinstance(incoming, dict):
            update_dict(prev[key], incoming)
        elif isinstance(incoming, (int, float)):
            prev[key] += incoming


class LineFormat(object):

    """
    Describes the column layout announced by a log header line
    """
    # pylint:disable = R0902,R0903

    def __init__(self, format_line):
        """
        :param str format_line: whitespace-separated header line; its first
            token is the stamp preceded by one extra character
        :raises KeyError: when a required column name is absent
        :raises IndexError: when the line contains no tokens at all
        """
        columns = format_line.split()
        self.field_count = len(columns)
        # Drop the single leading character of the first token to get the stamp.
        self.stamp = columns[0][1:]

        # Column name -> zero-based position in the line.
        position = dict((title, index) for index, title in enumerate(columns))
        self.req_time_fn = position['RequestTime']
        self.tot_time_fn = position['TotalTime']
        self.int_stat_fn = position['IntStat']
        self.ext_time_fn = position['ExtTime']
        self.ext_stat_fn = position['ExtStat']
        self.request_fn = position['Request']
        self.reply_code_fn = position['ReplyCode']
        self.perf_limits_stat_fn = position['PerfLimitsStat']


class IntervalMetrics(object):
    # pylint:disable = R0902,R0914

    """Class that aggregates data for one time interval"""

    # Attributes exported by as_dict() and merged by update_from_dict().
    METRICS_NAMES = ['tot_times', 'wext_times', 'ext_times', 'int_sums', 'ext_errors', 'errors', 'request_errors', 'perf_limits']

    def __init__(self, disable_ext_sharded=None, only_ext_sharded=None):
        """
        disable_ext_sharded: disable per-shard aggregates for these ext requests
        only_ext_sharded: enable per-shard aggregates only for these ext requests
                          (overrides disable_ext_sharded even if set to empty iterable)
        """
        if only_ext_sharded is not None:
            self.is_ext_sharded = lambda ext: ext in only_ext_sharded
        elif disable_ext_sharded is not None:
            self.is_ext_sharded = lambda ext: ext not in disable_ext_sharded
        else:
            self.is_ext_sharded = lambda _: True

        # NOTE(review): n_requests is initialised here but never updated by
        # any code visible in this file.
        self.n_requests = 0
        self.tot_times = defaultdict(list)       # handle -> total request times
        self.wext_times = defaultdict(list)      # handle -> total time minus ExtTime
        self.ext_times = defaultdict(list)       # ext (and ext + shard) -> ext call times
        self.int_sums = defaultdict(lambda: 0)   # IntStat position -> sum of values
        self.ext_errors = defaultdict(lambda: 0)      # ext (and ext + shard) -> failed calls
        self.errors = defaultdict(lambda: 0)          # handle -> replies whose code starts with '5'
        self.request_errors = defaultdict(lambda: 0)  # handle -> replies whose code starts with '4'
        self.perf_limits = defaultdict(lambda: 0)     # limit name -> occurrences

    def as_dict(self):
        """Return the aggregates listed in METRICS_NAMES keyed by attribute name"""
        return {metric: getattr(self, metric) for metric in self.METRICS_NAMES}

    def add_line(self, fields, fmt):
        """
        Add line data

        :param list fields: the log line split into fields
        :param fmt: object holding the field positions (``*_fn`` attributes)
        :raises Exception: whatever parsing raises (ValueError, IndexError...)
            on a malformed line; nothing is counted in that case
        """
        # Parse and validate everything first: a malformed line must raise
        # before any counter is touched, otherwise it would be half-counted.
        handle_match = HANDLE_RE.match(fields[fmt.request_fn])
        handle = handle_match.group(1) if handle_match is not None else 'unknown'

        tot_time = float(fields[fmt.tot_time_fn])

        ext_time_str = fields[fmt.ext_time_fn]
        wext_time = tot_time - float(ext_time_str) if ext_time_str else None

        int_stat_str = fields[fmt.int_stat_fn]
        int_values = [float(ival) for ival in int_stat_str.split(':')] if int_stat_str else []

        ext_stat_str = fields[fmt.ext_stat_fn]
        ext_entries = []
        if ext_stat_str:
            ext_entries = [decode_ext_stat_entry(entry_str) for entry_str in ext_stat_str.split(' ')]

        code_class = fields[fmt.reply_code_fn][0]

        # Limit names are the comma-separated part before the first ':'.
        perf_limits_stat = fields[fmt.perf_limits_stat_fn]
        limits = [limit for limit in perf_limits_stat.partition(':')[0].split(',') if limit]

        # Everything parsed successfully: commit to the counters.
        self.tot_times[handle].append(tot_time)
        if wext_time is not None:
            self.wext_times[handle].append(wext_time)

        for val_no, ival in enumerate(int_values):
            self.int_sums[val_no] += ival

        for ext, shard, error, ext_time in ext_entries:
            keys = [ext]
            if self.is_ext_sharded(ext):
                # '%s' keeps this working even if the decoder hands back a
                # non-string shard placeholder.
                keys.append(ext + '%s' % (shard,))
            for key in keys:
                self.ext_times[key].append(ext_time)
                if error:
                    self.ext_errors[key] += 1

        if code_class == '5':
            self.errors[handle] += 1
        elif code_class == '4':
            self.request_errors[handle] += 1

        for limit in limits:
            self.perf_limits[limit] += 1

    def update_from_dict(self, dict):  # pylint:disable=W0622
        """
        Merge aggregates produced by as_dict() of another interval into this one

        The parameter name shadows the builtin; it is kept as is so that
        keyword callers keep working.
        """
        for metric in self.METRICS_NAMES:
            update_dict(getattr(self, metric), dict.get(metric, None))

    def prepare_metrics(self, quantiles):
        """
        Calculate aggregates

        Sorts the collected timing lists in place.

        :param quantiles: percentiles to extract from timing lists
        :return dict: metric name -> value
        """
        total_count = 0

        metrics = {}

        for handle, tot_times in self.tot_times.items():
            handle_cnt = len(tot_times)
            total_count += handle_cnt

            metrics['requests.' + handle] = handle_cnt
            for q_key, q_val in list_qs_destructive(tot_times, quantiles):
                metrics['timings.%s.t%s' % (handle, q_key)] = q_val

        for handle, wext_times in self.wext_times.items():
            for q_key, q_val in list_qs_destructive(wext_times, quantiles):
                metrics['ext_diff_timings.%s.t%s' % (handle, q_key)] = q_val

        # int_stat values are per-request averages; without requests (possible
        # after update_from_dict) there is nothing to average over.
        if total_count:
            for val_no, isum in self.int_sums.items():
                metrics['int_stat.%s' % val_no] = 1.0 * isum / total_count

        for ext, ext_times in self.ext_times.items():
            metrics['ext_requests.%s' % ext] = len(ext_times)
            for q_key, q_val in list_qs_destructive(ext_times, quantiles):
                metrics['ext_timings.%s.t%s' % (ext, q_key)] = q_val

        for ext, error_count in self.ext_errors.items():
            metrics['ext_errors.%s' % ext] = error_count

        for handle, error_count in self.errors.items():
            metrics['errors.%s' % handle] = error_count

        for handle, error_count in self.request_errors.items():
            metrics['request_errors.%s' % handle] = error_count

        for limit, limit_count in self.perf_limits.items():
            metrics['perf_limits.%s' % limit] = limit_count

        return metrics

    def prepare_ext_statistics(self, quantiles):
        """
        Calculate better aggregates for ext requests

        :return dict: ext -> dict with 'total', 'cumtime', 'errors' and one
            entry per requested quantile
        """
        ext_metrics = {}
        for ext, ext_times in self.ext_times.items():
            ext_metrics[ext] = {
                'total': len(ext_times),
                'cumtime': round(sum(ext_times), 4),
                'errors': self.ext_errors.get(ext, 0),
            }
            for q_key, q_val in list_qs(ext_times, quantiles):
                ext_metrics[ext][str(q_key)] = round(q_val, 4)

        return ext_metrics


class Table(object):
    """
    Plain-text table with a fixed set of columns.

    Column widths grow to fit the widest value added so far.
    """

    def __init__(self, headers, header_line_sep='-', cell_sep='|'):
        """
        :param headers: column titles, in display order
        :param str header_line_sep: character repeated to draw horizontal rules
        :param str cell_sep: string placed between and around cells
        """
        # Column title -> current column width.
        self.__headers = OrderedDict(
            (header, len(header))
            for header in headers
        )
        self.__data = []

        self.header_line_sep = header_line_sep
        self.cell_sep = cell_sep

    def add(self, row):
        """
        Append a row.

        :param dict row: column title -> value; keys that are not table
            columns are ignored when rendering
        """
        for key, value in row.items():
            if key in self.__headers:
                self.__headers[key] = max(self.__headers[key], len(str(value)))
        self.__data.append(row)

    def _rule(self, width):
        """Return a horizontal rule of exactly ``width`` characters"""
        # Slicing keeps the width right even for a multi-character separator.
        return (self.header_line_sep * width)[:width]

    def _join_cells(self, cells):
        """Join rendered cells, adding the separator on both outer edges"""
        return self.cell_sep.join([''] + cells + [''])

    def __str__(self):
        header_line = self._join_cells([
            ' {header:^{cell_length}} '.format(header=header, cell_length=cell_length)
            for header, cell_length in self.__headers.items()
        ])
        rule = self._rule(len(header_line))

        lines = [rule, header_line, rule]
        for row in self.__data:
            # str() matches how widths were measured in add() and lets values
            # such as None be rendered.
            lines.append(self._join_cells([
                ' {value:<{cell_length}} '.format(value=str(row.get(header, '')), cell_length=cell_length)
                for header, cell_length in self.__headers.items()
            ]))
        lines.append(rule)
        return '\n'.join(lines)


def dump_ext_metrics(ext_metrics):
    """
    Render per-ext statistics as a text table, largest 'cumtime' first.

    :param dict ext_metrics: ext tag -> dict of statistics; every entry must
        contain 'cumtime'
    :return str: the rendered table (headers only when ext_metrics is empty)
    """
    headers = ['ext tag', 'total', 'errors', 'cumtime', ]
    # Collect extra columns from every entry: this works for an empty mapping
    # and does not rely on indexing dict.values().
    for metrics in ext_metrics.values():
        for key in metrics:
            if key not in headers:
                headers.append(key)
    table = Table(headers=headers)

    for ext, metrics in sorted(ext_metrics.items(), key=lambda k_v: -k_v[1]['cumtime']):
        entry = {
            'ext tag': ext,
        }
        entry.update(metrics)
        table.add(entry)

    return str(table)


def dump_metrics(role, timestamp, metrics):
    """
    Render metrics as text, one ``<prefix>.<name> <value> <timestamp>`` line
    per metric, sorted by metric name.

    The prefix is built from METRIC_GROUP, the local FQDN with dots replaced
    by underscores, and the role.
    """
    host_part = socket.getfqdn().replace('.', '_')
    prefix = METRIC_GROUP + '.' + host_part + '.' + role
    logging.debug(
        "Dumping %s metrics for %s with timestamp %s",
        len(metrics), role, timestamp
    )

    rendered = []
    for key in sorted(metrics.keys()):
        rendered.append('%s.%s %s %s' % (prefix, key, metrics[key], timestamp))
    return '\n'.join(rendered)


def list_qs(lst, quantiles):
    """
    Extracts quantiles from list

    The input list is left untouched. Nothing is yielded for an empty list.

    :param lst: values to take quantiles from
    :param quantiles: percentiles (0..100) to extract
    :return: generator of (label, value) pairs; the label is the percentile
        formatted with '%g' with the dot replaced by an underscore
    """
    sorted_lst = sorted(lst)
    n = len(sorted_lst)
    if not n:
        return
    for quan in quantiles:
        # Clamp to a valid index: the last item is n - 1, so the 100th
        # percentile maps to the maximum instead of running off the list.
        index = max(0, min(int(n * quan / 100), n - 1))
        yield ('%g' % quan).replace('.', '_'), sorted_lst[index]


def list_qs_destructive(lst, quantiles):
    """
    Extracts quantiles from list, disrupts item order in the list

    The list is sorted in place. Extracted values are multiplied by 1000
    (presumably seconds to milliseconds -- confirm with the log format).
    Nothing is yielded for an empty list.

    :param list lst: values to take quantiles from
    :param quantiles: percentiles (0..100) to extract
    :return: generator of (label, scaled value) pairs; the label is the
        percentile formatted with '%g' with the dot replaced by an underscore
    """
    lst.sort()
    n = len(lst)
    if not n:
        return
    for quan in quantiles:
        # Clamp to a valid index: the last item is n - 1, so the 100th
        # percentile maps to the maximum instead of running off the list.
        index = max(0, min(int(n * quan / 100), n - 1))
        yield ('%g' % quan).replace('.', '_'), lst[index] * 1000


class NoData(Exception):

    """
    Signals that the log file is still there but holds nothing new to read
    """


class NoLog(Exception):

    """
    Signals that the log file itself is missing or cannot be accessed
    """


class ParseError(RuntimeError):
    """Signals a log line that could not be parsed"""


class LogHandler(object):

    """
    Aggregates log lines into counters
    Dumps counters as metrics
    """
    # pylint: disable=R0902
    # Length of one aggregation interval, in the units of RequestTime
    # (which is compared against time.time(), i.e. seconds).
    INTERVAL_LEN = 1
    # Minimal number of seconds between two parsing speed reports.
    PARSING_SPEED_INTERVAL = 5

    def __init__(self,
                 role,
                 filepath,
                 max_lag,
                 quantiles,
                 disable_ext_sharded,
                 metrics_writers=()
                 ):
        """
        :param str role: role name; also used as the logger name
        :param str filepath: path of the log file to follow
        :param max_lag: lines older than this many seconds are skipped
        :param quantiles: percentiles to extract from timings
        :param disable_ext_sharded: ext requests that get no per-shard aggregates
        :param metrics_writers: objects with a ``write(role, timestamp, metrics)`` method
        :raises NoLog: if the log file cannot be opened
        """
        self.metrics_writers = metrics_writers
        self.role = role
        self.filepath = filepath
        self.quantiles = quantiles
        self.disable_ext_sharded = disable_ext_sharded
        self.log_file = None
        self.inode = None
        self.fmt = None
        self.cur_interval_id = 0
        self.intervals = dict()
        self.time_counter = time.time()
        self.line_counter = 0
        self.skip_counter = 0
        self.last_line_ts = 0
        self.parse_errors_counter = 0
        self.next_parse_errors_send_ts = 0
        self.max_lag = max_lag

        # The logger has to exist before the file is opened: _open_log()
        # reports failures through it.
        self.log = logging.getLogger(role)
        self._open_log()
        self.log.info("LogHandler created.")

    def parse_lines(self):
        """
        Read log line-by-line
        Raise either NoLog or NoData

        The loop has no normal exit: it ends only when an exception
        propagates, which _handle_no_data() arranges once the file has
        nothing more to read.
        """
        while True:
            # Parse errors are reported after every batch of 100 reads.
            for _ in range(100):
                line = self.log_file.readline()
                if line:
                    try:
                        self._parse_line(line)
                    except ParseError as exc:
                        self.log.error("%s: %s", exc, line)
                        self.parse_errors_counter += 1
                else:
                    self._handle_no_data()
            self._send_parse_errors()

    def _send_parse_errors(self):
        """Write the parse error counter as a metric, at most once per second"""
        now = time.time()
        if now > self.next_parse_errors_send_ts:
            self._write_metrics(now, {'parse_errors': self.parse_errors_counter})
            self.parse_errors_counter = 0
            self.next_parse_errors_send_ts = now + 1.0

    def _parse_line(self, line):
        """
        Handle one log line: a header line (starting with '#') replaces the
        current format, a data line is added to its interval.

        :raises ParseError: on a bad header, stamp, field count, timestamp
            or field contents
        """
        if line[0] == '#':
            try:
                self.fmt = LineFormat(line)
            except Exception:
                raise ParseError("Bad format line")
            else:
                return

        if self.fmt is None:
            # No header seen yet, so the fields cannot be interpreted.
            return

        # Drop the line terminator so that it does not end up inside the
        # last field.
        fields = line.rstrip('\r\n').split('\t')
        if fields[0][1:] != self.fmt.stamp:
            raise ParseError("Bad stamp: {}".format(fields[0][1:]))
        if len(fields) != self.fmt.field_count:
            raise ParseError('Line should have {} fields'.format(self.fmt.field_count))

        try:
            self.last_line_ts = float(fields[self.fmt.req_time_fn])
        except ValueError:
            raise ParseError("Bad timestamp")
        if self.get_lag() > self.max_lag:
            self.skip_counter += 1
            if self.skip_counter % 100000 == 0:
                self.log.info("Total %s lines skipped", self.skip_counter)
            return

        self.line_counter += 1
        interval_id = int(self.last_line_ts / self.INTERVAL_LEN)

        if interval_id > self.cur_interval_id:
            self._new_interval(interval_id)

        if interval_id in self.intervals:
            try:
                self.intervals[interval_id].add_line(fields, self.fmt)
            except Exception as exc:
                raise ParseError(str(exc))
        else:
            # The interval of this line has already been flushed.
            self.log.warning("Dropping line that is %s seconds old", (
                self.cur_interval_id - interval_id) * self.INTERVAL_LEN)

    def get_lag(self):
        """
        :return: delay between last line timestamp and now
        """
        return time.time() - self.last_line_ts

    def _new_interval(self, new_interval_id):
        """
        Start a new interval.

        Leaves the three most recent stored intervals untouched
        For all older intervals:
            calculates metrics
            sends them everywhere
            deletes interval
        Also calculates parsing speed
        """
        self.cur_interval_id = new_interval_id

        for iid in sorted(self.intervals.keys())[:-3]:
            metrics = self.intervals[iid].prepare_metrics(self.quantiles)
            del self.intervals[iid]

            timestamp = iid * self.INTERVAL_LEN
            self._write_metrics(timestamp, metrics)

        self.intervals[self.cur_interval_id] = IntervalMetrics(
            disable_ext_sharded=self.disable_ext_sharded
        )

        new_time_counter = time.time()
        if (new_time_counter - self.time_counter) > self.PARSING_SPEED_INTERVAL:
            self.log.info(
                "Parsing speed: %g RPS",
                self.line_counter / (new_time_counter - self.time_counter)
            )
            self.time_counter = new_time_counter
            self.line_counter = 0

    def _write_metrics(self, timestamp, metrics):
        """Pass the metrics to every configured writer"""
        for wrt in self.metrics_writers:
            wrt.write(self.role, timestamp, metrics)

    def _handle_no_data(self):
        """
        Check log file inode and reopen it if the log was rotated
        :raises NoLog: if log file disappeared
        :raises NoData: if log file is the same
        """
        try:
            same_inode = (os.stat(self.filepath).st_ino == self.inode)
        except OSError:
            logging.info("Failed to stat %s", self.filepath)
            raise NoLog()
        if same_inode:
            raise NoData()

        # The log was rotated. Open the new file before closing the old one:
        # if opening fails, self.log_file must stay a usable (open) file
        # because parse_lines() will read from it again.
        old_file = self.log_file
        self._open_log()
        old_file.close()

    def _open_log(self):
        """
        Open log file and remember its inode
        :raises NoLog: if the file cannot be opened
        """
        try:
            new_file = open(self.filepath)
        except (IOError, OSError) as exc:
            # open() raises IOError on Python 2, where it is not an OSError.
            self.log.error("Failed to open %s: %s", self.filepath, exc)
            raise NoLog()
        self.log_file = new_file
        self.inode = os.fstat(new_file.fileno()).st_ino


class MultiLogHandler(object):

    """
    Processes multiple access logs
    """

    # Extracts the role from a log file name of the form access-<role>.log
    ROLE_RE = re.compile(r'access-(\w+)\.log')

    def __init__(
            self,
            metrics_writers,
            log_glob="/var/log/yabs/access-*.log",
            quantiles=(90, 95, 98, 99, 99.5, 99.8),
            disable_ext_sharded=('bsbts_all', 'market', 'yacofast', 'count'),
            max_lag=60
    ):
        """
        :param metrics_writers: passed to every LogHandler, which calls
            ``write(role, timestamp, metrics)`` on each of them
        :param str log_glob: logs to handle
        :param set quantiles: quantiles to gather for timings and ext_timings
        :param set disable_ext_sharded: external calls that
            are not recorded per-shard
        :param int max_lag: log lines that are older than
            max_lag seconds are dropped
        """
        self.log_glob = log_glob
        self.quantiles = set(quantiles)
        self.disable_ext_sharded = set(disable_ext_sharded)
        self.max_lag = max_lag
        self.metrics_writers = metrics_writers
        self.known_logs = {}  # role -> LogHandler

    def run(self):
        """Refresh log handlers and handle logs"""
        while True:
            self._refresh_log_handlers()
            if self.known_logs:
                self._handle_known_logs()
            else:
                logging.info("No logs, sleeping 1 s")
                time.sleep(1)

    def _refresh_log_handlers(self):
        """
        Create LogHandlers for new logs.
        Delete LogHandlers for logs that are long dead.
        """

        existing_roles = set()
        for log_path in glob.glob(self.log_glob):
            match = self.ROLE_RE.match(os.path.basename(log_path))
            if match is None:
                continue
            role = match.group(1)
            if role in self.known_logs:
                existing_roles.add(role)
            else:
                try:
                    log = LogHandler(
                        role=role,
                        filepath=log_path,
                        max_lag=self.max_lag,
                        quantiles=self.quantiles,
                        disable_ext_sharded=self.disable_ext_sharded,
                        metrics_writers=self.metrics_writers
                    )
                except NoLog:
                    logging.info("%s log suddenly disappeared", role)
                else:
                    self.known_logs[role] = log
                    existing_roles.add(role)

        known_roles = set(self.known_logs.keys())
        logging.info("Known roles: %s, existing roles: %s",
                     ', '.join(sorted(known_roles)),
                     ', '.join(sorted(existing_roles))
                     )
        for dead_role in known_roles - existing_roles:
            handler = self.known_logs[dead_role]
            # A vanished log is kept until its last line is over 120 s old.
            if handler.get_lag() > 120:
                logging.info("Stopping to handle %s log", dead_role)
                del self.known_logs[dead_role]
                # Release the file descriptor right away instead of leaving
                # it to garbage collection.
                log_file = getattr(handler, 'log_file', None)
                if log_file is not None:
                    log_file.close()

    def _handle_known_logs(self):
        """
        We run for 3 seconds or until some log disappears
        """
        rotated_logs = set()
        deadline = time.time() + 3
        while not rotated_logs:
            next_wakeup = time.time() + 0.5

            for role, log in self.known_logs.items():
                try:
                    log.parse_lines()
                except NoData:
                    # Nothing new in this log for now; move on to the next one.
                    pass
                except NoLog:
                    logging.info("%s log disappeared", role)
                    rotated_logs.add(role)
            now = time.time()
            if now > deadline:
                break
            if now < next_wakeup:
                time.sleep(next_wakeup - now)


def send_metrics(metrics, host='localhost', port=42000):
    """
    Send metrics to graphite

    Best effort: any failure is logged and swallowed, nothing is raised for
    connection or send errors.

    :param metrics: payload to send; text is encoded as UTF-8 first
    :param str host: host to connect to
    :param int port: TCP port to connect to
    """
    sender_socket = None
    try:
        if not isinstance(metrics, bytes):
            # Sockets send bytes; text has to be encoded first.
            metrics = metrics.encode('utf-8')
        sender_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sender_socket.connect((host, port))
        sender_socket.sendall(metrics)
    except Exception:
        logging.exception("Failed to send data")
    finally:
        # Close the socket on failure too, otherwise the descriptor leaks.
        if sender_socket is not None:
            sender_socket.close()
