# coding: utf-8

import heapq
import logging
import io
import os
import os.path
import socket
import struct
import time
import xml.etree.ElementTree
import simplejson as json

from collections import namedtuple
from .log import log_lock

# Logger used by every reader and checker defined in this module.
log = logging.getLogger('unistat')


class BaseFileReader(object):
    """Context manager that keeps a binary file handle usable across reopens.

    The first time a given inode is opened, reading starts at the end of the
    file (moved back to a multiple of ``_alignment()`` when that is non-zero).
    When the same inode is opened again, reading resumes from the offset
    remembered at close time, or from the beginning if the file has become
    shorter than that offset.
    """

    def __init__(self, path):
        """Store the resolved *path*; the file is opened by ``__enter__``."""
        self._path = os.path.realpath(os.path.abspath(path))
        self._fd = None
        self.__offset = None
        self.__inode = None
        self.__prev_reopen = 0

    def __enter__(self):
        self.__open()
        return self

    def __exit__(self, *args):
        self.__close()

    def _alignment(self):
        """Return the record size the start offset must be a multiple of (0 = none)."""
        return 0

    def _reopen(self):
        """Close and reopen the file, then throttle so reopens are not back-to-back."""
        self.__close()
        self.__open()
        self.__sleep()

    def __open(self):
        """Open the file and position it according to the remembered inode/offset."""
        self._fd = io.open(self._path, 'rb')
        try:
            fstat = os.fstat(self._fd.fileno())
            if self.__inode == fstat.st_ino:
                if self.__offset is not None and self.__offset <= fstat.st_size:
                    self._fd.seek(self.__offset)
                    with log_lock:
                        log.info('open %s same inode %s, read from %s', self._path, self.__inode, self.__offset)
                else:
                    # The file is shorter than where we stopped: start over
                    # from the position of the freshly opened handle.
                    self.__offset = self._fd.tell()
                    with log_lock:
                        log.info('open %s same inode %s, read from %s, assume file was truncated',
                                 self._path, self.__inode, self.__offset)
            else:
                self.__inode = fstat.st_ino
                self._fd.seek(0, 2)
                if self._alignment() != 0:
                    # Step back to the last whole-record boundary.
                    offset = self._fd.tell() % self._alignment()
                    self._fd.seek(-offset, 2)
                self.__offset = self._fd.tell()
                with log_lock:
                    log.info('open %s new inode %s', self._path, self.__inode)
        except Exception:
            # Do not leak the handle when positioning fails half-way.
            self._fd.close()
            raise

    def __close(self):
        """Remember the current offset and close the handle, if one is open."""
        # After a failed reopen the handle is already closed (or was never
        # created); calling tell() on it would raise and hide the real error.
        if self._fd is None or self._fd.closed:
            return
        self.__offset = self._fd.tell()
        self._fd.close()
        with log_lock:
            log.info('close %s with inode %s', self._path, self.__inode)

    def __sleep(self):
        """Sleep one second when the previous reopen was less than a second ago."""
        now = time.time()
        if now - self.__prev_reopen < 1:
            time.sleep(1)
            self.__prev_reopen = now + 1
        else:
            self.__prev_reopen = now


class LineReader(object):
    """Line reader built on ``fd.readline()`` that joins partial lines.

    Each call performs one ``readline()``. A line is yielded (without its
    trailing newline) only once the newline has been seen; text read before
    that is kept and prepended to the next read. Works with both text-mode
    and binary-mode file objects and yields the same type ``readline()``
    returned.
    """

    def __init__(self):
        # Partial line collected so far; None until something is pending, so
        # the accumulator always has the same type as the data being read.
        self.__line = None
        self.__eof = False

    def eof(self):
        """Return True if the most recent read returned no data."""
        return self.__eof

    def __call__(self, fd):
        """Read once from *fd* and yield the completed line, if any."""
        chunk = fd.readline()
        if not chunk:
            self.__eof = True
            return
        self.__eof = False
        line = chunk if self.__line is None else self.__line + chunk
        # Compare against a newline of the matching type: indexing bytes on
        # Python 3 gives an int, so ``line[-1] == '\n'`` would never match.
        newline = b'\n' if isinstance(line, bytes) else u'\n'
        if line.endswith(newline):
            self.__line = None
            yield line[:-1]
        else:
            self.__line = line


