Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DETR #11653

Merged
merged 31 commits into from
Jun 9, 2021
Merged

Add DETR #11653

Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d21c87b
Squash all commits of modeling_detr_v7 branch into one
NielsRogge May 10, 2021
8e06e0b
Improve docs
NielsRogge May 10, 2021
f2592fc
Fix tests
LysandreJik May 10, 2021
1c76900
Style
LysandreJik May 10, 2021
645801d
Improve docs some more and fix most tests
NielsRogge May 10, 2021
0a02d46
Fix slow tests of ViT, DeiT and DETR
NielsRogge May 10, 2021
82507da
Improve replacement of batch norm
NielsRogge May 14, 2021
eb4ee5c
Restructure timm backbone forward
NielsRogge May 14, 2021
9f23454
Make DetrForSegmentation support any timm backbone
NielsRogge May 14, 2021
cfc4b4b
Fix name of output
NielsRogge May 14, 2021
7ddb65a
Address most comments by @LysandreJik
NielsRogge May 21, 2021
a4dec31
Give better names for variables
NielsRogge May 21, 2021
844a280
Conditional imports + timm in setup.py
LysandreJik May 27, 2021
0c5d9ea
Address additional comments by @sgugger
NielsRogge May 28, 2021
f9d379d
Make style, add require_timm and require_vision to testsé
NielsRogge May 28, 2021
0d3eb3c
Remove train_backbone attribute of DetrConfig, add methods to freeze/…
NielsRogge May 28, 2021
9acf820
Add png files to fixtures
NielsRogge May 28, 2021
5d05e45
Fix type hint
NielsRogge May 28, 2021
db9c555
Add timm to workflows
LysandreJik Jun 1, 2021
7bef619
Add `BatchNorm2d` to the weight initialization
LysandreJik Jun 1, 2021
681d714
Fix retain_grad test
NielsRogge Jun 1, 2021
c5ef475
Replace model checkpoints by Facebook namespace
NielsRogge Jun 1, 2021
fdc7edb
Fix name of checkpoint in test
NielsRogge Jun 1, 2021
457b26d
Add user-friendly message when scipy is not available
NielsRogge Jun 1, 2021
734adf9
Address most comments by @patrickvonplaten
NielsRogge Jun 2, 2021
5237b13
Remove return_intermediate_layers attribute of DetrConfig and simplif…
NielsRogge Jun 3, 2021
674297d
Better initialization
LysandreJik Jun 7, 2021
fb30f72
Scipy is necessary to get sklearn metrics
LysandreJik Jun 7, 2021
731297d
Rename TimmBackbone to DetrTimmConvEncoder and rename DetrJoiner to D…
NielsRogge Jun 8, 2021
dcbd458
Make style
NielsRogge Jun 8, 2021
2479544
Improve docs and add 2 community notebooks
NielsRogge Jun 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix tests
  • Loading branch information
