Skip to content

Commit

Permalink
add Bigbird ONNX config (huggingface#16427)
Browse files Browse the repository at this point in the history
* add Bigbird ONNX config
  • Loading branch information
vumichien authored and elusenji committed Jun 12, 2022
1 parent 5f2c641 commit cf131e2
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/source/en/serialization.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Ready-made configurations include the following architectures:
- BART
- BEiT
- BERT
- BigBird
- Blenderbot
- BlenderbotSmall
- CamemBERT
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/big_bird/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


_import_structure = {
"configuration_big_bird": ["BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdConfig"],
"configuration_big_bird": ["BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdConfig", "BigBirdOnnxConfig"],
}

if is_sentencepiece_available():
Expand Down Expand Up @@ -66,7 +66,7 @@
]

if TYPE_CHECKING:
from .configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig
from .configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig, BigBirdOnnxConfig

if is_sentencepiece_available():
from .tokenization_big_bird import BigBirdTokenizer
Expand Down
14 changes: 14 additions & 0 deletions src/transformers/models/big_bird/configuration_big_bird.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" BigBird model configuration"""
from collections import OrderedDict
from typing import Mapping

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging


Expand Down Expand Up @@ -160,3 +163,14 @@ def __init__(
self.block_size = block_size
self.num_random_blocks = num_random_blocks
self.classifier_dropout = classifier_dropout


class BigBirdOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
]
)
4 changes: 2 additions & 2 deletions src/transformers/models/big_bird/modeling_big_bird.py
Original file line number Diff line number Diff line change
Expand Up @@ -3000,7 +3000,7 @@ def forward(
# setting lengths logits to `-inf`
logits_mask = self.prepare_question_mask(question_lengths, seqlen)
if token_type_ids is None:
token_type_ids = (~logits_mask).long()
token_type_ids = torch.ones(logits_mask.size(), dtype=int) - logits_mask
logits_mask = logits_mask
logits_mask[:, 0] = False
logits_mask.unsqueeze_(2)
Expand Down Expand Up @@ -3063,5 +3063,5 @@ def prepare_question_mask(q_lengths: torch.Tensor, maxlen: int):
# q_lengths -> (bz, 1)
mask = torch.arange(0, maxlen).to(q_lengths.device)
mask.unsqueeze_(0) # -> (1, maxlen)
mask = mask < q_lengths
mask = torch.where(mask < q_lengths, 1, 0)
return mask
10 changes: 10 additions & 0 deletions src/transformers/onnx/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ..models.bart import BartOnnxConfig
from ..models.beit import BeitOnnxConfig
from ..models.bert import BertOnnxConfig
from ..models.big_bird import BigBirdOnnxConfig
from ..models.blenderbot import BlenderbotOnnxConfig
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
from ..models.camembert import CamembertOnnxConfig
Expand Down Expand Up @@ -156,6 +157,15 @@ class FeaturesManager:
"question-answering",
onnx_config_cls=BertOnnxConfig,
),
"bigbird": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"token-classification",
"question-answering",
onnx_config_cls=BigBirdOnnxConfig,
),
"ibert": supported_features_mapping(
"default",
"masked-lm",
Expand Down
1 change: 1 addition & 0 deletions tests/onnx/test_onnx_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def test_values_override(self):
PYTORCH_EXPORT_MODELS = {
("albert", "hf-internal-testing/tiny-albert"),
("bert", "bert-base-cased"),
("bigbird", "google/bigbird-roberta-base"),
("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"),
("distilbert", "distilbert-base-cased"),
Expand Down

0 comments on commit cf131e2

Please sign in to comment.