forked from zhanghang1989/PyTorch-Encoding
-
Notifications
You must be signed in to change notification settings - Fork 0
/
prepare_imagenet.py
131 lines (117 loc) · 4.92 KB
/
prepare_imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Prepare the ImageNet dataset"""
import os
import argparse
import tarfile
import pickle
import gzip
import subprocess
from tqdm import tqdm
from encoding.utils import check_sha1, download, mkdir
# Default value of the --target-dir option (see parse_args).
_TARGET_DIR = os.path.expanduser('~/.encoding/datasets/imagenet')
# File name of the training archive looked up inside --download-dir, and the
# SHA-1 digest it is compared against when --checksum is given.
_TRAIN_TAR = 'ILSVRC2012_img_train.tar'
_TRAIN_TAR_SHA1 = '43eda4fe35c1705d6606a6a7a633bc965d194284'
# File name of the validation archive looked up inside --download-dir, and the
# SHA-1 digest it is compared against when --checksum is given.
_VAL_TAR = 'ILSVRC2012_img_val.tar'
_VAL_TAR_SHA1 = '5f3f73da3395154b60528b2b2a2caf2374f5f178'
def parse_args():
    """Parse the command-line options of the ImageNet setup script.

    Returns
    -------
    argparse.Namespace
        Namespace with the attributes ``download_dir``, ``target_dir``,
        ``checksum``, ``with_rec`` and ``num_thread``.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Setup the ImageNet dataset.')
    # Options are registered in this order so the generated help text keeps
    # the same layout.
    option_table = (
        ('--download-dir', dict(
            required=True,
            help="The directory that contains downloaded tar files")),
        ('--target-dir', dict(
            default=_TARGET_DIR,
            help="The directory to store extracted images")),
        ('--checksum', dict(
            action='store_true',
            help="If check integrity before extracting.")),
        ('--with-rec', dict(
            action='store_true',
            help="If build image record files.")),
        ('--num-thread', dict(
            type=int, default=1,
            help="Number of threads to use when building image record file.")),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def check_file(filename, checksum, sha1):
    """Validate that an archive exists and, optionally, that it is intact.

    Parameters
    ----------
    filename : str
        Path of the file to validate.
    checksum : bool
        When true, the file is additionally verified with ``check_sha1``.
    sha1 : str
        Expected SHA-1 digest handed to ``check_sha1``.

    Raises
    ------
    ValueError
        If the file does not exist, or if the checksum was requested and
        ``check_sha1`` reports a mismatch.
    """
    if not os.path.exists(filename):
        raise ValueError('File not found: '+filename)
    if not checksum:
        # Integrity check was not requested; existence is enough.
        return
    if check_sha1(filename, sha1):
        return
    raise ValueError('Corrupted file: '+filename)
def build_rec_process(img_dir, train=False, num_thread=1):
    """Build an ImageRecord file for the images stored under ``img_dir``.

    The ``im2rec.py`` script and the ``.lst`` file of the selected split are
    downloaded into a ``rec`` directory located next to ``img_dir``. The
    script is then run with the current Python interpreter, and both
    downloaded files are removed afterwards, whether or not the run
    succeeded.

    Parameters
    ----------
    img_dir : str
        Directory that holds the extracted images.
    train : bool
        Use the ``train`` prefix when true, otherwise ``val``.
    num_thread : int
        Value forwarded to the script's ``--num-thread`` option.
    """
    import sys

    rec_dir = os.path.abspath(os.path.join(img_dir, '../rec'))
    mkdir(rec_dir)
    prefix = 'train' if train else 'val'
    print('Building ImageRecord file for ' + prefix + ' ...')

    # NOTE(review): both URLs are reproduced unchanged from the original
    # code; confirm that they still resolve.
    script_path = os.path.join(rec_dir, 'im2rec.py')
    script_url = 'https://raw.githubusercontent.com/apache/incubator-encoding/master/tools/im2rec.py'
    lst_path = os.path.join(rec_dir, prefix + '.lst')
    lst_url = 'http://data.encoding.io/models/imagenet/resnet/' + prefix + '.lst'

    try:
        # download lst file and im2rec script
        download(script_url, script_path)
        download(lst_url, lst_path)
        # execution
        cmd = [
            sys.executable,
            script_path,
            rec_dir,
            img_dir,
            '--recursive',
            '--pass-through',
            '--pack-label',
            '--num-thread',
            str(num_thread)
        ]
        returncode = subprocess.call(cmd)
    finally:
        # Remove the temporary helper files even if a download or the
        # script run raised, so no stale script is left behind.
        for leftover in (script_path, lst_path):
            if os.path.exists(leftover):
                os.remove(leftover)

    if returncode != 0:
        # Report the failure instead of claiming success; no exception is
        # raised so the caller keeps its previous control flow.
        print('Building ImageRecord file for ' + prefix +
              ' failed: im2rec.py exited with code ' + str(returncode))
        return
    print('ImageRecord file for ' + prefix + ' has been built!')
def extract_train(tar_fname, target_dir, with_rec=False, num_thread=1):
    """Extract the training archive into one sub-directory per inner tar.

    ``target_dir`` is created first (``os.makedirs`` raises if it already
    exists). Every member of the outer archive is extracted into
    ``target_dir``, unpacked into a directory named after the member without
    its extension, and the intermediate member file is then deleted.

    Parameters
    ----------
    tar_fname : str
        Path of the training tar file.
    target_dir : str
        Directory to create and fill with the extracted images.
    with_rec : bool
        When true, ``build_rec_process`` is called on ``target_dir`` after
        extraction.
    num_thread : int
        Forwarded to ``build_rec_process``.
    """
    # NOTE(review): members are extracted without an explicit tarfile
    # ``filter``; the archives are presumably trusted -- confirm.
    os.makedirs(target_dir)
    with tarfile.open(tar_fname) as tar:
        print("Extracting "+tar_fname+"...")
        # extract each class one-by-one
        # The progress bar is a context manager so it is closed even when
        # extracting one of the members raises.
        with tqdm(total=len(tar.getnames())) as pbar:
            for class_tar in tar:
                pbar.set_description('Extract '+class_tar.name)
                tar.extract(class_tar, target_dir)
                class_fname = os.path.join(target_dir, class_tar.name)
                class_dir = os.path.splitext(class_fname)[0]
                os.mkdir(class_dir)
                with tarfile.open(class_fname) as f:
                    f.extractall(class_dir)
                # The inner tar is no longer needed once it is unpacked.
                os.remove(class_fname)
                pbar.update(1)
    if with_rec:
        build_rec_process(target_dir, True, num_thread)
def extract_val(tar_fname, target_dir, with_rec=False, num_thread=1):
    """Extract the validation archive and sort its images into sub-folders.

    Parameters
    ----------
    tar_fname : str
        Path of the validation tar file.
    target_dir : str
        Directory to create and fill with the extracted images.
    with_rec : bool
        When true, ``build_rec_process`` is called on ``target_dir`` before
        the images are moved.
    num_thread : int
        Forwarded to ``build_rec_process``.
    """
    os.makedirs(target_dir)
    print('Extracting ' + tar_fname)
    with tarfile.open(tar_fname) as archive:
        archive.extractall(target_dir)

    # Record files are built while the images are still in a flat layout.
    if with_rec:
        build_rec_process(target_dir, False, num_thread)

    # The folder/image mapping is read from a gzip-compressed pickle that
    # sits in the same directory as this script.
    script_dir = os.path.dirname(__file__)
    maps_path = os.path.join(script_dir, 'imagenet_val_maps.pklz')
    with gzip.open(maps_path, 'rb') as handle:
        dirs, mappings = pickle.load(handle)

    for subdir in dirs:
        os.makedirs(os.path.join(target_dir, subdir))
    for entry in mappings:
        # entry[0] is the file name, entry[1] the destination sub-folder.
        image_name = entry[0]
        subdir = entry[1]
        os.rename(os.path.join(target_dir, image_name),
                  os.path.join(target_dir, subdir, image_name))
def main():
    """Entry point: validate the downloaded archives and extract both splits.

    Raises
    ------
    ValueError
        If the target directory already exists, or if ``check_file`` rejects
        one of the archives.
    """
    options = parse_args()

    out_root = os.path.expanduser(options.target_dir)
    if os.path.exists(out_root):
        raise ValueError('Target dir ['+out_root+'] exists. Remove it first')

    src_root = os.path.expanduser(options.download_dir)
    train_tar_fname = os.path.join(src_root, _TRAIN_TAR)
    val_tar_fname = os.path.join(src_root, _VAL_TAR)
    # Both archives are validated before anything is written to disk.
    for tar_path, digest in ((train_tar_fname, _TRAIN_TAR_SHA1),
                             (val_tar_fname, _VAL_TAR_SHA1)):
        check_file(tar_path, options.checksum, digest)

    if options.with_rec:
        os.makedirs(os.path.join(out_root, 'rec'))
    extract_train(train_tar_fname, os.path.join(out_root, 'train'),
                  options.with_rec, options.num_thread)
    extract_val(val_tar_fname, os.path.join(out_root, 'val'),
                options.with_rec, options.num_thread)
# Run the dataset setup only when this file is executed as a script.
if __name__ == '__main__':
    main()