LysandreJik authored and NielsRogge committed Jun 8, 2021
commit f2592fc7b4ff152d43ef0c684250ebb758aaa95e
4 changes: 4 additions & 0 deletions src/transformers/models/detr/configuration_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class DetrConfig(PretrainedConfig):
just in case (e.g., 512 or 1024 or 2048).
init_std (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
init_xavier_std (:obj:`float`, `optional`, defaults to 1.):
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
The scaling factor used for the Xavier initialization gain in the HM Attention map module.
encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the encoder. See the `LayerDrop paper <see
https://arxiv.org/abs/1909.11556>`__ for more details.
Expand Down Expand Up @@ -142,6 +144,7 @@ def __init__(
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
init_xavier_std=1.,
classifier_dropout=0.0,
scale_embedding=False,
auxiliary_loss=False,
Expand Down Expand Up @@ -176,6 +179,7 @@ def __init__(
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.init_xavier_std = init_xavier_std
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
Expand Down
85 changes: 78 additions & 7 deletions src/transformers/models/detr/modeling_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ class DetrObjectDetectionOutput(ModelOutput):
Optional, only returned when auxilary losses are activated (i.e. :obj:`config.auxiliary_loss` is set to
`True`) and labels are provided. It is a list of dictionnaries containing the two above keys (:obj:`logits`
and :obj:`pred_boxes`) for each decoder layer.
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.

If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
1, hidden_size)` is output.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of
Expand Down Expand Up @@ -129,6 +134,7 @@ class DetrObjectDetectionOutput(ModelOutput):
logits: torch.FloatTensor = None
pred_boxes: torch.FloatTensor = None
auxiliary_outputs: Optional[List[Dict]] = None
last_hidden_state: Optional[torch.FloatTensor] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
Expand All @@ -138,16 +144,72 @@ class DetrObjectDetectionOutput(ModelOutput):


@dataclass
class DetrSegmentationOutput(DetrObjectDetectionOutput):
class DetrForSegmentationOutput(ModelOutput):
"""
This class adds one attribute to DetrObjectDetectionOutput, namely predicted masks.
Output type of :class:`~transformers.DetrForSegmentation`.

Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` are provided)):
Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
scale-invariant IoU loss.
loss_dict (:obj:`Dict`, `optional`):
A dictionary containing the individual losses.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_queries, num_classes + 1)`):
Classification logits (including no-object) for all queries.
pred_boxes (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_queries, 4)`):
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
possible padding). You can use :class:`~transformers.DetrForObjectDetection.post_process` to retrieve the
unnormalized bounding boxes.
pred_masks (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_queries, width, height)`):
...
auxiliary_outputs (:obj:`list[Dict]`, `optional`):
Optional, only returned when auxilary losses are activated (i.e. :obj:`config.auxiliary_loss` is set to
`True`) and labels are provided. It is a list of dictionnaries containing the two above keys (:obj:`logits`
and :obj:`pred_boxes`) for each decoder layer.
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.

If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
1, hidden_size)` is output.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of
each layer plus the initial embedding outputs.
decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to
compute the weighted average in the self-attention heads.
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the
attention softmax, used to compute the weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of
each layer plus the initial embedding outputs.
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to
compute the weighted average in the self-attention heads.
"""

loss: Optional[torch.FloatTensor] = None
loss_dict: Optional[Dict] = None
logits: torch.FloatTensor = None
pred_boxes: torch.FloatTensor = None
pred_masks: torch.FloatTensor = None
auxiliary_outputs: Optional[List[Dict]] = None
last_hidden_state: Optional[torch.FloatTensor] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None


# BELOW: utilities copied from
Expand Down Expand Up @@ -676,7 +738,9 @@ class DetrPreTrainedModel(PreTrainedModel):

def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, nn.Linear):
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
Expand Down Expand Up @@ -1332,6 +1396,7 @@ class labels themselves should be a :obj:`torch.LongTensor` of len :obj:`(number
logits=logits,
pred_boxes=pred_boxes,
auxiliary_outputs=auxiliary_outputs,
last_hidden_state=outputs.last_hidden_state,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
Expand Down Expand Up @@ -1359,9 +1424,14 @@ def __init__(self, config: DetrConfig):

# segmentation head
hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
self.bbox_attention = DetrMHAttentionMap(hidden_size, hidden_size, number_of_heads, dropout=0.0)
self.mask_head = DetrMaskHeadSmallConv(hidden_size + number_of_heads, [1024, 512, 256], hidden_size)

self.init_weights()

# The DetrMHAttentionMap has a custom layer initialization scheme which must not get overwritten by the
# self.init_weights()
self.bbox_attention = DetrMHAttentionMap(hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std)

@add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DetrSegmentationOutput, config_class=_CONFIG_FOR_DOC)
def forward(
Expand Down Expand Up @@ -1542,6 +1612,7 @@ def forward(
pred_boxes=pred_boxes,
pred_masks=pred_masks,
auxiliary_outputs=auxiliary_outputs,
last_hidden_state=decoder_outputs.last_hidden_state,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
Expand Down Expand Up @@ -1637,7 +1708,7 @@ def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]):
class DetrMHAttentionMap(nn.Module):
"""This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""

def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True):
def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
super().__init__()
self.num_heads = num_heads
self.hidden_dim = hidden_dim
Expand All @@ -1648,8 +1719,8 @@ def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True):

