# encoding: utf-8

from jafar.tests import JafarTestCase
from jafar.utils.static_dict import StaticDict, StaticMapping

import numpy as np
import string


class MyPicklableClass(object):
    """Simple holder of two attributes, ``a`` and ``b``, used as a StaticDict value.

    It is defined at module level so that ``pickle`` can locate the class by its
    qualified name when instances are serialized.
    """

    def __init__(self, a, b):
        self.a, self.b = a, b


class StaticDictTestCase(JafarTestCase):
    """Tests for ``StaticDict`` instances built with ``StaticDict.create(keys, values)``.

    Covers lookup, membership, the dict-like views (``items``/``keys``/``values``/
    ``iteritems``), ``len``, conversion to a plain ``dict``, and several value types
    (ints, floats, byte strings, unicode strings, lists and arbitrary objects).
    """

    def test_access(self):
        static_dict = StaticDict.create(['ab', 'bc', 'cd'], [1, 2, 3])
        self.assertEqual(static_dict['ab'], 1)
        self.assertEqual(static_dict['bc'], 2)
        self.assertEqual(static_dict['cd'], 3)

    def test_in(self):
        keys = ['abc', 'bcd', 'cde']
        static_dict = StaticDict.create(keys, [1, 2, 3])
        for key in keys:
            self.assertIn(key, static_dict)

    def test_items(self):
        keys = [4, 8, 123]
        values = [7, 14, 21]
        static_dict = StaticDict.create(keys, values)
        self.assertEqual(
            sorted(static_dict.items()),
            sorted(zip(keys, values))
        )

    def test_already_indexed(self):
        # if keys are already sorted integers in 0..len(keys) - 1, it's a special case with empty hash function
        static_dict = StaticDict.create([0, 1, 2], [1, 2, 3])
        self.assertEqual(static_dict[0], 1)
        self.assertEqual(static_dict[1], 2)
        self.assertEqual(static_dict[2], 3)

    def test_compare_with_dict(self):
        np.random.seed(1234)
        n_keys = 1000
        # ascii_letters rather than the locale-dependent string.letters, so the seeded
        # keys are identical on every machine (the two match under the C locale)
        keys = [''.join(row) for row in np.random.choice(list(string.ascii_letters), size=(n_keys, 10))]
        values = list(range(n_keys))
        static_dict = StaticDict.create(keys, values)
        simple_dict = dict(zip(keys, values))
        for key in keys:
            self.assertEqual(static_dict[key], simple_dict[key])

    def test_len(self):
        static_dict = StaticDict.create(['x', 'y', 'z'], ['a', 'b', 'c'])
        self.assertEqual(len(static_dict), 3)

    def test_keys(self):
        keys = ['never', 'gonna', 'give', 'you', 'up']
        static_dict = StaticDict.create(keys, [1, 2, 3, 4, 5])
        self.assertEqual(sorted(static_dict.keys()), sorted(keys))

    def test_values(self):
        values = ['never', 'gonna', 'let', 'you', 'down']
        static_dict = StaticDict.create([6, 7, 8, 9, 10], values)
        self.assertEqual(sorted(static_dict.values()), sorted(values))

    def test_iteritems(self):
        keys = [1, 2, 3, 4, 5]
        values = [6, 7, 8, 9, 10]
        items = list(zip(keys, values))
        static_dict = StaticDict.create(keys, values)
        # compare as multisets so that a missing or duplicated pair fails the test,
        # not only an unexpected one; pairs are normalised to tuples first
        self.assertItemsEqual([(key, value) for key, value in static_dict.iteritems()], items)

    def test_inconsistent_types(self):
        consistent_keys = ['london', 'bridge', 'is', 'falling', 'down']
        inconsistent_keys = [1, 2, 3, None, 'wat']
        consistent_values = ['every', 'dream', 'that', 'I', 'dream']
        inconsistent_values = ['seems', 'to', 'float', 'on', 404]

        # mixed value types are accepted; only mixed key types are rejected
        StaticDict.create(consistent_keys, inconsistent_values)
        # NOTE(review): relies on StaticDict raising AssertionError, which would not
        # happen under `python -O` if it is implemented with bare `assert` — confirm
        with self.assertRaises(AssertionError):
            StaticDict.create(inconsistent_keys, consistent_values)
        with self.assertRaises(AssertionError):
            StaticDict.create(inconsistent_keys, inconsistent_values)

    def test_json_serializable(self):
        keys = ['app1', 'app2', 'app3']
        values = [['category1', 'category2'], ['category1'], ['category3', 'category1']]
        static_dict = StaticDict.create(keys, values)
        self.assertEqual(static_dict['app1'], ['category1', 'category2'])
        self.assertEqual(static_dict['app2'], ['category1'])
        self.assertEqual(static_dict['app3'], ['category3', 'category1'])

    def test_pickle_serializable(self):
        keys = ['app1', 'app2', 'app3']
        values = [MyPicklableClass(1, 2), MyPicklableClass(3, 4), MyPicklableClass(5, 6)]
        static_dict = StaticDict.create(keys, values)
        self.assertEqual(static_dict['app1'].a, 1)
        self.assertEqual(static_dict['app1'].b, 2)
        self.assertEqual(static_dict['app2'].a, 3)
        self.assertEqual(static_dict['app2'].b, 4)
        self.assertEqual(static_dict['app3'].a, 5)
        self.assertEqual(static_dict['app3'].b, 6)

    def test_float_values(self):
        keys = ['show', 'must', 'go', 'on']
        values = [0.123, 0.456, 0.789, 0.101]
        static_dict = StaticDict.create(keys, values)
        for key, value in zip(keys, values):
            self.assertAlmostEqual(static_dict[key], value)

    def test_string_values(self):
        keys = ['hello', 'darkness', 'my', 'old', 'friend']
        values = ["I've come", 'to talk', 'with', 'you', 'again']
        static_dict = StaticDict.create(keys, values)
        for key, value in zip(keys, values):
            self.assertEqual(static_dict[key], value)

    def test_unicode_values(self):
        keys = [1, 2, 3, 4, 5]
        values = [u'один', u'два', u'три', u'четыре', u'пять']
        static_dict = StaticDict.create(keys, values)
        for key, value in zip(keys, values):
            self.assertEqual(static_dict[key], value)

    def test_convert_to_dict(self):
        keys = ['show', 'must', 'go', 'on']
        values = [0.123, 0.456, 0.789, 0.101]
        static_dict = StaticDict.create(keys, values)
        dict_from_static_dict = dict(static_dict)
        self.assertItemsEqual(static_dict.items(), dict_from_static_dict.items())


