Skip to content

Commit

Permalink
Bug fix in the MLP layer: torch.squeeze was removing the batch size w…
Browse files Browse the repository at this point in the history
…hen equal to 1 (facebookresearch#179)

Summary:
The fix comes with unit tests so that error will be caught up if it comes back.

Pull Request resolved: facebookresearch#179

Reviewed By: prigoyal

Differential Revision: D26338689

Pulled By: QuentinDuval

fbshipit-source-id: be9aaf668faf046060df321b62a462d79c7b0d8b
  • Loading branch information
QuentinDuval authored and facebook-github-bot committed Feb 16, 2021
1 parent 471d73f commit 75d472b
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 5 deletions.
61 changes: 61 additions & 0 deletions tests/test_mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import unittest

import torch
from vissl.models.heads import LinearEvalMLP, MLP
from vissl.utils.hydra_config import AttrDict


class TestMLP(unittest.TestCase):
"""
Unit test to verify that correct construction of MLP layers
and linear evaluation MLP layers
"""

MODEL_CONFIG = AttrDict({
"HEAD": {
"BATCHNORM_EPS": 1e-6,
"BATCHNORM_MOMENTUM": 0.99,
"PARAMS_MULTIPLIER": 1.0,
}
})

def test_mlp(self):
mlp = MLP(self.MODEL_CONFIG, dims=[2048, 100])

x = torch.randn(size=(4, 2048))
out = mlp(x)
assert out.shape == torch.Size([4, 100])

x = torch.randn(size=(1, 2048))
out = mlp(x)
assert out.shape == torch.Size([1, 100])

def test_mlp_reshaping(self):
mlp = MLP(self.MODEL_CONFIG, dims=[2048, 100])

x = torch.randn(size=(1, 2048, 1, 1))
out = mlp(x)
assert out.shape == torch.Size([1, 100])

def test_mlp_catch_bad_shapes(self):
mlp = MLP(self.MODEL_CONFIG, dims=[2048, 100])

x = torch.randn(size=(1, 2048, 2, 1))
with self.assertRaises(AssertionError) as context:
mlp(x)
assert context.exception is not None

def test_eval_mlp_shape(self):
eval_mlp = LinearEvalMLP(
self.MODEL_CONFIG,
in_channels=2048,
dims=[2048 * 2 * 2, 1000],
)

resnet_feature_map = torch.randn(size=(4, 2048, 2, 2))
out = eval_mlp(resnet_feature_map)
assert out.shape == torch.Size([4, 1000])

resnet_feature_map = torch.randn(size=(1, 2048, 2, 2))
out = eval_mlp(resnet_feature_map)
assert out.shape == torch.Size([1, 1000])
9 changes: 5 additions & 4 deletions vissl/models/heads/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@ def forward(self, batch: torch.Tensor):
len(batch) == 1
), "MLP input should be either a tensor (2D, 4D) or list containing 1 tensor."
batch = batch[0]
batch = torch.squeeze(batch)
assert (
len(batch.shape) <= 2
), f"MLP expected 2D input tensor or 4D tensor of shape NxCx1x1. got: {batch.shape}"
if batch.ndim > 2:
assert (
all(d == 1 for d in batch.shape[2:])
), f"MLP expected 2D input tensor or 4D tensor of shape NxCx1x1. got: {batch.shape}"
batch = batch.reshape((batch.size(0), batch.size(1)))
out = self.clf(batch)
return out
2 changes: 1 addition & 1 deletion vissl/models/model_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def is_feature_extractor_model(model_config):
If the model is a feature extractor model:
- evaluation model is on
- trunk is frozen
- number of features specified for features extratction > 0
- number of features specified for features extraction > 0
"""
if (
model_config.FEATURE_EVAL_SETTINGS.EVAL_MODE_ON
Expand Down

0 comments on commit 75d472b

Please sign in to comment.