from unittest import TestCase
import operator
import math


# Scalar types accepted as the right-hand operand of Array arithmetic.
try:
    # Python 2 has a separate arbitrary-precision ``long`` type.
    NUMERIC = (int, float, long)
except NameError:
    # Python 3 merged ``long`` into ``int``; without this fallback the module
    # fails at import time.
    NUMERIC = (int, float)


class NotMatchDim(Exception):
    """Raised when two arrays have dimensions incompatible for an operation.

    Args:
        size1: dimension list of the left operand (may be None).
        size2: dimension list of the right operand (may be None).
    """

    def __init__(self, size1=None, size2=None):
        # Forward the sizes to Exception so ``args`` is populated; otherwise
        # ``args`` is empty and repr() shows a bare ``NotMatchDim()``.
        super(NotMatchDim, self).__init__(size1, size2)
        self.size1 = size1
        self.size2 = size2

    def __str__(self):
        return 'bad sizes: {} and {}'.format(self.size1, self.size2)


class Array:
    """Minimal n-dimensional array of floats built from nested Python lists.

    A 1-D array keeps its values in ``data`` as a flat list of floats; an
    array with more axes keeps one sub-``Array`` per entry of its first axis.
    ``dim`` holds the length of every axis, outermost first.
    """

    def __init__(self, lst):
        """Build an array from a non-empty (nested) list or another Array.

        Only the first element of each nesting level is inspected to work out
        the dimensions, so ragged input is not detected.  A 1-D ``Array``
        passed in is not copied: the new array shares its ``data`` list.
        """
        self.dim = self.__get_dim(lst)
        if len(self.dim) > 1:
            self.data = [Array(row) for row in lst]
        elif isinstance(lst, Array):
            # Shares (does not copy) the underlying list, so e.g. the rows of
            # a slice returned by __getitem__ alias the source rows.
            self.data = lst.data
        else:
            self.data = [float(a) for a in lst]

    @classmethod
    def __get_dim(cls, data):
        """Return the axis lengths of ``data`` as a new list (``[]`` for a scalar)."""
        if isinstance(data, Array):
            # Copy so the new array never shares its ``dim`` list with ``data``.
            return list(data.dim)
        elif isinstance(data, list):
            assert len(data) > 0, 'empty list'
            return [len(data)] + cls.__get_dim(data[0])
        else:
            return []

    def copy(self):
        """Return a deep copy: no row or data list is shared with ``self``."""
        new_arr = Array([0])  # placeholder; dim and data are overwritten below
        new_arr.dim = self.dim[:]
        if len(self.dim) == 1:
            new_arr.data = self.data[:]
        else:
            new_arr.data = [arr.copy() for arr in self.data]
        return new_arr

    def tolist(self):
        """Return the contents as freshly built nested Python lists.

        The result never aliases the array's internal storage, and arrays
        with three or more axes are converted all the way down.
        """
        if len(self.dim) == 1:
            return self.data[:]
        return [arr.tolist() for arr in self.data]

    @classmethod
    def zeros(cls, dim):
        """Return an array of ``0.0`` with axis lengths ``dim`` (list or tuple)."""
        new_arr = Array([1])  # placeholder; dim and data are overwritten below
        # Normalise to a list: __eq__ compares ``dim`` with ``!=``, and a
        # tuple never equals the list dims produced by the constructor.
        new_arr.dim = list(dim)
        if len(dim) == 1:
            # Floats, matching what __init__ stores.
            new_arr.data = [0.0] * dim[0]
        else:
            new_arr.data = [cls.zeros(dim[1:]) for _ in range(dim[0])]
        return new_arr

    def __str(self, depth):
        """Return display lines indented by ``depth``.

        At most 4 entries per axis and 5 nesting levels are shown; the rest
        is elided with ``'...'``.
        """
        if depth > 4:
            return ['...']
        shift = '  ' * depth
        if len(self.dim) > 1:
            vals = sum([val.__str(depth + 1) for val in self.data[:4]], [])
            vals += ['...'] if len(self) > 4 else []
            res = [shift + '['] + vals + [shift + ']']
        else:
            vals = [str(val) for val in self.data[:4]]
            vals += ['...'] if len(self) > 4 else []
            vals = ' '.join(vals)
            res = [''.join([shift, '[',
                           vals, ']'])]
        return res

    def __str__(self):
        return '\n'.join(self.__str(0))

    def __repr__(self):
        return str(self.data)

    def __getitem__(self, item):
        """Index with an int, a slice, or a tuple of them (leading axes first).

        An int returns the stored entry (a float or a sub-Array); a slice
        returns a new Array built from the selected entries.

        Raises:
            IndexError: for any index that is not an int or a slice.
        """
        if not isinstance(item, tuple):
            item = (item,)
        first, rest = item[0], item[1:]
        if isinstance(first, int):
            if rest:
                return self.data[first][rest]
            return self.data[first]
        elif isinstance(first, slice):
            # slice.indices() resolves None, negative and out-of-range bounds
            # against this axis' length exactly like list slicing does.
            rows = range(*first.indices(self.dim[0]))
            if rest:
                res = [self.data[i][rest] for i in rows]
            else:
                res = [self.data[i] for i in rows]
            return Array(res)
        else:
            raise IndexError('only int and slices can be indexes')

    def __setitem__(self, key, value):
        """Store ``value`` at ``key`` (an index or tuple of indices).

        A list value is wrapped in an Array first; the stored value is
        returned.
        """
        if not isinstance(key, tuple):
            key = (key,)
        if isinstance(value, list):
            value = Array(value)
        first = key[0]
        if len(key) > 1:
            self.data[first][key[1:]] = value
        else:
            self.data[first] = value
        return value

    @staticmethod
    def inc_pos(pos, dim):
        """Advance index list ``pos`` in place to the next row-major position.

        Each axis wraps at ``dim[i]``; the last position wraps to all zeros.
        """
        shift = 1
        i = len(dim) - 1
        while shift and i >= 0:
            pos[i] = (pos[i] + shift) % (dim[i])
            if pos[i] != 0:
                shift = 0
            i -= 1

    def keys(self):
        """Yield every full index tuple in row-major (last axis fastest) order."""
        pos = [0] * len(self.dim)
        end = [d - 1 for d in self.dim]
        while pos != end:
            yield tuple(pos)
            self.inc_pos(pos, self.dim)
        yield tuple(pos)

    def __iter__(self):
        for i in range(self.dim[0]):
            yield self.data[i]

    def __eq__(self, other):
        """Elementwise equality; NotImplemented for non-Array operands."""
        if not isinstance(other, Array):
            return NotImplemented
        if self.dim != other.dim:
            return False
        # Equal dims means both arrays share the same key sequence.
        return all(self[key] == other[key] for key in self.keys())

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, != compares
        # object identity.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __operation__(self, other, oper):
        """Apply binary ``oper`` elementwise and return a new array.

        ``other`` may be a scalar (an instance of ``NUMERIC``), an Array with
        the same dims, or an Array whose dims equal ``self.dim[1:]``; the last
        case applies it to every entry of the first axis.  ``self`` is not
        modified.

        Raises:
            NotMatchDim: ``other`` is an Array with incompatible dims.
            ValueError: ``other`` is of any other type.
        """
        new_arr = self.copy()
        if isinstance(other, NUMERIC):
            for key in new_arr.keys():
                new_arr[key] = oper(new_arr[key], other)
        elif isinstance(other, Array):
            if self.dim == other.dim:
                for key in new_arr.keys():
                    new_arr[key] = oper(new_arr[key], other[key])
            elif self.dim[1:] == other.dim:
                for i in range(new_arr.dim[0]):
                    new_arr.data[i] = new_arr.data[i].__operation__(other, oper)
            else:
                raise NotMatchDim(self.dim, other.dim)
        else:
            raise ValueError(
                'unsupported operand type: {}'.format(type(other).__name__))
        return new_arr

    def __add__(self, other):
        return self.__operation__(other, operator.add)

    def __sub__(self, other):
        return self.__operation__(other, operator.sub)

    def __mul__(self, other):
        return self.__operation__(other, operator.mul)

    def __div__(self, other):
        # Used by ``/`` on Python 2 without ``from __future__ import division``.
        return self.__operation__(other, operator.div)

    def __truediv__(self, other):
        # Used by ``/`` on Python 3 or under ``from __future__ import division``.
        return self.__operation__(other, operator.truediv)

    def __pow__(self, power, modulo=None):
        # ``modulo`` is accepted for protocol compatibility but ignored.
        return self.__operation__(power, math.pow)

    def __len__(self):
        return self.dim[0]

    def mean(self):
        """Mean along the first axis.

        Returns a float for a 1-D array, otherwise an Array with dims
        ``self.dim[1:]``.
        """
        if len(self.dim) == 1:
            # math.fsum sums every element accurately and returns a float, so
            # the division is a true division on Python 2 as well.
            return math.fsum(self.data) / len(self)
        total = self[0].copy()
        for i in range(1, len(self)):
            total += self[i]
        total /= len(self)
        return total

    def std(self):
        """Population standard deviation (divides by N) along the first axis."""
        mean = self.mean()
        if len(self.dim) == 1:
            temp = 0
        else:
            temp = self.zeros(self.dim[1:])
        for i in range(len(self)):
            sub = self[i] - mean
            temp += sub * sub
        temp /= len(self)
        return temp ** 0.5

    def scale(self):
        """Return a copy standardised along the first axis.

        The copy has the mean subtracted and is divided by the standard
        deviation.  Raises ZeroDivisionError if any standard deviation is 0
        (e.g. a constant column).
        """
        x = self.copy()
        mean = x.mean()
        x -= mean
        x /= x.std()
        return x
        return x


