forked from zhanghang1989/PyTorch-Encoding
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- ADE20K training model - Amazon legal approval fixes zhanghang1989#69
- Loading branch information
1 parent
9bc7053
commit 32e382b
Showing
26 changed files
with
610 additions
and
114 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import os | ||
from tqdm import trange | ||
from PIL import Image, ImageOps, ImageFilter | ||
import numpy as np | ||
import torch | ||
|
||
from .base import BaseDataset | ||
|
||
# COCO annotations use NUM_CHANNEL = 91 category ids; this dataset keeps only
# the 20 PASCAL-VOC classes plus background.  CAT_LIST maps a class index
# (position in the list) to the COCO category id (value); index 0 = background.
CAT_LIST = [
    0,   # background
    5,   # airplane
    2,   # bicycle
    16,  # bird
    9,   # boat
    44,  # bottle
    6,   # bus
    3,   # car
    17,  # cat
    62,  # chair
    21,  # cow
    67,  # dining table
    18,  # dog
    19,  # horse
    4,   # motorcycle
    1,   # person
    64,  # potted plant
    20,  # sheep
    63,  # couch
    7,   # train
    72,  # tv
]
|
||
|
||
class COCOSegmentation(BaseDataset):
    """Semantic-segmentation dataset built from COCO 2014 instance annotations.

    Only the 21 categories in ``CAT_LIST`` (background + the 20 PASCAL VOC
    classes) are kept; every other annotation is ignored when the target
    mask is rendered.  Image ids whose rendered mask has too few labelled
    pixels are filtered out once and the surviving id list is cached on disk.
    """

    def __init__(self, root=os.path.expanduser('~/.encoding/data'), split='train',
                 mode=None, transform=None, target_transform=None):
        super(COCOSegmentation, self).__init__(
            root, split, mode, transform, target_transform)
        # Imported lazily so the rest of the package works without pycocotools.
        from pycocotools.coco import COCO
        from pycocotools import mask
        # NOTE(review): the branch keys on `mode`, not `split`; when mode is
        # None this always selects the val set -- confirm against BaseDataset.
        if mode == 'train':
            print('train set')
            ann_file = os.path.join(root, 'coco/annotations/instances_train2014.json')
            ids_file = os.path.join(root, 'coco/annotations/train_ids.pth')
            # Bug fix: __getitem__ joins file names against self.root, so the
            # image directory must be stored on the instance.  The original
            # only rebound the local `root`, which was never read again, and
            # self.root kept pointing at the dataset base directory.
            self.root = os.path.join(root, 'coco/train2014')
        else:
            print('val set')
            ann_file = os.path.join(root, 'coco/annotations/instances_val2014.json')
            ids_file = os.path.join(root, 'coco/annotations/val_ids.pth')
            self.root = os.path.join(root, 'coco/val2014')
        self.coco = COCO(ann_file)
        self.coco_mask = mask
        # Reuse the cached list of qualified image ids when present; otherwise
        # filter the full image list once and cache it for later runs.
        if os.path.exists(ids_file):
            self.ids = torch.load(ids_file)
        else:
            ids = list(self.coco.imgs.keys())
            self.ids = self._preprocess(ids, ids_file)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair for ``self.ids[index]``.

        The image is loaded as RGB; the mask is rendered on the fly from the
        COCO instance annotations via :meth:`_gen_seg_mask`.
        """
        coco = self.coco
        img_id = self.ids[index]
        img_metadata = coco.loadImgs(img_id)[0]
        path = img_metadata['file_name']
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        mask = Image.fromarray(self._gen_seg_mask(cocotarget, img_metadata['height'],
                               img_metadata['width']))
        # synchronized transform: image and mask must be augmented identically
        if self.mode == 'train':
            img, mask = self._sync_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._val_sync_transform(img, mask)
        else:
            assert self.mode == 'testval'
            mask = self._mask_transform(mask)
        # general resize, normalize and toTensor
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return img, mask

    def __len__(self):
        """Number of qualified images in this split."""
        return len(self.ids)

    def _gen_seg_mask(self, target, h, w):
        """Rasterize COCO annotations into an ``(h, w)`` uint8 class-index mask.

        Parameters
        ----------
        target : list of COCO annotation dicts for one image
        h, w : image height and width in pixels

        Categories outside ``CAT_LIST`` are skipped.  Where instances overlap,
        the first instance wins: a pixel is only written while it is still 0.
        """
        mask = np.zeros((h, w), dtype=np.uint8)
        coco_mask = self.coco_mask
        for instance in target:
            rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
            m = coco_mask.decode(rle)
            cat = instance['category_id']
            if cat in CAT_LIST:
                c = CAT_LIST.index(cat)
            else:
                continue
            if len(m.shape) < 3:
                # single polygon -> (h, w) binary mask
                mask[:, :] += (mask == 0) * (m * c)
            else:
                # several polygons -> (h, w, n); collapse to binary before writing
                mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
        return mask

    def _preprocess(self, ids, ids_file):
        """Filter ``ids`` to images with >1000 labelled pixels and cache them.

        The surviving id list is written to ``ids_file`` with ``torch.save``
        so the slow pass over all annotations only ever runs once per split.
        """
        # Bug fix: the original message was missing the space between the two
        # concatenated sentences ("a while.But don't worry, it only run ...").
        print("Preprocessing mask, this will take a while. " +
              "But don't worry, it only runs once for each split.")
        tbar = trange(len(ids))
        new_ids = []
        for i in tbar:
            img_id = ids[i]
            cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
            img_metadata = self.coco.loadImgs(img_id)[0]
            mask = self._gen_seg_mask(cocotarget, img_metadata['height'],
                                      img_metadata['width'])
            # more than 1k pixels
            if (mask > 0).sum() > 1000:
                new_ids.append(img_id)
            tbar.set_description('Doing: {}/{}, got {} qualified images'.\
                format(i, len(ids), len(new_ids)))
        print('Found number of qualified images: ', len(new_ids))
        torch.save(new_ids, ids_file)
        return new_ids
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.