- The proposed WeakTr fully explores the potential of the plain ViT in the WSSS domain. It achieves state-of-the-art results on both challenging WSSS benchmarks, with 74.0% mIoU on the PASCAL VOC 2012 and 46.9% on the COCO 2014 validation sets, significantly surpassing previous methods.
- The proposed WeakTr based on the improved ViT, pretrained on ImageNet-21k and fine-tuned on ImageNet-1k, performs even better, with 78.4% mIoU on the PASCAL VOC 2012 and 50.3% on the COCO 2014 validation sets.
This paper explores the properties of the plain Vision Transformer (ViT) for Weakly-supervised Semantic Segmentation (WSSS). The class activation map (CAM) is of critical importance for understanding a classification network and launching WSSS. We observe that different attention heads of ViT focus on different image areas. A novel weight-based method is therefore proposed to estimate the importance of the attention heads end-to-end, while the self-attention maps are adaptively fused into high-quality CAM results that tend to cover objects more completely. Besides, we propose a ViT-based gradient clipping decoder for online retraining with the CAM results to complete the WSSS task. We name this plain-Transformer-based weakly-supervised learning framework WeakTr. It achieves state-of-the-art WSSS performance on standard benchmarks, i.e., 78.4% mIoU on the val set of PASCAL VOC 2012 and 50.3% mIoU on the val set of COCO 2014.
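The head-importance idea above can be sketched in a few lines of numpy. This is an illustrative toy, not the repo's implementation: it assumes we already have each head's class-to-patch attention map and a learned scalar weight per head, and fuses them by a softmax-weighted sum.

```python
import numpy as np

def fuse_head_attentions(head_maps: np.ndarray, head_weights: np.ndarray) -> np.ndarray:
    """Toy sketch of weight-based attention fusion.

    head_maps: (H, N) attention from the class token to N patch tokens, one row per head.
    head_weights: (H,) learned importance scores, one per head.
    Returns the (N,) fused map: a softmax-weighted sum over heads, so heads
    judged more important contribute more to the final CAM.
    """
    w = np.exp(head_weights - head_weights.max())
    w = w / w.sum()                              # softmax-normalized importance
    return (w[:, None] * head_maps).sum(axis=0)  # (N,) fused attention map

rng = np.random.default_rng(0)
fused = fuse_head_attentions(rng.random((6, 196)), rng.random(6))
print(fused.shape)  # (196,) — one value per 14x14 patch token
```

In WeakTr the weights are estimated end-to-end during classification training; here they are just random numbers standing in for learned scores.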
Step1: End-to-End CAM Generation
Step2: Online Retraining with Gradient Clipping Decoder
conda create --name weaktr python=3.7
conda activate weaktr
pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 --extra-index-url https://download.pytorch.org/whl/cu111
pip install -r requirements.txt
Then, install mmcv==1.4.0 and mmsegmentation following the official instructions.
pip install -U openmim
mim install mmcv-full==1.4.0
pip install mmsegmentation==0.30.0
Finally, install pydensecrf from source.
pip install git+https://github.com/lucasb-eyer/pydensecrf.git
Pascal VOC 2012
- First, download the Pascal VOC 2012 dataset using the scripts in the data dir.
cd data
sh download_and_convert_voc12.sh
- Then download SBD annotations from here.
The folder structure is assumed to be:
- data
- download_and_convert_voc12.sh
+ voc12
+ VOCdevkit
+ VOC2012
+ JPEGImages
+ SegmentationClass
+ SegmentationClassAug
- voc12
- cls_labels.npy
- train_aug_id.txt
- train_id.txt
- val_id.txt
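Before training, the layout can be sanity-checked with a short script. This is a hypothetical helper, not part of the repo; the paths simply mirror the tree above.

```python
import os

# Paths assumed by this README's VOC12 tree (adjust if yours differ).
EXPECTED = [
    "voc12/VOCdevkit/VOC2012/JPEGImages",
    "voc12/VOCdevkit/VOC2012/SegmentationClass",
    "voc12/VOCdevkit/VOC2012/SegmentationClassAug",
    "voc12/cls_labels.npy",
    "voc12/train_aug_id.txt",
    "voc12/train_id.txt",
    "voc12/val_id.txt",
]

def missing_paths(data_root: str = "data") -> list:
    """Return the expected paths that do not exist under data_root."""
    return [p for p in EXPECTED if not os.path.exists(os.path.join(data_root, p))]

print(missing_paths("data"))  # empty list == layout looks right
```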
COCO 2014
- First, download the COCO 2014 dataset using the scripts in the data dir.
cd data
sh download_and_convert_coco.sh
The folder structure is assumed to be:
- data
- download_and_convert_coco.sh
- voc12
+ coco
+ images
+ voc_format
+ class_labels
+ train.txt
+ val.txt
- coco
- cls_labels.npy
- train_id.txt
- val_id.txt
# Training
python main.py --model deit_small_WeakTr_patch16_224 \
--batch-size 64 \
--data-set VOC12 \
--img-list voc12 \
--img-ms-list voc12/train_id.txt \
--gt-dir SegmentationClass \
--scales 1.0 \
--cam-npy-dir $your_cam_dir \
--visualize-cls-attn \
--patch-attn-refine \
--data-path data/voc12 \
--output_dir $your_output_dir \
--finetune https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth \
--if_eval_miou \
--lr 4e-4 \
--seed 504 \
--extra-token
# Generate CAM
python main.py --model deit_small_WeakTr_patch16_224 \
--data-set VOC12MS \
--scales 1.0 0.8 1.2 \
--img-list voc12 \
--data-path data/voc12 \
--img-ms-list voc12/train_aug_id.txt \
--gt-dir SegmentationClassAug \
--output_dir $your_model_dir \
--resume $your_checkpoint_path \
--gen_attention_maps \
--attention-type fused \
--visualize-cls-attn \
--patch-attn-refine \
--cam-npy-dir $your_CAM_npy_dir
# CRF post-processing
python evaluation.py --list voc12/train_aug_id.txt \
--data-path data/voc12 \
--gt_dir SegmentationClassAug \
--img_dir JPEGImages \
--type npy \
--t 42 \
--predict_dir $your_CAM_npy_dir \
--out-crf \
--out-dir $your_CAM_label_dir
We provide the best checkpoint for CAM generation and the CAM label for online retraining in Google Drive; the CAM label reaches 69% mIoU on the train set.
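Conceptually, the pseudo-label that the CRF step refines comes from thresholding the CAMs against a constant background score (the --t flag above, divided by 100). A hedged numpy sketch of that thresholding, under the assumption that the CAMs are normalized to [0, 1]:

```python
import numpy as np

def cam_to_label(cams: np.ndarray, present_classes: list, bg_thresh: float = 0.42) -> np.ndarray:
    """Toy pseudo-label generation from CAMs (illustrative, not the repo's code).

    cams: (K, H, W) normalized CAMs for the K classes present in the image.
    present_classes: the K dataset class ids (0-based) those CAMs belong to.
    A pixel becomes background (label 0) when no class score beats bg_thresh;
    otherwise it gets the winning class id + 1.
    """
    bg = np.full((1,) + cams.shape[1:], bg_thresh)
    scores = np.concatenate([bg, cams], axis=0)   # (K+1, H, W), row 0 = background
    winner = scores.argmax(axis=0)                # 0 means background wins
    ids = np.array([0] + [c + 1 for c in present_classes])
    return ids[winner]

cams = np.zeros((2, 2, 2)); cams[0, 0, 0] = 0.9; cams[1, 1, 1] = 0.8
print(cam_to_label(cams, [4, 7]))  # [[5 0] [0 8]] — two confident pixels, rest background
```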
cd OnlineRetraining
DATASET=$your_dataset_path WORK=$your_project_path python -m segm.train \
--log-dir $your_log_dir \
--dataset pascal_voc --backbone $your_model_name --decoder mask_transformer \
--batch-size 4 --epochs 100 -lr 1e-4 \
--num-workers 2 --eval-freq 1 \
--ann-dir $your_CAM_label_dir \
--start-value 1.2 --patch-size 120
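The gradient clipping decoder is the part of online retraining that guards against noise in the CAM pseudo-masks. A very loose illustration of the idea, assuming (this is our reading, not the repo's code) that --start-value acts as an initial per-pixel loss threshold above which pixels are treated as mislabeled and excluded:

```python
import numpy as np

def clipped_loss(pixel_losses: np.ndarray, threshold: float) -> float:
    """Toy sketch: average only the pixel losses below the threshold.

    Pixels with unusually large loss are likely labeled wrong in the CAM
    pseudo-mask, so their gradients are suppressed instead of propagated.
    """
    kept = pixel_losses[pixel_losses <= threshold]
    return float(kept.mean()) if kept.size else 0.0

losses = np.array([0.1, 0.3, 5.0, 0.2])          # 5.0 = probable label noise
print(round(clipped_loss(losses, threshold=1.2), 3))  # 0.2
```

The actual decoder is ViT-based and operates on gradients during backpropagation; this snippet only conveys the filter-out-the-noisy-pixels intuition.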
cd OnlineRetraining
- Multi-scale Evaluation
MASTER_PORT=10201 DATASET=$your_dataset_path PYTHONPATH=. WORK=$your_project_path python segm/eval/miou.py --window-batch-size 1 --multiscale \
$your_checkpoint_path \
--predict-dir $your_pred_npy_dir \
pascal_voc
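Multi-scale evaluation runs the model at several input scales and averages the resized outputs. A rough numpy sketch of that aggregation (nearest-neighbor resize for brevity; the actual code interpolates differently):

```python
import numpy as np

def resize_nn(m: np.ndarray, size: tuple) -> np.ndarray:
    """Nearest-neighbor resize of a (C, h, w) score map to (C, *size)."""
    c, h, w = m.shape
    ys = (np.arange(size[0]) * h // size[0]).clip(0, h - 1)
    xs = (np.arange(size[1]) * w // size[1]).clip(0, w - 1)
    return m[:, ys][:, :, xs]

def aggregate_multiscale(maps_per_scale: list, out_size: tuple) -> np.ndarray:
    """Resize per-scale score maps to a common size, average, normalize."""
    resized = [resize_nn(m, out_size) for m in maps_per_scale]
    fused = np.mean(resized, axis=0)
    return fused / (fused.max() + 1e-5)

rng = np.random.default_rng(0)
# e.g. 20 class channels at scales 1.0, 0.8, 1.2 of a 14x14 token grid
maps = [rng.random((20, int(14 * s), int(14 * s))) for s in (1.0, 0.8, 1.2)]
print(aggregate_multiscale(maps, (14, 14)).shape)  # (20, 14, 14)
```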
- CRF post-processing
python -m segm.eval.make_crf \
--list ImageSets/Segmentation/val.txt \
--data-path ../data/voc12 \
--predict-dir $your_pred_npy_dir \
--predict-png-dir $your_pred_png_dir \
--img-path JPEGImages \
--gt-folder SegmentationClassAug \
--num-cls 21 \
--dataset voc12
- Evaluation
python -m segm.eval.make_crf \
--list ImageSets/Segmentation/val.txt \
--data-path ../data/voc12 \
--predict-dir $your_pred_crf_dir \
--type png \
--img-path JPEGImages \
--gt-folder SegmentationClassAug \
--num-cls 21 \
--dataset voc12
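The mIoU reported by these scripts is the standard metric: per-class intersection over union, averaged across the classes (21 for VOC, via --num-cls). A minimal reference reimplementation, not the repo's evaluator:

```python
import numpy as np

def miou(pred: np.ndarray, gt: np.ndarray, num_cls: int = 21) -> float:
    """Mean IoU from integer label maps, via a num_cls x num_cls confusion matrix."""
    conf = np.bincount(gt.ravel() * num_cls + pred.ravel(),
                       minlength=num_cls ** 2).reshape(num_cls, num_cls)
    inter = np.diag(conf)                            # correctly labeled pixels
    union = conf.sum(0) + conf.sum(1) - inter        # pred + gt - intersection
    ious = inter / np.maximum(union, 1)
    return float(ious[union > 0].mean())             # average over present classes

gt = np.array([[0, 0], [1, 1]]); pred = np.array([[0, 1], [1, 1]])
print(round(miou(pred, gt, num_cls=2), 3))  # 0.583 = mean of IoU 0.5 and 2/3
```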
Dataset | Checkpoint | CAM_Label | Train mIoU |
---|---|---|---|
Pascal VOC 2012 | Google Drive | Google Drive | 69.0% |
COCO 2014 | Google Drive | Google Drive | 41.9% |
Dataset | Method | Checkpoint | Val mIoU | Pseudo-mask | Train mIoU |
---|---|---|---|---|---|
Pascal VOC 2012 | WeakTr | Google Drive | 74.0% | Google Drive | 76.5% |
Pascal VOC 2012 | WeakTr (improved ViT) | Google Drive | 78.4% | Google Drive | 80.3% |
COCO 2014 | WeakTr | Google Drive | 46.9% | Google Drive | 48.9% |
COCO 2014 | WeakTr (improved ViT) | Google Drive | 50.3% | Google Drive | 51.3% |
If you find this repository/work helpful in your research, please consider citing the paper and giving it a ⭐.
@article{zhu2023weaktr,
  title={WeakTr: Exploring Plain Vision Transformer for Weakly-supervised Semantic Segmentation},
  author={Lianghui Zhu and Yingyue Li and Jieming Fang and Yan Liu and Hao Xin and Wenyu Liu and Xinggang Wang},
  journal={arXiv preprint arXiv:2304.01184},
  year={2023},
}