import os
import struct
import sys
import threading
import time
import types

import cPickle as pickle

import gevent.pool
import gevent.queue
import gevent.socket
import msgpack

from api.copier import errors

from .. import subprocess_gevent as subproc
from ..component import Component
from ..hasher import rehash
from ..utils import human_size, human_time


class Hasher(Component):
    def __init__(self, hasher_binary, max_parallel=16, parent=None):
        """Create a hasher component.

        :param hasher_binary: path to the external hasher executable; when falsy,
            ``rehash`` never uses the subprocess backend.
        :param max_parallel: how many items ``rehash_in_bulk`` hashes at once.
        :param parent: forwarded to ``Component`` as the parent component.
        """
        super(Hasher, self).__init__(logname='hasher', parent=parent)

        # Greenlet pool that bounds concurrency of rehash_in_bulk.
        self.pool = gevent.pool.Pool(size=max_parallel)
        # External hasher executable (may be empty/None).
        self.hasher_binary = hasher_binary

    @staticmethod
    def convert_sha1_blocks_from_lt_form(pieces, piece_length):
        """Split a flat string of 20-byte SHA-1 digests into a sha1_blocks list.

        ``pieces`` is the plain concatenation of the digests (the "lt form";
        presumably libtorrent's ``pieces`` layout, not verifiable here). Every
        digest is paired with the same ``piece_length``; the final block is not
        shortened.

        :returns: ``[1, (piece_length, digest0), (piece_length, digest1), ...]``
            where the leading ``1`` is the sha1_blocks structure version.
        """
        assert len(pieces) % 20 == 0

        digest_count = len(pieces) // 20
        blocks = [(piece_length, pieces[n * 20:(n + 1) * 20]) for n in range(digest_count)]
        return [1] + blocks

    def rehash_direct(self, path):
        """Hash ``path`` in the calling greenlet.

        The progress callback only yields to the gevent hub, so other greenlets
        get a chance to run between progress reports.

        :returns: whatever ``rehash()`` returns for ``path``.
        """
        def _progress(*args, **kwargs):
            # Accept any arguments: _rehash_in_thread_helper's callback for the
            # same rehash() takes (idx, bytes), so a one-argument callback would
            # raise TypeError on the first progress report.
            gevent.sleep()  # cooperative yield

        return rehash(path, progress_func=_progress)

    def _rehash_in_thread_helper(self, path, pipe, done_cb, alive, progress=None):
        """Worker-thread body for ``rehash_in_thread``.

        Hashes ``path`` with ``rehash()`` and reports over the write end of a
        pipe using frames of ``struct.pack('!I', len) + msgpack``:

        * ``('progress', {'hashed': <bytes>})`` on every progress callback, only
          when ``progress`` is truthy;
        * ``('finished', )`` at the end, unless the reader already gave up
          (``alive()`` is False).

        The outcome goes through ``done_cb``: ``(result, )`` on success or
        ``('error', <pickled exception>)`` on failure.

        This thread owns ``pipe`` and closes it before returning, so the reader
        never closes a descriptor the thread might still write to.

        :param pipe: write end of an ``os.pipe()``; closed by this method.
        :param done_cb: receives the outcome tuple.
        :param alive: callable returning False once the reader stopped listening.
        :param progress: only its truthiness is used (enables progress frames).
        """
        def _autostop():
            # Forcibly stop hashing (raised from inside rehash's progress
            # callback) once the reader has given up.
            if not alive():
                raise gevent.GreenletExit()

        def _send(message):
            data = msgpack.dumps(message)
            os.write(pipe, struct.pack('!I', len(data)) + data)

        if progress:
            def progress_func(idx, hashed_bytes):
                _autostop()
                _send(('progress', {'hashed': hashed_bytes}))
        else:
            def progress_func(*args, **kwargs):
                # Cancellation must work even when nobody wants progress frames.
                _autostop()

        try:
            try:
                done_cb((rehash(path, progress_func=progress_func), ))
            except BaseException as ex:
                done_cb(('error', pickle.dumps(ex)))
            finally:
                if alive():
                    try:
                        _send(('finished', ))
                    except OSError:
                        # The reader closed its end in the meantime (EPIPE);
                        # nobody is listening any more.
                        pass
        finally:
            os.close(pipe)

    def _fd_read_bytes(self, pipe, count, timeout):
        """Read exactly ``count`` bytes from file descriptor ``pipe`` cooperatively.

        Readability is awaited with ``gevent.socket.wait_read`` so other
        greenlets keep running.

        :param timeout: overall budget in seconds for the whole read, or None to
            wait indefinitely.
        :returns: the ``count`` bytes, or ``''`` if EOF comes first (partially
            read data is discarded).
        :raises gevent.socket.timeout: when the budget is exhausted.
        """
        deadline = time.time() + timeout if timeout is not None else None
        chunks = []
        left = count

        while left:
            if deadline is not None:
                timeout = max(0, deadline - time.time())
                if not timeout:
                    raise gevent.socket.timeout('timed out')

            gevent.socket.wait_read(pipe, timeout)

            buf = os.read(pipe, left)
            if not buf:
                return ''

            left -= len(buf)
            chunks.append(buf)

        return ''.join(chunks)

    def rehash_in_thread(self, path, progress=None, timeout=60):
        """Hash ``path`` in a daemon ``threading.Thread`` without blocking greenlets.

        The worker (``_rehash_in_thread_helper``) streams frames through a pipe.
        This greenlet reads them with gevent-friendly waits and forwards
        ``progress(hashed_bytes)`` for every progress frame.

        :param timeout: overall seconds to wait for the worker, or None.
        :returns: whatever ``rehash()`` returned in the worker.
        :raises errors.CopierError: on timeout or when the worker left no result.
        :raises: the worker's exception (unpickled) if hashing failed.
        """
        rpipe, wpipe = os.pipe()

        alive = [True]
        ret = []

        thr = threading.Thread(
            target=self._rehash_in_thread_helper,
            args=(path, wpipe, ret.append, lambda: alive[0], progress)
        )
        thr.daemon = True
        try:
            thr.start()
        except BaseException:
            # No worker took ownership of wpipe, so both ends are ours to close.
            os.close(rpipe)
            os.close(wpipe)
            raise

        deadline = time.time() + timeout if timeout is not None else None

        def _remaining():
            # Seconds left until the overall deadline (may be <= 0), or None.
            return deadline - time.time() if deadline is not None else None

        try:
            while True:
                try:
                    header = self._fd_read_bytes(rpipe, 4, _remaining())
                    if len(header) < 4:
                        # EOF: worker closed its end without a 'finished' frame.
                        break
                    msglen = struct.unpack('!I', header)[0]
                    msg = self._fd_read_bytes(rpipe, msglen, _remaining())
                except gevent.socket.timeout as ex:
                    raise errors.CopierError('Unable to hash in thread: %s' % (str(ex), ))

                if len(msg) < msglen:
                    # EOF in the middle of a frame.
                    break

                msg = msgpack.loads(msg)
                if not msg:
                    # We got corrupt msg, maybe thread died
                    break
                if msg[0] == 'finished':
                    break
                if msg[0] == 'progress':
                    progress(msg[1]['hashed'])

            if not ret:
                raise errors.CopierError('Unable to hash in thread: empty result')

            result = ret[0]
            if len(result) == 2:
                # ('error', pickled exception) produced by our own worker thread.
                raise pickle.loads(result[1])

            return result[0]
        except BaseException:
            # Ask the worker to stop at its next progress callback.
            alive[0] = False
            raise
        finally:
            # wpipe belongs to the worker thread, which closes it itself.
            os.close(rpipe)

    def rehash_in_subprocess(self, path, progress=None, timeout=None):
        """Hash ``path`` with the external ``self.hasher_binary`` process.

        The binary runs as ``<hasher_binary> <path.strpath> -omsgpack`` plus
        ``--progress`` when ``progress`` is given. Its stdout carries
        ``!I``-length-prefixed msgpack frames. Frames whose first element is
        ``'progress'`` are forwarded as ``progress(frame[2])``, and the last other
        frame is the result.

        :param path: object with a ``strpath`` attribute (presumably a py.path).
        :param progress: optional callable receiving the hashed-byte count.
        :param timeout: accepted for signature compatibility with
            ``rehash_in_thread``; currently ignored.
        :returns: the decoded result list/tuple.
        :raises errors.CopierError: on exit code 2, stderr output with a non-zero
            exit code, or missing/malformed output.
        :raises: the exception reported by the hasher, reconstructed locally.
        """
        proc = subproc.Popen(
            [self.hasher_binary, path.strpath, '-omsgpack'] + (['--progress'] if progress else []),
            close_fds=True,
            stdout=subproc.PIPE, stderr=subproc.PIPE
        )

        if progress:
            stdout, stderr = self._communicate_with_progress(proc, progress)
        else:
            try:
                stdout, stderr = proc.communicate()
            except BaseException:
                proc.kill()
                raise

        if proc.returncode is None:
            # Both pipes reached EOF, so the child is exiting: reap it to get its
            # real exit status (killing it here would replace e.g. code 2 with
            # a signal).
            try:
                proc.wait()
            except BaseException:
                proc.kill()
                raise

        if proc.returncode == 2:
            raise errors.CopierError('skybone-hasher: msgpack problem (%s)' % (stderr, ))

        if stderr != '' and proc.returncode != 0:
            raise errors.CopierError(
                'skybone-hasher: failed with stderr: %s, return code %d' % (stderr, proc.returncode)
            )

        # Copier hasher output <int>datalen + <msgpack>
        # But we ignore datalen completely
        if len(stdout) <= 4:
            raise errors.CopierError(
                'skybone-hasher: no output (stderr: %r, return code %r)' % (stderr, proc.returncode)
            )

        unpacked = msgpack.loads(stdout[4:])

        if not isinstance(unpacked, (list, tuple)):
            raise errors.CopierError('skybone-hasher: msgpack problem: %r (stderr: %r)' % (stdout, stderr))

        if len(unpacked) == 2:
            # Error report: ((error, maker, args), trace).
            (_error, maker, args), trace = unpacked
            self.log.warning('skybone-hasher: %s', trace)
            raise self._reconstruct_hasher_error(maker, args)

        if proc.returncode == 0:
            assert stderr == ''

        return unpacked

    def _communicate_with_progress(self, proc, progress):
        """Drain ``proc``'s stdout and stderr concurrently, forwarding progress.

        :returns: ``(stdout, stderr)``. ``stdout`` is the last non-progress frame
            re-encoded with its ``!I`` length prefix (``''`` if none arrived).
            ``stderr`` is everything the child wrote to stderr.
        """
        state = ['', '']
        stdout_fd = proc.stdout.fileno()

        def _stdout_reader():
            # Returns True on clean EOF, False when it stopped on an empty frame.
            while True:
                header = self._fd_read_bytes(stdout_fd, 4, None)
                if not header:
                    return True

                llen = struct.unpack('!I', header)[0]
                msg = msgpack.loads(self._fd_read_bytes(stdout_fd, llen, None))
                if not msg:
                    return False

                if msg[0] == 'progress':
                    progress(msg[2])
                else:
                    state[0] = struct.pack('!I', llen) + msgpack.dumps(msg)

        def _stderr_reader():
            while True:
                gevent.socket.wait_read(proc.stderr.fileno())
                data = proc.stderr.read()
                if not data:
                    break
                state[1] += data

        grn_stdout = gevent.spawn(_stdout_reader)
        grn_stderr = gevent.spawn(_stderr_reader)

        try:
            grn_stdout.join()
            if not (grn_stdout.successful() and grn_stdout.value):
                # stdout is no longer drained: stop the child so it cannot block
                # on a full pipe and keep the stderr reader waiting forever.
                if proc.returncode is None:
                    proc.kill()
            grn_stderr.join()
        except BaseException:
            grn_stdout.kill()
            grn_stderr.kill()
            proc.kill()
            raise

        # Re-raise reader failures (e.g. truncated or corrupt frames).
        grn_stdout.get()
        grn_stderr.get()

        return state[0], state[1]

    @staticmethod
    def _reconstruct_hasher_error(maker, args):
        """Rebuild an exception reported by the hasher as ``(maker, args)``.

        ``maker`` is a dotted path. If it resolves to an exception class, it is
        called as ``maker(*args)``. Otherwise it is treated as a factory and
        called as ``maker(exception_class, *args[1:])``, where ``args[0]`` is the
        dotted path of the exception class.

        NOTE(review): this imports and calls whatever the hasher output names,
        so the hasher binary must be trusted.
        """
        def _resolve(dotted):
            module_name, attr = dotted.rsplit('.', 1)
            return getattr(__import__(module_name, fromlist=[0]), attr)

        maker_meth = _resolve(maker)
        # isinstance(..., type) first: issubclass() raises TypeError for
        # non-class callables such as builtin functions.
        if isinstance(maker_meth, type) and issubclass(maker_meth, BaseException):
            return maker_meth(*args)

        args = list(args)
        return maker_meth(_resolve(args[0]), *args[1:])

    def rehash(self, ritem, target_name='data', progress=None):
        """Hash ``ritem`` and return ``(md5sum, sha1_blocks, stat)``.

        The backend is chosen by file size:

        * ``ritem.size`` > 64 MiB and ``self.hasher_binary`` set: the external
          hasher process (``rehash_in_subprocess``), no timeout;
        * otherwise: a Python thread (``rehash_in_thread``), with a 120 s timeout
          when a hasher binary is configured and none otherwise.

        ``target_name`` is accepted for interface compatibility and unused here.

        :param ritem: object exposing ``size``, ``path``, ``inode`` and ``mtime``.
        :param progress: optional callable forwarded to the hashing backend.
        :returns: ``(md5sum, sha1_blocks, (inode, size, atime, ctime, mtime))``.
            ``sha1_blocks`` is ``(1, (length, digest), ...)``; the leading ``1``
            is the structure version, and the last block is truncated to the
            file size.
        """
        if ritem.size > 64 * 1024 * 1024 and self.hasher_binary:
            hash_func = self.rehash_in_subprocess
            hash_func_name = 'subprocess'
            hash_timeout = None
        else:
            hash_func = self.rehash_in_thread
            hash_func_name = 'thread'
            # Bounded wait only when a hasher binary is configured.
            hash_timeout = 120 if self.hasher_binary else None

        ts = time.time()
        piece_size, pieces, stat_info, md5sum = hash_func(
            ritem.path, progress=progress, timeout=hash_timeout
        )
        inode, size, atime, ctime, mtime = stat_info

        # Use the piece size reported by the hasher instead of assuming 4 MiB.
        sha1_blocks = [1]  # 1 is the version of sha1_blocks structure
        for idx, piece in enumerate(pieces):
            piece_start = idx * piece_size
            piece_end = min(size, piece_start + piece_size)
            sha1_blocks.append((piece_end - piece_start, piece))

        self.log.debug(
            'Hashed md5:%s [%s in %s using %s, pool size %d] (%s ino:%d mtime:%d)',
            md5sum.encode('hex'),
            human_size(size),
            human_time(time.time() - ts),
            hash_func_name,
            len(self.pool),
            ritem.path,
            ritem.inode,
            ritem.mtime,
        )
        return md5sum, tuple(sha1_blocks), (inode, size, atime, ctime, mtime)

    def rehash_in_bulk(self, ritems, target_name='data', progress=None):
        result_queue = gevent.queue.Queue()

        spawned = []
        done = [0]

        if progress:
            def _on_progress(hashed_bytes):
                progress(done[0], hashed_bytes)
        else:
            _on_progress = None

        def _spawner():
            try:
                for ritem in ritems:
                    def _collector(grn, ritem=ritem):
                        done[0] += 1
                        try:
                            result = (True, grn.get())
                        except BaseException:
                            result = (False, sys.exc_info())
                        result_queue.put((ritem, result))

                    grn = self.pool.spawn(self.rehash, ritem, target_name=target_name, progress=_on_progress)
                    grn.rawlink(_collector)
                    spawned.append(grn)
            except:
                result_queue.put((None, (False, sys.exc_info())))

        try:
            grn = gevent.spawn(_spawner)

            for _ in range(len(ritems)):
                data = result_queue.get()
                ritem, (flag, result) = data
                if not flag:
                    raise result[0], result[1], result[2]

                yield ritem, result

            grn.get()
        except:
            ei = sys.exc_info()

            # Kill spawner
            grn.kill()

            # Kill all spawned hashers
            [grn.kill() for grn in spawned]  # kill all spawned

            raise ei[0], ei[1], ei[2]
