Add option for BatchNorm to handle batches of size one #5530

Merged · 12 commits · Sep 30, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
### Changed
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530))
- Fixed a bug when applying several scalers with `PNAConv` ([#5514](https://github.com/pyg-team/pytorch_geometric/issues/5514))
- Allow `.` in `ParameterDict` key names ([#5494](https://github.com/pyg-team/pytorch_geometric/pull/5494))
- Renamed `drop_unconnected_nodes` to `drop_unconnected_node_types` and `drop_orig_edges` to `drop_orig_edge_types` in `AddMetapaths` ([#5490](https://github.com/pyg-team/pytorch_geometric/pull/5490))
9 changes: 9 additions & 0 deletions test/nn/norm/test_batch_norm.py
@@ -17,3 +17,12 @@ def test_batch_norm(conf):

    out = norm(x)
    assert out.size() == (100, 16)

    x = torch.randn(1, 16)
    with pytest.raises(ValueError):
        _ = norm(x)

    norm = BatchNorm(16, affine=conf, track_running_stats=conf,
                     allow_single_element=True)
    out = norm(x)
    assert torch.allclose(out, x)
23 changes: 21 additions & 2 deletions torch_geometric/nn/norm/batch_norm.py
@@ -30,18 +30,37 @@ class BatchNorm(torch.nn.Module):
            :obj:`False`, this module does not track such statistics and always
            uses batch statistics in both training and eval modes.
            (default: :obj:`True`)
        allow_single_element (bool, optional): If set to :obj:`True`, batches
            with only a single element will work as though in evaluation.
            That is, the running mean and variance will be used.
            Requires :obj:`track_running_stats=True`. (default: :obj:`False`)
    """
    def __init__(self, in_channels, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
    def __init__(self, in_channels: int, eps: float = 1e-5,
                 momentum: float = 0.1, affine: bool = True,
                 track_running_stats: bool = True,
                 allow_single_element: bool = False):
        super().__init__()

        if allow_single_element and not track_running_stats:
            raise ValueError("'allow_single_element' requires "
                             "'track_running_stats' to be set to `True`")

        self.module = torch.nn.BatchNorm1d(in_channels, eps, momentum, affine,
                                           track_running_stats)
        self.in_channels = in_channels
        self.allow_single_element = allow_single_element

    def reset_parameters(self):
        self.module.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        """"""
        if self.allow_single_element and x.size(0) <= 1:
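            # Use the running statistics by temporarily switching to eval
            # mode, since batch statistics cannot be computed for a single
            # element.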
            training = self.module.training
            self.module.eval()
            out = self.module(x)
            self.module.training = training
            return out
        return self.module(x)

    def __repr__(self):
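
For reference, a minimal usage sketch of the new option (the import path and tensor shapes here are illustrative, mirroring the added test, and are not part of the diff):

    import torch
    from torch_geometric.nn import BatchNorm

    # allow_single_element requires track_running_stats=True (the default).
    norm = BatchNorm(16, allow_single_element=True)

    x = torch.randn(1, 16)  # a "batch" containing a single element
    out = norm(x)           # falls back to the running statistics instead of raising
    assert out.size() == (1, 16)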