class StaticMappingTestCase(JafarTestCase):
    """Tests for ``StaticMapping`` built with ``StaticMapping.create(array)``.

    ``setUp`` builds ``self.mapping`` from seven distinct names. The tests check that
    the keys map onto ``range(len(array))``, that ``mapping.reverse`` maps those
    indices back to the keys, and that unknown keys/indices raise ``KeyError``.
    """

    def setUp(self):
        super(StaticMappingTestCase, self).setUp()
        # builtin ``object`` instead of the ``np.object`` alias, which NumPy deprecated
        # in 1.20 and removed in 1.24
        self.array = np.array(['anton', 'artem', 'roma', 'sasha', 'stepa', 'islam', 'misha'], dtype=object)
        self.mapping = StaticMapping.create(self.array)

    def test_return_index(self):
        # with return_index=True, create also returns an index that reorders the input
        # array into the order stored in mapping.array
        mapping, index = StaticMapping.create(self.array, return_index=True)
        self.assertTrue((mapping.array == self.array[index]).all())

    def test_nothing_is_lost(self):
        self.assertItemsEqual(self.array, self.mapping.array)

    def test_maps_to_range(self):
        self.assertItemsEqual(self.mapping.map(self.array, -1), range(len(self.array)))

    def test_keys(self):
        self.assertItemsEqual(self.mapping.keys(), self.array)

    def test_reverse_keys(self):
        self.assertItemsEqual(self.mapping.reverse.keys(), range(len(self.array)))

    def test_values(self):
        self.assertItemsEqual(self.mapping.values(), range(len(self.array)))

    def test_reverse_values(self):
        self.assertItemsEqual(self.mapping.reverse.values(), self.array)

    def test_items(self):
        self.assertItemsEqual(self.mapping.items(), zip(self.mapping.keys(), self.mapping.values()))

    def test_reverse_items(self):
        self.assertItemsEqual(self.mapping.reverse.items(),
                              zip(self.mapping.reverse.keys(), self.mapping.reverse.values()))

    def test_map_and_getitem(self):
        # vectorised map() must agree element-wise with scalar __getitem__
        self.assertSequenceEqual([self.mapping[x] for x in self.array], list(self.mapping.map(self.array, -1)))

    def test_reverse_keys_values(self):
        # the forward and reverse views list the same data in the same order
        self.assertSequenceEqual(list(self.mapping.values()), list(self.mapping.reverse.keys()))
        self.assertSequenceEqual(list(self.mapping.keys()), list(self.mapping.reverse.values()))

    def test_get_unknown_item(self):
        with self.assertRaises(KeyError):
            self.mapping['timur']

    def test_get_unknown_item_reverse(self):
        # indices just outside 0..len(array) - 1 on both sides
        with self.assertRaises(KeyError):
            self.mapping.reverse[-1]
        with self.assertRaises(KeyError):
            self.mapping.reverse[1000]

    def test_convert_to_dict(self):
        mapping_dict = dict(self.mapping)
        self.assertItemsEqual(self.mapping.items(), mapping_dict.items())

    def test_map_with_invalid_keys(self):
        # the second argument of map() is the value used for unknown keys ('timur')
        res = self.mapping.map(np.array(['anton', 'artem', 'timur', 'roma'], dtype=object), -1)
        self.assertNotEqual(res[0], -1)
        self.assertNotEqual(res[1], -1)
        self.assertEqual(res[2], -1)
        self.assertNotEqual(res[3], -1)
