Skip to content

Commit

Permalink
transforms (zhanghang1989#272)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghang1989 committed May 4, 2020
1 parent f70fa97 commit b8d83b0
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 15 deletions.
2 changes: 1 addition & 1 deletion encoding/lib/cpu/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ py::array_t<float> apply_transform(int H, int W, int C, py::array_t<float> img,
auto ctm_buf = ctm.request();

// printf("H: %d, W: %d, C: %d\n", H, W, C);
py::array_t<float> result{img_buf.size};
py::array_t<float> result{(unsigned long)img_buf.size};
auto res_buf = result.request();

float *img_ptr = (float *)img_buf.ptr;
Expand Down
20 changes: 10 additions & 10 deletions encoding/transforms/get_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_transform(dataset, base_size=None, crop_size=224, rand_aug=False, etrans
CenterCrop(crop_size),
])
train_transforms.extend([
RandomHorizontalFlip(),
RandomHorizontalFlip(),
ColorJitter(0.4, 0.4, 0.4),
ToTensor(),
Lighting(0.1, _imagenet_pca['eigval'], _imagenet_pca['eigvec']),
Expand Down Expand Up @@ -65,16 +65,16 @@ def get_transform(dataset, base_size=None, crop_size=224, rand_aug=False, etrans
normalize,
])
elif dataset == 'cifar10':
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
transform_train = Compose([
RandomCrop(32, padding=4),
RandomHorizontalFlip(),
ToTensor(),
Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
transform_val = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
transform_val = Compose([
ToTensor(),
Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
return transform_train, transform_val
Expand Down
11 changes: 7 additions & 4 deletions encoding/utils/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ class LR_Scheduler(object):
iters_per_epoch: number of iterations per epoch
"""
def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0,
lr_step=0, warmup_epochs=0):
lr_step=0, warmup_epochs=0, quiet=False):
self.mode = mode
print('Using {} LR scheduler with warm-up epochs of {}!'.format(self.mode, warmup_epochs))
self.quiet = quiet
if not quiet:
print('Using {} LR scheduler with warm-up epochs of {}!'.format(self.mode, warmup_epochs))
if mode == 'step':
assert lr_step
self.base_lr = base_lr
Expand All @@ -57,8 +59,9 @@ def __call__(self, optimizer, i, epoch, best_pred):
else:
raise NotImplemented
if epoch > self.epoch and (epoch == 0 or best_pred > 0.0):
print('\n=>Epoch %i, learning rate = %.4f, \
previous best = %.4f' % (epoch, lr, best_pred))
if not self.quiet:
print('\n=>Epoch %i, learning rate = %.4f, \
previous best = %.4f' % (epoch, lr, best_pred))
self.epoch = epoch
assert lr >= 0
self._adjust_learning_rate(optimizer, lr)
Expand Down

0 comments on commit b8d83b0

Please sign in to comment.