from __future__ import division, print_function

import contextlib
import ctypes
import errno
import os
import stat
import time
import threading
import mmap

import gevent
import gevent.event
try:
    import gevent.coros as coros
except ImportError:
    import gevent.lock as coros

from ..component import Component
from ..greenish.deblock import Deblock
from ..fadvise_random import fadv_random


if os.uname()[0].lower() == 'linux':
    try:
        import directio
    except ImportError:
        # directio cannot be imported on old Linux systems (with an old glibc).
        directio = None

    # Flag bits passed to sync_file_range() below.
    SYNC_FILE_RANGE_WAIT_BEFORE = 1
    SYNC_FILE_RANGE_WRITE = 2
    SYNC_FILE_RANGE_WAIT_AFTER = 4

    # use_errno must be requested when the library is loaded: only then does ctypes
    # copy the C errno into its private per-thread slot after each call, which is
    # what ctypes.get_errno() reads. Setting a `use_errno` attribute on the function
    # object afterwards has no effect.
    sync_file_range = ctypes.CDLL(None, use_errno=True).sync_file_range
    sync_file_range.argtypes = [ctypes.c_int, ctypes.c_int64, ctypes.c_int64, ctypes.c_uint]
else:
    # Neither direct I/O helpers nor sync_file_range are used on other platforms.
    directio = None
    sync_file_range = None


@contextlib.contextmanager
def dummy_user_privileges(*args, **kwargs):
    """No-op context manager: accepts and ignores any arguments, yields nothing."""
    yield None


class DummySem(object):
    """Context manager that does nothing on entry and nothing on exit."""

    def __enter__(self):
        # Nothing to acquire.
        return None

    def __exit__(self, *args, **kwargs):
        # Nothing to release; returning None lets any exception propagate.
        return None


class IOFile(object):
    """File-like helper built directly on a raw OS file descriptor.

    `fd` is the descriptor all operations act on; `io` is stored as-is and not
    used by any method of this class.
    """

    def __init__(self, fd, io):
        self.io = io
        self.fd = fd

    def read(self, size):
        """Read up to `size` bytes from the current position."""
        chunk = os.read(self.fd, size)
        return chunk

    def write(self, data, sha1=None):
        """Write all of `data`, retrying after short writes. `sha1` is ignored."""
        view = memoryview(data)
        total = len(view)
        offset = 0

        while offset < total:
            # Slicing a memoryview does not copy the remaining payload.
            written = os.write(self.fd, view[offset:])
            assert written > 0, 'SKYDEV-1004: weird os.write result %r' % (written, )
            offset += written

    def seek(self, pos):
        """Move to absolute offset `pos`."""
        os.lseek(self.fd, pos, os.SEEK_SET)

    def truncate(self, size):
        """Set the file length to `size` bytes."""
        os.ftruncate(self.fd, size)

    def stat(self):
        """Return the os.fstat() result for the descriptor."""
        info = os.fstat(self.fd)
        return info

    def close(self):
        """Close the underlying descriptor."""
        os.close(self.fd)

    def fsync(self):
        """Flush the file to disk via os.fsync()."""
        os.fsync(self.fd)


class IOWorker(object):
    """Blocking file operations (open/read/write/truncate/fsync) with simple counters.

    Per the comments in this module, these methods are invoked from deblock threads
    rather than from the main gevent thread, so they must not log; warnings are
    returned to the caller instead.

    `priv` is a context-manager factory called as
    `priv(user=uid, store=True, limit=False)` around file opens.
    `sync_writes_every` is the number of written blocks to accumulate before they
    are flushed with sync_file_range(); it is forced to 0 when sync_file_range is
    unavailable.
    """

    # Size of a full data block and of the mmap bounce buffer used for direct I/O.
    _BLOCK_SIZE = 4 * 1024 * 1024
    # Direct writes are padded up to a multiple of this many bytes.
    _DIRECT_ALIGN = 4096

    def __init__(self, priv, sync_writes_every, log):
        self.stats = {
            'open_cnt': 0,
            'open_cnt_fail': 0,
            'open_sec': 0,
            'read_cnt': 0,
            'read_bytes': 0,
            'read_sec': 0,
            'write_cnt': 0,
            'write_bytes': 0,
            'write_sec': 0,
            'seek_cnt': 0,
            'seek_sec': 0,
        }
        self.priv = priv
        self.log = log

        self._force_fsync = False
        self._max_write_chunk = None
        self._sync_writes_every = sync_writes_every if sync_file_range else 0

        if self._sync_writes_every:
            self._unsynced_blocks = []  # list of tuples (filename, range start, range length)
            self._sync_writes_lock = threading.Lock()  # IOWorker methods are called from other threads

        super(IOWorker, self).__init__()

    def set_force_fsync(self, flag):
        """Enable or disable an fsync() after every write()."""
        self._force_fsync = flag

    def set_max_write_chunk(self, size):
        """Limit the size of a single directio.write() call; None/0 disables the limit."""
        self._max_write_chunk = size

    def sleep(self, seconds):
        """Block for `seconds` and return True. Used primarily in tests."""
        time.sleep(seconds)
        return True

    def _add_unsynced_block(self, fn, start, length):
        """
        Record a written block of data and periodically sync all previously recorded
        blocks to disk.

        Returns None when nothing had to be reported, or a newline-joined warning
        message when some sync_file_range() calls failed. Exceptions from opening
        the files propagate; the recorded blocks are dropped in every case once a
        flush has been attempted.
        """
        with self._sync_writes_lock:
            self._unsynced_blocks.append((fn, start, length))
            if len(self._unsynced_blocks) < self._sync_writes_every:
                return None

            warn_msgs = []
            # Bound before the try block so the cleanup below can always iterate it.
            files = {}
            flags = SYNC_FILE_RANGE_WAIT_BEFORE | SYNC_FILE_RANGE_WRITE | SYNC_FILE_RANGE_WAIT_AFTER
            try:
                # Open every distinct file once, even if several of its blocks are pending.
                for name in set(name for name, _, _ in self._unsynced_blocks):
                    files[name] = self._fopen(name, write=True)

                for name, blk_start, blk_length in self._unsynced_blocks:
                    ret = sync_file_range(files[name].fd, blk_start, blk_length, flags)
                    # This function is called from a deblock thread, logging here will break gevent.
                    # We need to return warning messages to the main thread instead.
                    if ret < 0:
                        err = ctypes.get_errno()
                        warn_msgs.append(
                            'sync_file_range({}, {}, {}, {}) failed: [Errno {}] {}'.format(
                                name, blk_start, blk_length, flags, err, os.strerror(err)
                            )
                        )
            finally:
                # .values() works on both Python 2 and 3 (itervalues() is Python 2 only).
                for opened in files.values():
                    opened.close()
                self._unsynced_blocks = []

        return '\n'.join(warn_msgs) if warn_msgs else None

    def _fopen(self, fn, read=False, write=False, direct=False):
        """
        Open `fn` either read-only or write-only and return an IOFile for it.

        Exactly one of `read`/`write` must be true. With `direct`, O_DIRECT is
        requested and silently dropped if the open fails with EINVAL. Raises
        IOError(666, ...) when the path is not a regular file. Updates the
        open_cnt / open_cnt_fail / open_sec counters.
        """
        assert not (read and write)
        assert read or write

        self.stats['open_cnt'] += 1
        ts = time.time()
        try:
            access = os.O_RDONLY if read else os.O_WRONLY
            flags = (os.O_DIRECT if direct else 0) | access

            try:
                fd = os.open(fn, flags)
            except OSError as ex:
                if ex.errno == errno.EINVAL and direct:
                    # Retry without O_DIRECT.
                    flags = access
                    fd = os.open(fn, flags)
                else:
                    raise

            try:
                fadv_random(fd)
                fp = IOFile(fd, self)
            except:  # cleanup only, the exception is re-raised
                os.close(fd)
                raise

            try:
                fstat = fp.stat()

                if not stat.S_ISREG(fstat.st_mode):
                    raise IOError(666, 'Not a regular file')

                self.stats['open_sec'] += time.time() - ts
                return fp
            except:  # cleanup only, the exception is re-raised
                fp.close()  # this will close fd
                raise

        except BaseException:
            self.stats['open_cnt_fail'] += 1
            raise

    def read(self, uid, fn, start, length, direct):
        """
        Read data from `fn` starting at offset `start` and return it.

        The file is opened under `self.priv(user=uid, ...)`. Without `direct`, at
        most `length` bytes are read with a single os.read(). With `direct`, a full
        4 MiB block is requested through directio when available, otherwise through
        an mmap-backed buffer of which the first `length` bytes are returned.
        Updates the read_* and seek_* counters on every path.
        """
        with self.priv(user=uid, store=True, limit=False):
            fp = self._fopen(fn, read=True, direct=direct)

        self.stats['read_cnt'] += 1

        try:
            self.stats['seek_cnt'] += 1
            ts = time.time()
            fp.seek(start)
            self.stats['seek_sec'] += time.time() - ts

            ts = time.time()
            if not direct:
                data = fp.read(length)
            elif directio:
                # NOTE(review): `length` is not applied on this path, the caller gets
                # whatever directio.read() returns for a 4 MiB request -- confirm that
                # callers expect this.
                data = directio.read(fp.fd, self._BLOCK_SIZE)
            else:
                buf = mmap.mmap(-1, self._BLOCK_SIZE)
                try:
                    # The duplicated descriptor is closed as soon as the read is done,
                    # without touching fp.fd.
                    with os.fdopen(os.dup(fp.fd), 'rb') as dup_fp:
                        dup_fp.readinto(buf)
                    buf.seek(0)
                    data = buf.read(length)
                finally:
                    buf.close()

            self.stats['read_bytes'] += len(data)
            self.stats['read_sec'] += time.time() - ts
            return data
        finally:
            fp.close()

    def _write_direct(self, fd, data):
        """Pad `data` to the direct-I/O alignment and write it via directio, chunked if configured."""
        padding = -len(data) % self._DIRECT_ALIGN
        # b' ' keeps the concatenation valid for byte strings on Python 2 and 3.
        aligned = data + b' ' * padding if padding else data

        step = self._max_write_chunk
        if step and step < len(aligned):
            for offset in range(0, len(aligned), step):
                directio.write(fd, aligned[offset:offset + step])
        else:
            directio.write(fd, aligned)

    def write(self, uid, fn, start, data, direct, truncate=False, sha1=None):
        """
        Write `data` into the existing file `fn` at offset `start`.

        The file is opened under `self.priv(user=uid, ...)`. With `direct`, data is
        written through directio (padded to a 4096-byte multiple) or, when directio
        is unavailable, through an mmap-backed buffer; a direct block shorter than
        4 MiB forces `truncate`. With `truncate`, the file is cut to end right after
        the written data. `sha1` is only passed through to the file object's write().

        Returns a tuple `(fstat result, warning message or None)`; the warning comes
        from the periodic sync_file_range() flush. Updates the write_* counters.
        """
        length = len(data)
        warn_msg = None

        with self.priv(user=uid, store=True, limit=False):
            fp = self._fopen(fn, write=True, direct=direct)

        self.stats['write_cnt'] += 1

        try:
            fp.seek(start)

            ts = time.time()
            if not direct:
                fp.write(data, sha1=sha1)
            else:
                if directio:
                    self._write_direct(fp.fd, data)
                else:
                    # NOTE(review): this hands the whole 4 MiB buffer to os.write()
                    # regardless of len(data); for short blocks the truncate assertion
                    # below may then trip -- confirm this fallback is exercised.
                    buf = mmap.mmap(-1, self._BLOCK_SIZE)
                    try:
                        buf.write(data)
                        os.write(fp.fd, buf)
                    finally:
                        buf.close()

                if length != self._BLOCK_SIZE:
                    truncate = True

            self.stats['write_bytes'] += length
            self.stats['write_sec'] += time.time() - ts

            if truncate:
                # Truncate should be done only for last not full page, blocks < 4Mb should never come
                # in the middle of file.
                st = fp.stat()
                truncate_bytes = st.st_size - (start + length)

                # If somehow we want to truncate more than 1 page (more than 4095 bytes) -- assert here
                # for further investigation
                assert truncate_bytes < 4096, '%s: attempt to truncate %d bytes (>=4096)' % (fn, truncate_bytes)
                assert truncate_bytes >= 0, '%s: attempt to grow file by %d bytes' % (fn, -truncate_bytes)

                # truncate_bytes == 0 happens when the file is already aligned (its size
                # is a multiple of 4k pages); no truncate is needed then.
                # Nothing is logged here: logging from a different thread will break gevent.
                if truncate_bytes != 0:
                    fp.truncate(start + length)

            if self._sync_writes_every:
                warn_msg = self._add_unsynced_block(fn, start, length)

            if self._force_fsync:
                fp.fsync()

            return (fp.stat(), warn_msg)
        finally:
            fp.close()

    def truncate(self, fn, size):
        """Open `fn` for writing and set its length to `size` bytes."""
        fp = self._fopen(fn, write=True)
        try:
            fp.truncate(size)
        finally:
            fp.close()

    def fsync(self, fn):
        """Open `fn` for writing and fsync() it."""
        fp = self._fopen(fn, write=True)
        try:
            fp.fsync()
        finally:
            fp.close()


