import os
import six
import sys
import imp
import zipfile
import marshal

from itertools import product
from threading import RLock
from ya.skynet.util.errors import saveTraceback, getTraceback, setTraceback


class InMemoryImporter(object):
    """Finder/loader that imports Python modules from an in-memory zip archive.

    ``data`` is the raw content of a zip file.  Every ``.py`` / ``.pyc`` member
    becomes an importable module whose dotted name mirrors its path inside the
    archive (``a/b/__init__.py`` -> ``a.b``, ``a/b/c.py`` -> ``a.b.c``).
    ``find_module`` ignores its ``path`` argument, so the importer presumably is
    installed on ``sys.meta_path`` by the caller (not visible here).
    """

    # Pseudo-location used as the prefix of ``__file__`` for loaded modules.
    PATH = '<IN-MEMORY>'

    # Zip member names always use '/' as separator, whatever the host OS is.
    _ZIP_SEP = '/'

    # Bytes preceding the marshalled code object in a .pyc file:
    # magic + mtime (Python 2 / < 3.3), + source size (3.3+), + flags (3.7+, PEP 552).
    if sys.version_info >= (3, 7):
        PYC_HEADER_SIZE = 16
    elif sys.version_info >= (3, 3):
        PYC_HEADER_SIZE = 12
    else:
        PYC_HEADER_SIZE = 8

    def __init__(self, data):
        """Open ``data`` (bytes of a zip archive) and index its Python modules."""
        self.import_lock = RLock()
        self.data = data
        if six.PY2:
            io = six.moves.cStringIO(self.data)
        else:
            io = six.BytesIO(self.data)
        self.storage = zipfile.ZipFile(io, "r")
        self.filenames = set(self.storage.namelist())
        self.modnames = set()
        for fn in self.filenames:
            if fn.endswith('.pyc'):
                self.modnames.add(self._modname(fn[:-4]))
            elif fn.endswith('.py'):
                self.modnames.add(self._modname(fn[:-3]))

    def _modname(self, filename):
        """Convert an extension-less archive path into a dotted module name."""
        parts = filename.split(self._ZIP_SEP)
        if parts[-1] == '__init__':
            parts = parts[:-1]
        return '.'.join(parts)

    def _module_basename(self, name):
        """Archive path (without extension) of plain module ``name``."""
        return self._ZIP_SEP.join(name.split('.'))

    def _package_basename(self, name):
        """Archive path (without extension) of package ``name``'s ``__init__``."""
        return self._ZIP_SEP.join(name.split('.') + ['__init__'])

    def _basenames(self, name):
        """Yield candidate archive paths for ``name``: package first, then module."""
        yield self._package_basename(name)
        yield self._module_basename(name)

    def _read(self, name):
        """Return the raw content of archive member ``name``.

        ``ZipFile.read`` closes the member itself, so no handle is leaked.
        """
        with self.import_lock:
            return self.storage.read(name)

    def _load_pyc(self, name):
        """Return the code object stored in the .pyc member ``name``.

        Raises:
            ImportError: the file was compiled by a different interpreter
                (magic number mismatch).
            ValueError / EOFError / TypeError: from ``marshal.loads`` on
                corrupt data.
        """
        data = self._read(name)
        if data[:4] != imp.get_magic():
            raise ImportError("bad magic number in {}".format(name))
        return marshal.loads(data[self.PYC_HEADER_SIZE:])

    def _init_module_attrs(self, mod, fullname, name, is_pkg):
        """Set the import-related attributes of ``mod`` loaded from member ``name``."""
        mod.__name__ = fullname
        mod.__file__ = self.PATH + '/' + name
        if is_pkg:
            # Empty __path__ marks the module as a package; submodules are
            # still resolved by this importer via their full dotted name.
            mod.__path__ = []
            mod.__package__ = fullname
        else:
            mod.__package__ = fullname.rpartition('.')[0]
        mod.__loader__ = self

    def find_module(self, fullname, path=None):
        """Return ``self`` if ``fullname`` is in the archive, else ``None``."""
        if fullname in self.modnames:
            return self

    def load_module(self, fullname):
        """Load ``fullname`` from the archive, register it in ``sys.modules`` and return it.

        Candidates are tried in order ``pkg/__init__.pyc``, ``pkg/__init__.py``,
        ``mod.pyc``, ``mod.py``; a candidate that fails is removed from
        ``sys.modules`` and the next one is tried.

        Raises:
            ImportError: no candidate exists, or every existing one failed (the
                last failure's traceback is attached via ``setTraceback``).
        """
        with self.import_lock:
            # Checked under the lock so two threads cannot both execute the
            # same module's code.
            if fullname in sys.modules:
                return sys.modules[fullname]

            last_error = None
            for basename, ext in product(self._basenames(fullname), ('.pyc', '.py')):
                name = basename + ext
                if name not in self.filenames:
                    continue
                is_pkg = basename.endswith(self._ZIP_SEP + '__init__')
                try:
                    mod = sys.modules.setdefault(fullname, imp.new_module(fullname))
                    # Attributes must exist before exec: the module body may read them.
                    self._init_module_attrs(mod, fullname, name, is_pkg)

                    code = self._load_pyc(name) if ext == '.pyc' else self._read(name)
                    six.exec_(code, mod.__dict__)

                    # Some packages (e.g. werkzeug) replace their own entry in
                    # sys.modules while initialising, so re-apply the attributes
                    # to whatever object is registered now.
                    mod = sys.modules[fullname]
                    self._init_module_attrs(mod, fullname, name, is_pkg)
                except Exception as ex:
                    saveTraceback(ex)
                    ex.filename = name
                    sys.modules.pop(fullname, None)
                    last_error = ex
                    continue

                return mod

        if last_error is None:
            raise ImportError("no such module: {}".format(fullname))
        err = ImportError("Cannot import module {} from {}: {}".format(
            fullname, last_error.filename, last_error))
        setTraceback(err, getTraceback(last_error))
        raise err

    def get_source(self, fullname):
        """Return the source text of ``fullname`` (decoded as UTF-8 on Python 3).

        Raises:
            ImportError: the archive has no ``.py`` file for ``fullname``.
        """
        if fullname in self.modnames:
            for basename in self._basenames(fullname):
                name = basename + '.py'
                if name in self.filenames:
                    data = self._read(name)
                    if six.PY3:
                        return data.decode('utf-8')
                    return data

        raise ImportError("No sources for module {}".format(fullname))

    def get_code(self, fullname):
        """Return the code object from ``fullname``'s .pyc, or ``None``.

        ``None`` is returned when the module is unknown, has no .pyc, or the
        .pyc is for another interpreter / corrupt.
        """
        if fullname not in self.modnames:
            return None
        for basename in self._basenames(fullname):
            name = basename + '.pyc'
            if name in self.filenames:
                try:
                    return self._load_pyc(name)
                except (ImportError, TypeError, ValueError, EOFError):
                    pass  # wrong pyc version or corrupt data; try next candidate
        return None

    def is_package(self, fullname):
        """Return True if ``fullname`` has an ``__init__.py``/``.pyc`` in the archive."""
        if fullname not in self.modnames:
            return False

        name = self._package_basename(fullname)
        return any((name + ext) in self.filenames for ext in ('.py', '.pyc'))

    def package_submodules(self, fullname):
        """Return all archive module names nested (at any depth) under ``fullname``."""
        if fullname not in self.modnames:
            return []

        return [x for x in self.modnames if x.startswith(fullname + '.')]
