import os
import re
import sys
import gzip
import time
import random
import threading

from ..kernel_util import logging
from ..kernel_util.logging import renameLevels as rename_logging_levels  # noqa
from .utils import Path
from .greendeblock import Deblock

import gevent
import six
import py


class RotatingHandler(logging.handlers.BaseRotatingHandler):
    """
    Rotating logging handler with mixed logic of TimingRotatingFileHandler
    and base RotatingFileHandler.

    Rotates logfile at :arg tuple when: (e.g. (1, 0, 0) means 1:00AM every
    night). Rotated file name in this case will be <filename>.%Y-%m-%d.

    Additionally rotates forcibly by size. If the size of the logfile grows beyond
    :arg int max_bytes: it will be forcibly rotated.

    So, this could produce these files:

        filename => filename.1111-11-11  < current log file (symlink)
        filename.1111-11-11        < file rotated by size  (#1 of 11 day)
        filename.1111-11-11.001    < also rotated by size  (#2 of 11 day)
        filename.1111-11-11.002    < rotated at midnight   (#3 of 11 day)
        filename.1111-11-10        < rotated at midnight
        filename.1111-11-9         < rotated at midnight

    This way date in filename suffix always means day logfile belongs to.

    If the program is turned off on day 01 at 23:59, the rotation will not be performed.
    But if you start it again on day 03, it will first rotate the logfile and mark it
    with the day 01 suffix. This is done by looking at the logfile modification time:
    if it does not match the current day -- roll over to that day.

    As a bonus - automatically reopens the logfile if its inode changed. So rotating by
    external tools is supported as well.
    """

    # strftime() pattern of the per-day suffix appended to the log file name.
    suffix = '%Y-%m-%d'
    # Zero-padded sequence number used for files rotated within the same day.
    suffix2 = '%03d'
    # Matches what follows "<filename>." in a backup name: the day, an optional
    # ".NNN" sequence number and an optional ".gz" extension.  The dot before
    # "gz" is escaped so that only a literal ".gz" is accepted.
    extMatch = re.compile(r'^\d{4}-\d{2}-\d{2}(\.\d{3})?(\.gz)?$')

    def __init__(self, filename, max_bytes=10 * 1024 * 1024, backup_count=5, time_diff=True):
        """
        Open today's log file and make `filename` a symlink pointing to it.

        :param filename: path of the "current log" symlink; the real log files
            live next to it and carry a date suffix.
        :param int max_bytes: size at which the log file is forcibly rotated.
        :param int backup_count: number of log files `_cleanup` leaves on disk.
        :param time_diff: controls the shift applied to UTC when the day suffix
            is computed (see `_getTimeDiff`): True, False or seconds as int.
        """
        assert isinstance(max_bytes, int) and max_bytes > 0
        assert isinstance(backup_count, int) and backup_count >= 0
        assert isinstance(time_diff, (bool, int))

        log_path = Path(filename)
        self._filename = log_path
        self._dirpath = log_path.dirpath()
        self._max_bytes = max_bytes
        self._backup_count = backup_count
        self._time_diff = time_diff
        self._stat = None
        self._currentInode = None
        self._compressThread = None
        self._lock = threading.Lock()
        self._emit_lock = threading.Lock()

        # The deblock helper gets its own, non-propagating logger which writes
        # to stderr -- to avoid deadlock with self.
        deblock_logger = logging.getLogger('nonexistent')
        deblock_logger.propagate = False
        if not deblock_logger.handlers:
            deblock_logger.addHandler(logging.StreamHandler(sys.stderr))
        self._deblock = Deblock(logger=deblock_logger, name='log ' + log_path.basename)

        todays_file = self._getDfn()
        logging.handlers.BaseRotatingHandler.__init__(self, todays_file, mode='a')

        # Replace whatever currently occupies the symlink name.
        if log_path.check(exists=1) or log_path.check(link=1):
            log_path.remove()
        log_path.mksymlinkto(self._dirpath.bestrelpath(Path(todays_file)))

        # Fallback handler used by handleError(); optional.
        try:
            self._slave_handler = logging.handlers.SysLogHandler('/dev/log')
        except Exception:
            self._slave_handler = None

    def shouldRollover(self, record):
        if not os.path.exists(self.baseFilename):
            self._reopen()
            self._cleanup()
            self._compress()
            return False

        self._getStat()
        if self._stat and self._stat.st_size >= self._max_bytes:
            return True

        if self._needReopen():
            self._reopen()
            self._cleanup()
            self._compress()

    def doRollover(self):
        """
        Rotate by size: free today's file name, then reopen, prune and compress.
        """
        todays_file = self._getDfn()
        self._prepareFileSlot(todays_file)
        for step in (self._reopen, self._cleanup, self._compress):
            step()

    def _needReopen(self):
        if self._getDfn() != self.baseFilename:
            return True

        prevInode = self._currentInode
        self._getStat()
        return self._currentInode != prevInode or self._currentInode is None

    def _getDfn(self, base=None):
        timeTuple = time.gmtime(int(time.time()) - self._getTimeDiff())
        dfn = '%s.%s' % (self._filename.strpath, time.strftime(self.suffix, timeTuple))
        return dfn

    def _getStat(self):
        try:
            self._stat = os.stat(self.baseFilename)
            self._currentInode = self._stat.st_ino
        except OSError:
            self._stat = None
            self._currentInode = None
        return self._stat

    def _open(self):
        stream = logging.handlers.BaseRotatingHandler._open(self)
        self._getStat()
        return stream

    def _reopen(self):
        """
        Close the current stream, open the file for the current day and
        repoint the `filename` symlink to it.
        """
        if self.stream:
            self.stream.close()
            self.stream = None
        self.baseFilename = self._getDfn(self._filename.strpath)
        # Go through self._open() so that the cached stat/inode describe the
        # file which is actually open now.
        self.stream = self._open()

        # An existing symlink is left in place: it gets replaced by the rename
        # below.  Only something that is not a symlink (and therefore should
        # not be there at all) is dropped beforehand.
        if self._filename.check(exists=1) and not self._filename.check(link=1):
            self._filename.remove()

        # To make replacing symlink operation atomic, we create
        # new symlink and rename (atomically) new to old.

        filename_tmp = Path(self._filename + '_tmp')
        # A leftover temporary link from an interrupted run would make
        # mksymlinkto() fail.
        if filename_tmp.check(exists=1) or filename_tmp.check(link=1):
            filename_tmp.remove()

        filename_tmp.mksymlinkto(self._dirpath.bestrelpath(Path(self.baseFilename)))
        filename_tmp.rename(self._filename)

    def _prepareFileSlot(self, dfn):
        """
        Free the name `dfn` by renumbering every backup that starts with it.

        All files returned by `_getBackupNames` whose path starts with `dfn`
        are renamed to ``dfn.001``, ``dfn.002``, ... in their sorted order
        (``.gz`` is kept for compressed ones).  This also closes gaps in the
        numbering.  The renaming goes through temporary names carrying a random
        marker, so that a new name never collides with a file which has not
        been moved yet.

        :param str dfn: name the next log file is going to use.
        :return str: `dfn`, unchanged.
        """
        with self._lock:
            same_day = [
                path for path in self._getBackupNames()
                if path.strpath.startswith(dfn)
            ]
            marker = '_' + ''.join(random.sample('abcdefghijklmnopqrstuvwxyz', 5))

            # Pass 1: park every file under its new number plus the marker.
            parked = []
            for seq, path in enumerate(same_day, 1):
                temporary = Path(dfn + '.' + (self.suffix2 % (seq, )) + marker)
                path.rename(temporary)
                parked.append((temporary, path.ext))

            # Pass 2: strip the marker, restoring ".gz" where it was present.
            for temporary, ext in parked:
                final_name = temporary.strpath[:-len(marker)]
                if ext == '.gz':
                    final_name += '.gz'
                temporary.rename(final_name)

            return dfn

    def _getTimeDiff(self):
        if self._time_diff is True:
            return time.altzone
        elif self._time_diff is False:
            return 0
        else:
            return self._time_diff

    def _sort_backups(self, a, b):
        a, b = a.strpath, b.strpath
        if a.endswith('.gz'):
            a = a[:-3]
        if b.endswith('.gz'):
            b = b[:-3]
        return cmp(a, b)

    def _getBackupNames(self):
        prefix = self._filename.strpath + '.'
        plen = len(prefix)
        result = []

        for filename in self._dirpath.listdir():
            if filename.strpath[:plen] == prefix:
                suffix = filename.strpath[plen:]
                if self.extMatch.match(suffix):
                    result.append(filename)

        result.sort(cmp=self._sort_backups)
        return result

    def _cleanup(self):
        with self._lock:
            result = self._getBackupNames()

            prefix = self._filename.strpath + '.'
            plen = len(prefix)

            while len(result) > self._backup_count:
                days = []
                for filename in result:
                    if filename.strpath[:plen] == prefix:
                        suffix = filename.strpath[plen:]
                        match = self.extMatch.match(suffix)
                        if match:
                            if suffix.endswith('.gz'):
                                suffix = suffix[:-3]
                            if match.groups()[0] is not None:
                                suffix = suffix[:-len(match.groups()[0])]
                            if suffix not in days:
                                days.append(suffix)

                oldestDay = sorted(days)[0]

                left = len(result) - self._backup_count
                for filename in reversed(result):
                    suffix = filename.strpath[plen:]
                    if suffix.startswith(oldestDay):
                        filename.remove()
                        result.remove(filename)
                        left -= 1

                    if left == 0:
                        break

    def _compress(self):
        def _runner():
            try:
                while True:
                    with self._lock:
                        toCompress = [
                            x for x in self._getBackupNames()
                            if x.ext != '.gz' and x.strpath != self.baseFilename
                        ]
                        if not toCompress:
                            break

                        fn = toCompress[0]
                        fngz = fn.dirpath().join(fn.basename + '.gz')
                        try:
                            fpgz = gzip.open(fngz.strpath, mode='wb')
                            with fn.open(mode='rb') as fp:
                                while 1:
                                    data = fp.read(4 * 1024 * 1024)
                                    if not data:
                                        break
                                    fpgz.write(data)
                        except py.error.ENOENT:
                            continue
                        finally:
                            fpgz.close()

                            try:
                                fn.remove()
                            except Exception:
                                pass

            finally:
                self._compressThread = None

        if self._compressThread is None or not self._compressThread.isAlive():
            thr = threading.Thread(target=_runner)
            thr.daemon = True
            self._compressThread = thr
            thr.start()

    def _emit(self, record):
        with self._emit_lock:
            return super(RotatingHandler, self).emit(record)

    def emit(self, record):
        try:
            if self._deblock and isinstance(threading.current_thread(), threading._MainThread):
                return self._deblock.apply(self._emit, record)
            else:
                return self._emit(record)
        except Exception:
            self.handleError(record)

    def createLock(self):
        """
        Leave the handler without a lock object (``self.lock`` is None).
        """
        self.lock = None  # everything is done with deblock

    def handleError(self, record):
        if self._slave_handler is not None:
            try:
                self._slave_handler.emit(record)
            except Exception:
                pass
            else:
                return

        logging.handlers.BaseRotatingHandler.handleError(self, record)

    def stop_async(self):
        self._deblock.stop()
        self._deblock = None