nn.init.zeros_(self.k_linear.bias)
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
nn.init.zeros_(self.q_linear.bias)
nn.init.xavier_uniform_(self.k_linear.weight)
nn.init.xavier_uniform_(self.q_linear.weight)
nn.init.xavier_uniform_(self.k_linear.weight, gain=std)
nn.init.xavier_uniform_(self.q_linear.weight, gain=std)
self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

def forward(self, q, k, mask: Optional[Tensor] = None):
Expand Down
5 changes: 4 additions & 1 deletion tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,6 @@ def test_retain_grad_hidden_states_attentions(self):

outputs = model(**inputs)

print(outputs)
output = outputs[0]

if config.is_encoder_decoder:
Expand Down Expand Up @@ -1236,6 +1235,9 @@ def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
Expand All @@ -1249,6 +1251,7 @@ def recursive_check(tuple_object, dict_object):
recursive_check(tuple_output, dict_output)

for model_class in self.all_model_classes:
print(model_class)
model = model_class(config)
model.to(torch_device)
model.eval()
Expand Down
31 changes: 16 additions & 15 deletions tests/test_modeling_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
import inspect
import math
import unittest
from typing import List, Tuple, Dict

from transformers import is_timm_available, is_vision_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor

from .test_modeling_common import ModelTesterMixin, floats_tensor, _config_zero_init

if is_timm_available():
import torch
Expand Down Expand Up @@ -87,17 +87,17 @@ def __init__(
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size])

pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size])
pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device)

labels = None
if self.use_labels:
# labels is a list of Dict (each Dict being the labels for a given example in the batch)
labels = []
for i in range(self.batch_size):
target = {}
target["class_labels"] = torch.randint(high=self.num_labels, size=(self.n_targets,))
target["boxes"] = torch.rand(self.n_targets, 4)
target["masks"] = torch.rand(self.n_targets, self.min_size, self.max_size)
target["class_labels"] = torch.randint(high=self.num_labels, size=(self.n_targets,), device=torch_device)
target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device)
target["masks"] = torch.rand(self.n_targets, self.min_size, self.max_size, device=torch_device)
labels.append(target)

config = DetrConfig(
Expand Down Expand Up @@ -176,12 +176,12 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
labels = []
for i in range(self.model_tester.batch_size):
target = {}
target["class_labels"] = torch.randint(
high=self.model_tester.num_labels, size=(self.model_tester.n_targets,)
target["class_labels"] = torch.ones(
size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long
)
target["boxes"] = torch.rand(self.model_tester.n_targets, 4)
target["masks"] = torch.rand(
self.model_tester.n_targets, self.model_tester.min_size, self.model_tester.max_size
target["boxes"] = torch.ones(self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float)
target["masks"] = torch.ones(
self.model_tester.n_targets, self.model_tester.min_size, self.model_tester.max_size, device=torch_device, dtype=torch.float
)
labels.append(target)
inputs_dict["labels"] = labels
Expand Down Expand Up @@ -238,6 +238,7 @@ def test_attention_outputs(self):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes

for model_class in self.all_model_classes:
print(model_class)
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
Expand Down Expand Up @@ -278,12 +279,12 @@ def test_attention_outputs(self):
# loss is at first position
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
# Object Detection model returns pred_logits and pred_boxes instead of last_hidden_state
# Object Detection model returns pred_logits and pred_boxes
if model_class.__name__ == "DetrForObjectDetection":
correct_outlen += 1
correct_outlen += 2
# Panoptic Segmentation model returns pred_logits, pred_boxes, pred_masks
if model_class.__name__ == "DetrForSegmentation":
correct_outlen += 2
correct_outlen += 3
if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned

Expand Down Expand Up @@ -389,7 +390,6 @@ def test_different_timm_backbone(self):

self.assertTrue(outputs)


TOLERANCE = 1e-4


Expand Down Expand Up @@ -488,3 +488,4 @@ def test_inference_panoptic_segmentation_head(self):
[[-7.7558, -10.8788, -11.9797], [-11.8881, -16.4329, -17.7451], [-14.7316, -19.7383, -20.3004]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.pred_masks[0, 0, :3, :3], expected_slice_masks, atol=1e-4))
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved