This is the official PyTorch implementation of OGM-GE, proposed in "Balanced Multimodal Learning via On-the-fly Gradient Modulation". OGM-GE is a flexible plug-in module that enhances the optimization process of multimodal learning. Please refer to our CVPR 2022 paper for more details.
Paper Title: "Balanced Multimodal Learning via On-the-fly Gradient Modulation"
Authors: Xiaokang Peng*, Yake Wei*, Andong Deng, Dong Wang and Di Hu
Accepted by: IEEE Conference on Computer Vision and Pattern Recognition (CVPR 2022, Oral Presentation)
[arXiv] [Supplementary Material]
We observe that the potential of multimodal information is not fully exploited even when the multimodal model outperforms its uni-modal counterpart. We conduct linear probing experiments to explore the quality of jointly trained encoders, and find them under-optimized (the yellow line) compared with the uni-modal model (the red line). We propose the OGM-GE method to adaptively improve the optimization process and achieve consistent improvement (the blue line). As shown in the following figure, it improves both the multimodal performance and the uni-modal representations.
Pipeline of our OGM-GE method, consisting of two submodules:
- On-the-fly Gradient Modulation (OGM), which is designed to adaptively balance the training between modalities;
- Adaptive Gaussian noise Enhancement (GE), which restores the gradient intensity and brings generalization.
- Ubuntu 16.04
- CUDA Version: 11.1
- PyTorch 1.8.1
- torchvision 0.9.1
- python 3.7.6
Download Original Dataset: CREMA-D, AVE, VGGSound, Kinetics-Sounds.
For example, we provide code to pre-process videos into RGB frames and audio wav files in directory pre-process/. The pre-processed data can be obtained by running:
python pre-processing/obtain_audio_spectrogram.py
and
python pre-processing/obtain_frames.py
After downloading and processing the data, build the data directory with the following structure. Take AVE for example:
AVE
│------ visual
│---------sample1
│------------frame1
│------------frame2
│------ audio
│---------sample1.wav
│---------sample2.wav
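Before training, it can help to check that every sample has both modalities on disk. The following is a minimal sketch that assumes the layout shown above (a frame folder per sample under visual/ and one wav file per sample under audio/). The helper name `list_paired_samples` is hypothetical and is not part of this repository.

```python
from pathlib import Path


def list_paired_samples(root):
    """Return sample names that have both a frame folder and a wav file.

    Assumed layout (taken from the tree above):
        <root>/visual/<sample>/   folder of RGB frames
        <root>/audio/<sample>.wav audio track
    """
    root = Path(root)
    # Sample names that have a frame folder.
    visual = {p.name for p in (root / "visual").iterdir() if p.is_dir()}
    # Sample names that have a wav file (file name without the extension).
    audio = {p.stem for p in (root / "audio").glob("*.wav")}
    # Keep only samples that are present in both modalities.
    return sorted(visual & audio)
```

Calling `list_paired_samples("AVE")` on the example tree would keep only the samples that appear under both visual/ and audio/.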
Our proposed OGM-GE can work as a simple but useful plugin for some widely used multimodal fusion frameworks. We display the core part of the code as follows:
import torch
# --- in training step ---
# out_a and out_v are calculated to estimate the performance of the 'a' and 'v' modalities.
x, y, out = model(spec.unsqueeze(1).float(), image.float(), label, iteration)
out_v = (torch.mm(x,torch.transpose(model.module.fc_.weight[:,:512],0,1)) + model.module.fc_.bias/2)
out_a = (torch.mm(y,torch.transpose(model.module.fc_.weight[:,512:],0,1)) + model.module.fc_.bias/2)
loss = criterion(out, label)
# Calculate original gradient first
loss.backward()
# Calculate the discrepancy ratio and the coefficients k_a, k_v.
k_a,k_v = calculate_coefficient(label, out_a, out_v)
# Gradient Modulation begins before optimization, and with GE applied.
update_model_with_OGM_GE(model, k_a, k_v)
# Optimize the modulated parameters.
optimizer.step()
# --- continue with the next training step ---
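The two helpers called above, `calculate_coefficient` and `update_model_with_OGM_GE`, are not shown in this snippet. The sketch below illustrates the arithmetic they are expected to perform, based on our reading of the paper: each modality is scored by the softmax probability it assigns to the ground-truth class, the dominant modality gets a coefficient 1 - tanh(alpha * ratio), and GE adds zero-mean Gaussian noise to the modulated gradient. It is written in dependency-free Python on plain lists so the numbers are easy to follow. The function names with the `_sketch` suffix, the `alpha=0.5` default, and the choice of the gradient's sample standard deviation as the noise scale are assumptions for illustration, not the repository's code.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over one list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def calculate_coefficient_sketch(labels, out_a, out_v, alpha=0.5):
    """Return (k_a, k_v) for one mini-batch (illustrative only).

    The modality with the higher summed ground-truth probability is treated
    as dominant and receives a coefficient below 1; the other keeps 1.
    """
    score_a = sum(softmax(o)[y] for o, y in zip(out_a, labels))
    score_v = sum(softmax(o)[y] for o, y in zip(out_v, labels))
    ratio_a = score_a / score_v  # discrepancy ratio seen from the audio side
    ratio_v = score_v / score_a  # discrepancy ratio seen from the visual side
    k_a = 1.0 - math.tanh(alpha * ratio_a) if ratio_a > 1 else 1.0
    k_v = 1.0 - math.tanh(alpha * ratio_v) if ratio_v > 1 else 1.0
    return k_a, k_v


def modulate_with_ge_sketch(grads, k, rng=random):
    """Scale gradient values by k, then add zero-mean Gaussian noise (GE).

    Assumption: the noise standard deviation is the sample standard
    deviation of the unscaled gradient values.
    """
    mean = sum(grads) / len(grads)
    var = sum((g - mean) ** 2 for g in grads) / max(len(grads) - 1, 1)
    std = math.sqrt(var)
    return [g * k + rng.gauss(0.0, std) for g in grads]
```

In actual training the same operations would be applied to tensors: the coefficients come from the batch outputs `out_a` and `out_v`, and the scaling plus noise is applied to the gradients of each modality's encoder after `loss.backward()` and before `optimizer.step()`.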
The difference between main.py and main_old.py is as follows:
In main.py, we consider adaptive imbalance during the whole optimization according to Eq. 10 in our paper.
In main_old.py, the auditory modality is viewed as dominant by default.
--modulation OGM_GE --modulation_starts 0 --modulation_ends 50 --fusion_method concat --alpha 0.5
You can train your model simply by running:
python main.py --dataset VGGSound --train
You can also adapt the training to your own setting by adding arguments. For example, to train our model on the CREMA-D dataset with the gated fusion method and only OGM (i.e., without GE), modulating the gradient from epoch 20 to epoch 80, run the following command:
train.py --train --dataset CREMAD --fusion_method gated --modulation OGM --modulation_starts 20 --modulation_ends 80 --alpha 0.3
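To make the role of these arguments concrete, here is a minimal sketch of how such flags could be parsed and how the modulation window could be checked per epoch. The flag names come from the commands above. The parser itself, the helper `modulation_active`, the placeholder default for `--dataset`, and the choice of an inclusive window are assumptions for illustration, not the repository's actual argument handling.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags used in the commands above."""
    parser = argparse.ArgumentParser()
    # Placeholder default; the real script may use a different one.
    parser.add_argument("--dataset", type=str, default="CREMAD")
    parser.add_argument("--fusion_method", type=str, default="concat")
    parser.add_argument("--modulation", type=str, default="OGM_GE")
    parser.add_argument("--modulation_starts", type=int, default=0)
    parser.add_argument("--modulation_ends", type=int, default=50)
    parser.add_argument("--alpha", type=float, default=0.5)
    parser.add_argument("--train", action="store_true")
    return parser


def modulation_active(epoch, args):
    """Return True if gradient modulation should be applied in this epoch.

    Assumption: the window is inclusive on both ends.
    """
    return args.modulation_starts <= epoch <= args.modulation_ends
```

With the CREMA-D example above, `modulation_active` would be false for epochs 0 to 19, true for epochs 20 to 80, and false afterwards, so the model trains without modulation outside that window.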
You can test the performance of a trained model by simply running:
python main.py --ckpt_path /PATH-to-trained-ckpt
Kinetics-Sounds and VGGSound: Coming Soon
As shown in the picture above, 'playing guitar' is a class in which audio surpasses the visual modality for most samples ('shovelling snow' is just the opposite), and we can tell that audio achieves more adequate training and leads the optimization process. Our OGM-GE (as well as OGM) gains improvement in both modalities as well as in multimodal performance, and the weaker visual modality profits more. The evaluation metric used for 'audio' and 'visual' is the prediction accuracy with classification scores from just that one modality.
If you find this work useful, please consider citing it.
@inproceedings{Peng2022Balanced,
title = {Balanced Multimodal Learning via On-the-fly Gradient Modulation},
author = {Peng, Xiaokang and Wei, Yake and Deng, Andong and Wang, Dong and Hu, Di},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year = {2022}
}
This research was supported by Public Computing Cloud, Renmin University of China.
This project is released under the GNU General Public License v3.0.
If you have any detailed questions or suggestions, you can email us: yakewei@ruc.edu.cn and andongdeng69@gmail.com