from __future__ import unicode_literals
import os
import hashlib
import tarfile
import filecmp

from nanny_sox_audit import fileutils


class ResourceChecker(object):
    """Check that a single file still matches an expected MD5 hex digest.

    :param root: label reported in the ``OK`` message.
    :param resource: path of the file to hash.
    :param checksum: expected MD5 ``hexdigest()`` of ``resource``.
    """

    def __init__(self, root, resource, checksum):
        self.root = root
        self.resource = resource
        self.checksum = checksum

        # Bytes read per iteration while hashing (1 MiB), so the file is
        # hashed incrementally rather than read into memory in one go.
        self._chunk_size = 1024 * 1024

    def check_resource_checksum(self):
        """Hash ``self.resource`` and compare it with ``self.checksum``.

        :returns: True when the file's MD5 hex digest equals ``self.checksum``.
        :raises IOError, OSError: if the file cannot be opened or read.
        """
        digest = hashlib.md5()
        with open(self.resource, 'rb') as stream:
            # iter() with a b'' sentinel stops as soon as read() hits EOF.
            for block in iter(lambda: stream.read(self._chunk_size), b''):
                digest.update(block)
            return self.checksum == digest.hexdigest()

    def check(self):
        """Return a ``(status, message)`` pair for this resource.

        ``status`` is ``'OK'`` when the checksum matches, and ``'CRIT'`` when
        it differs or the file cannot be read (the message is then the error
        text).
        """
        try:
            matches = self.check_resource_checksum()
        except (IOError, OSError) as exc:
            return 'CRIT', str(exc)
        if matches:
            return 'OK', '{}: ok'.format(self.root)
        return 'CRIT', '{}: resource modified'.format(self.resource)


class ArchiveChecker(ResourceChecker):
    """Verify a tar archive's checksum and that its sources match ``root``.

    The first successful :meth:`check` does three things:

    - verifies the archive checksum;
    - extracts the archive into a scratch directory;
    - compares every extracted source file with the file at the same
      relative path under ``root``.

    It then records ``fileutils.get_max_mtime`` of those files. Later checks
    only compare that value again, instead of repeating the full comparison.
    """

    # Name of the scratch directory (joined onto the current working
    # directory) that the archive is extracted into.
    SRC_TMP_DIR = 'src_checker_sources'

    # Allowed difference between the recorded and current max mtime
    # (presumably seconds, absorbing float timestamp rounding — TODO confirm
    # against fileutils.get_max_mtime).
    MTIME_TOLERANCE = 0.1

    def __init__(self, root, arc, arc_checksum):
        """
        :param root: directory holding the sources to verify.
        :param arc: path to the tar archive with the reference sources.
        :param arc_checksum: expected MD5 hex digest of ``arc``.
        """
        ResourceChecker.__init__(self, root, arc, arc_checksum)

        # Set once check_sources() has succeeded; switches check() to the
        # cheaper mtime-only comparison.
        self._sources_checked = False
        # Absolute paths under root of the source files found in the archive.
        self._files = []
        # fileutils.get_max_mtime(self._files) recorded on success.
        self._max_mtime = None

    def check_sources(self):
        """Extract the archive and compare its source files with ``root``.

        Files are selected by ``fileutils.yield_source_files``. Each one is
        compared by content with the same relative path under ``root``. A
        file missing from ``root``, or one that cannot be read, counts as a
        mismatch.

        On success, records ``fileutils.get_max_mtime`` of the files under
        ``root`` and marks the sources as checked.

        :returns: True if every source file matches, False otherwise.
        :raises IOError, OSError: propagated from extraction or file access.
        :raises tarfile.TarError: if the archive cannot be read as a tar.
        """
        rel_paths = []
        files = []
        src_tmp_path = os.path.join(os.getcwd(), self.SRC_TMP_DIR)
        with fileutils.force_create_tmp_dir(path=src_tmp_path) as tmp:
            with tarfile.open(self.resource, 'r') as tar:
                # NOTE(review): extractall() trusts member names. check()
                # verifies the archive checksum first, but an archive with
                # "../" or absolute member names could still write outside
                # tmp. Consider filtering members.
                tar.extractall(tmp)

            for extracted in fileutils.yield_source_files(tmp):
                rel = os.path.relpath(extracted, tmp)
                rel_paths.append(rel)
                files.append(os.path.abspath(os.path.join(self.root, rel)))

            # Replace rather than extend, so repeated calls (e.g. after a
            # failed comparison) do not accumulate duplicate paths.
            self._files = files

            # shallow=False: compare file contents instead of trusting
            # matching size/mtime signatures, which can be forged.
            _, mismatch, errors = filecmp.cmpfiles(
                self.root, tmp, rel_paths, shallow=False)
            if mismatch or errors:
                return False

        self._max_mtime = fileutils.get_max_mtime(self._files)
        self._sources_checked = True
        return True

    def check_last_mtime(self):
        """Return True if the files' max mtime still matches the recorded one.

        Only meaningful after :meth:`check_sources` has succeeded, because
        that is where ``self._max_mtime`` is set. Differences up to
        ``MTIME_TOLERANCE`` are ignored.
        """
        max_mtime = fileutils.get_max_mtime(self._files)
        return abs(max_mtime - self._max_mtime) <= self.MTIME_TOLERANCE

    def check(self):
        """Run the archive/source integrity check.

        Until sources have been verified once, this verifies the archive
        checksum and then the extracted sources. After that, it only
        re-checks the files' max mtime.

        :returns: ``(status, message)`` where status is ``'OK'`` or ``'CRIT'``.
        """
        try:
            if self._sources_checked:
                if not self.check_last_mtime():
                    return 'CRIT', '{}: last mtime modified'.format(self.root)
            else:
                if not self.check_resource_checksum():
                    return 'CRIT', '{}: archive modified'.format(self.resource)
                if not self.check_sources():
                    return 'CRIT', '{}: sources modified'.format(self.root)
            return 'OK', '{}: ok'.format(self.root)
        except (IOError, OSError, tarfile.TarError) as e:
            # Report unreadable or corrupt inputs as a CRIT result instead of
            # letting the exception escape the checker.
            return 'CRIT', str(e)
