from StringIO import StringIO
from operator import itemgetter
import abc
import tarfile
import time
import json
import itertools
import random
import yt.wrapper as yt

class FileContainer(object):
    """Abstract sink that collects named files (for example an archive).

    Subclasses implement ``add_file`` and ``close``.  Instances may also be
    used as context managers: ``close()`` is called when the ``with`` block
    exits, whether or not the block raised.
    """
    __metaclass__ = abc.ABCMeta

    @abc.abstractmethod
    def add_file(self, relpath, content):
        """Store ``content`` under the relative path ``relpath``."""
        pass

    @abc.abstractmethod
    def close(self):
        """Finalize the container and release any underlying resources."""
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Returning False lets any exception raised inside the block propagate.
        return False

class NullContainer(FileContainer):
    """Container that discards every file it is given (output disabled)."""

    def add_file(self, relpath, content):
        """Ignore ``content``; nothing is stored."""
        return None

    def close(self):
        """Nothing is held open, so there is nothing to release."""
        return None

class TarballContainer(FileContainer):
    """File container that writes every added file into a tar archive.

    All members share the modification time captured at construction.
    """

    def __init__(self, path, gzip=True):
        """Open ``path`` for writing; gzip-compress it when ``gzip`` is true."""
        mode = 'w:gz' if gzip else 'w'
        self.tar = tarfile.open(path, mode)
        self.timestamp = time.time()

    def add_file(self, relpath, content):
        """Add a member named ``relpath`` whose data is ``content``.

        Text (unicode) content is encoded as UTF-8 first, so that the size
        recorded in the member header equals the number of bytes written.
        """
        if not isinstance(content, bytes):
            content = content.encode('utf-8')
        tarinfo = tarfile.TarInfo(name=relpath)
        # Size comes from the payload itself, not from StringIO internals.
        tarinfo.size = len(content)
        # Tar headers store whole seconds.
        tarinfo.mtime = int(self.timestamp)
        self.tar.addfile(tarinfo=tarinfo, fileobj=StringIO(content))

    def close(self):
        """Finish the archive and close the underlying file."""
        self.tar.close()

def generic_read_images_table(src, max_images, key=None, filterfunc=None):
    """Yield records from the YT table ``src``, optionally capped and filtered.

    :param src: YT path of the source table.
    :param max_images: maximum number of records to yield, or ``None`` for
        no limit.
    :param key: callable mapping a record to a sortable value.  It is only
        used when ``src`` has more than ``max_images`` rows: then the
        filtered records are sorted by ``key`` and the first ``max_images``
        are yielded.
    :param filterfunc: optional predicate; records for which it returns a
        false value are skipped.
    :raises ValueError: if the table must be cut down to ``max_images`` rows
        but no ``key`` was given.

    When the table is small enough, records are yielded in table order
    without sorting.  All reads happen inside a single YT transaction.
    """
    def mapper(rec):
        if filterfunc is not None and not filterfunc(rec):
            return
        # Temporary column used only to sort; stripped before yielding.
        rec['__sorting_tag'] = key(rec)
        yield rec

    with yt.Transaction():
        if max_images is not None and yt.get(src + '/@row_count') > max_images:
            if key is None:
                # The mapper would otherwise call None and the map job would
                # fail with a TypeError; report the real cause up front.
                raise ValueError(
                    'key is required to select %s rows out of %s'
                    % (max_images, src))
            with yt.TempTable() as tmp:
                yt.run_map(mapper, src, tmp)
                yt.run_sort(tmp, sort_by=['__sorting_tag'])
                for rec in itertools.islice(yt.read_table(tmp, raw=False), 0, max_images):
                    del rec['__sorting_tag']
                    yield rec
        else:
            for rec in yt.read_table(src, raw=False):
                if filterfunc is None or filterfunc(rec):
                    yield rec

def generic_store_results(image_iter, result_container, debug_container, resultfunc, indexfunc=None, image_format='gif'):
    """Render records to images and store them in two file containers.

    For each record, ``resultfunc(record)`` must return ``(image, answer)``,
    where ``image`` has a ``save(stream, format=...)`` method (presumably a
    PIL image).  For every record:

    * ``result_container`` receives ``NNN/NNN/<name>.e<image_format>``
      holding ``answer + '\\x00' + <encoded image>``;
    * ``debug_container`` receives ``debug/<name>.<image_format>`` holding
      the encoded image alone.

    ``<name>`` is the record index (its position in ``image_iter``, or
    ``indexfunc(record)`` when given) zero-padded to 8 digits.  After all
    records, ``debug/metadata.json`` maps each result path to its record
    without the ``'image'`` key.

    Both containers are closed even if processing fails part-way; the
    metadata file is written only when every record succeeded.
    """
    metadata = {}

    try:
        try:
            for image_index, image_data in enumerate(image_iter):
                if indexfunc is not None:
                    image_index = indexfunc(image_data)

                image_name = format(image_index, "08")
                # Spread files over nested directories keyed by name prefix.
                image_path = (image_name[0:3] + "/" + image_name[3:6] + "/" +
                              image_name + '.e' + image_format)

                image, answer = resultfunc(image_data)
                # Copy so the caller's record is untouched; the raw image
                # column is dropped before the record goes into the JSON.
                record = dict(image_data)
                del record['image']
                metadata[image_path] = record

                output = StringIO()
                image.save(output, format=image_format.upper())
                image_bin = output.getvalue()

                # NOTE(review): answer is assumed to be a byte string; a
                # non-ASCII unicode answer cannot be joined with binary data.
                content = ''.join([answer, '\x00', image_bin])
                result_container.add_file(image_path, content)
                debug_container.add_file('debug/%s.%s' % (image_name, image_format), image_bin)
        finally:
            result_container.close()

        debug_container.add_file('debug/metadata.json',
                json.dumps(metadata, indent=1, ensure_ascii=False, sort_keys=True))
    finally:
        debug_container.close()


def irandmerge(*iters_lengths):
    """Randomly interleave several iterables, preserving each one's order.

    :param iters_lengths: ``(iterable, length)`` pairs, where ``length`` is
        the number of items to take from that iterable.
    :returns: a generator yielding the sum of all lengths items.  At each
        step the next item comes from iterable ``i`` with probability
        proportional to the number of items still owed by it, so every
        interleaving is equally likely.
    :raises ValueError: if a length is negative, or if an iterable is
        exhausted before it has produced ``length`` items.

    Items beyond an iterable's declared length are never consumed.
    """
    iters = []
    remaining = []
    for it, length in iters_lengths:
        if length < 0:
            raise ValueError('Negative length %d for iterator with index %d'
                             % (length, len(iters)))
        iters.append(iter(it))
        remaining.append(length)
    total = sum(remaining)

    def weighted_random_index():
        # Draw a uniform position among the items still owed, then find the
        # iterator whose bucket contains it.
        n = random.randrange(total)
        for i, grouplen in enumerate(remaining):
            if n < grouplen:
                return i
            n -= grouplen
        raise AssertionError('remaining counts do not add up to total')

    while total > 0:
        idx = weighted_random_index()
        try:
            item = next(iters[idx])
        except StopIteration:
            raise ValueError('Unexpected stop of iterator with index %d' % idx)
        remaining[idx] -= 1
        total -= 1
        yield item
