Skip to content

Commit

Permalink
Merge pull request #851 from zlheui/distributed-cnn-mnist-dataset
Browse files Browse the repository at this point in the history
data downloading for mnist
  • Loading branch information
lzjpaul committed May 21, 2021
2 parents 4dab111 + ee48750 commit 9e96ed9
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 0 deletions.
49 changes: 49 additions & 0 deletions examples/cifar_distributed_cnn/data/download_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import os
import urllib.request


def check_exist_or_download(url):

download_dir = '/tmp/'
name = url.rsplit('/', 1)[-1]
filename = os.path.join(download_dir, name)

if not os.path.isfile(filename):
print("Downloading %s" % url)
urllib.request.urlretrieve(url, filename)
else:
print("Already Downloaded: %s" % url)


if __name__ == '__main__':

#url of the mnist dataset
train_x_url = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
train_y_url = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
valid_x_url = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'
valid_y_url = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'

#download the mnist dataset
check_exist_or_download(train_x_url)
check_exist_or_download(train_y_url)
check_exist_or_download(valid_x_url)
check_exist_or_download(valid_y_url)
91 changes: 91 additions & 0 deletions examples/cifar_distributed_cnn/data/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import numpy as np
import os
import sys
import gzip
import codecs


def check_dataset_exist(dirpath):
if not os.path.exists(dirpath):
print(
'The MNIST dataset does not exist. Please download the mnist dataset using python data/download_mnist.py'
)
sys.exit(0)
return dirpath


def load_dataset():
train_x_path = '/tmp/train-images-idx3-ubyte.gz'
train_y_path = '/tmp/train-labels-idx1-ubyte.gz'
valid_x_path = '/tmp/t10k-images-idx3-ubyte.gz'
valid_y_path = '/tmp/t10k-labels-idx1-ubyte.gz'

train_x = read_image_file(check_dataset_exist(train_x_path)).astype(
np.float32)
train_y = read_label_file(check_dataset_exist(train_y_path)).astype(
np.float32)
valid_x = read_image_file(check_dataset_exist(valid_x_path)).astype(
np.float32)
valid_y = read_label_file(check_dataset_exist(valid_y_path)).astype(
np.float32)
return train_x, train_y, valid_x, valid_y


def read_label_file(path):
with gzip.open(path, 'rb') as f:
data = f.read()
assert get_int(data[:4]) == 2049
length = get_int(data[4:8])
parsed = np.frombuffer(data, dtype=np.uint8, offset=8).reshape((length))
return parsed


def get_int(b):
return int(codecs.encode(b, 'hex'), 16)


def read_image_file(path):
with gzip.open(path, 'rb') as f:
data = f.read()
assert get_int(data[:4]) == 2051
length = get_int(data[4:8])
num_rows = get_int(data[8:12])
num_cols = get_int(data[12:16])
parsed = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(
(length, 1, num_rows, num_cols))
return parsed


def normalize(train_x, val_x):
train_x /= 255
val_x /= 255
return train_x, val_x


def load():
train_x, train_y, val_x, val_y = load_dataset()
train_x, val_x = normalize(train_x, val_x)
train_x = train_x.astype(np.float32)
val_x = val_x.astype(np.float32)
train_y = train_y.astype(np.int32)
val_y = val_y.astype(np.int32)
return train_x, train_y, val_x, val_y

0 comments on commit 9e96ed9

Please sign in to comment.