-
Notifications
You must be signed in to change notification settings - Fork 2
/
mnist.py
98 lines (88 loc) · 2.89 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# -*- coding: utf-8 -*-
import numpy as np
import struct
import os
import matplotlib.pyplot as plt
import _pickle as cPickle
import gzip
"""
The first method
"""
# Byte-order prefix for struct format strings: '>' means big-endian.
_tag = '>'

# Field layouts without a byte-order prefix.
_twoBytes = 'I' * 2       # two integers
_fourBytes = 'I' * 4      # four integers
_pictureBytes = '784B'    # one picture: 784 bytes, i.e. 28*28
_lableByte = '1B'         # one label: a single byte

# The same layouts with the big-endian prefix applied.
_msb_twoBytes, _msb_fourBytes, _msb_pictureBytes, _msb_lableByte = (
    _tag + layout
    for layout in (_twoBytes, _fourBytes, _pictureBytes, _lableByte)
)
def getImage(filename=None):
    """Read an IDX3 image file and return its images as flattened rows.

    The file starts with four big-endian unsigned 32-bit integers (magic
    number, image count, rows, columns) followed by one unsigned byte per
    pixel, image after image.

    Args:
        filename: Path of the uncompressed image file.

    Returns:
        A 2-D NumPy array of shape ``(image_count, rows * columns)`` using
        NumPy's default integer dtype, one flattened image per row.

    Raises:
        struct.error: If the file is too short for its header or for the
            amount of pixel data the header announces.
    """
    # Explicit big-endian layout; its .size is the exact header length, with
    # no native alignment padding involved.
    header = struct.Struct('>IIII')

    # The context manager closes the file even if read() raises.
    with open(filename, 'rb') as binfile:
        buf = binfile.read()

    _, numImgs, numRows, numCols = header.unpack_from(buf, 0)

    # Use the dimensions stored in the file instead of assuming 28*28.
    pixelsPerImage = numRows * numCols
    expected = numImgs * pixelsPerImage
    available = len(buf) - header.size
    if available < expected:
        raise struct.error(
            'image file truncated: header announces %d pixel bytes, found %d'
            % (expected, available)
        )

    # Decode every pixel in a single pass instead of one unpack per image.
    pixels = np.frombuffer(
        buf, dtype=np.uint8, count=expected, offset=header.size
    )
    # astype(int) yields a writable copy in the default integer dtype, which
    # is what building the array from lists of Python ints produced before.
    return pixels.reshape(numImgs, pixelsPerImage).astype(int)
def getlable(filename=None):
    """Read an IDX1 label file and return its labels as a 1-D array.

    The file starts with two big-endian unsigned 32-bit integers (magic
    number, item count) followed by one unsigned byte per label.

    Args:
        filename: Path of the uncompressed label file.

    Returns:
        A 1-D NumPy array of length ``item_count`` using NumPy's default
        integer dtype.

    Raises:
        struct.error: If the file is too short for its header or for the
            number of labels the header announces.
    """
    # Explicit big-endian layout; its .size is the exact header length, with
    # no native alignment padding involved.
    header = struct.Struct('>II')

    # The context manager closes the file even if read() raises.
    with open(filename, 'rb') as binfile:
        buf = binfile.read()

    _, numItems = header.unpack_from(buf, 0)

    available = len(buf) - header.size
    if available < numItems:
        raise struct.error(
            'label file truncated: header announces %d labels, found %d'
            % (numItems, available)
        )

    # Decode all labels in one pass instead of one unpack per byte.
    labels = np.frombuffer(
        buf, dtype=np.uint8, count=numItems, offset=header.size
    )
    # astype(int) yields a writable copy in the default integer dtype, which
    # is what building the array from a list of Python ints produced before.
    return labels.astype(int)
def outImg(arrX, arrY, order):
    """Show the image stored at index ``order`` and print its label.

    Args:
        arrX: Indexable collection of images; ``arrX[order]`` must contain
            exactly 784 values because it is reshaped to 28x28.
        arrY: Indexable collection of labels, indexed the same way as arrX.
        order: Index of the sample to display.

    Side effects:
        Prints the shape of the selected image (before reshaping) and its
        label, then opens a matplotlib figure and calls ``plt.show()``.
    """
    image = np.array(arrX[order])
    print(image.shape)
    image = image.reshape(28, 28)

    label = arrY[order]
    print(label)

    plt.figure()
    # Reversed gray colormap: per the MNIST site, "Pixel values are 0 to 255.
    # 0 means background (white), 255 means foreground (black)."
    plt.imshow(image, cmap="gray_r")
    # To write the figure to disk instead of only displaying it, call
    # plt.savefig(...) here, before plt.show().
    plt.show()
"""
The second method
"""
def load_data(filename=None):
    """Load three dataset splits from a gzip-compressed pickle file.

    Args:
        filename: Path of the ``.pkl.gz`` file. The pickled object must
            unpack into exactly three items.

    Returns:
        The tuple ``(training_data, validation_data, test_data)`` as stored
        in the file.

    Raises:
        ValueError: If the pickled object does not unpack into three items.
    """
    # SECURITY: unpickling can execute arbitrary code embedded in the file.
    # Only call this on dataset files obtained from a trusted source.
    #
    # The context manager closes the gzip stream even if unpickling fails.
    with gzip.open(filename, 'rb') as f:
        # encoding='latin1' lets Python 3 read pickles written by Python 2
        # (required for NumPy arrays pickled there); it does not change how
        # pickles written by Python 3 are read.
        training_data, validation_data, test_data = cPickle.load(
            f, encoding='latin1'
        )
    return (training_data, validation_data, test_data)
def test_cPickle():
    """Demo of the pickle-based loader.

    Loads './dataset/MNIST/mnist.pkl.gz', prints the length of the test
    split, and displays training sample 1000 together with its label.
    """
    train_split, _validation_split, test_split = load_data(
        './dataset/MNIST/mnist.pkl.gz'
    )
    print(len(test_split))

    # The training split is indexed as [0] for the images, [1] for the labels.
    images = train_split[0]
    labels = train_split[1]
    outImg(images, labels, 1000)
def test():
    """Demo of the raw IDX loaders.

    Reads the training images and labels from './dataset/MNIST/' and
    displays sample 1000 together with its label.
    """
    images = getImage('./dataset/MNIST/train-images.idx3-ubyte')
    labels = getlable('./dataset/MNIST/train-labels.idx1-ubyte')
    outImg(images, labels, 1000)
if __name__ == "__main__":
    # Runs the first method (raw IDX files) by default; call test_cPickle()
    # here instead to try the second method (gzip-compressed pickle).
    test()