class SlowReadDetector(object):
    """Tracks how many reads in a row returned exactly the maximum size."""

    def __init__(self, max_possible_read_size):
        self.__times_read_max_possible = 0
        self.__max_possible_read_size = max_possible_read_size

    def update(self, read_size):
        """Extend the streak on a maximum-size read, otherwise reset it."""
        hit_limit = read_size == self.__max_possible_read_size
        streak = self.__times_read_max_possible + 1 if hit_limit else 0
        self.__times_read_max_possible = streak

    def check(self, threshold):
        """Return True when the current streak has reached *threshold*."""
        return threshold <= self.__times_read_max_possible


class BufferedLineReader(object):
    DEFAULT_SLEEP_TIME = 0.1
    MIN_SLEEP_TIME = 0.01
    MAX_SLEEP_TIME = 1
    DEFAULT_BUFFER_SIZE = 4096
    MIN_BUFFER_SIZE = 64
    MAX_BUFFER_SIZE = 1024 * 1024
    SLOW_READ_THRESHOLD = 3

    def __init__(self):
        self.__sleep_time = self.DEFAULT_SLEEP_TIME
        self.__buffer_size = self.DEFAULT_BUFFER_SIZE
        self.__stored = bytearray()
        self.__buffer = bytearray(self.__buffer_size)
        self.__stored_pos = 0
        self.__buffer_pos = 0
        self.__eof = False
        self.__slow_read_detector = SlowReadDetector(max_possible_read_size=self.MAX_BUFFER_SIZE)

    def eof(self):
        return self.__eof

    def __call__(self, fd):
        read_size = fd.readinto(self.__buffer)
        if not read_size:
            self.__slow_read_detector.update(read_size)
            self.__eof = True
            return
        self.__slow_read_detector.update(read_size)
        self.__eof = False
        new_line_pos = self.__buffer.find('\n')
        if new_line_pos == -1:
            if not self.__stored:
                self.__stored = self.__buffer
                self.__buffer = bytearray(self.__buffer_size)
            else:
                self.__stored += self.__buffer
        else:
            yield (self.__stored[self.__stored_pos:] + self.__buffer[:new_line_pos]).decode('utf-8')
            self.__stored = bytearray()
            self.__stored_pos = 0
            self.__buffer_pos = new_line_pos + 1
            while self.__buffer_pos < read_size:
                new_line_pos = self.__buffer.find('\n', self.__buffer_pos)
                if new_line_pos == -1:
                    self.__stored = self.__buffer
                    self.__stored_pos = self.__buffer_pos
                    self.__buffer = bytearray(self.__buffer_size)
                    break
                else:
                    yield self.__buffer[self.__buffer_pos:new_line_pos].decode('utf-8')
                    self.__buffer_pos = new_line_pos + 1
        self.__sleep(read_size)

    def __sleep(self, read_size):
        if read_size < self.__buffer_size * 0.9:
            if self.__sleep_time < self.MAX_SLEEP_TIME:
                self.__sleep_time = min(self.MAX_SLEEP_TIME, self.__sleep_time * 1.1)
            else:
                self.__sleep_time = self.DEFAULT_SLEEP_TIME
                self.__buffer_size = max(self.MIN_BUFFER_SIZE, self.__buffer_size / 2)
                self.__buffer = bytearray(self.__buffer_size)
            time.sleep(self.__sleep_time)
        elif read_size < self.__buffer_size * 0.95:
            time.sleep(self.__sleep_time)
        else:
            if self.__sleep_time > self.MIN_SLEEP_TIME:
                self.__sleep_time = max(self.MIN_SLEEP_TIME, self.__sleep_time * 0.9)
            else:
                self.__sleep_time = self.DEFAULT_SLEEP_TIME
                self.__buffer_size = min(self.MAX_BUFFER_SIZE, self.__buffer_size * 2)
                self.__buffer = bytearray(self.__buffer_size)
            if self.__slow_read_detector.check(threshold=self.SLOW_READ_THRESHOLD):
                with log_lock:
                    log.warning('read full buffer with max size %s times in row,'
                                + ' probably unistat is too slow', self.SLOW_READ_THRESHOLD)


