import numpy as np
import gzip
from tensorflow.contrib.learn.python.learn.datasets import base

# CVDF mirror of http://yann.lecun.com/exdb/mnist/
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'

# Gzip-compressed training-image archive, fetched from SOURCE_URL.
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'

# IDX magic numbers, stored as a big-endian uint32 at the start of each file.
# Per the IDX format notes on the MNIST page, the third byte is the element
# type (0x08 = unsigned byte) and the fourth byte is the number of dimensions.
IMAGE_DATA_FILE_MAGIC = 0x00000803  # == 2051: ubyte data, 3 dimensions
LABEL_DATA_FILE_MAGIC = 0x00000801  # == 2049: ubyte data, 1 dimension

def _read32(bytestream):
    dt = np.dtype(np.uint32).newbyteorder('>')
    return np.frombuffer(bytestream.read(4), dtype=dt)[0]

def _extract_images(file):
    """Decode a gzip-compressed IDX image file into a uint8 array.

    Args:
        file: A binary file object positioned at the start of the gzip data
            (e.g. the result of ``open(path, 'rb')``). Its ``name`` attribute,
            if present, is used in progress and error messages.

    Returns:
        A ``numpy.uint8`` array of shape ``(num_images, rows, cols, channels)``.
        It is a read-only view over the decompressed bytes.

    Raises:
        ValueError: If the magic number is not recognised, or the pixel data
            is shorter than the header declares.
    """
    name = getattr(file, 'name', repr(file))
    print('Extracting', name)
    with gzip.GzipFile(fileobj=file) as bytestream:
        magic = _read32(bytestream)
        if magic == IMAGE_DATA_FILE_MAGIC:
            channels = 1
        elif magic == IMAGE_DATA_FILE_MAGIC + 1:
            # NOTE(review): a standard 4-D IDX header (magic 0x804) carries a
            # fourth size field; this loader instead assumes 3 channels and
            # reads only three size fields. Confirm against the files used.
            channels = 3
        else:
            raise ValueError('Invalid magic number %d in image file: %s' %
                             (magic, name))
        # Convert to Python ints so the size product below cannot wrap around
        # at 2**32 if _read32 yields a fixed-width numpy integer.
        num_images = int(_read32(bytestream))
        rows = int(_read32(bytestream))
        cols = int(_read32(bytestream))
        print("images = ({}, {}, {}, {})".format(num_images, rows, cols, channels))
        expected = num_images * rows * cols * channels
        buf = bytestream.read(expected)
    # Fail with a clear message instead of an opaque reshape error.
    if len(buf) != expected:
        raise ValueError('Truncated image file %s: expected %d bytes of pixel '
                         'data, got %d' % (name, expected, len(buf)))
    data = np.frombuffer(buf, dtype=np.uint8)
    return data.reshape(num_images, rows, cols, channels)

def _normalize_dtype(dtype):
    """Return ``dtype`` as a ``numpy.dtype``, or None if it cannot be interpreted."""
    if dtype is None:
        return None
    # Objects that expose their numpy equivalent via ``as_numpy_dtype``
    # (e.g., presumably, TensorFlow DTypes) are unwrapped first.
    as_numpy = getattr(dtype, 'as_numpy_dtype', None)
    if as_numpy is not None:
        dtype = as_numpy
    try:
        return np.dtype(dtype)
    except TypeError:
        return None

def read_mnist(work_dir,
               dtype=np.float32,
               seed=None):
    """Download (if needed) and decode the MNIST training images.

    Args:
        work_dir: Directory passed to ``base.maybe_download`` as the location
            for ``train-images-idx3-ubyte.gz``; the file is then read from the
            path that call returns.
        dtype: Requested element type. Any floating-point dtype (for example
            ``np.float32``, ``'float32'``, or ``np.float64``) yields pixel
            values divided by 255 in that dtype. Anything else returns the raw
            ``uint8`` pixels unchanged.
        seed: Unused; accepted only for signature compatibility.

    Returns:
        An array of shape ``(num_images, rows, cols, channels)``.
    """
    local_file = base.maybe_download(TRAIN_IMAGES, work_dir,
                                     SOURCE_URL + TRAIN_IMAGES)
    with open(local_file, 'rb') as f:
        train_images = _extract_images(f)
    # Normalise first: a plain `dtype == np.float32` test is False for the
    # string alias 'float32' and ignores other float types such as float64.
    target = _normalize_dtype(dtype)
    if target is not None and np.issubdtype(target, np.floating):
        train_images = train_images.astype(target) / 255.

    return train_images
