Skip to content

Commit

Permalink
Change data shape in data_feed.py (open-mmlab#3026)
Browse files Browse the repository at this point in the history
  • Loading branch information
qingqing01 committed Aug 6, 2019
1 parent b2b359c commit 0520640
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions ppdet/data/data_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@
DecodeImage, MixupImage, NormalizeBox, NormalizeImage, RandomDistort,
RandomFlipImage, RandomInterpImage, ResizeImage, ExpandImage, CropImage,
Permute)
from ppdet.data.transform.arrange_sample import (ArrangeRCNN, ArrangeTestRCNN,
ArrangeSSD, ArrangeTestSSD,
ArrangeYOLO, ArrangeEvalYOLO,
ArrangeTestYOLO)
from ppdet.data.transform.arrange_sample import (
ArrangeRCNN, ArrangeTestRCNN, ArrangeSSD, ArrangeTestSSD, ArrangeYOLO,
ArrangeEvalYOLO, ArrangeTestYOLO)

__all__ = [
'PadBatch', 'MultiScale', 'RandomShape', 'DataSet', 'CocoDataSet',
Expand Down Expand Up @@ -138,8 +137,8 @@ def create_reader(feed, max_iter=0, args_path=None, my_source=None):
ops.append(op_dict)
transform_config['OPS'] = ops

return Reader.create(feed.mode, data_config,
transform_config, max_iter, my_source)
return Reader.create(feed.mode, data_config, transform_config, max_iter,
my_source)


# XXX batch transforms are only stubs for now, actually handled by `post_map`
Expand Down Expand Up @@ -412,6 +411,7 @@ def __init__(self,
num_workers=num_workers)


# yapf: disable
@register
class FasterRCNNTrainFeed(DataFeed):
__doc__ = DataFeed.__doc__
Expand All @@ -422,7 +422,7 @@ def __init__(self,
'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
'is_crowd'
],
image_shape=[3, 1333, 800],
image_shape=[3, 800, 1333],
sample_transforms=[
DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5),
Expand Down Expand Up @@ -467,7 +467,7 @@ def __init__(self,
dataset=CocoDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 1333, 800],
image_shape=[3, 800, 1333],
sample_transforms=[
DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406],
Expand Down Expand Up @@ -508,7 +508,7 @@ def __init__(self,
dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 1333, 800],
image_shape=[3, 800, 1333],
sample_transforms=[
DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406],
Expand Down Expand Up @@ -555,7 +555,7 @@ def __init__(self,
'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
'is_crowd', 'gt_mask'
],
image_shape=[3, 1333, 800],
image_shape=[3, 800, 1333],
sample_transforms=[
DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5, is_mask_flip=True),
Expand Down Expand Up @@ -601,7 +601,7 @@ def __init__(self,
dataset=CocoDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 1333, 800],
image_shape=[3, 800, 1333],
sample_transforms=[
DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406],
Expand Down Expand Up @@ -647,7 +647,7 @@ def __init__(self,
dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 1333, 800],
image_shape=[3, 800, 1333],
sample_transforms=[
DecodeImage(to_rgb=True),
NormalizeImage(
Expand Down Expand Up @@ -900,7 +900,7 @@ class YoloEvalFeed(DataFeed):
def __init__(self,
dataset=CocoDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_size', 'im_id', 'gt_box',
fields=['image', 'im_size', 'im_id', 'gt_box',
'gt_label', 'is_difficult'],
image_shape=[3, 608, 608],
sample_transforms=[
Expand Down Expand Up @@ -985,3 +985,4 @@ def __init__(self,
use_process=use_process)
self.mode = 'TEST'
self.bufsize = 128
# yapf: enable

0 comments on commit 0520640

Please sign in to comment.