class SmartLoggerAdapter(logging.LoggerAdapter):
    """
    Logger adapter which stays an adapter (of the same type and with the same
    ``extra``) when navigating to child and parent loggers.
    """

    def __init__(self, logger, extra):
        """
        :param logger: logger (or adapter) to wrap.
        :param dict extra: contextual information passed to the base adapter.
        """
        super(SmartLoggerAdapter, self).__init__(logger, extra)
        # Recent Python 3 versions define ``LoggerAdapter.name`` as a read-only
        # property returning ``logger.name``; assigning to it would raise
        # AttributeError, so the attribute is set only where it is missing.
        if not isinstance(getattr(type(self), 'name', None), property):
            self.name = logger.name

    def get_child(self, *args, **kwargs):
        """Return an adapter of the same type around a child of the wrapped logger."""
        return type(self)(self.logger.getChild(*args, **kwargs), self.extra)

    def get_parent(self):
        """
        Return an adapter of the same type around the parent logger.

        The wrapped object's own ``get_parent`` is preferred when it has one.

        :return: the adapter, or None if there is no parent.
        """
        get_parent = getattr(self.logger, 'get_parent', None)
        parent = get_parent() if get_parent else self.logger.parent
        if parent is None:
            return None
        return type(self)(parent, self.extra)

    @staticmethod
    def get_local(key):
        """
        Look `key` up in the ``slocal`` (or else ``tlocal``) mapping attached
        to the current greenlet.

        :return: the stored value, or None if there is no mapping or no key.
        """
        grn = gevent.getcurrent()
        local = getattr(grn, 'slocal', getattr(grn, 'tlocal', {}))
        return local.get(key)

    getChild = get_child

    warn = logging.LoggerAdapter.warning
    normal = logging.LoggerAdapter.info