class TextFileReader(BaseFileReader):
    """File reader that yields lines produced by a pluggable line reader.

    ``line_reader`` is a zero-argument factory; a new reader instance is
    created on ``__enter__`` and again after every reopen, so anything the
    previous instance had buffered is dropped at that point.
    """

    def __init__(self, path, line_reader=BufferedLineReader):
        super(TextFileReader, self).__init__(path)
        self.__line_reader = None
        self.__line_reader_builder = line_reader

    def __enter__(self):
        # The reader must exist before the base class opens the file.
        self.__line_reader = self.__line_reader_builder()
        super(TextFileReader, self).__enter__()
        return self

    def __exit__(self, *args):
        super(TextFileReader, self).__exit__(*args)
        self.__line_reader = None

    def __call__(self, stop):
        """Yield lines until EOF or *stop*; reopen the file after a plain EOF."""
        lines = read_until_eof_or_stop(reader=self.__line_reader, fd=self._fd, stop=stop)
        for line in lines:
            yield line
        if not self.__line_reader.eof() or stop.is_set():
            return
        self._reopen()
        self.__line_reader = self.__line_reader_builder()

    def __str__(self):
        return 'text file %s' % self._path


class UnbufferedTextFileReader(TextFileReader):
    """TextFileReader preconfigured to read through ``LineReader``."""

    def __init__(self, path):
        super(UnbufferedTextFileReader, self).__init__(path=path, line_reader=LineReader)


PaRecord = namedtuple(
    'PaRecord',
    ('type', 'host', 'req', 'suid', 'spent_ms', 'timestamp'),
)


class PaReader(object):
    """Reads fixed-size binary records from a file object, one read per call.

    ``FIELDS`` lists, in record order, each field's name, its width in bytes
    and the callable that converts the raw slice into a value. Integer fields
    are unpacked with the native ``'@I'`` format; string fields are built from
    every non-zero byte of the slice.
    """

    FIELDS = (
        ('type', 4, lambda raw: struct.unpack('@I', raw)[0]),
        ('host', 16, lambda raw: ''.join(chr(byte) for byte in raw if byte)),
        ('req', 24, lambda raw: ''.join(chr(byte) for byte in raw if byte)),
        ('suid', 16, lambda raw: ''.join(chr(byte) for byte in raw if byte)),
        ('spent_ms', 4, lambda raw: struct.unpack('@I', raw)[0]),
        ('timestamp', 4, lambda raw: struct.unpack('@I', raw)[0]),
    )
    RECORD_SIZE = sum(field[1] for field in FIELDS)
    assert RECORD_SIZE == 68

    def __init__(self):
        self.__eof = False
        self.__binary_line = bytearray()

    def eof(self):
        """Return True if the most recent read returned no data."""
        return self.__eof

    def __call__(self, fd):
        """Read the bytes still missing for one record; yield it once complete."""
        missing = self.RECORD_SIZE - len(self.__binary_line)
        chunk = fd.read(missing)
        self.__eof = not chunk
        if self.__eof:
            return
        self.__binary_line += chunk
        if len(self.__binary_line) != self.RECORD_SIZE:
            return
        yield self.__parse_binary_record(self.__binary_line)
        # Start collecting the next record only after the consumer resumed us.
        self.__binary_line = bytearray()

    def __parse_binary_record(self, binary_line):
        """Slice *binary_line* according to FIELDS and build a PaRecord."""
        parsed = {}
        start = 0
        for name, width, convert in self.FIELDS:
            end = start + width
            parsed[name] = convert(binary_line[start:end])
            start = end
        return PaRecord(**parsed)


