Make DetrForSegmentation support any timm backbone
NielsRogge committed Jun 8, 2021
1 parent eb4ee5c commit 9f23454
Showing 3 changed files with 20 additions and 20 deletions.
12 changes: 5 additions & 7 deletions docs/source/model_doc/detr.rst
@@ -98,13 +98,11 @@ Tips:
`num_boxes` variable in the `SetCriterion` class of `modeling_detr.py`. When training on multiple nodes, this should
be set to the average number of target boxes across all nodes, as can be seen in the original implementation `here
<https://github.com/facebookresearch/detr/blob/a54b77800eb8e64e3ad0d8237789fcbf2f8350c5/models/detr.py#L227-L232>`__.
- :class:`~transformers.DetrForObjectDetection` can be initialized with any convolutional backbone available in the
`timm library <https://github.com/rwightman/pytorch-image-models>`__. Initializing with a MobileNet backbone for
example can be done by setting the :obj:`backbone` attribute of :class:`~transformers.DetrConfig` to
:obj:`"tf_mobilenetv3_small_075"`, and then initializing :class:`~transformers.DetrForObjectDetection` with that
config. Note that :class:`~transformers.DetrForSegmentation` does not support any timm backbone, as the mask head
depends on the intermediate feature maps from a ResNet. One would need to update the mask head in order to work with
a different backbone.
- :class:`~transformers.DetrForObjectDetection` and :class:`~transformers.DetrForSegmentation` can be initialized
with any convolutional backbone available in the `timm library
<https://github.com/rwightman/pytorch-image-models>`__. Initializing with a MobileNet backbone for example can be
done by setting the :obj:`backbone` attribute of :class:`~transformers.DetrConfig` to
:obj:`"tf_mobilenetv3_small_075"`, and then initializing the model with that config.
- DETR resizes the input images such that the shortest side is at least a certain amount of pixels while the longest is
at most 1333 pixels. At training time, scale augmentation is used such that the shortest side is randomly set to at
least 480 and at most 800 pixels. At inference time, the shortest side is set to 800. One can use
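The new documentation paragraph can be tried out directly. Below is a minimal illustrative sketch (not part of the diff), assuming a transformers build that includes this commit and an installed timm; following the test further down, the backbone attribute is set on the config after construction. Note that building the models downloads the timm backbone weights.

from transformers import DetrConfig, DetrForObjectDetection, DetrForSegmentation

# Any convolutional timm backbone works; the name below is the one used in the docs and tests.
config = DetrConfig()
config.backbone = "tf_mobilenetv3_small_075"

detection_model = DetrForObjectDetection(config)
# After this commit, the segmentation model accepts the same config as well.
segmentation_model = DetrForSegmentation(config)
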
25 changes: 15 additions & 10 deletions src/transformers/models/detr/modeling_detr.py
@@ -18,7 +18,6 @@
import math
import random
from dataclasses import dataclass
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import torch
@@ -257,7 +256,7 @@ def forward(self, x):
return x * scale + bias


def replace_batch_normalization(m, name=""):
def replace_batch_norm(m, name=""):
for attr_str in dir(m):
target_attr = getattr(m, attr_str)
if isinstance(target_attr, torch.nn.BatchNorm2d):
@@ -269,7 +268,7 @@ def replace_batch_normalization(m, name=""):
frozen.running_var.data.copy_(bn.running_var)
setattr(m, attr_str, frozen)
for n, ch in m.named_children():
replace_batch_normalization(ch, n)
replace_batch_norm(ch, n)


class TimmBackbone(nn.Module):
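As context for the rename above: the helper walks the module tree and replaces every BatchNorm2d with a frozen copy whose statistics and affine parameters are buffers, so they stay fixed while DETR trains with small detection batches. The sketch below is not part of the diff; it uses a hypothetical FrozenBatchNorm2d stand-in for DetrFrozenBatchNorm2d and a simplified named_children recursion instead of the dir()-based walk shown in the hunk.

import torch
from torch import nn

class FrozenBatchNorm2d(nn.Module):  # hypothetical stand-in for DetrFrozenBatchNorm2d
    def __init__(self, num_features):
        super().__init__()
        # buffers, not parameters, so nothing here receives gradients
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        scale = self.weight * (self.running_var + 1e-5).rsqrt()
        bias = self.bias - self.running_mean * scale
        return x * scale[None, :, None, None] + bias[None, :, None, None]

def replace_batch_norm(module):
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            frozen = FrozenBatchNorm2d(child.num_features)
            frozen.weight.data.copy_(child.weight)
            frozen.bias.data.copy_(child.bias)
            frozen.running_mean.data.copy_(child.running_mean)
            frozen.running_var.data.copy_(child.running_var)
            setattr(module, name, frozen)
        else:
            replace_batch_norm(child)

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
replace_batch_norm(net)
print(net)  # the BatchNorm2d entry is now a FrozenBatchNorm2d
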
@@ -286,13 +285,13 @@ def __init__(self, name: str, train_backbone: bool, dilation: bool, return_inter
kwargs = {}
if dilation:
kwargs["output_stride"] = 16
backbone = create_model(name, pretrained=True, features_only=True, **kwargs)
backbone = create_model(name, pretrained=True, features_only=True, out_indices=(1, 2, 3, 4), **kwargs)
# replace batch norm by frozen batch norm
with torch.no_grad():
replace_batch_normalization(backbone)
replace_batch_norm(backbone)
self.body = backbone
self.return_intermediate_layers = return_intermediate_layers
self.num_channels = self.body.feature_info.channels(-1)
self.intermediate_channel_sizes = self.body.feature_info.channels()

for name, parameter in self.body.named_parameters():
if not train_backbone or "layer2" not in name and "layer3" not in name and "layer4" not in name:
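The two changed lines in this hunk carry most of the commit: out_indices=(1, 2, 3, 4) makes timm return the four intermediate feature maps, and feature_info.channels() records all of their widths instead of only the last one. A hedged sketch of the timm side (assumes timm is installed; pretrained=False is used here only to avoid a download, the diff itself keeps pretrained=True):

import timm
import torch

backbone = timm.create_model(
    "tf_mobilenetv3_small_075", pretrained=False, features_only=True, out_indices=(1, 2, 3, 4)
)

# One channel count per requested stage; for a ResNet-50 this list would be [256, 512, 1024, 2048].
print(backbone.feature_info.channels())

# The forward pass returns a list of four feature maps at increasing strides.
features = backbone(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in features])
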
@@ -1133,12 +1132,14 @@ def __init__(self, config: DetrConfig):
super().__init__(config)

# Create backbone + positional encoding
backbone = TimmBackbone(config.backbone, config.train_backbone, config.dilation, config.return_intermediate_layers)
backbone = TimmBackbone(
config.backbone, config.train_backbone, config.dilation, config.return_intermediate_layers
)
position_embeddings = build_position_encoding(config)
self.backbone = Joiner(backbone, position_embeddings)

# Create projection layer
self.input_projection = nn.Conv2d(backbone.num_channels, config.d_model, kernel_size=1)
self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)

self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
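The input projection change follows directly: the 1x1 convolution that feeds the transformer must match the channel count of the last backbone stage, which is 2048 only for a ResNet-50. A tiny sketch, with hypothetical stage widths for a small timm backbone:

import torch
from torch import nn

intermediate_channel_sizes = [16, 24, 48, 432]  # hypothetical stage widths for a small timm backbone
d_model = 256  # DetrConfig default hidden size

input_projection = nn.Conv2d(intermediate_channel_sizes[-1], d_model, kernel_size=1)
last_feature_map = torch.randn(1, intermediate_channel_sizes[-1], 25, 34)
print(input_projection(last_feature_map).shape)  # torch.Size([1, 256, 25, 34])
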

@@ -1442,7 +1443,11 @@ def __init__(self, config: DetrConfig):

# segmentation head
hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
self.mask_head = DetrMaskHeadSmallConv(hidden_size + number_of_heads, [1024, 512, 256], hidden_size)
intermediate_channel_sizes = self.detr.model.backbone[0].intermediate_channel_sizes

self.mask_head = DetrMaskHeadSmallConv(
hidden_size + number_of_heads, intermediate_channel_sizes[::-1][-3:], hidden_size
)

self.init_weights()
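A quick check of the new FPN-channel expression: with the standard ResNet-50 stage widths it reproduces the list that was hard-coded before this commit, and with any other backbone it adapts automatically.

# ResNet-50 stage widths, fine to coarse.
intermediate_channel_sizes = [256, 512, 1024, 2048]
print(intermediate_channel_sizes[::-1][-3:])  # [1024, 512, 256], the previously hard-coded values

# Hypothetical widths for a smaller timm backbone.
print([16, 24, 48, 432][::-1][-3:])  # [48, 24, 16]
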

@@ -1574,7 +1579,7 @@ def forward(
# bbox_mask is of shape (batch_size, num_queries, number_of_attention_heads in bbox_attention, height/32, width/32)
bbox_mask = self.bbox_attention(sequence_output, memory, mask=~mask)

seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[3][0], features[2][0], features[1][0]])
seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[2][0], features[1][0], features[0][0]])

pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])
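The index shift mirrors the out_indices change: for a ResNet, timm's default features_only extractor returned five maps (including the stride-2 stem output), while out_indices=(1, 2, 3, 4) returns four, so the three maps consumed by the mask head move down by one position. A small illustration with placeholder names:

old_features = ["stem", "stage1", "stage2", "stage3", "stage4"]  # five maps before this commit (ResNet case)
new_features = ["stage1", "stage2", "stage3", "stage4"]          # four maps with out_indices=(1, 2, 3, 4)

# Both selections pick stages 3, 2, 1 in coarse-to-fine order for the mask head.
assert [old_features[3], old_features[2], old_features[1]] == [new_features[2], new_features[1], new_features[0]]
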

3 changes: 0 additions & 3 deletions tests/test_modeling_detr.py
@@ -378,9 +378,6 @@ def test_different_timm_backbone(self):
config.backbone = "tf_mobilenetv3_small_075"

for model_class in self.all_model_classes:
# the mask head of DetrForSegmentation does not support any timm backbone
if model_class.__name__ == "DetrForSegmentation":
continue
model = model_class(config)
model.to(torch_device)
model.eval()
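With the skip removed, test_different_timm_backbone now also builds DetrForSegmentation on the MobileNet backbone. A hedged standalone sketch of what that exercises (assumes timm is installed; the backbone weights are downloaded, all other weights are randomly initialized):

import torch
from transformers import DetrConfig, DetrForSegmentation

config = DetrConfig()
config.backbone = "tf_mobilenetv3_small_075"

model = DetrForSegmentation(config)
model.eval()

with torch.no_grad():
    outputs = model(pixel_values=torch.randn(1, 3, 224, 224))

# One mask logit map per object query.
print(outputs.pred_masks.shape)
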