class TestArray(TestCase):
    """Unit tests for Array.

    Checks use ``self.assert*`` rather than bare ``assert``: bare asserts are
    stripped under ``python -O`` (turning the tests into no-ops) and report
    no values on failure.
    """

    def test_get_dim(self):
        data_set = [
            ([1], [1]),
            ([1, 2, 3], [3]),
            ([[1, 2], [3, 4]], [2, 2])
        ]
        for arr, size in data_set:
            a = Array(arr)
            self.assertEqual(a.dim, size, 'wrong dim')
            self.assertEqual(a.dim[0], len(a.data))

    def test_keys(self):
        a = Array([[1, 2], [3, 4]])
        for key, value in zip(a.keys(), [1, 2, 3, 4]):
            self.assertEqual(a[key], value)

    def test_copy(self):
        a = Array([[1, 2], [3, 4]])
        b = a.copy()
        self.assertEqual(a, b)
        self.assertIsNot(a, b)
        self.assertIsNot(a[0], b[0])
        self.assertIsNot(a[1], b[1])

    def test_operator(self):
        a = Array([[1, 2], [3, 4]])
        reserve = a.copy()
        self.assertEqual(a + 1, Array([[2, 3], [4, 5]]))
        self.assertEqual(a * 2, Array([[2, 4], [6, 8]]))
        self.assertEqual(a / 2.0, Array([[0.5, 1], [1.5, 2]]))
        self.assertEqual(a - 1, Array([[0, 1], [2, 3]]))
        # Operators must not modify their left operand.
        self.assertEqual(reserve, a)
        self.assertIsNot(a, reserve)

    def test_get_set_slices(self):
        a = Array([[1, 2], [3, 4]])
        self.assertEqual(a[0], Array([1, 2]))
        self.assertEqual(a[0, 0], 1)
        self.assertEqual(a[0, 1], 2)
        self.assertEqual(a[:], a)
        self.assertEqual(a[:, 0], Array([1, 3]))
        self.assertEqual(a[:, 1], Array([2, 4]))

    def test_zeros(self):
        dim = [2, 2]
        a = Array.zeros(dim)
        self.assertEqual(a, Array([[0, 0], [0, 0]]))

    def test_mean(self):
        a = Array([[1, 2], [3, 4]])
        mean = Array([2, 3])
        self.assertEqual(mean, a.mean())

    def test_std(self):
        a = Array.zeros([2, 2])
        self.assertEqual(a.std(), Array([0, 0]))
        a = Array.zeros([2])
        self.assertEqual(a.std(), 0)
        a = Array([0, 2])
        self.assertEqual(a.std(), 1)
        a = Array([[0, 1], [0, 1]])
        self.assertEqual(a.std(), Array([0, 0]))
        a = Array([[-1, 0], [1, 0]])
        self.assertEqual(a.std(), Array([1, 0]))

    def test_scale(self):
        a = Array([[1, 12], [0, -1]])
        b = a.scale()
        self.assertEqual(b.std(), Array([1, 1]))
        self.assertEqual(b.mean(), Array([0, 0]))
