from __future__ import absolute_import

import contextlib

try:
    import zstandard
except ImportError:
    zstandard = None

from ..greenish.deblock import Deblock
from ..utils import timer, dummy_timer


# Names of the codecs usable in this process; filled in further down only
# for codecs whose backing library imported successfully.
SUPPORTED_CODECS = set()
# codec name => highest compression level parse_compression_mode() accepts.
MAX_LEVEL = dict()

# Marker that identifies the compression entry in a list of capabilities.
CAP_PREFIX = "comp:"

def compression_codecs_to_capability(codecs):
    """Encode codec names as a single capability string.

    The names are joined with commas, in the iteration order of *codecs*,
    and prefixed with CAP_PREFIX; e.g. ['zstd'] -> 'comp:zstd'.
    """
    joined = ','.join(codecs)
    return CAP_PREFIX + joined

def capabilities_to_compression_codecs(caps):
    """Extract the codec list from the first compression capability in *caps*.

    Inverse of compression_codecs_to_capability(): 'comp:zstd' -> ['zstd'].

    :param caps: iterable of capability strings.
    :returns: list of codec names taken from the first entry that starts
        with CAP_PREFIX (empty names are dropped, so 'comp:' -> []), or
        None when no entry carries the prefix.
    """
    for cap in caps:
        if cap.startswith(CAP_PREFIX):
            # ''.split(',') is [''], so without the filter an empty codec
            # list would round-trip as [''] and 'comp:zstd,' would yield a
            # bogus '' codec name.
            return [codec for codec in cap[len(CAP_PREFIX):].split(',') if codec]
    return None


if zstandard:
    SUPPORTED_CODECS.add('zstd')
    MAX_LEVEL['zstd'] = 3

    class ZstdCompressor(object):
        """One-shot zstd compressor that can log how long each call takes.

        :param log: logger used for the timing debug message.
        :param level: zstd compression level; 1 is used when None.
        :param measure_time: when true, each compress() call is timed and
            the duration is logged at debug level.
        """

        def __init__(self, log, level=None, measure_time=True):
            self.log = log
            if level is None:
                level = 1
            self._ctx = zstandard.ZstdCompressor(level)
            self._measure_time = measure_time
            # dummy_timer stands in when timing is off, so compress() can use
            # the same `with` statement either way.
            self._timer = timer if measure_time else dummy_timer

        def compress(self, data):
            """Compress *data* in a single call and return the result."""
            # Named `elapsed` (not `timer`) to avoid shadowing the
            # module-level `timer` import.
            with self._timer() as elapsed:
                result = self._ctx.compress(data)
            if self._measure_time:
                self.log.debug('zstd compression took %.3fms', elapsed.spent * 1000)

            return result


    class ZstdDecompressor(object):
        """Streaming zstd decompressor that writes output to a caller stream.

        Protocol: start(out_stream), then write() compressed chunks, then
        finish(). finish() resets the session, so the object can be
        start()-ed again afterwards. Time spent in write()/finish() is
        accumulated and logged at debug level when *measure_time* is true.
        """

        def __init__(self, log, measure_time=True):
            self.log = log
            self._ctx = zstandard.ZstdDecompressor()
            self._measure_time = measure_time
            self._timer = timer if measure_time else dummy_timer

            # Per-session state; None/0 while no session is active.
            self._out_stream = None
            self._writer = None
            self._elapsed = 0

        def start(self, out_stream):
            """Begin a session whose decompressed output goes to *out_stream*."""
            assert self._out_stream is None
            assert self._writer is None

            self._out_stream = out_stream
            self._writer = self._ctx.stream_writer(out_stream)
            self._elapsed = 0

        def write(self, data):
            """Feed a chunk of compressed *data* into the current session."""
            with self._timer() as elapsed:
                self._writer.write(data)
            self._elapsed += elapsed.spent

        def finish(self):
            """Flush and close the stream writer, then end the session.

            Session state is cleared in a ``finally`` clause: if flush() or
            close() raises, the object is still reusable instead of tripping
            the assertions in the next start().
            """
            try:
                with self._timer() as elapsed:
                    self._writer.flush()
                    # NOTE(review): in python-zstandard >= 0.15 close() also
                    # closes the wrapped out_stream by default (closefd=True);
                    # confirm callers do not need out_stream open afterwards.
                    self._writer.close()
                self._elapsed += elapsed.spent

                if self._measure_time:
                    self.log.debug('zstd decompression took %.3fms', self._elapsed * 1000)
            finally:
                self._out_stream = None
                self._writer = None
                self._elapsed = 0
