Add option for BatchNorm to handle batches of size one #5530

Merged: 12 commits, Sep 30, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
### Changed
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530))
- Fixed a bug when applying several scalers with `PNAConv` ([#5514](https://github.com/pyg-team/pytorch_geometric/issues/5514))
- Allow `.` in `ParameterDict` key names ([#5494](https://github.com/pyg-team/pytorch_geometric/pull/5494))
- Renamed `drop_unconnected_nodes` to `drop_unconnected_node_types` and `drop_orig_edges` to `drop_orig_edge_types` in `AddMetapaths` ([#5490](https://github.com/pyg-team/pytorch_geometric/pull/5490))
9 changes: 9 additions & 0 deletions test/nn/norm/test_batch_norm.py
@@ -17,3 +17,12 @@ def test_batch_norm(conf):

    out = norm(x)
    assert out.size() == (100, 16)

    x = torch.randn(1, 16)
    with pytest.raises(ValueError):
        _ = norm(x)

    norm = BatchNorm(16, affine=conf, track_running_stats=conf,
                     allow_no_batch=True)
    out = norm(x)
    assert torch.allclose(out, x)
25 changes: 23 additions & 2 deletions torch_geometric/nn/norm/batch_norm.py
@@ -30,18 +30,39 @@ class BatchNorm(torch.nn.Module):
            :obj:`False`, this module does not track such statistics and always
            uses batch statistics in both training and eval modes.
            (default: :obj:`True`)
        allow_no_batch (bool, optional): If set to :obj:`True`, batches with
            only a single element will work as though in training mode. That is
Review comment (Member), suggested change:
- only a single element will work as though in training mode. That is
+ only a single element will work as in evaluation. That is
            the running mean and variance will be used.
            (default: :obj:`False`)
Review comment (Member), suggested change:
- (default: :obj:`False`)
+ Requires :obj:`track_running_stats=True`. (default: :obj:`False`)

Review comment (Contributor Author, Padarn): Hmm, just saw this. Sorry, I'm not clear on why this is the desired behaviour?
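For context: "the running mean and variance will be used" means normalizing exactly as torch.nn.BatchNorm1d does in eval mode, and those buffers only exist when track_running_stats=True, which is presumably why the extra requirement is suggested. A small plain-PyTorch illustration, not part of this PR's diff:

import torch

bn = torch.nn.BatchNorm1d(16)  # track_running_stats=True by default
bn.eval()
x = torch.randn(1, 16)

# In eval mode the stored running statistics are used, so even a single
# element can be normalized (affine weight/bias are 1/0 right after init):
expected = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
assert torch.allclose(bn(x), expected, atol=1e-6)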

"""
def __init__(self, in_channels, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True):
def __init__(self, in_channels: int, eps: float = 1e-5,
momentum: float = 0.1, affine: bool = True,
track_running_stats: bool = True,
allow_no_batch: bool = False):
Review comment (Member), suggested change:
- allow_no_batch: bool = False):
+ allow_no_batch: bool = True):

Let's make this True by default?

Review comment (Contributor Author, Padarn): I personally feel making it False by default might be better, as it's a bit of a corner case and doesn't match the PyTorch behaviour.

        super().__init__()
        self.module = torch.nn.BatchNorm1d(in_channels, eps, momentum, affine,
                                           track_running_stats)
        self.allow_no_batch = allow_no_batch
        self.in_channels = in_channels

    def reset_parameters(self):
        self.module.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        """"""
        if self.allow_no_batch and x.size(0) <= 1:
            training = self.module.training
            running_mean = self.module.running_mean
            running_var = self.module.running_var
            if running_mean is None or running_var is None:
Review comment (Member): I am still not sure about this while looking at https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L53. I think running_mean is only None in the case of track_running_stats=False, in which case we should also error out.

Review comment (Contributor Author, Padarn): Oh hmm, yes, you're right, the problem only happens if track_running_stats is False. In our current implementation, if track_running_stats=False we do not raise an exception in training mode. What would be the logic for us to do so in this specific case?

Review comment (Contributor Author, Padarn, Sep 29, 2022): The tests in test/nn/norm/test_batch_norm.py all pass with this implementation; removing this will throw an error in the single combination:
  • track_running_stats=False
  • batch_size=1

Review comment (Member): Yes, it is impossible to support track_running_stats=False with batch_size=1. The current implementation simply does not do any normalization, which is probably not desired. I would simply error out in this case TBH.

Review comment (Contributor Author, Padarn): Hmm, yeah, I guess that makes the most sense to me. It's a bit ambiguous, but I haven't got a concrete use case for supporting this either, so I have updated based on your suggestion.
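Following this exchange the PR was updated to error out for this combination. The updated code is not shown in this excerpt, but the guard would look something like the sketch below (illustrative only; the exact message and placement may differ):

# Illustrative sketch of the guard discussed above, not the exact merged code:
if self.allow_no_batch and x.size(0) <= 1:
    if not self.module.track_running_stats:
        raise ValueError(
            "'allow_no_batch' requires 'track_running_stats=True', since a "
            "single element cannot be normalized from batch statistics alone")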

                self.module.running_var = torch.ones(self.in_channels)
Review comment (Member): Can this really be None? I assume they will be initialized by PyTorch.

Review comment (Contributor Author, Padarn): Yes: https://github.com/pytorch/pytorch/blob/9c036aa112b0a8fd9afb824d1fda058e2b66ba1d/torch/nn/modules/batchnorm.py#L68

It's a bit of an odd case, but this combined with https://github.com/pytorch/pytorch/blob/9c036aa112b0a8fd9afb824d1fda058e2b66ba1d/torch/nn/modules/batchnorm.py#L175 will cause errors.

As a side note, I suspect that line should be: `if not self.training or not self.track_running_stats`
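A quick way to reproduce the PyTorch behaviour described above (observed with torch.nn.BatchNorm1d; minor version-dependent details are possible):

import torch

bn = torch.nn.BatchNorm1d(16, track_running_stats=False)
print(bn.running_mean, bn.running_var)  # None None: the buffers are never created

bn.eval()
# Without running statistics, batch statistics are used even in eval mode,
# and a single element cannot provide them:
try:
    bn(torch.randn(1, 16))
except ValueError as err:
    print(err)  # "Expected more than 1 value per channel when training, ..."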

Review comment (Contributor): Regarding `self.module.running_var = torch.ones(self.in_channels)`: this seems like it would cause a device mismatch if the module is already on CUDA.

Review comment (Contributor Author, Padarn, Sep 26, 2022): Good point. I guess the important thing is probably to match the device of x. But to simplify, I just added them to `__init__`.
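For reference, the device-matching alternative mentioned here could look like the sketch below; this is the option discussed in the thread, not the code the PR settled on:

# Sketch: create the fallback statistics on the input's device/dtype so the
# temporary eval-mode pass also works when the module lives on CUDA.
if running_mean is None or running_var is None:
    self.module.running_mean = torch.zeros(self.in_channels, device=x.device,
                                           dtype=x.dtype)
    self.module.running_var = torch.ones(self.in_channels, device=x.device,
                                         dtype=x.dtype)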

                self.module.running_mean = torch.zeros(self.in_channels)
            self.module.eval()
            out = self.module(x)
            self.module.training = training
            self.module.running_mean = running_mean
            self.module.running_var = running_var
            return out
        return self.module(x)

    def __repr__(self):
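Taken together, a minimal usage sketch of the option added in this PR (argument name and default taken from this diff; they may differ in the merged release):

import torch
from torch_geometric.nn import BatchNorm

norm = BatchNorm(16, allow_no_batch=True)
norm.train()
x = torch.randn(1, 16)          # a "batch" with a single element
out = norm(x)                   # no error: running statistics are used instead
assert out.size() == (1, 16)

norm = BatchNorm(16)            # default behaves like torch.nn.BatchNorm1d
try:
    norm(x)                     # raises in training mode for a single element
except ValueError as err:
    print(err)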