class SkynetFormatter(logging.Formatter):
    """
    Formatter producing ``<date> <level> [<logger name>]  <message>`` lines;
    continuation lines of multi-line messages are aligned under the message.
    """

    def __init__(self):
        logging.Formatter.__init__(self)

    def formatTime(self, record, datefmt=None):
        """
        Same as the standard implementation, but milliseconds are separated
        with '.'.

        :param logging.LogRecord record:
        :param str datefmt:
        :return str:
        """
        moment = self.converter(record.created)
        if datefmt:
            return time.strftime(datefmt, moment)
        seconds = time.strftime("%Y-%m-%d %H:%M:%S", moment)
        return "%s.%03d" % (seconds, record.msecs)

    def format(self, record):
        """
        Render `record` as text, including exception information if present.

        :param logging.LogRecord record:
        :return str:
        """
        levelno = record.levelno
        if levelno > 5:
            # Regular levels are shown by name...
            level = '[%-1s]' % logging.getLevelName(levelno)
        elif levelno < 0:
            # ...the lowest ones by number.
            level = '(%s)' % str(levelno)
        else:
            level = '(%s)' % (' %d' % levelno)

        date = self.formatTime(record)
        message = record.getMessage()
        header = '{0} {1} [{2}]  '.format(date, level, record.name)

        # The traceback text is rendered once and cached on the record.
        if record.exc_info and not record.exc_text:
            record.exc_text = self.formatException(record.exc_info)

        if record.exc_text:
            if not message.endswith("\n"):
                message += "\n"
            try:
                message += record.exc_text
            except UnicodeError:
                # Non-ASCII characters (e.g. in file names) can break the
                # concatenation when message is Unicode and exc_text is str,
                # see issue 8924.  'replace' copes with mixed encodings,
                # see issue 13232.
                message += record.exc_text.decode(sys.getfilesystemencoding(),
                                                  'replace')

        if '\n' not in message:
            return '{header}{message}'.format(header=header, message=message)

        # Multi-line message: the first line gets the header, the following
        # ones a ": " marker indented to the width of the header.
        rows = message.strip().split('\n')
        prepend = '%s%s' % (' ' * (len(header) - 2), ': ')
        rendered = [header + rows[0]]
        for row in rows[1:]:
            rendered.append('%s%s' % (prepend, row))
        return '\n'.join(rendered)