else:
    # zstandard is unavailable: the factories below check for None before
    # instantiating these.
    ZstdCompressor = None
    ZstdDecompressor = None


# Import-time sanity check: each supported codec declares a max level and
# MAX_LEVEL names no codec that is not supported.
assert set(MAX_LEVEL) == SUPPORTED_CODECS

def parse_compression_mode(mode):
    """Split a 'codec' or 'codec/level' mode string into (codec, level).

    The level is whatever follows the last '/', parsed with int().

    Examples (with MAX_LEVEL['zstd'] == 3)::

        parse_compression_mode('zstd/3') == ('zstd', 3)
        parse_compression_mode('zstd')   == ('zstd', None)

    :raises ValueError: if the level is not an integer, the codec is not in
        SUPPORTED_CODECS, or the level exceeds MAX_LEVEL[codec].
    """
    if '/' not in mode:
        codec, level = mode, None
    else:
        codec, raw_level = mode.rsplit('/', 1)
        try:
            level = int(raw_level)
        except ValueError:
            raise ValueError('invalid compression level {}'.format(raw_level))

    if codec not in SUPPORTED_CODECS:
        raise ValueError('unsupported compression codec: {}'.format(codec))

    if level is not None:
        limit = MAX_LEVEL[codec]
        if level > limit:
            raise ValueError(
                'compression level {} is too high: max allowed for {} is {}'.format(
                    level, codec, limit
                )
            )
    return codec, level

# Shared Deblock instance; make_compressor() wraps every compressor it
# creates with compressor_deblock.make_proxy(). Per the original intent, this
# runs compression operations asynchronously in a separate thread so that
# potentially expensive compression does not slow down skybone-skybit.
compressor_deblock = Deblock()

def make_compressor(mode, log):
    """Build a compressor for a 'codec' or 'codec/level' *mode* string.

    The compressor is returned wrapped in the proxy produced by
    compressor_deblock.make_proxy().

    :param mode: compression mode, parsed by parse_compression_mode().
    :param log: logger handed to the compressor (used for timing messages).
    :raises ValueError: if *mode* is malformed, names an unsupported codec,
        or requests a level above the codec's maximum.
    :raises RuntimeError: if the codec has no compressor implementation.
    """
    codec, level = parse_compression_mode(mode)
    if codec == 'zstd' and ZstdCompressor is not None:
        return compressor_deblock.make_proxy(ZstdCompressor(log, level))
    # Call-form raise: valid on Python 2 and 3, unlike `raise Exc, msg`.
    raise RuntimeError('No compressor defined for codec {}'.format(codec))

def make_decompressor(codec, log):
    """Build a streaming decompressor for the codec named *codec*.

    :param codec: bare codec name (no '/level' suffix), e.g. 'zstd'.
    :param log: logger handed to the decompressor (used for timing messages).
    :raises RuntimeError: if no decompressor is available for *codec*,
        including 'zstd' when the zstandard module is not installed.
    """
    if codec == 'zstd' and ZstdDecompressor is not None:
        return ZstdDecompressor(log)
    # Call-form raise: valid on Python 2 and 3, unlike `raise Exc, msg`.
    raise RuntimeError('No decompressor defined for codec {}'.format(codec))