class Deblock(Deblock):
    """Extension of the imported Deblock that counts calls to apply()."""

    def __init__(self, *args, **kwargs):
        super(Deblock, self).__init__(*args, **kwargs)
        # Number of apply() invocations made through this instance.
        self.apply_count = 0

    def apply(self, *args, **kwargs):
        """Bump the call counter, then delegate to the base implementation."""
        self.apply_count = self.apply_count + 1
        result = super(Deblock, self).apply(*args, **kwargs)
        return result


class IO(Component):
    """
    Front-end that runs IOWorker operations on a fixed pool of Deblock threads.

    `threads` is the pool size. `sem` is an optional context manager entered around
    every call (a no-op DummySem is used when omitted). `priv` is an optional
    context-manager factory passed to the worker for file opens (a no-op is used
    when omitted). `sync_writes_every` is forwarded to IOWorker.
    """

    # errno values that are re-raised without being logged.
    _EXPECTED_ERRNOS = frozenset((errno.EPERM, errno.ENOENT, errno.EACCES))
    # Pseudo errno used by IOWorker for IOError('Not a regular file').
    _NOT_REGULAR_FILE_ERRNO = 666

    def __init__(self, threads, sem=None, priv=None, sync_writes_every=0, parent=None):
        super(IO, self).__init__(logname='io', parent=parent)

        if not priv:
            priv = dummy_user_privileges

        self.worker = IOWorker(priv, sync_writes_every, log=self.log.getChild('wrkr'))

        self.threads = [Deblock() for _ in range(threads)]
        # Indexes into self.threads that are not running a call right now.
        self.threads_free = set(range(threads))
        # Limits concurrent calls to the pool size, so threads_free is never empty
        # while the semaphore is held.
        self.threads_sem = coros.Semaphore(threads)

        self.busy = 0

        self.sem = sem if sem else DummySem()

        # Names of all files passed to write_block() so far.
        self._written_files = set()

    def set_force_fsync(self, flag):
        """Enable or disable fsync after every write; enabling requires a single-thread pool."""
        # Validate before touching the worker so a failed check leaves it unchanged.
        if flag:
            assert len(self.threads) == 1
        self.worker.set_force_fsync(flag)

    def set_max_write_chunk(self, size):
        """Set the maximum size of one direct write call; must be a positive int below 1 GiB."""
        assert isinstance(size, int)
        assert size > 0 and size < 1 * 1024 * 1024 * 1024
        self.worker.set_max_write_chunk(size)

    def read_block(self, uid, fn, start, length, direct=False):
        """Run IOWorker.read() in a pool thread and return its result."""
        return self._iocall(self.worker.read, uid, fn, start, length, direct)

    def write_block(self, uid, fn, start, data, direct=False, truncate=False, sha1=None):
        """Remember `fn` as written, then run IOWorker.write() in a pool thread."""
        self._written_files.add(fn)
        return self._iocall(self.worker.write, uid, fn, start, data, direct, truncate=truncate, sha1=sha1)

    def truncate(self, fn, length):
        """Run IOWorker.truncate() in a pool thread."""
        return self._iocall(self.worker.truncate, fn, length)

    def noop(self, sleep):
        """Occupy a pool thread for `sleep` seconds (IOWorker.sleep)."""
        return self._iocall(self.worker.sleep, sleep)

    def fsync(self, fn):
        """Run IOWorker.fsync() in a pool thread."""
        return self._iocall(self.worker.fsync, fn)

    def iter_written_files(self):
        """Yield every file name that was passed to write_block()."""
        # Iterate over a snapshot: the consumer may switch greenlets between items
        # (e.g. by calling fsync()), and a concurrent write_block() adding to the
        # live set would raise "Set changed size during iteration".
        for fn in tuple(self._written_files):
            yield fn

    @classmethod
    def _is_expected_error(cls, ex):
        """Tell whether `ex` is a permission/missing-file style error that is not worth logging."""
        if isinstance(ex, OSError) and ex.errno in cls._EXPECTED_ERRNOS:
            return True
        if isinstance(ex, IOError):
            return ex.errno == cls._NOT_REGULAR_FILE_ERRNO or ex.errno in cls._EXPECTED_ERRNOS
        return False

    def _iocall(self, func, *args, **kwargs):
        """
        Run `func(*args, **kwargs)` on the free pool thread with the lowest index.

        Waits for a free thread, enters `self.sem` around the call and returns the
        call's result. Any exception is re-raised; it is logged as a warning first
        unless it is one of the expected OSError/IOError cases.
        """
        with self.threads_sem:
            thr_idx = min(self.threads_free)
            thr = self.threads[thr_idx]

            self.threads_free.discard(thr_idx)
            try:
                with self.sem:
                    return thr.apply(func, *args, **kwargs)
            except BaseException as ex:
                if not self._is_expected_error(ex):
                    import traceback
                    self.log.warning(traceback.format_exc())
                raise
            finally:
                self.threads_free.add(thr_idx)

    def stats(self):
        """Return per-thread call counts and a copy of the worker's I/O counters."""
        return {
            'thread_counts': [thr.apply_count for thr in self.threads],
            'io': dict(self.worker.stats),
        }