def adapt_logger(logger):
    """Wrap `logger` into a `SmartLoggerAdapter` with an empty ``extra`` dict."""
    no_extra = {}
    return SmartLoggerAdapter(logger, no_extra)


# Registry of the RotatingHandler instances created by initialize_log() and
# HierarchicalLoggerFactory.make_logger(); HierarchicalLoggerFactory.shutdown()
# walks it to stop their asynchronous part.
_rotating_handlers = set()


def initialize_log(logdir, basename, max_bytes, backup_count, root_log=logging.getLogger(''), level=logging.DEBUG):
    """
    Configure `root_log` to write into ``<logdir>/<basename>.log`` through a
    `RotatingHandler`.

    If the environment variable ``SKYNET_LOG_STDOUT`` equals ``'1'``, it is
    removed from the environment and records are additionally written to
    stderr with a coloured formatter.

    :param logdir: directory for the log files, a path object or a string.
    :param str basename: log file name without the ``.log`` extension.
    :param int max_bytes: size limit passed to `RotatingHandler`.
    :param int backup_count: number of files passed to `RotatingHandler`.
    :param root_log: logger to configure, either an object or a logger name.
    :param int level: level set on both the logger and the file handler.
    :return: the configured logger.
    """
    rename_logging_levels(1)

    # Accept plain strings the same way HierarchicalLoggerFactory does.
    if isinstance(logdir, six.string_types):
        logdir = Path(logdir)

    if isinstance(root_log, six.string_types):
        root_log = logging.getLogger(root_log)

    root_log.propagate = False
    root_log.setLevel(level)

    if os.getenv('SKYNET_LOG_STDOUT', False) == '1':
        # Consume the flag so that it is not seen again through os.environ.
        os.environ.pop('SKYNET_LOG_STDOUT', None)

        console_handler = logging.StreamHandler(sys.stderr)
        console_handler.setLevel(logging.DEBUG)
        formatter = logging.ColoredFormatter(
            fmt='%(asctime)s %(levelname)s %(name)s  %(message)s',
            fmtName='[%(name)-30s]',
            fmtLevelname='[%(levelname)-1s]',
            fmtAsctime='%(datetime)s.%(msecs)003d',
            hungryLevels={
                'info': ['message'],
                'debug': ['message'],
                'warning': ['message'],
                'error': ['message', 'name']
            }
        )
        console_handler.setFormatter(formatter)
        root_log.addHandler(console_handler)

    handler = RotatingHandler(
        logdir.join(basename + '.log').strpath,
        max_bytes=max_bytes,
        backup_count=backup_count
    )
    # Registered so that the handler's async part can be stopped on shutdown.
    _rotating_handlers.add(handler)

    handler.setFormatter(SkynetFormatter())
    handler.setLevel(level)
    root_log.addHandler(handler)

    return root_log


