import numpy as np

from jafar.storages.memmap.storage import MultiprocessMemmapStorage
from jafar.tests import JafarTestCase


class StorageMixin(object):
    def test_int(self):
        self.storage.store('my.int.key', 1)
        self.assertEquals(self.storage.get_object('my.int.key'), 1)

    def test_float(self):
        self.storage.store('my.float.key', 1.23456)
        self.assertAlmostEquals(self.storage.get_object('my.float.key'), 1.23456)

    def test_numpy_int(self):
        self.storage.store('my.int.key', np.int32(1))
        self.assertEquals(self.storage.get_object('my.int.key'), 1)

    def test_numpy_float(self):
        self.storage.store('my.float.key', np.float32(1.23456))
        self.assertAlmostEquals(self.storage.get_object('my.float.key'), 1.23456)

    def test_dict(self):
        dct = {
            'a': 1, 'b': 2, 'c': 3.14, 'd': 'wat', 'f': None
        }
        self.storage.store('my.dict.key', dct)
        proxy = self.storage.get_proxy('my.dict.key')
        self.assertEquals(dct['a'], proxy['a'])
        self.assertEquals(dct['c'], proxy['c'])
        self.assertEquals(dct['d'], proxy['d'])
        self.assertEquals(dct['f'], proxy['f'])
        for k, v in proxy.iteritems():
            self.assertTrue(k in dct)
            self.assertEquals(v, dct[k])
        self.assertEquals(set(dct.keys()), set(proxy.keys()))
        self.assertEquals(set(dct.values()), set(proxy.values()))
        self.assertEquals(set(dct.items()), set(proxy.items()))
        self.assertEquals(len(dct), len(proxy))
        self.assertTrue('a' in proxy)
        self.assertFalse('z' in proxy)
        self.assertEquals(dct.get('z'), proxy.get('z'))
        self.assertEquals(dct.get('c'), proxy.get('c'))
        self.assertTrue(np.array_equal(self.storage.get_dict_values('my.dict.key', ['c', 'f', 'd']),
                                       [3.14, None, 'wat']))
        self.assertEquals(self.storage.get_object('my.dict.key'), dct)

    def test_matrix(self):
        arr = np.random.rand(10, 10)
        self.storage.store('my.matrix.key', arr)
        obj = np.vstack(self.storage.get_matrix_rows('my.matrix.key', range(10)))
        self.assertTrue(isinstance(obj, np.ndarray))
        self.assertTrue(np.allclose(obj, arr))
        self.assertTrue(np.allclose(self.storage.get_matrix_rows('my.matrix.key', [1, 2, 4]), arr[[1, 2, 4]]))


class MemoryStorageTestCase(StorageMixin, JafarTestCase):
    """Run the shared StorageMixin tests against the in-memory backend."""

    # Expose ``self.memory_storage`` (presumably set up by JafarTestCase —
    # not visible here) under the ``storage`` name that StorageMixin reads.
    storage = property(lambda self: self.memory_storage)


class MemmapStorageTestCase(StorageMixin, JafarTestCase):
    """Run the shared StorageMixin tests, plus memmap-specific ones, against
    the memmap-backed storage."""

    @property
    def storage(self):
        # ``memmap_storage`` is presumably provided by JafarTestCase — TODO confirm.
        return self.memmap_storage

    def test_object_array(self):
        # memmap storage handles object arrays in a special way, so check the
        # round-trip through both get_object and get_proxy.
        # Use the builtin ``object`` dtype: the ``np.object`` alias was
        # deprecated in NumPy 1.20 and removed in 1.24 (AttributeError).
        array = np.array(['string', 'another string'], dtype=object)
        self.storage.store('my.object.array.key', array)
        self.assertTrue(np.array_equal(array, self.storage.get_object('my.object.array.key')))
        self.assertTrue(np.array_equal(array, self.storage.get_proxy('my.object.array.key')))


class MultiprocessingStorageTestCase(MemmapStorageTestCase):
    """Re-run the memmap test suite against MultiprocessMemmapStorage."""

    # NOTE(review): every read of ``self.storage`` constructs a brand-new
    # MultiprocessMemmapStorage, so a test's ``store`` call and its later
    # ``get_*`` calls go to different instances. That only passes if the
    # instances share their backing data. Confirm whether this is intentional
    # (e.g. simulating separate processes) or whether one instance should be
    # cached per test.
    storage = property(lambda self: MultiprocessMemmapStorage())
