From 933e4d3cb6753e569fa8ac81a5e637d9414ce1f0 Mon Sep 17 00:00:00 2001
From: MengzhangLI
Date: Thu, 1 Dec 2022 19:03:10 +0800
Subject: [PATCH] [Feature] Support MaskFormer(NeurIPS'2021) in MMSeg 1.x
 (#2215)

* [Feature] Support MaskFormer(NeurIPS'2021) in MMSeg 1.x
* add mmdet try except logic
* refactor config files
* add readme
* fix config
* update models & logs
* add MMDET installation and fix info
* fix comments
* fix
* fix config norm optimizer setting
* update models & logs & unittest
* add docstring of MaskFormerHead
* wait for mmdet 3.0.0rc4
* replace seg_mask with seg_logits & add docstring for batch_input_shape
* use mmdet3.0.0rc4
* fix readme and modify config comments
* add mmdet installation in pr_stage_test.yml
* update mmcv version in pr_stage_test.yml
* add mmdet in build_cpu of pr_stage_test.yml
* modify mmdet& mmcv installation in merge_stage_test.yml
* fix typo
* update test.yml
* update test.yml
---
 .circleci/test.yml                            |   9 +-
 .github/workflows/merge_stage_test.yml        |  12 +-
 .github/workflows/pr_stage_test.yml           |   9 +-
 README.md                                     |   1 +
 README_zh-CN.md                               |   1 +
 configs/maskformer/README.md                  |  60 ++++++
 configs/maskformer/maskformer.yml             | 101 ++++++++++
 ...ormer_r101-d32_8xb2-160k_ade20k-512x512.py |   7 +
 ...former_r50-d32_8xb2-160k_ade20k-512x512.py | 143 +++++++++++++++
 ...swin-s_upernet_8xb2-160k_ade20k-512x512.py |  79 ++++++++
 ...swin-t_upernet_8xb2-160k_ade20k-512x512.py |  81 ++++++++
 mmseg/models/decode_heads/__init__.py         |   3 +-
 mmseg/models/decode_heads/maskformer_head.py  | 173 ++++++++++++++++++
 model-index.yml                               |   1 +
 requirements/mminstall.txt                    |   1 +
 .../test_heads/test_maskformer_head.py        |  54 ++++++
 16 files changed, 724 insertions(+), 11 deletions(-)
 create mode 100644 configs/maskformer/README.md
 create mode 100644 configs/maskformer/maskformer.yml
 create mode 100644 configs/maskformer/maskformer_r101-d32_8xb2-160k_ade20k-512x512.py
 create mode 100644 configs/maskformer/maskformer_r50-d32_8xb2-160k_ade20k-512x512.py
 create mode 100644 configs/maskformer/maskformer_swin-s_upernet_8xb2-160k_ade20k-512x512.py
 create mode 100644 configs/maskformer/maskformer_swin-t_upernet_8xb2-160k_ade20k-512x512.py
 create mode 100644 mmseg/models/decode_heads/maskformer_head.py
 create mode 100644 tests/test_models/test_heads/test_maskformer_head.py

diff --git a/.circleci/test.yml b/.circleci/test.yml
index 76f9f70d8f..c8acb48293 100644
--- a/.circleci/test.yml
+++ b/.circleci/test.yml
@@ -61,8 +61,9 @@ jobs:
           command: |
             pip install git+https://github.com/open-mmlab/mmengine.git@main
             pip install -U openmim
-            mim install 'mmcv>=2.0.0rc1'
+            mim install 'mmcv>=2.0.0rc3'
             pip install git+https://github.com/open-mmlab/mmclassification@dev-1.x
+            pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
             pip install -r requirements/tests.txt -r requirements/optional.txt
       - run:
           name: Build and install
@@ -96,18 +97,20 @@ jobs:
           command: |
             git clone -b main --depth 1 https://github.com/open-mmlab/mmengine.git /home/circleci/mmengine
             git clone -b dev-1.x --depth 1 https://github.com/open-mmlab/mmclassification.git /home/circleci/mmclassification
+            git clone -b dev-3.x --depth 1 https://github.com/open-mmlab/mmdetection.git /home/circleci/mmdetection
       - run:
           name: Build Docker image
           command: |
             docker build .circleci/docker -t mmseg:gpu --build-arg PYTORCH=<< parameters.torch >> --build-arg CUDA=<< parameters.cuda >> --build-arg CUDNN=<< parameters.cudnn >>
-            docker run --gpus all -t -d -v /home/circleci/project:/mmseg -v /home/circleci/mmengine:/mmengine -v /home/circleci/mmclassification:/mmclassification -w /mmseg --name mmseg mmseg:gpu
+            docker run --gpus all -t -d -v /home/circleci/project:/mmseg -v /home/circleci/mmengine:/mmengine -v /home/circleci/mmclassification:/mmclassification -v /home/circleci/mmdetection:/mmdetection -w /mmseg --name mmseg mmseg:gpu
       - run:
           name: Install mmseg dependencies
           command: |
             docker exec mmseg pip install -e /mmengine
             docker exec mmseg pip install -U openmim
-            docker exec mmseg mim install 'mmcv>=2.0.0rc1'
+            docker exec mmseg mim install 'mmcv>=2.0.0rc3'
             docker exec mmseg pip install -e /mmclassification
+            docker exec mmseg pip install -e /mmdetection
             docker exec mmseg pip install -r requirements/tests.txt -r requirements/optional.txt
       - run:
           name: Build and install
diff --git a/.github/workflows/merge_stage_test.yml b/.github/workflows/merge_stage_test.yml
index 42a9dc0c46..b4a4a4424d 100644
--- a/.github/workflows/merge_stage_test.yml
+++ b/.github/workflows/merge_stage_test.yml
@@ -44,8 +44,9 @@ jobs:
           python -V
           pip install -U openmim
           pip install git+https://github.com/open-mmlab/mmengine.git
-          mim install 'mmcv>=2.0.0rc1'
+          mim install 'mmcv>=2.0.0rc3'
           pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+          pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
       - name: Install unittest dependencies
         run: pip install -r requirements/tests.txt -r requirements/optional.txt
       - name: Build and install
@@ -92,8 +93,9 @@ jobs:
           python -V
           pip install -U openmim
           pip install git+https://github.com/open-mmlab/mmengine.git
-          mim install 'mmcv>=2.0.0rc1'
+          mim install 'mmcv>=2.0.0rc3'
           pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+          pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
       - name: Install unittest dependencies
         run: pip install -r requirements/tests.txt -r requirements/optional.txt
       - name: Build and install
@@ -155,8 +157,9 @@ jobs:
           python -V
           pip install -U openmim
           pip install git+https://github.com/open-mmlab/mmengine.git
-          mim install 'mmcv>=2.0.0rc1'
+          mim install 'mmcv>=2.0.0rc3'
           pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+          pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
       - name: Install unittest dependencies
         run: pip install -r requirements/tests.txt -r requirements/optional.txt
       - name: Build and install
@@ -187,8 +190,9 @@ jobs:
           python -V
           pip install -U openmim
           pip install git+https://github.com/open-mmlab/mmengine.git
-          mim install 'mmcv>=2.0.0rc1'
+          mim install 'mmcv>=2.0.0rc3'
           pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+          pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
       - name: Install unittest dependencies
         run: pip install -r requirements/tests.txt -r requirements/optional.txt
       - name: Build and install
diff --git a/.github/workflows/pr_stage_test.yml b/.github/workflows/pr_stage_test.yml
index 30e50a962d..302c4689f9 100644
--- a/.github/workflows/pr_stage_test.yml
+++ b/.github/workflows/pr_stage_test.yml
@@ -40,8 +40,9 @@ jobs:
         run: |
           pip install -U openmim
           pip install git+https://github.com/open-mmlab/mmengine.git
-          mim install 'mmcv>=2.0.0rc1'
+          mim install 'mmcv>=2.0.0rc3'
           pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+          pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
       - name: Install unittest dependencies
         run: pip install -r requirements/tests.txt -r requirements/optional.txt
       - name: Build and install
@@ -92,8 +93,9 @@ jobs:
           python -V
           pip install -U openmim
          pip install git+https://github.com/open-mmlab/mmengine.git
-          mim install 'mmcv>=2.0.0rc1'
+          mim install 'mmcv>=2.0.0rc3'
           pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+          pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
       - name: Install unittest dependencies
         run: pip install -r requirements/tests.txt -r requirements/optional.txt
       - name: Build and install
@@ -124,8 +126,9 @@ jobs:
           python -V
           pip install -U openmim
           pip install git+https://github.com/open-mmlab/mmengine.git
-          mim install 'mmcv>=2.0.0rc1'
+          mim install 'mmcv>=2.0.0rc3'
           pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+          pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
       - name: Install unittest dependencies
         run: pip install -r requirements/tests.txt -r requirements/optional.txt
       - name: Build and install
diff --git a/README.md b/README.md
index 0930f5c27a..378b2ab08c 100644
--- a/README.md
+++ b/README.md
@@ -139,6 +139,7 @@ Supported methods:
 - [x] [Segmenter (ICCV'2021)](configs/segmenter)
 - [x] [SegFormer (NeurIPS'2021)](configs/segformer)
 - [x] [K-Net (NeurIPS'2021)](configs/knet)
+- [x] [MaskFormer (NeurIPS'2021)](configs/maskformer)
 
 Supported datasets:
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 9d9f5716f2..b66977035f 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -134,6 +134,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 OpenMMLab
 - [x] [Segmenter (ICCV'2021)](configs/segmenter)
 - [x] [SegFormer (NeurIPS'2021)](configs/segformer)
 - [x] [K-Net (NeurIPS'2021)](configs/knet)
+- [x] [MaskFormer (NeurIPS'2021)](configs/maskformer)
 
 已支持的数据集:
diff --git a/configs/maskformer/README.md b/configs/maskformer/README.md
new file mode 100644
index 0000000000..5e33d17afb
--- /dev/null
+++ b/configs/maskformer/README.md
@@ -0,0 +1,60 @@
+# MaskFormer
+
+[MaskFormer: Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278)
+
+## Introduction
+
+<a href="https://github.com/facebookresearch/MaskFormer/">Official Repo</a>
+
+<a href="https://github.com/open-mmlab/mmdetection/blob/dev-3.x/mmdet/models/dense_heads/maskformer_head.py#L21">Code Snippet</a>
+
+## Abstract
+
+Modern approaches typically formulate semantic segmentation as a per-pixel classification task, while instance-level segmentation is handled with an alternative mask classification. Our key insight: mask classification is sufficiently general to solve both semantic- and instance-level segmentation tasks in a unified manner using the exact same model, loss, and training procedure. Following this observation, we propose MaskFormer, a simple mask classification model which predicts a set of binary masks, each associated with a single global class label prediction. Overall, the proposed mask classification-based method simplifies the landscape of effective approaches to semantic and panoptic segmentation tasks and shows excellent empirical results. In particular, we observe that MaskFormer outperforms per-pixel classification baselines when the number of classes is large. Our mask classification-based method outperforms both current state-of-the-art semantic (55.6 mIoU on ADE20K) and panoptic segmentation (52.7 PQ on COCO) models.
+
+<!-- [IMAGE] MaskFormer architecture overview -->
+
+```bibtex
+@article{cheng2021per,
+  title={Per-pixel classification is not all you need for semantic segmentation},
+  author={Cheng, Bowen and Schwing, Alex and Kirillov, Alexander},
+  journal={Advances in Neural Information Processing Systems},
+  volume={34},
+  pages={17864--17875},
+  year={2021}
+}
+```
+
+### Usage
+
+- MaskFormer requires [MMDetection](https://github.com/open-mmlab/mmdetection) to be installed first.
+
+```shell
+pip install "mmdet>=3.0.0rc4"
+```
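+
+If you are unsure whether a compatible MMDetection is already present, a quick check is (a minimal sketch; it assumes only that `mmdet` exposes `__version__`):
+
+```python
+# Check that MMDetection is importable and new enough for MaskFormer.
+try:
+    import mmdet
+    print('mmdet version:', mmdet.__version__)  # expect 3.0.0rc4 or later
+except ModuleNotFoundError:
+    print('mmdet missing; run: pip install "mmdet>=3.0.0rc4"')
+```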
+
+## Results and models
+
+### ADE20K
+
+| Method     | Backbone  | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU  | mIoU(ms+flip) | config | download |
+| ---------- | --------- | --------- | ------- | -------- | -------------- | ----- | ------------- | ------ | -------- |
+| MaskFormer | R-50-D32  | 512x512   | 160000  | 3.29     | 42.20          | 44.29 | -             | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/maskformer/maskformer_r50-d32_8xb2-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/maskformer/maskformer_r50-d32_8xb2-160k_ade20k-512x512/maskformer_r50-d32_8xb2-160k_ade20k-512x512_20221030_182724-cbd39cc1.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/maskformer/maskformer_r50-d32_8xb2-160k_ade20k-512x512/maskformer_r50-d32_8xb2-160k_ade20k-512x512_20221030_182724.json) |
+| MaskFormer | R-101-D32 | 512x512   | 160000  | 4.12     | 34.90          | 45.11 | -             | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/maskformer/maskformer_r101-d32_8xb2-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/maskformer/maskformer_r101-d32_8xb2-160k_ade20k-512x512/maskformer_r101-d32_8xb2-160k_ade20k-512x512_20221031_223053-c8e0931d.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/maskformer/maskformer_r101-d32_8xb2-160k_ade20k-512x512/maskformer_r101-d32_8xb2-160k_ade20k-512x512_20221031_223053.json) |
+| MaskFormer | Swin-T    | 512x512   | 160000  | 3.73     | 40.53          | 46.69 | -             | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/maskformer/maskformer_swin-t_upernet_8xb2-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/maskformer/maskformer_swin-t_upernet_8xb2-160k_ade20k-512x512/maskformer_swin-t_upernet_8xb2-160k_ade20k-512x512_20221114_232813-03550716.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/maskformer/maskformer_swin-t_upernet_8xb2-160k_ade20k-512x512/maskformer_swin-t_upernet_8xb2-160k_ade20k-512x512_20221114_232813.json) |
+| MaskFormer | Swin-S    | 512x512   | 160000  | 5.33     | 26.98          | 49.36 | -             | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/maskformer/maskformer_swin-s_upernet_8xb2-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/maskformer/maskformer_swin-s_upernet_8xb2-160k_ade20k-512x512/maskformer_swin-s_upernet_8xb2-160k_ade20k-512x512_20221115_114710-5ab67e58.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/maskformer/maskformer_swin-s_upernet_8xb2-160k_ade20k-512x512/maskformer_swin-s_upernet_8xb2-160k_ade20k-512x512_20221115_114710.json) |
+
+Note:
+
+- All MaskFormer experiments were run on 8 V100 (32 GB) GPUs with 2 samples per GPU.
+- MaskFormer results are not very stable across runs: the accuracy (mIoU) of the `R-101-D32` model ranges from 44.7 to 46.0, and that of `Swin-S` from 49.0 to 49.8.
+- The ResNet backbones used in the MaskFormer models are standard `ResNet` rather than `ResNetV1c`.
+- Test-time augmentation is not yet supported in MMSegmentation 1.x; "ms+flip" results will be added as soon as possible.
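+
+The latency figures stored in `maskformer.yml` below are in ms/im, i.e. the reciprocal of the fps column above. A quick cross-check (plain Python, no MMSeg dependency assumed):
+
+```python
+# Convert the reported throughput (fps) to latency (ms/im); the results
+# should match the `inference time (ms/im)` entries in maskformer.yml.
+for name, fps in [('R-50-D32', 42.20), ('R-101-D32', 34.90),
+                  ('Swin-T', 40.53), ('Swin-S', 26.98)]:
+    print(f'{name}: {1000 / fps:.2f} ms/im')
+# R-50-D32: 23.70, R-101-D32: 28.65, Swin-T: 24.67, Swin-S: 37.06
+```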
diff --git a/configs/maskformer/maskformer.yml b/configs/maskformer/maskformer.yml
new file mode 100644
index 0000000000..1b3d398e34
--- /dev/null
+++ b/configs/maskformer/maskformer.yml
@@ -0,0 +1,100 @@
+Collections:
+- Name: MaskFormer
+  Metadata:
+    Training Data:
+    - ADE20K
+  Paper:
+    URL: https://arxiv.org/abs/2107.06278
+    Title: 'MaskFormer: Per-Pixel Classification is Not All You Need for Semantic
+      Segmentation'
+  README: configs/maskformer/README.md
+  Code:
+    URL: https://github.com/open-mmlab/mmdetection/blob/dev-3.x/mmdet/models/dense_heads/maskformer_head.py#L21
+    Version: dev-3.x
+  Converted From:
+    Code: https://github.com/facebookresearch/MaskFormer/
+Models:
+- Name: maskformer_r50-d32_8xb2-160k_ade20k-512x512
+  In Collection: MaskFormer
+  Metadata:
+    backbone: R-50-D32
+    crop size: (512,512)
+    lr schd: 160000
+    inference time (ms/im):
+    - value: 23.7
+      hardware: V100
+      backend: PyTorch
+      batch size: 1
+      mode: FP32
+      resolution: (512,512)
+    Training Memory (GB): 3.29
+  Results:
+  - Task: Semantic Segmentation
+    Dataset: ADE20K
+    Metrics:
+      mIoU: 44.29
+  Config: configs/maskformer/maskformer_r50-d32_8xb2-160k_ade20k-512x512.py
+  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/maskformer/maskformer_r50-d32_8xb2-160k_ade20k-512x512/maskformer_r50-d32_8xb2-160k_ade20k-512x512_20221030_182724-cbd39cc1.pth
+- Name: maskformer_r101-d32_8xb2-160k_ade20k-512x512
+  In Collection: MaskFormer
+  Metadata:
+    backbone: R-101-D32
+    crop size: (512,512)
+    lr schd: 160000
+    inference time (ms/im):
+    - value: 28.65
+      hardware: V100
+      backend: PyTorch
+      batch size: 1
+      mode: FP32
+      resolution: (512,512)
+    Training Memory (GB): 4.12
+  Results:
+  - Task: Semantic Segmentation
+    Dataset: ADE20K
+    Metrics:
+      mIoU: 45.11
+  Config: configs/maskformer/maskformer_r101-d32_8xb2-160k_ade20k-512x512.py
+  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/maskformer/maskformer_r101-d32_8xb2-160k_ade20k-512x512/maskformer_r101-d32_8xb2-160k_ade20k-512x512_20221031_223053-c8e0931d.pth
+- Name: maskformer_swin-t_upernet_8xb2-160k_ade20k-512x512
+  In Collection: MaskFormer
+  Metadata:
+    backbone: Swin-T
+    crop size: (512,512)
+    lr schd: 160000
+    inference time (ms/im):
+    - value: 24.67
+      hardware: V100
+      backend: PyTorch
+      batch size: 1
+      mode: FP32
+      resolution: (512,512)
+    Training Memory (GB): 3.73
+  Results:
+  - Task: Semantic Segmentation
+    Dataset: ADE20K
+    Metrics:
+      mIoU: 46.69
+  Config: configs/maskformer/maskformer_swin-t_upernet_8xb2-160k_ade20k-512x512.py
+  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/maskformer/maskformer_swin-t_upernet_8xb2-160k_ade20k-512x512/maskformer_swin-t_upernet_8xb2-160k_ade20k-512x512_20221114_232813-03550716.pth
+- Name: maskformer_swin-s_upernet_8xb2-160k_ade20k-512x512
+  In Collection: MaskFormer
+  Metadata:
+    backbone: Swin-S
+    crop size: (512,512)
+    lr schd: 160000
+    inference time (ms/im):
+    - value: 37.06
+      hardware: V100
+      backend: PyTorch
+      batch size: 1
+      mode: FP32
+      resolution: (512,512)
+    Training Memory (GB): 5.33
+  Results:
+  - Task: Semantic Segmentation
+    Dataset: ADE20K
+    Metrics:
+      mIoU: 49.36
+  Config: configs/maskformer/maskformer_swin-s_upernet_8xb2-160k_ade20k-512x512.py
+  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/maskformer/maskformer_swin-s_upernet_8xb2-160k_ade20k-512x512/maskformer_swin-s_upernet_8xb2-160k_ade20k-512x512_20221115_114710-5ab67e58.pth
diff --git a/configs/maskformer/maskformer_r101-d32_8xb2-160k_ade20k-512x512.py b/configs/maskformer/maskformer_r101-d32_8xb2-160k_ade20k-512x512.py
new file mode 100644
index 0000000000..04bd37546a
--- /dev/null
+++ b/configs/maskformer/maskformer_r101-d32_8xb2-160k_ade20k-512x512.py
@@ -0,0 +1,7 @@
+_base_ = './maskformer_r50-d32_8xb2-160k_ade20k-512x512.py'
+
+model = dict(
+    backbone=dict(
+        depth=101,
+        init_cfg=dict(type='Pretrained',
+                      checkpoint='torchvision://resnet101')))
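+
+# The seven lines above are the entire config; everything else is inherited
+# from the R-50 base via `_base_`. A quick way to inspect the merged result
+# (a sketch; assumes mmengine is installed and the snippet is run from the
+# repository root):
+#
+#   from mmengine import Config
+#   cfg = Config.fromfile(
+#       'configs/maskformer/maskformer_r101-d32_8xb2-160k_ade20k-512x512.py')
+#   print(cfg.model.backbone.depth)  # 101, overriding the base's depth=50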
diff --git a/configs/maskformer/maskformer_r50-d32_8xb2-160k_ade20k-512x512.py b/configs/maskformer/maskformer_r50-d32_8xb2-160k_ade20k-512x512.py
new file mode 100644
index 0000000000..7d8f657221
--- /dev/null
+++ b/configs/maskformer/maskformer_r50-d32_8xb2-160k_ade20k-512x512.py
@@ -0,0 +1,143 @@
+_base_ = [
+    '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
+    '../_base_/schedules/schedule_160k.py'
+]
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+crop_size = (512, 512)
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    size=crop_size,
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,
+    pad_val=0,
+    seg_pad_val=255)
+# model_cfg
+num_classes = 150
+model = dict(
+    type='EncoderDecoder',
+    data_preprocessor=data_preprocessor,
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 1, 1),
+        strides=(1, 2, 2, 2),
+        norm_cfg=norm_cfg,
+        norm_eval=True,
+        style='pytorch',
+        contract_dilation=True,
+        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
+    decode_head=dict(
+        type='MaskFormerHead',
+        in_channels=[256, 512, 1024,
+                     2048],  # input channels of pixel_decoder modules
+        feat_channels=256,
+        in_index=[0, 1, 2, 3],
+        num_classes=150,
+        out_channels=256,
+        num_queries=100,
+        pixel_decoder=dict(
+            type='mmdet.PixelDecoder',
+            norm_cfg=dict(type='GN', num_groups=32),
+            act_cfg=dict(type='ReLU')),
+        enforce_decoder_input_project=False,
+        positional_encoding=dict(
+            type='mmdet.SinePositionalEncoding', num_feats=128,
+            normalize=True),
+        transformer_decoder=dict(
+            type='mmdet.DetrTransformerDecoder',
+            return_intermediate=True,
+            num_layers=6,
+            transformerlayers=dict(
+                type='mmdet.DetrTransformerDecoderLayer',
+                attn_cfgs=dict(
+                    type='mmdet.MultiheadAttention',
+                    embed_dims=256,
+                    num_heads=8,
+                    attn_drop=0.1,
+                    proj_drop=0.1,
+                    dropout_layer=None,
+                    batch_first=False),
+                ffn_cfgs=dict(
+                    embed_dims=256,
+                    feedforward_channels=2048,
+                    num_fcs=2,
+                    act_cfg=dict(type='ReLU', inplace=True),
+                    ffn_drop=0.1,
+                    dropout_layer=None,
+                    add_identity=True),
+                # The following parameter is not used; it is kept only to
+                # satisfy the current API.
+                feedforward_channels=2048,
+                operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
+                                 'ffn', 'norm')),
+            init_cfg=None),
+        loss_cls=dict(
+            type='mmdet.CrossEntropyLoss',
+            use_sigmoid=False,
+            loss_weight=1.0,
+            reduction='mean',
+            class_weight=[1.0] * num_classes + [0.1]),
+        loss_mask=dict(
+            type='mmdet.FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            reduction='mean',
+            loss_weight=20.0),
+        loss_dice=dict(
+            type='mmdet.DiceLoss',
+            use_sigmoid=True,
+            activate=True,
+            reduction='mean',
+            naive_dice=True,
+            eps=1.0,
+            loss_weight=1.0),
+        train_cfg=dict(
+            assigner=dict(
+                type='mmdet.HungarianAssigner',
+                match_costs=[
+                    dict(type='mmdet.ClassificationCost', weight=1.0),
+                    dict(
+                        type='mmdet.FocalLossCost',
+                        weight=20.0,
+                        binary_input=True),
+                    dict(
+                        type='mmdet.DiceCost',
+                        weight=1.0,
+                        pred_act=True,
+                        eps=1.0)
+                ]),
+            sampler=dict(type='mmdet.MaskPseudoSampler'))),
+    # training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'),
+)
+# optimizer
+optimizer = dict(
+    type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0001)
+optim_wrapper = dict(
+    _delete_=True,
+    type='OptimWrapper',
+    optimizer=optimizer,
+    clip_grad=dict(max_norm=0.01, norm_type=2),
+    paramwise_cfg=dict(custom_keys={
+        'backbone': dict(lr_mult=0.1),
+    }))
+# learning policy
+param_scheduler = [
+    dict(
+        type='PolyLR',
+        eta_min=0,
+        power=0.9,
+        begin=0,
+        end=160000,
+        by_epoch=False)
+]
+
+# In the MaskFormer implementation, batch size 2 per GPU is used by default.
+train_dataloader = dict(batch_size=2, num_workers=2)
+val_dataloader = dict(batch_size=1, num_workers=4)
+test_dataloader = val_dataloader
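+
+# Note: the `match_costs` of the Hungarian assigner above deliberately mirror
+# the loss weights (ClassificationCost 1.0 ~ loss_cls, FocalLossCost 20.0 ~
+# loss_mask, DiceCost 1.0 ~ loss_dice), so the query-to-ground-truth
+# assignment that minimises the matching cost is also the assignment the
+# training losses are computed against.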
diff --git a/configs/maskformer/maskformer_swin-s_upernet_8xb2-160k_ade20k-512x512.py b/configs/maskformer/maskformer_swin-s_upernet_8xb2-160k_ade20k-512x512.py
new file mode 100644
index 0000000000..2cbc038ac2
--- /dev/null
+++ b/configs/maskformer/maskformer_swin-s_upernet_8xb2-160k_ade20k-512x512.py
@@ -0,0 +1,79 @@
+checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_small_patch4_window7_224_20220317-7ba6d6dd.pth'  # noqa
+_base_ = './maskformer_r50-d32_8xb2-160k_ade20k-512x512.py'
+backbone_norm_cfg = dict(type='LN', requires_grad=True)
+depths = [2, 2, 18, 2]
+model = dict(
+    backbone=dict(
+        _delete_=True,
+        type='SwinTransformer',
+        pretrain_img_size=224,
+        embed_dims=96,
+        patch_size=4,
+        window_size=7,
+        mlp_ratio=4,
+        depths=depths,
+        num_heads=[3, 6, 12, 24],
+        strides=(4, 2, 2, 2),
+        out_indices=(0, 1, 2, 3),
+        qkv_bias=True,
+        qk_scale=None,
+        patch_norm=True,
+        drop_rate=0.,
+        attn_drop_rate=0.,
+        drop_path_rate=0.3,
+        use_abs_pos_embed=False,
+        act_cfg=dict(type='GELU'),
+        norm_cfg=backbone_norm_cfg,
+        init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)),
+    decode_head=dict(
+        type='MaskFormerHead',
+        in_channels=[96, 192, 384,
+                     768],  # input channels of pixel_decoder modules
+    ))
+
+# optimizer
+optimizer = dict(
+    type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01)
+# Set all backbone layers to lr_mult=1.0.
+# Set all norm layers, the position embedding and the query embedding
+# to decay_mult=0.0.
+backbone_norm_multi = dict(lr_mult=1.0, decay_mult=0.0)
+backbone_embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
+embed_multi = dict(decay_mult=0.0)
+custom_keys = {
+    'backbone': dict(lr_mult=1.0),
+    'backbone.patch_embed.norm': backbone_norm_multi,
+    'backbone.norm': backbone_norm_multi,
+    'relative_position_bias_table': backbone_embed_multi,
+    'query_embed': embed_multi,
+}
+custom_keys.update({
+    f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
+    for stage_id, num_blocks in enumerate(depths)
+    for block_id in range(num_blocks)
+})
+custom_keys.update({
+    f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
+    for stage_id in range(len(depths) - 1)
+})
+# optimizer wrapper
+optim_wrapper = dict(
+    _delete_=True,
+    type='OptimWrapper',
+    optimizer=optimizer,
+    clip_grad=dict(max_norm=0.01, norm_type=2),
+    paramwise_cfg=dict(custom_keys=custom_keys))
+
+# learning policy
+param_scheduler = [
+    dict(
+        type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
+    dict(
+        type='PolyLR',
+        eta_min=0.0,
+        power=1.0,
+        begin=1500,
+        end=160000,
+        by_epoch=False,
+    )
+]
diff --git a/configs/maskformer/maskformer_swin-t_upernet_8xb2-160k_ade20k-512x512.py b/configs/maskformer/maskformer_swin-t_upernet_8xb2-160k_ade20k-512x512.py
new file mode 100644
index 0000000000..aa242dbe31
--- /dev/null
+++ b/configs/maskformer/maskformer_swin-t_upernet_8xb2-160k_ade20k-512x512.py
@@ -0,0 +1,81 @@
+_base_ = './maskformer_r50-d32_8xb2-160k_ade20k-512x512.py'
+
+checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth'  # noqa
+backbone_norm_cfg = dict(type='LN', requires_grad=True)
+depths = [2, 2, 6, 2]
+model = dict(
+    backbone=dict(
+        _delete_=True,
+        type='SwinTransformer',
+        pretrain_img_size=224,
+        embed_dims=96,
+        patch_size=4,
+        window_size=7,
+        mlp_ratio=4,
+        depths=depths,
+        num_heads=[3, 6, 12, 24],
+        strides=(4, 2, 2, 2),
+        out_indices=(0, 1, 2, 3),
+        qkv_bias=True,
+        qk_scale=None,
+        patch_norm=True,
+        drop_rate=0.,
+        attn_drop_rate=0.,
+        drop_path_rate=0.3,
+        use_abs_pos_embed=False,
+        act_cfg=dict(type='GELU'),
+        norm_cfg=backbone_norm_cfg,
+        init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)),
+    decode_head=dict(
+        type='MaskFormerHead',
+        in_channels=[96, 192, 384,
+                     768],  # input channels of pixel_decoder modules
+    ))
+
+# optimizer
+optimizer = dict(
+    type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01)
+
+# Set all backbone layers to lr_mult=1.0.
+# Set all norm layers, the position embedding and the query embedding
+# to decay_mult=0.0.
+backbone_norm_multi = dict(lr_mult=1.0, decay_mult=0.0)
+backbone_embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
+embed_multi = dict(decay_mult=0.0)
+custom_keys = {
+    'backbone': dict(lr_mult=1.0),
+    'backbone.patch_embed.norm': backbone_norm_multi,
+    'backbone.norm': backbone_norm_multi,
+    'relative_position_bias_table': backbone_embed_multi,
+    'query_embed': embed_multi,
+}
+custom_keys.update({
+    f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
+    for stage_id, num_blocks in enumerate(depths)
+    for block_id in range(num_blocks)
+})
+custom_keys.update({
+    f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
+    for stage_id in range(len(depths) - 1)
+})
+# optimizer wrapper
+optim_wrapper = dict(
+    _delete_=True,
+    type='OptimWrapper',
+    optimizer=optimizer,
+    clip_grad=dict(max_norm=0.01, norm_type=2),
+    paramwise_cfg=dict(custom_keys=custom_keys))
+
+# learning policy
+param_scheduler = [
+    dict(
+        type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
+    dict(
+        type='PolyLR',
+        eta_min=0.0,
+        power=1.0,
+        begin=1500,
+        end=160000,
+        by_epoch=False,
+    )
+]
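+
+# For reference, with depths = [2, 2, 6, 2] the two `custom_keys.update`
+# comprehensions above expand to entries such as
+#   'backbone.stages.0.blocks.0.norm' ... 'backbone.stages.2.blocks.5.norm'
+#   'backbone.stages.0.downsample.norm' ... 'backbone.stages.2.downsample.norm'
+# each mapped to dict(lr_mult=1.0, decay_mult=0.0), which keeps the Swin
+# LayerNorm parameters out of weight decay while leaving their learning rate
+# untouched.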
diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py
index 8add7615c2..c6976652d7 100644
--- a/mmseg/models/decode_heads/__init__.py
+++ b/mmseg/models/decode_heads/__init__.py
@@ -15,6 +15,7 @@ from .isa_head import ISAHead
 from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
 from .lraspp_head import LRASPPHead
+from .maskformer_head import MaskFormerHead
 from .nl_head import NLHead
 from .ocr_head import OCRHead
 from .point_head import PointHead
@@ -36,5 +37,5 @@
     'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
     'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
     'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
-    'KernelUpdateHead', 'KernelUpdator'
+    'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead'
 ]
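The head in the next file guards its MMDetection import with a try/except and falls back to `BaseModule` rather than `None`: a class statement evaluates its base classes at import time, so a `None` base would make `import mmseg` itself fail whenever MMDetection is absent. A minimal illustration (plain Python, no dependencies):

```python
# Subclassing None raises at class-creation time, not at instantiation.
try:
    class Broken(None):  # bases are evaluated when the class is defined
        pass
except TypeError as err:
    print('class creation failed:', err)
```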
diff --git a/mmseg/models/decode_heads/maskformer_head.py b/mmseg/models/decode_heads/maskformer_head.py
new file mode 100644
index 0000000000..98ca92b996
--- /dev/null
+++ b/mmseg/models/decode_heads/maskformer_head.py
@@ -0,0 +1,173 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.model import BaseModule
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+try:
+    from mmdet.models.dense_heads import \
+        MaskFormerHead as MMDET_MaskFormerHead
+except ModuleNotFoundError:
+    # Fall back to BaseModule so that this module can still be imported
+    # when MMDetection is absent; subclassing None would fail at class
+    # creation time and break `import mmseg` entirely.
+    MMDET_MaskFormerHead = BaseModule
+
+from mmseg.registry import MODELS
+from mmseg.structures.seg_data_sample import SegDataSample
+from mmseg.utils import ConfigType, SampleList
+
+
+@MODELS.register_module()
+class MaskFormerHead(MMDET_MaskFormerHead):
+    """Implements the MaskFormer head.
+
+    See `Per-Pixel Classification is Not All You Need for Semantic
+    Segmentation <https://arxiv.org/abs/2107.06278>`_ for details.
+
+    Args:
+        num_classes (int): Number of classes. Default: 150.
+        align_corners (bool): align_corners argument of F.interpolate.
+            Default: False.
+        ignore_index (int): The label index to be ignored. Default: 255.
+    """
+
+    def __init__(self,
+                 num_classes: int = 150,
+                 align_corners: bool = False,
+                 ignore_index: int = 255,
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+
+        self.num_classes = num_classes
+        self.align_corners = align_corners
+        # The EncoderDecoder wrapper expects out_channels to equal the
+        # number of classes, regardless of kwargs['out_channels'].
+        self.out_channels = num_classes
+        self.ignore_index = ignore_index
+
+        feat_channels = kwargs['feat_channels']
+        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
+
+    def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
+        """Convert the ground-truth paradigm from MMSegmentation to
+        MMDetection so that ``MMDET_MaskFormerHead`` can be called normally.
+        Specifically, ``batch_gt_instances`` is constructed from the
+        semantic segmentation ground truth.
+
+        Args:
+            batch_data_samples (List[:obj:`SegDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_sem_seg`.
+
+        Returns:
+            tuple: A tuple containing two lists.
+
+                - batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                    gt_instance. It usually includes ``labels``, the unique
+                    ground-truth label ids of the image with shape
+                    (num_gt, ), and ``masks``, the ground-truth mask of
+                    each instance with shape (num_gt, h, w).
+                - batch_img_metas (list[dict]): List of image meta
+                    information.
+        """
+        batch_img_metas = []
+        batch_gt_instances = []
+        for data_sample in batch_data_samples:
+            # Add `batch_input_shape` to the metainfo of data_sample, which
+            # is required by MaskFormerHead of MMDetection.
+            metainfo = data_sample.metainfo
+            metainfo['batch_input_shape'] = metainfo['img_shape']
+            data_sample.set_metainfo(metainfo)
+            batch_img_metas.append(data_sample.metainfo)
+            gt_sem_seg = data_sample.gt_sem_seg.data
+            classes = torch.unique(
+                gt_sem_seg,
+                sorted=False,
+                return_inverse=False,
+                return_counts=False)
+
+            # remove ignored region
+            gt_labels = classes[classes != self.ignore_index]
+
+            # one binary mask per ground-truth class present in the image
+            masks = []
+            for class_id in gt_labels:
+                masks.append(gt_sem_seg == class_id)
+
+            if len(masks) == 0:
+                gt_masks = torch.zeros((0, gt_sem_seg.shape[-2],
+                                        gt_sem_seg.shape[-1])).to(gt_sem_seg)
+            else:
+                gt_masks = torch.stack(masks).squeeze(1)
+
+            instance_data = InstanceData(
+                labels=gt_labels, masks=gt_masks.long())
+            batch_gt_instances.append(instance_data)
+        return batch_gt_instances, batch_img_metas
+
+    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
+             train_cfg: ConfigType) -> dict:
+        """Perform forward propagation and loss calculation of the decoder
+        head on the features of the upstream network.
+
+        Args:
+            x (tuple[Tensor]): Multi-level features from the upstream
+                network, each is a 4D-tensor.
+            batch_data_samples (List[:obj:`SegDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_sem_seg`.
+            train_cfg (ConfigType): Training config.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        # convert the batch of SegDataSample to InstanceData
+        batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
+            batch_data_samples)
+
+        # forward
+        all_cls_scores, all_mask_preds = self(x, batch_data_samples)
+
+        # loss
+        losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
+                                   batch_gt_instances, batch_img_metas)
+
+        return losses
+
+    def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict],
+                test_cfg: ConfigType) -> Tensor:
+        """Test without augmentation.
+
+        Args:
+            x (tuple[Tensor]): Multi-level features from the upstream
+                network, each is a 4D-tensor.
+            batch_img_metas (List[dict]): List of image meta information.
+            test_cfg (ConfigType): Test config.
+
+        Returns:
+            Tensor: Segmentation logits of shape (batch_size, num_classes,
+                H, W).
+        """
+        # The forward function of MaskFormerHead in MMDetection takes
+        # 'batch_data_samples' as input, but only their image meta
+        # information (batch_input_shape) is actually used here.
+        batch_data_samples = []
+        for metainfo in batch_img_metas:
+            metainfo['batch_input_shape'] = metainfo['img_shape']
+            batch_data_samples.append(SegDataSample(metainfo=metainfo))
+        all_cls_scores, all_mask_preds = self(x, batch_data_samples)
+        mask_cls_results = all_cls_scores[-1]
+        mask_pred_results = all_mask_preds[-1]
+
+        # upsample masks to the padded input resolution
+        img_shape = batch_img_metas[0]['batch_input_shape']
+        mask_pred_results = F.interpolate(
+            mask_pred_results,
+            size=img_shape,
+            mode='bilinear',
+            align_corners=False)
+
+        # semantic inference: drop the trailing no-object logit, then
+        # weight each query's mask by its class scores.
+        cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]  # (B, Q, C)
+        mask_pred = mask_pred_results.sigmoid()  # (B, Q, H, W)
+        seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
+        return seg_logits
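For intuition, here is a standalone sketch of the per-sample ground-truth conversion that `_seg_data_to_instance_data` performs (assumes only `torch`; shapes are simplified to a 2-D map, whereas the head receives `(1, H, W)` tensors and hence squeezes an extra dim):

```python
import torch

# A toy semantic map with 4 classes and one ignored pixel.
ignore_index = 255
gt_sem_seg = torch.randint(0, 4, (16, 16))
gt_sem_seg[0, 0] = ignore_index

# Unique labels present in the image, minus the ignored region.
classes = torch.unique(gt_sem_seg, sorted=False)
gt_labels = classes[classes != ignore_index]

# One binary mask per present class: (num_gt, 16, 16).
gt_masks = torch.stack([gt_sem_seg == cid for cid in gt_labels]).long()
print(gt_labels.tolist(), gt_masks.shape)
```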
diff --git a/model-index.yml b/model-index.yml
index b087a7294c..6aacf72b0d 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -25,6 +25,7 @@ Import:
 - configs/isanet/isanet.yml
 - configs/knet/knet.yml
 - configs/mae/mae.yml
+- configs/maskformer/maskformer.yml
 - configs/mobilenet_v2/mobilenet_v2.yml
 - configs/mobilenet_v3/mobilenet_v3.yml
 - configs/nonlocal_net/nonlocal_net.yml
diff --git a/requirements/mminstall.txt b/requirements/mminstall.txt
index 2d43c0cb42..d27af8dd0f 100644
--- a/requirements/mminstall.txt
+++ b/requirements/mminstall.txt
@@ -1,3 +1,4 @@
 mmcls>=1.0.0rc0
 mmcv>=2.0.0rc3,<2.1.0
+mmdet>=3.0.0rc4
 mmengine>=0.1.0,<1.0.0
diff --git a/tests/test_models/test_heads/test_maskformer_head.py b/tests/test_models/test_heads/test_maskformer_head.py
new file mode 100644
index 0000000000..fe4bf96fea
--- /dev/null
+++ b/tests/test_models/test_heads/test_maskformer_head.py
@@ -0,0 +1,54 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from os.path import dirname, join
+
+import torch
+from mmengine import Config
+from mmengine.structures import PixelData
+
+from mmseg.registry import MODELS
+from mmseg.structures import SegDataSample
+from mmseg.utils import register_all_modules
+
+
+def test_maskformer_head():
+    register_all_modules()
+    repo_dpath = dirname(dirname(__file__))
+    cfg = Config.fromfile(
+        join(
+            repo_dpath,
+            '../../configs/maskformer/maskformer_r50-d32_8xb2-160k_ade20k-512x512.py'  # noqa
+        ))
+    cfg.model.train_cfg = None
+    decode_head = MODELS.build(cfg.model.decode_head)
+    inputs = (torch.randn(1, 256, 32, 32), torch.randn(1, 512, 16, 16),
+              torch.randn(1, 1024, 8, 8), torch.randn(1, 2048, 4, 4))
+    # test inference
+    batch_img_metas = [
+        dict(
+            scale_factor=(1.0, 1.0),
+            img_shape=(512, 683),
+            ori_shape=(512, 683))
+    ]
+    test_cfg = dict(mode='whole')
+    output = decode_head.predict(inputs, batch_img_metas, test_cfg)
+    assert output.shape == (1, 150, 512, 683)
+
+    # test training
+    inputs = (torch.randn(2, 256, 32, 32), torch.randn(2, 512, 16, 16),
+              torch.randn(2, 1024, 8, 8), torch.randn(2, 2048, 4, 4))
+    batch_data_samples = []
+    img_meta = {
+        'img_shape': (512, 512),
+        'ori_shape': (480, 640),
+        'pad_shape': (512, 512),
+        'scale_factor': (1.425, 1.425),
+    }
+    for _ in range(2):
+        data_sample = SegDataSample(
+            gt_sem_seg=PixelData(data=torch.ones(512, 512).long()))
+        data_sample.set_metainfo(img_meta)
+        batch_data_samples.append(data_sample)
+    train_cfg = {}
+    losses = decode_head.loss(inputs, batch_data_samples, train_cfg)
+    assert all(loss in losses
+               for loss in ('loss_cls', 'loss_mask', 'loss_dice'))
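The `(1, 150, 512, 683)` output asserted above is produced by the semantic-inference step at the end of `predict`; a standalone sketch of just that step, with illustrative shapes (assumes only `torch`):

```python
import torch
import torch.nn.functional as F

B, Q, C, H, W = 1, 100, 150, 64, 64          # batch, queries, classes, size
mask_cls_results = torch.randn(B, Q, C + 1)  # class logits incl. no-object
mask_pred_results = torch.randn(B, Q, H, W)  # per-query mask logits

cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]  # drop no-object
mask_pred = mask_pred_results.sigmoid()
seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
assert seg_logits.shape == (B, C, H, W)
```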