from __future__ import absolute_import

import logging
import tarfile
import hashlib
import sys
import os


# TODO: support xz compression.
# NOTE: this would require patching TarFile.open, tarfile._Stream.__init__ and
# NOTE: tarfile._StreamProxy.getcomptype -- or simply migrating to Python 3.3+,
# NOTE: whose tarfile module supports xz natively.

class TarFile(tarfile.TarFile):
    """Safe tarfile: do not operate on files with unsafe names, fsync files after extraction.

    Members rejected by ``is_valid`` are dropped while the archive is read, so
    they never show up in ``members``, in iteration, in ``md5()`` or in
    extraction.

    Extra keyword arguments accepted on top of ``tarfile.TarFile``:

    log
        Optional logger-like object with a ``log(level, fmt, *args)`` method.
        When given, ``_dbg`` messages are sent to it instead of stderr.
    fsyncqueue
        Optional queue-like object with a ``put`` method.  When given, the
        path of every extracted regular file is put on it (after the file has
        been closed) instead of being fsync'ed inline.
    """

    def __init__(self, *args, **kwargs):
        # Pop our own options before the base constructor sees them; the base
        # __init__ may already read members (and thus call _dbg/next).
        self.log = kwargs.pop('log', None)
        self.fsyncqueue = kwargs.pop('fsyncqueue', None)
        self._md5 = None  # cached hex digest, filled lazily by md5()
        super(TarFile, self).__init__(*args, **kwargs)

    def _dbg(self, level, msg):
        """Emit a tarfile debug message.

        With a logger, every message is forwarded (level 1 -> INFO,
        2 -> DEBUG, anything else -> WARNING) and filtering is left to the
        logger.  Without one, fall back to the stock behaviour: print to
        stderr when ``level <= self.debug``.
        """
        log_level = {
            1: logging.INFO,
            2: logging.DEBUG,
        }.get(level, logging.WARNING)
        if self.log is not None:
            self.log.log(log_level, '%s', msg)
        elif level <= self.debug:
            print >>sys.stderr, msg

    def __iter__(self):
        """Iterate over valid members, using the filtering TarIter while
        the archive has not been fully loaded yet."""
        if self._loaded:
            return iter(self.members)
        else:
            return TarIter(self)

    def next(self):
        """Return the next *valid* member, or None at the end of the archive.

        Unsafe members are logged and removed from ``self.members`` (where
        the base implementation is expected to have just appended them).
        """
        result = super(TarFile, self).next()
        while result is not None and not is_valid(result):
            self._dbg(1, 'skipped not secure path: %r' % (result.name,))
            # Only drop the entry if it really is the member just returned;
            # never pop an unrelated (already validated) member.
            if self.members and self.members[-1] is result:
                self.members.pop()
            result = super(TarFile, self).next()

        return result

    MD5_CHUNK_SIZE = 1 << 14  # 16 KB

    def _calculate_md5(self):
        """Compute an MD5 hex digest over the archive's valid members.

        Members are visited sorted by name.  For each one the digest is fed
        the normalized relative name, then either the symlink target or the
        file contents (hard links hash their target's contents), and, for
        non-symlinks, the permission bits and mtime as decimal strings.
        The archive root ('.') is skipped.

        NOTE(review): fields are concatenated without separators, so
        different archives could in principle produce the same byte stream.
        Changing this would change every existing digest, so it is left as is.
        """
        md5 = hashlib.md5()

        for member in sorted(self, key=lambda tarinfo: tarinfo.name):
            name = os.path.relpath(member.name, '.')
            if name == '.':
                continue
            md5.update(name)
            if member.issym():
                md5.update(member.linkname)
            elif not member.isdir():
                tgt = self._find_link_target(member) if member.islnk() else member
                source = self.extractfile(tgt)
                try:
                    while True:
                        data = source.read(self.MD5_CHUNK_SIZE)
                        if not data:
                            break
                        md5.update(data)
                finally:
                    source.close()

            if not member.issym():
                mode = str(member.mode & 0o7777)
                mtime = str(member.mtime)
                md5.update(mode)
                md5.update(mtime)

        return md5.hexdigest()

    def md5(self):
        """Return the (cached) MD5 hex digest computed by _calculate_md5."""
        if not self._md5:
            self._md5 = self._calculate_md5()

        return self._md5

    def makefile(self, tarinfo, targetpath):
        """Write the contents of member *tarinfo* to *targetpath*.

        Without ``fsyncqueue`` the file is flushed and fsync'ed before being
        closed.  With one, *targetpath* is put on the queue after the file
        has been closed so that a consumer can sync it later.
        """
        source = self.extractfile(tarinfo)
        try:
            with tarfile.bltn_open(targetpath, 'wb') as target:
                tarfile.copyfileobj(source, target)
                if self.fsyncqueue is None:
                    # fsync() only persists data the kernel already has:
                    # push Python's userspace buffer down first.
                    target.flush()
                    os.fsync(target.fileno())
            if self.fsyncqueue is not None:
                self.fsyncqueue.put(targetpath)
        finally:
            source.close()


class TarIter(tarfile.TarIter):
    """
    Iterator over a tarfile that never yields unsafe members.

    Every candidate produced by the stock iterator is checked with
    ``is_valid``; rejected ones (absolute paths, parent references, device
    nodes, unsupported types) are reported through the owning tarfile's
    ``_dbg`` and skipped.
    """
    def next(self):
        # tarfile.TarIter is an old-style class on Python 2, so the base
        # method is called explicitly instead of through super().
        candidate = tarfile.TarIter.next(self)
        while True:
            if is_valid(candidate):
                return candidate
            self.tarfile._dbg(1, 'skipped not secure path: %r' % (candidate.name,))
            candidate = tarfile.TarIter.next(self)


def _is_safe_path(path):
    """Return False if *path* is absolute or contains ``..`` anywhere.

    The check is deliberately conservative: any ``..`` substring is
    rejected, even inside a single component such as ``foo..bar``.
    """
    return not ('..' in path or path.startswith('/'))


def is_valid(tarinfo):
    """Return True if *tarinfo* is considered safe to list, hash or extract.

    Rejected members:
      * names that are absolute or contain ``..``;
      * hard links whose target (``linkname``) is absolute or contains
        ``..``.  Such a link could otherwise point outside the extraction
        directory, and a later member with the same name would then be
        written through it into an arbitrary file;
      * character/block devices and FIFOs;
      * member types tarfile does not support.
    """
    # forbidden filenames
    if not _is_safe_path(tarinfo.name):
        return False

    # Hard-link targets are archive-relative paths and get the same rule.
    # NOTE(review): symlink targets are NOT checked (e.g. ``x -> /etc``
    # followed by ``x/passwd``); a lexical check alone cannot cover symlink
    # chains and would reject legitimate absolute symlinks -- confirm policy.
    if tarinfo.islnk() and not _is_safe_path(tarinfo.linkname):
        return False

    # isdev() covers character devices, block devices and FIFOs.
    if tarinfo.isdev():
        return False

    if tarinfo.type not in tarfile.SUPPORTED_TYPES:
        return False

    return True


# Module-level convenience alias in the style of ``tarfile.open``, bound to the
# safe ``TarFile`` subclass above.  Because ``tarfile.TarFile.open`` is a
# classmethod, archives opened this way are presumably instances of that
# subclass (filtering, md5(), fsync-on-extract) -- confirm if relied upon.
# Note: this shadows the builtin ``open`` within this module; ``makefile``
# uses ``tarfile.bltn_open`` for real files.
open = TarFile.open