class PaLogReader(BaseFileReader):
    """File reader that yields records produced by a pluggable record reader.

    ``pa_reader`` is a zero-argument factory; a new reader instance is created
    on ``__enter__`` and again after every reopen.
    """

    def __init__(self, path, pa_reader=PaReader):
        super(PaLogReader, self).__init__(path)
        self.__pa_reader = None
        self.__pa_reader_builder = pa_reader

    def __enter__(self):
        # Must be created first: opening the file asks _alignment() for the
        # record size, which is read from the reader instance.
        self.__pa_reader = self.__pa_reader_builder()
        super(PaLogReader, self).__enter__()
        return self

    def __exit__(self, *args):
        super(PaLogReader, self).__exit__(*args)
        self.__pa_reader = None

    def __call__(self, stop):
        """Yield records until EOF or *stop*; reopen the file after a plain EOF."""
        records = read_until_eof_or_stop(reader=self.__pa_reader, fd=self._fd, stop=stop)
        for record in records:
            yield record
        if not self.__pa_reader.eof() or stop.is_set():
            return
        self._reopen()
        self.__pa_reader = self.__pa_reader_builder()

    def _alignment(self):
        """Return the record size of the current reader instance."""
        return self.__pa_reader.RECORD_SIZE

    def __str__(self):
        return 'pa log %s' % self._path


def read_until_eof_or_stop(reader, fd, stop):
    """Yield everything *reader* produces from *fd* until EOF persists or *stop* is set.

    ``reader(fd)`` is called repeatedly. *stop* is consulted on every 1000th
    iteration and whenever a call produced nothing. When a call produced
    nothing and ``reader.eof()`` is true, the generator sleeps one second and
    tries again; a second such empty EOF with no values in between ends it.
    """
    iteration = 0
    already_waited = False
    while True:
        iteration += 1
        if iteration % 1000 == 0 and stop.is_set():
            return
        produced = 0
        for item in reader(fd):
            yield item
            produced += 1
        if produced:
            already_waited = False
            continue
        if stop.is_set():
            return
        if not reader.eof():
            continue
        if already_waited:
            return
        time.sleep(1)
        already_waited = True


class StatServerReader(object):
    """Polls a ``host:port`` endpoint, at most once per *interval* seconds.

    Every call opens a new TCP connection, reads it to the end and yields the
    raw result exactly once.
    """

    def __init__(self, path, interval=1):
        # Split on the last colon only, so the host part may contain colons.
        host, port = path.rsplit(':', 1)
        self.__host = host
        self.__port = int(port)
        self.__interval = interval
        self.__previous = 0

    def __enter__(self):
        with log_lock:
            log.info('watch stat server at %s:%s', self.__host, self.__port)
        self.__previous = 0
        return self

    def __exit__(self, *args):
        pass

    def __call__(self, stop):
        """Wait out the rest of the interval, then fetch and yield one response."""
        elapsed = time.time() - self.__previous
        if elapsed < self.__interval:
            time.sleep(self.__interval - elapsed)
        with TcpConnection(host=self.__host, port=self.__port) as connection:
            response = connection.read(stop)
        self.__previous = time.time()
        yield response

    def __str__(self):
        return 'stat server at %s:%s' % (self.__host, self.__port)


class StatServerXmlReader(StatServerReader):
    """StatServerReader that parses each non-empty response as XML."""

    def __call__(self, stop):
        """Yield the parsed root element of each non-empty response."""
        responses = super(StatServerXmlReader, self).__call__(stop)
        for payload in responses:
            if not payload:
                continue
            yield xml.etree.ElementTree.fromstring(payload)


