A new activation function ACON that is very simple and effective !! #2891
@nmaac thanks for the idea, looks promising! Any object detection results so far?
@nmaac ah great, thank you! Yes, this is quite a significant improvement in your Table 9. Which ACON version would you recommend we try, and what values for p1, p2, beta?
The right place to include a new activation would be utils/activations, and then the place to swap out nn.SiLU() for a new activation is here on L39 of models/common.py (Lines 33 to 43 in d48a34d):
I would like to suggest ACON-C, which improves accuracy with negligible overhead. You can use the code directly: https://github.com/nmaac/acon/blob/8782b65f5d7b3523f656beceb586b54d04019705/acon.py#L4-19
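For reference, ACON-C can be sketched as a small PyTorch module following the formula (p1 - p2) * x * sigmoid(beta * (p1 - p2) * x) + p2 * x from the linked file; the per-channel parameter shapes here are an assumption based on that reference code:

```python
import torch
import torch.nn as nn

class AconC(nn.Module):
    """Sketch of ACON-C: a learnable unification of ReLU and Swish.
    p1, p2, beta are learned per channel."""
    def __init__(self, width):
        super().__init__()
        self.p1 = nn.Parameter(torch.randn(1, width, 1, 1))
        self.p2 = nn.Parameter(torch.randn(1, width, 1, 1))
        self.beta = nn.Parameter(torch.ones(1, width, 1, 1))

    def forward(self, x):
        dpx = (self.p1 - self.p2) * x
        # reduces to Swish when p1=1, p2=0, and toward ReLU as beta -> inf
        return dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x
```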
@nmaac @ilem777 I've added AconC to our activations study here: I just started runs with AconC(), MetaAconC() and FReLU(); you can track their progress live at the link above. Training time will be about 3 days. I tried MetaAconC but ran into issues: the nn.BatchNorm2d(16) layers produced errors on inputs of size (1, 16, 1, 1), perhaps I implemented the function incorrectly.
@AyushExel I spotted something concerning that I was hoping you could look at. When runs are public, like the activation study above, the 'stop run' button appears to work even when the visitor is incognito / not signed in.
@glenn-jocher thanks for reporting this. I'll check whether the button actually stops the runs for non-authorized users. If it does, then it's a very bad bug; otherwise it's just a minor frontend bug. I'll file a ticket for this to get fixed.
It's because nn.BatchNorm2d needs batch size > 1 when training: `m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(2, ch, s, s))])  # forward`
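This constraint is easy to reproduce: in training mode, nn.BatchNorm2d needs more than one value per channel to compute batch statistics, which is why the stride-probe forward pass above uses a dummy batch of 2:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(16)
bn.train()  # training mode computes batch statistics

# A (1, 16, 1, 1) input gives exactly one value per channel, so the
# batch variance is undefined and PyTorch raises a ValueError.
try:
    bn(torch.zeros(1, 16, 1, 1))
except ValueError as e:
    print("train, batch=1:", e)

# Batch size 2 works fine, as does batch size 1 in eval mode
# (which uses the running statistics instead).
bn(torch.zeros(2, 16, 1, 1))
bn.eval()
bn(torch.zeros(1, 16, 1, 1))
```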
@glenn-jocher you can simply remove the two BN layers.
@nmaac oh, I think I misunderstood before. You mean to remove self.bn1 and self.bn2 completely from the MetaAconC() module for all batch sizes?
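A minimal sketch of what that workaround might look like: a hypothetical MetaAconC with beta predicted per channel by the fc1/fc2 bottleneck and the two BatchNorm layers dropped (`r` is an assumed channel-reduction ratio, not taken from the original code):

```python
import torch
import torch.nn as nn

class MetaAconC(nn.Module):
    """Sketch of meta-ACON with self.bn1/self.bn2 removed.
    beta is predicted from globally pooled features by a small bottleneck."""
    def __init__(self, width, r=16):
        super().__init__()
        hidden = max(r, width // r)
        self.fc1 = nn.Conv2d(width, hidden, 1)
        self.fc2 = nn.Conv2d(hidden, width, 1)
        self.p1 = nn.Parameter(torch.randn(1, width, 1, 1))
        self.p2 = nn.Parameter(torch.randn(1, width, 1, 1))

    def forward(self, x):
        # per-channel beta from global average pooling; no BN, so this
        # also works in training mode with batch size 1
        beta = torch.sigmoid(self.fc2(self.fc1(x.mean(dim=(2, 3), keepdim=True))))
        dpx = (self.p1 - self.p2) * x
        return dpx * torch.sigmoid(beta * dpx) + self.p2 * x
```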
@WongKinYiu yes, this is a good solution too, though it will make model creation a bit slower for all other models. The nn.BatchNorm2d() layers are OK for batch-size 1 inference?
@WongKinYiu @nmaac I'm curious: looking at the ACON implementation, have you guys tried simply training SiLU with a learnable beta? I've never done this before. nn.SiLU() does not allow this, but I think I might try testing it with a custom SiLU to see how it affects the results.
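Such a custom SiLU could look like this: a hypothetical drop-in (not part of YOLOv5) with a single learnable scalar beta, which reduces exactly to nn.SiLU at initialization where beta = 1:

```python
import torch
import torch.nn as nn

class SiLUBeta(nn.Module):
    """Swish with a learnable scalar beta: x * sigmoid(beta * x).
    nn.SiLU fixes beta = 1; here it is trained like any other parameter."""
    def __init__(self):
        super().__init__()
        self.beta = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x * torch.sigmoid(self.beta * x)
```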
nn.BatchNorm2d() layers can do batch-size 1 inference.
@glenn-jocher SiLU with a learnable beta does not show benefits: in the paper, Swish-1 and Swish show comparable results when beta is set to 1. That is why meta-ACON learns beta explicitly, which is what produces the improvements.
@nmaac understood, thanks for the explanation. I had to completely remove the BN layers from MetaAconC, otherwise instabilities appeared in training (two 'STOPPED' runs below). Results should be done in about a day, but based on the current trends it doesn't initially seem like I was able to produce better results with either AconC or MetaAconC. The best performing activation in the study by far was FReLU, though this should be taken with a grain of salt, as FReLU really blurs the line between an activation and a convolution layer. Due to the added parameters and FLOPs I would also assume FReLU disproportionately improves smaller models like YOLOv5s, with unclear benefit for larger models like YOLOv5x6, which may necessitate a second study in the future.
@glenn-jocher Yes, the curves show comparable results. Which activations did you change? Did you pre-train the backbone? In my experiments I usually change the activations in the backbone and pre-train the backbone on ImageNet first. I can help with the pre-training if needed :)
I think the main reason is that @glenn-jocher forgot to add …
@nmaac well, that's a good question: should the activation function parameters be exempt from weight decay? We use the following parameter groups to exempt .bias parameters and BatchNorm layers from weight decay, so at the moment only the fc1 and fc2 biases are exempt from decay (Lines 115 to 123 in 1849916):
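One way to sketch that kind of grouping, extended so that bare nn.Parameter attributes such as an activation's p1/p2/beta are also exempt (an illustrative helper, not the actual train.py code):

```python
import torch
import torch.nn as nn

def build_param_groups(model, weight_decay=5e-4):
    """Split parameters into decay / no-decay groups.
    Only conv and linear weights receive weight decay; biases, BatchNorm
    affine parameters, and bare nn.Parameter attributes (e.g. an
    activation's p1/p2/beta) do not."""
    decay, no_decay = [], []
    for module in model.modules():
        for name, p in module.named_parameters(recurse=False):
            if name == "weight" and isinstance(module, (nn.Conv2d, nn.Linear)):
                decay.append(p)
            else:
                no_decay.append(p)
    return [{"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0}]
```

The groups can then be handed straight to an optimizer, e.g. `torch.optim.SGD(build_param_groups(model), lr=0.01, momentum=0.9)`.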
The activation function implementations are all in utils/activations (Lines 58 to 98 in 1849916):
The activations_study branch used in this study replaces all activations in the YOLOv5 model by redefining them here (Lines 34 to 55 in c9c95fb):
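A generic sketch of swapping every activation in an existing model (a hypothetical helper, not the actual activations_study code; note that AconC would need a factory that knows the channel width, whereas parameter-free activations can be passed directly):

```python
import torch
import torch.nn as nn

def replace_activations(model, new_act_factory):
    """Recursively replace every nn.SiLU in `model` with a fresh module
    returned by new_act_factory(). Modifies the model in place."""
    for name, child in model.named_children():
        if isinstance(child, nn.SiLU):
            setattr(model, name, new_act_factory())
        else:
            replace_activations(child, new_act_factory)
    return model
```

For example, `replace_activations(model, nn.ReLU)` swaps in ReLU everywhere SiLU appears.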
All YOLOv5 models are trained from scratch to 300 epochs using all default settings. Training commands are shown in the W&B link to reproduce (COCO dataset autodownloads):
@glenn-jocher |
@developer0hye well I'm not sure. The ACON authors @nmaac didn't answer my question of whether we should exempt some of the ACON parameters from weight decay. The current results are here for all the activations on YOLOv5s: https://wandb.ai/glenn-jocher/activations |
@glenn-jocher @developer0hye In my experiments, the weight decay setting does not affect the results very much. But I suggest try another initialization approach:
In my experiments, the old initialization with decay drops 0.2% AP.
@nmaac @WongKinYiu got it, thanks guys! |
Any updates on this? How's ACON performing?
🚀 Feature
There is a new activation function ACON (CVPR 2021) that unifies ReLU and Swish.
ACON is simple but very effective, code is here: https://github.com/nmaac/acon/blob/main/acon.py#L19
The improvements are very significant:
Motivation
Pitch
I would like to suggest replacing SiLU with ACON directly: since SiLU (Swish) is already used in your project, its more general and effective form ACON may also show improvements.
Alternatives
It also has an enhanced version, meta-ACON, which uses a small network to learn beta explicitly; this may affect speed a bit.
Additional context
Code and paper.