Add option for BatchNorm to handle batches of size one #5530
Conversation
This is still WIP because I realised my simple test case would fail due to the base behaviour of PyTorch's BatchNorm:

r"""
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
"""
if self.training:
    bn_training = True
else:
    bn_training = (self.running_mean is None) and (self.running_var is None)

I currently think the best way to support this is to update mean and var even when the batch is of size one. But the variance is still zero in this case.
Codecov Report
@@            Coverage Diff             @@
##           master    #5530      +/-   ##
==========================================
+ Coverage   83.67%   83.69%   +0.01%
==========================================
  Files         346      346
  Lines       19017    19013       -4
==========================================
- Hits        15913    15912       -1
+ Misses       3104     3101       -3
Would it make sense if we repeat x?

x = torch.randn(1, 16)
x = torch.cat([x, x], dim=0)
bn = torch.nn.BatchNorm1d(16)
x = bn(x)[0].unsqueeze(0)  # shape [1, 16]
@@ -30,18 +30,26 @@ class BatchNorm(torch.nn.Module):
            :obj:`False`, this module does not track such statistics and always
            uses batch statistics in both training and eval modes.
            (default: :obj:`True`)
        skip_no_batch (bool, optional): If set to :obj:`True`, batches with
I don't think this is a good choice since now we apply no normalization at all. Instead, IMO it is better to use training statistics for normalization.
Thanks for the feedback. My only concern with this is that the running statistics may be based on a small subset unless we also add to the running mean/variance when the batch size is one. WDYT?
Can you clarify? What's the problem with using the running mean/variance for normalization here?
I've updated based on your suggestion.
But to clarify: in the BatchNorm layer, we calculate mean and variance using a running_mean and running_var, which are exponentially smoothed. But if we switch to eval mode for the cases where the batch size is one, we will exclude those examples from the mean/var calculation.
I also suspect that with small batch sizes, the mean/variance calculated this way will not approximate the population well.
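For reference, a schematic of the exponential smoothing being described (a standalone sketch, not the PR's code; torch.nn.BatchNorm1d applies an equivalent momentum update to its running buffers internally):

import torch

momentum = 0.1
running_mean, running_var = torch.zeros(16), torch.ones(16)

x = torch.randn(8, 16)                    # one mini-batch
batch_mean = x.mean(dim=0)
batch_var = x.var(dim=0, unbiased=True)   # degenerate for a single-element batch (zero biased, NaN unbiased)
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
running_var = (1 - momentum) * running_var + momentum * batch_var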
Yes, these examples will be excluded from the mean/var calculation, but I guess there is no way around this (it would be the case in your previous implementation as well). I think the longer we train, the more stable the running mean/var will become. Ideally, batches with only a single node for a node type should not be too frequent, so excluding them should not cause too much harm.
but I guess there is no way around this (it would be the case in your previous implementation as well).

Yes, agreed. The previous implementation was not better in this respect. It was just simpler, and I thought I'd raise the question.
running_mean = self.module.running_mean
running_var = self.module.running_var
if running_mean is None or running_var is None:
    self.module.running_var = torch.ones(self.in_channels)
Can this really be None? I assume they will be initialized by PyTorch.
It's a bit of an odd case, but this combined with https://github.com/pytorch/pytorch/blob/9c036aa112b0a8fd9afb824d1fda058e2b66ba1d/torch/nn/modules/batchnorm.py#L175 will cause errors.
As a side note, I suspect that line should be:
if not self.training or not self.track_running_stats
self.module.running_var = torch.ones(self.in_channels)

It seems this would cause a device mismatch if the module is already on CUDA.
Good point. I guess the important thing is probably to match the device of x. But to simplify, I just added them in __init__:
def __init__(self, in_channels: int, eps: float = 1e-5,
             momentum: float = 0.1, affine: bool = True,
             track_running_stats: bool = True,
             allow_no_batch: bool = False):
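To illustrate the device question discussed above, a small hypothetical sketch (not the PR's actual code) of the two ways fallback statistics can be kept on the right device:

import torch

class DeviceSafeNorm(torch.nn.Module):
    # Hypothetical wrapper for illustration only.
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels
        # Option A: register the fallback statistics as buffers in __init__,
        # so module.to(device) / module.cuda() moves them with the module.
        self.register_buffer('fallback_mean', torch.zeros(in_channels))
        self.register_buffer('fallback_var', torch.ones(in_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Option B: create tensors at call time on the device of the input.
        var = torch.ones(self.in_channels, device=x.device)
        return (x - self.fallback_mean) / (var + 1e-5).sqrt()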
Suggested change:
- allow_no_batch: bool = False):
+ allow_no_batch: bool = True):

Let's make this True by default?
Personally, I feel making it False by default might be better, as it's a bit of a corner case and doesn't match the PyTorch behaviour.
training = self.module.training
running_mean = self.module.running_mean
running_var = self.module.running_var
if running_mean is None or running_var is None:
I am still not sure about this while looking at https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L53. I think running_mean is only None in case of track_running_stats=False, in which case we should also error out.
Oh hmm, yes, you're right, the problem only happens if track_running_stats is False.
In our current implementation, if track_running_stats=False we do not raise an exception in training mode. What would be the logic for us to do so in this specific case?
The tests in test/nn/norm/test_batch_norm.py all pass with this implementation; removing this will throw an error in the single combination of track_running_stats=False and batch_size=1.
Yes, it is impossible to support track_running_stats=False with batch_size=1. The current implementation simply does not do any normalization, which is probably not desired. I would simply error out in this case TBH.
Hmm, yeah, I guess that makes the most sense to me. It's a bit ambiguous, but I haven't got a concrete use case for supporting this either, so I have updated based on your suggestion.
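A minimal sketch of what erroring out could look like (a hypothetical helper for illustration; the names follow the diffs in this PR, not necessarily the merged code):

def validate_batch_norm_args(track_running_stats: bool, allow_no_batch: bool) -> None:
    # A single-element batch can only be normalized with running statistics,
    # so the two options must be compatible.
    if allow_no_batch and not track_running_stats:
        raise ValueError("'allow_no_batch' requires 'track_running_stats' to be True")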
@@ -30,18 +30,41 @@ class BatchNorm(torch.nn.Module):
            :obj:`False`, this module does not track such statistics and always
            uses batch statistics in both training and eval modes.
            (default: :obj:`True`)
        allow_no_batch (bool, optional): If set to :obj:`True`, batches with
            only a single element will work as though in training mode. That is
Suggested change:
- only a single element will work as though in training mode. That is
+ only a single element will work as in evaluation. That is
allow_no_batch (bool, optional): If set to :obj:`True`, batches with
    only a single element will work as though in training mode. That is
    the running mean and variance will be used.
    (default: :obj:`False`)
Suggested change:
- (default: :obj:`False`)
+ Requires :obj:`track_running_stats=True`. (default: :obj:`False`)
Hmm, just saw this. Sorry, I'm not clear on why this is the desired behaviour?
I missed this, @EdisonLeeeee. What would be the reason to do this?
Branch updated from 0fa22ad to 2709908.
@Padarn I thought this would avoid the error when the batch size is 1 while keeping
Oh I see. I'm not sure how well this would work, because the variance would still be zero for the batch and so you would have a division by zero.
Interestingly, it worked without any errors. It seems that PyTorch is able to handle such cases properly except when the batch size is 1.
Oh interesting. Well, yes, we could do this instead. I'm not really sure if either is better.
--
By communicating with Grab Holdings Limited and/or its subsidiaries,
associate companies and jointly controlled entities (collectively, “Grab”),
you are deemed to have consented to the processing of your personal data as
set out in the Privacy Notice which can be viewed at
https://grab.com/privacy/ <https://grab.com/privacy/>
This email
contains confidential information that may be privileged and is only for
the intended recipient(s). If you are not the intended recipient(s), please
do not disseminate, distribute or copy this email. Please notify Grab
immediately if you have received this by mistake and delete this email from
your system. Email transmission may not be secure or error-free as any
information could be intercepted, corrupted, lost, destroyed, delayed or
incomplete, or contain viruses. Grab does not accept liability for any
errors or omissions in this email that arise as a result of email
transmission. All intellectual property rights in this email and any
attachments shall remain vested in Grab, unless otherwise provided by law
|
* add skip for batch norm
* add changelog
* device of default mean/var
* device of default mean/var
* device of default mean/var
* device of default mean/var
* update naming for new argument
* require track running
* Update torch_geometric/nn/norm/batch_norm.py
* Update torch_geometric/nn/norm/batch_norm.py
* update

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
The purpose of this PR is to provide a way for BatchNorm to work even when the batch size is one. This comes up when training heterogeneous graphs with rare node types. This PR addresses #5529.
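A rough usage sketch of what this enables (the argument name allow_no_batch is taken from the diffs in this conversation; the merged argument name and exact behaviour may differ):

import torch
from torch_geometric.nn import BatchNorm

norm = BatchNorm(16, allow_no_batch=True)
norm.train()

x = torch.randn(5, 16)
out = norm(x)            # regular batch: batch statistics are used and the running stats updated

x_single = torch.randn(1, 16)
out = norm(x_single)     # single-element batch: normalized with the running statistics instead of erroring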