class StatServerJsonReader(StatServerReader):
    """StatServerReader that parses each non-empty response as JSON."""

    def __call__(self, stop):
        """Yield the decoded JSON value of each non-empty response."""
        responses = super(StatServerJsonReader, self).__call__(stop)
        for payload in responses:
            if not payload:
                continue
            yield json.loads(payload)


class TcpConnection(object):
    """Context manager around a single IPv4 TCP connection read to the end."""

    def __init__(self, host, port):
        self.__host = host
        self.__port = port
        self.__socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    def __enter__(self):
        try:
            self.__socket.connect((self.__host, self.__port))
        except Exception:
            # __exit__ is not invoked when __enter__ raises, so release the
            # socket here instead of leaving it to the garbage collector.
            self.__socket.close()
            raise
        return self

    def __exit__(self, _exc_type, _exc_val, _exc_tb):
        self.__socket.close()

    def read(self, stop):
        """Receive until the peer closes the connection.

        Returns everything received as one byte string, or None if *stop* was
        found set before the end of the stream was reached. *stop* is only
        checked between ``recv`` calls.
        """
        parts = []
        while not stop.is_set():
            part = self.__socket.recv(4096)
            if not part:
                # Joining once is linear and works with the bytes objects
                # recv() returns on Python 3 (on Python 2 this is a str).
                return b''.join(parts)
            parts.append(part)
        return None


# mtime is the first field, so heap ordering pops the smallest mtime first.
CoreDumpFile = namedtuple('CoreDumpFile', ('mtime', 'name'))


class CoreDumpChecker(object):
    """Watches a directory and reports each new core dump file once.

    A directory entry is considered when its name contains exactly three dots.
    Entries are identified by ``(inode, mtime)``; every identity is reported a
    single time, smallest mtime first.
    """

    def __init__(self, path, interval=1):
        self.__path = os.path.abspath(path)
        self.__interval = interval
        self.__seen = set()
        self.__queue = list()
        # None = not checked yet, True/False = result of the last check.
        self.__path_exists = None
        # Initialised here as well so the checker can be called before
        # __enter__ without failing on arithmetic with None.
        self.__previous = 0

    def __enter__(self):
        with log_lock:
            log.info('watch core dumps in %s', self.__path)
        self.__previous = 0
        return self

    def __exit__(self, *args):
        pass

    def __call__(self, stop):
        """Yield at most one pending CoreDumpFile.

        The directory is rescanned only when nothing is queued, and no more
        often than once per interval. *stop* is accepted for interface
        compatibility with the other readers and is not used.
        """
        if not self.__queue:
            elapsed = time.time() - self.__previous
            if elapsed < self.__interval:
                time.sleep(self.__interval - elapsed)
            self.__scan()
            self.__previous = time.time()
        if self.__queue:
            yield heapq.heappop(self.__queue)

    def __scan(self):
        """Queue every matching directory entry that has not been seen yet."""
        if not os.path.exists(self.__path):
            # Warn on the first check and on each exists -> missing change.
            if self.__path_exists is not False:
                with log_lock:
                    log.warning("cores path doesn't exist: %s", self.__path)
                self.__path_exists = False
            return
        self.__path_exists = True
        for name in os.listdir(self.__path):
            if name.count('.') != 3:
                continue
            try:
                stat = os.stat(os.path.join(self.__path, name))
            except OSError:
                # The entry disappeared between listdir() and stat().
                continue
            identity = (stat.st_ino, stat.st_mtime)
            if identity in self.__seen:
                continue
            self.__seen.add(identity)
            heapq.heappush(self.__queue, CoreDumpFile(name=name, mtime=stat.st_mtime))

    def __str__(self):
        return 'core dumps in %s' % self.__path
