
Add option for BatchNorm to handle batches of size one #5530

Merged
12 commits merged into master from padarn/batch_norm_singleton on Sep 30, 2022

Conversation

Padarn
Contributor

@Padarn Padarn commented Sep 25, 2022

The purpose of this PR is to provide a way for BatchNorm to work even when the batch size is one. This comes up when training heterogeneous graphs with rare node types.

This PR addresses #5529.

@Padarn Padarn self-assigned this Sep 25, 2022
@Padarn
Contributor Author

Padarn commented Sep 25, 2022

This is still WIP because I realised my simple test case would fail due to the base behaviour of BatchNorm1d:

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

I currently think the best way to support this is to update the running mean and variance even when the batch size is one. But the batch variance is still zero in this case.
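For reference, a minimal reproduction of the failure mode this PR targets (a small sketch assuming stock torch.nn.BatchNorm1d behaviour):

import torch

bn = torch.nn.BatchNorm1d(16)

# In training mode, PyTorch rejects a batch of size one because the
# per-batch variance of a single sample is undefined.
bn.train()
x = torch.randn(1, 16)
try:
    bn(x)
except ValueError as e:
    print(e)  # "Expected more than 1 value per channel when training, ..."

# In eval mode the running statistics are used instead, so it works.
bn.eval()
print(bn(x).shape)  # torch.Size([1, 16])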

@codecov

codecov bot commented Sep 25, 2022

Codecov Report

Merging #5530 (2709908) into master (00c3a5d) will increase coverage by 0.01%.
The diff coverage is 100.00%.

❗ Current head 2709908 differs from pull request most recent head 611edb0. Consider uploading reports for the commit 611edb0 to get more accurate results

@@            Coverage Diff             @@
##           master    #5530      +/-   ##
==========================================
+ Coverage   83.67%   83.69%   +0.01%     
==========================================
  Files         346      346              
  Lines       19017    19013       -4     
==========================================
- Hits        15913    15912       -1     
+ Misses       3104     3101       -3     
Impacted Files Coverage Δ
torch_geometric/nn/norm/batch_norm.py 100.00% <100.00%> (ø)
torch_geometric/utils/scatter.py 66.66% <0.00%> (-33.34%) ⬇️
torch_geometric/sampler/hgt_sampler.py 100.00% <0.00%> (ø)
torch_geometric/loader/link_neighbor_loader.py 100.00% <0.00%> (ø)
torch_geometric/sampler/utils.py 80.59% <0.00%> (+0.07%) ⬆️
torch_geometric/sampler/neighbor_sampler.py 92.25% <0.00%> (+0.38%) ⬆️
torch_geometric/nn/dense/linear.py 83.96% <0.00%> (+0.51%) ⬆️
torch_geometric/sampler/base.py 96.77% <0.00%> (+5.86%) ⬆️


@EdisonLeeeee
Contributor

Would it make sense if we repeat x twice before feeding it into the batch norm?

import torch

x = torch.randn(1, 16)
x = torch.cat([x, x], dim=0)  # duplicate the single row
bn = torch.nn.BatchNorm1d(16)

x = bn(x)[0].unsqueeze(0)  # shape [1, 16]

@@ -30,18 +30,26 @@ class BatchNorm(torch.nn.Module):
:obj:`False`, this module does not track such statistics and always
uses batch statistics in both training and eval modes.
(default: :obj:`True`)
skip_no_batch (bool, optional): If set to :obj:`True`, batches with
Member

I don't think this is a good choice since now we apply no normalization at all. Instead, IMO it is better to use training statistics for normalization.

Contributor Author

Thanks for the feedback. My only concern is that the statistics may be based on a small subset of the data unless we also update the running mean/variance when the batch size is one. WDYT?

Member

Can you clarify? What's the problem with using the running mean/variance for normalization here?

Contributor Author

I've updated based on your suggestion.

But to clarify: in the BatchNorm layer, mean and variance are tracked via running_mean and running_var, which are exponentially smoothed across batches. If we switch to eval mode for the cases where the batch size is one, we exclude those examples from the mean/var calculation.

I also suspect that with small batch sizes the mean/variance calculated this way will not approximate the population statistics well.
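For context, the exponential smoothing under discussion is roughly the following (a sketch only; PyTorch additionally applies an unbiased-variance correction and tracks num_batches_tracked):

import torch

def update_running_stats(running_mean: torch.Tensor,
                         running_var: torch.Tensor,
                         batch: torch.Tensor,
                         momentum: float = 0.1):
    # Per-batch statistics over the batch dimension.
    batch_mean = batch.mean(dim=0)
    batch_var = batch.var(dim=0, unbiased=False)
    # Exponential moving average: recent batches dominate over time,
    # and batches skipped by an eval-mode fallback never contribute.
    running_mean = (1 - momentum) * running_mean + momentum * batch_mean
    running_var = (1 - momentum) * running_var + momentum * batch_var
    return running_mean, running_var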

Member

Yes, these examples will be excluded from the mean/var calculation, but I guess there is no way around this (it would be the case in your previous implementation as well). I think the longer we train, the more stable the running mean/var will become. Ideally, batches with only a single node for a node type should not be too frequent, so excluding them should not cause too much harm.

Contributor Author

but I guess there is no way around this (it would be the case in your previous implementation as well).

Yes, agreed. The previous implementation was no better in this respect; it was just simpler, and I thought I'd raise the question.

@rusty1s rusty1s changed the title [WIP] Add skip for batch norm [WIP] Add skip option for BatchNorm Sep 26, 2022
@Padarn Padarn changed the title [WIP] Add skip option for BatchNorm Add option for BatchNorm to handle batches of size one. Sep 26, 2022
running_mean = self.module.running_mean
running_var = self.module.running_var
if running_mean is None or running_var is None:
    self.module.running_var = torch.ones(self.in_channels)
Member

Can this really be None? I assume they will be initialized by PyTorch.

Contributor Author

Yes: https://github.com/pytorch/pytorch/blob/9c036aa112b0a8fd9afb824d1fda058e2b66ba1d/torch/nn/modules/batchnorm.py#L68

It's a bit of an odd case, but this combined with https://github.com/pytorch/pytorch/blob/9c036aa112b0a8fd9afb824d1fda058e2b66ba1d/torch/nn/modules/batchnorm.py#L175 will cause errors.

As a side note, I suspect that line should be:

if not self.training or not self.track_running_stats
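For reference, the None case is easy to reproduce with stock PyTorch:

import torch

bn = torch.nn.BatchNorm1d(16, track_running_stats=False)
print(bn.running_mean, bn.running_var)  # None None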

Contributor

self.module.running_var = torch.ones(self.in_channels)

It seems this would cause a device mismatch if the module is already on CUDA.

Contributor Author
@Padarn Padarn Sep 26, 2022

Good point. I guess the important thing is probably to match the device of x. But to simplify things, I just added them in __init__.
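A minimal sketch of the approach described above (hypothetical class and buffer names, not the PR's actual code): registering the default statistics as buffers in __init__ lets .to() / .cuda() move them together with the rest of the module, avoiding a device mismatch inside forward().

import torch

class NormWithDefaults(torch.nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # Buffers follow the module across devices, unlike tensors
        # allocated lazily inside forward().
        self.register_buffer('default_mean', torch.zeros(in_channels))
        self.register_buffer('default_var', torch.ones(in_channels))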

def __init__(self, in_channels: int, eps: float = 1e-5,
             momentum: float = 0.1, affine: bool = True,
             track_running_stats: bool = True,
             allow_no_batch: bool = False):
Member

Suggested change
allow_no_batch: bool = False):
allow_no_batch: bool = True):

Let's make this True by default?

Contributor Author

Personally, I feel making it False by default might be better, as it's a bit of a corner case and doesn't match the PyTorch behaviour.

training = self.module.training
running_mean = self.module.running_mean
running_var = self.module.running_var
if running_mean is None or running_var is None:
Member

I am still not sure about this while looking at https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L53. I think running_mean is only None in case of track_running_stats=False, in which case we should also error out.

Contributor Author

Oh hmm, yes, you're right, the problem only happens if track_running_stats is False.

In our current implementation, if track_running_stats=False we do not raise an exception in training mode. What would be the logic for us to do so in this specific case?

Contributor Author
@Padarn Padarn Sep 29, 2022

The tests in test/nn/norm/test_batch_norm.py all pass with this implementation; removing this would throw an error only for the single combination:

  • track_running_stats=False
  • batch_size=1

Member

Yes, it is impossible to support track_running_stats=False with batch_size=1. The current implementation simply does not apply any normalization, which is probably not desired. I would simply error out in this case, TBH.

Contributor Author

Hmm, yeah, I guess that makes the most sense to me. It's a bit ambiguous, but I don't have a concrete use case for supporting this either, so I have updated based on your suggestion.
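Putting the conclusions above together, the single-element handling could look roughly like this (a sketch with illustrative names, not necessarily the exact merged code): fall back to the running statistics for a batch of size one, and error out if they are unavailable.

import torch
import torch.nn.functional as F

def forward_single_element(module: torch.nn.BatchNorm1d,
                           x: torch.Tensor) -> torch.Tensor:
    if x.size(0) <= 1:
        if module.running_mean is None or module.running_var is None:
            # track_running_stats=False: no statistics to fall back to.
            raise ValueError("Single-element batches require "
                             "'track_running_stats=True'")
        # Normalize with the running statistics (training=False), as in
        # evaluation mode; the running stats themselves are not updated.
        return F.batch_norm(x, module.running_mean, module.running_var,
                            module.weight, module.bias, False,
                            module.momentum, module.eps)
    return module(x)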

@@ -30,18 +30,41 @@ class BatchNorm(torch.nn.Module):
:obj:`False`, this module does not track such statistics and always
uses batch statistics in both training and eval modes.
(default: :obj:`True`)
allow_no_batch (bool, optional): If set to :obj:`True`, batches with
only a single element will work as though in training mode. That is
Member

Suggested change
only a single element will work as though in training mode. That is
only a single element will work as in evaluation. That is

allow_no_batch (bool, optional): If set to :obj:`True`, batches with
only a single element will work as though in training mode. That is
the running mean and variance will be used.
(default: :obj:`False`)
Member

Suggested change
(default: :obj:`False`)
Requires :obj:`track_running_stats=True`. (default: :obj:`False`)

Contributor Author

Hmm, just saw this. Sorry, I'm not clear on why this is the desired behaviour?

@Padarn
Contributor Author

Padarn commented Sep 29, 2022

Would it make sense if we repeat x twice before feeding it into the batch norm?

x = torch.randn(1, 16)
x = torch.cat([x, x], dim=0)
bn = torch.nn.BatchNorm1d(16)

x = bn(x)[0].unsqueeze(0) # shape [1, 16]

I missed this @EdisonLeeeee. What would be the reason to do this?

@EdisonLeeeee
Contributor

@Padarn I thought this would avoid the error when the batch size is 1 while keeping running_mean updated in such cases. Just a proposal.

@Padarn
Contributor Author

Padarn commented Sep 29, 2022

Oh I see. I'm not sure how well this would work because the variance would still be zero for the batch and so you would have a division by zero.

@EdisonLeeeee
Contributor

Interestingly, it worked without any errors. It seems that PyTorch handles such cases properly; it is only a batch size of 1 that raises an error.
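For illustration, a quick check of the duplicated-row trick with default BatchNorm1d settings: the batch variance is zero, but eps is added before the division, so no NaNs appear (the normalized output collapses to the bias, i.e. zeros by default).

import torch

bn = torch.nn.BatchNorm1d(16)
bn.train()

x = torch.randn(1, 16)
out = bn(torch.cat([x, x], dim=0))[0].unsqueeze(0)
print(out.shape)               # torch.Size([1, 16])
print(torch.isnan(out).any())  # tensor(False)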

@Padarn
Contributor Author

Padarn commented Sep 29, 2022 via email

@rusty1s rusty1s changed the title Add option for BatchNorm to handle batches of size one. Add option for BatchNorm to handle batches of size one Sep 29, 2022
@rusty1s rusty1s enabled auto-merge (squash) September 30, 2022 06:02
@rusty1s rusty1s merged commit a49fd34 into master Sep 30, 2022
@rusty1s rusty1s deleted the padarn/batch_norm_singleton branch September 30, 2022 06:05
JakubPietrakIntel pushed a commit to JakubPietrakIntel/pytorch_geometric that referenced this pull request Nov 25, 2022
* add skip for batch norm

* add changelog

* device of default mean/var

* device of default mean/var

* device of default mean/var

* device of default mean/var

* update naming for new argument

* require track running

* Update torch_geometric/nn/norm/batch_norm.py

* Update torch_geometric/nn/norm/batch_norm.py

* update

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>