class HierarchicalLoggerFactory(object):
    """
    Creates loggers whose position in the logger hierarchy is mirrored by the
    directory their log file is written to.
    """

    def __init__(self, logdir, logger):
        """
        :param logdir: base directory, a path object or a string.
        :param logger: base logger, an object or a logger name.
        """
        if isinstance(logdir, six.string_types):
            logdir = Path(logdir)

        if isinstance(logger, six.string_types):
            logger = logging.getLogger(logger)

        self.logdir = logdir
        self.logger = logger

    def child(self, name):
        """Return a factory for the sub-directory and child logger called `name`."""
        return type(self)(self.logdir.join(name), self.logger.getChild(name))

    def make_logger(self, name, suffix=None):
        """
        Return the child logger writing to ``<logdir>/<name>/<logger name>.log``.

        The directory is created if needed.  A `RotatingHandler` is attached
        only if the logger has no handlers yet, so repeated calls return the
        same, already configured logger.

        :param str name: name of the sub-directory and of the child logger.
        :param suffix: if not None, the logger and the file are called
            ``<name>-<suffix>`` (the directory is still `name`).
        :return: the logger.
        """
        logdir = self.logdir.join(name)
        logdir.ensure(dir=True)

        name = name if suffix is None else "%s-%s" % (name, suffix)
        logger = self.logger.getChild(name)
        if not logger.handlers:
            logger.propagate = False
            handler = RotatingHandler(
                logdir.join(name + '.log').strpath,
                max_bytes=50 * 1024 * 1024,
                backup_count=14,
            )
            _rotating_handlers.add(handler)
            handler.setFormatter(SkynetFormatter())
            handler.setLevel(logging.DEBUG)
            logger.addHandler(handler)

        return logger

    def shutdown(self):
        """
        Stop the asynchronous part of every registered rotating handler.

        The registry is module-wide, so this affects all handlers, not only the
        ones created by this factory.  Handlers are taken out of the registry
        as they are stopped, which makes repeated calls (e.g. from several
        factories) harmless.
        """
        while _rotating_handlers:
            _rotating_handlers.pop().stop_async()
