BigBird #10183

Merged: 88 commits, Mar 30, 2021

Changes from 1 commit
Commits (88):

d0aa9ea  init bigbird (thevasudevgupta, Feb 14, 2021)
0e37183  model.__init__ working, conversion script ready, config updated (thevasudevgupta, Feb 15, 2021)
faa4ab3  add conversion script (thevasudevgupta, Feb 16, 2021)
facd93e  BigBirdEmbeddings working :) (thevasudevgupta, Feb 16, 2021)
28c8d13  slightly update conversion script (patrickvonplaten, Feb 16, 2021)
124b99f  BigBirdAttention working :) ; some bug in layer.output.dense (thevasudevgupta, Feb 17, 2021)
2c14788  add debugger-notebook (thevasudevgupta, Feb 17, 2021)
12a523b  forward() working for BigBirdModel :) ; replaced gelu with gelu_fast (thevasudevgupta, Feb 17, 2021)
aebd36b  tf code adapted to torch till rand_attn in bigbird_block_sparse_atten… (thevasudevgupta, Feb 19, 2021)
9df3127  BigBirdModel working in block-sparse attention mode :) (thevasudevgupta, Feb 23, 2021)
ad84acf  add BigBirdForPreTraining (thevasudevgupta, Feb 24, 2021)
4076c9b  small fix (thevasudevgupta, Feb 24, 2021)
78a205a  add tokenizer for BigBirdModel (thevasudevgupta, Feb 25, 2021)
644f65d  fix config & hence modeling (thevasudevgupta, Feb 25, 2021)
ce66bac  fix base prefix (thevasudevgupta, Feb 25, 2021)
f672205  init testing (thevasudevgupta, Feb 25, 2021)
372ff99  init tokenizer test (thevasudevgupta, Feb 26, 2021)
ed6dc49  pos_embed must be absolute, attn_type=original_full when add_cross_at… (thevasudevgupta, Feb 26, 2021)
7e05539  remove position_embedding_type arg (thevasudevgupta, Feb 26, 2021)
d257079  complete normal tests (thevasudevgupta, Feb 26, 2021)
07ec9a1  add comments to block sparse attention (thevasudevgupta, Feb 27, 2021)
01dd2e8  add attn_probs for sliding & global tokens (thevasudevgupta, Feb 27, 2021)
49d62e5  create fn for block sparse attn mask creation (thevasudevgupta, Feb 28, 2021)
5912716  add special tests (thevasudevgupta, Feb 28, 2021)
89de3c5  restore pos embed arg (thevasudevgupta, Feb 28, 2021)
b132905  minor fix (thevasudevgupta, Feb 28, 2021)
6ab2921  attn probs update (thevasudevgupta, Mar 1, 2021)
72e2532  make big bird fully gpu friendly (patrickvonplaten, Mar 2, 2021)
7401768  fix tests (patrickvonplaten, Mar 2, 2021)
da2824f  remove pruning (patrickvonplaten, Mar 2, 2021)
3a866e2  correct tokenzier & minor fixes (patrickvonplaten, Mar 2, 2021)
753ba75  update conversion script , remove norm_type (thevasudevgupta, Mar 2, 2021)
1e186d0  tokenizer-inference test add (thevasudevgupta, Mar 2, 2021)
72a150e  remove extra comments (thevasudevgupta, Mar 2, 2021)
24c74a9  add docs (thevasudevgupta, Mar 3, 2021)
79955e4  save intermediate (patrickvonplaten, Mar 3, 2021)
018b8fd  finish trivia_qa conversion (patrickvonplaten, Mar 4, 2021)
1716dea  small update to forward (thevasudevgupta, Mar 4, 2021)
15b7cfa  correct qa and layer (patrickvonplaten, Mar 4, 2021)
c300f3f  merge into master (patrickvonplaten, Mar 4, 2021)
2782295  better error message (patrickvonplaten, Mar 4, 2021)
56bd1d8  BigBird QA ready (thevasudevgupta, Mar 5, 2021)
ecfe137  fix rebased (thevasudevgupta, Mar 5, 2021)
eebd92a  add triva-qa debugger notebook (thevasudevgupta, Mar 5, 2021)
f6b6f43  qa setup (thevasudevgupta, Mar 6, 2021)
a50a10c  fixed till embeddings (thevasudevgupta, Mar 7, 2021)
edf5f2a  some issue in q/k/v_layer (thevasudevgupta, Mar 8, 2021)
a94d006  fix bug in conversion-script (thevasudevgupta, Mar 9, 2021)
3b489a3  fixed till self-attn (thevasudevgupta, Mar 9, 2021)
1e3aa50  qa fixed except layer norm (thevasudevgupta, Mar 11, 2021)
2f59e51  add qa end2end test (thevasudevgupta, Mar 12, 2021)
ef72bcd  fix gradient ckpting ; other qa test (thevasudevgupta, Mar 12, 2021)
8b94584  speed-up big bird a bit (patrickvonplaten, Mar 15, 2021)
468de78  hub_id=google (thevasudevgupta, Mar 12, 2021)
58ee280  clean up (thevasudevgupta, Mar 15, 2021)
e873658  make quality (thevasudevgupta, Mar 15, 2021)
4e13753  speed up einsum with bmm (patrickvonplaten, Mar 16, 2021)
e88110a  finish perf improvements for big bird (patrickvonplaten, Mar 16, 2021)
5f2d6a0  Merge branch 'master' into add_big_bird (patrickvonplaten, Mar 16, 2021)
cada132  Merge branch 'master' into add_big_bird (patrickvonplaten, Mar 22, 2021)
b8f41c0  remove wav2vec2 tok (patrickvonplaten, Mar 22, 2021)
22a71cc  fix tokenizer (patrickvonplaten, Mar 22, 2021)
5730a98  include docs (patrickvonplaten, Mar 22, 2021)
ab65872  correct docs (patrickvonplaten, Mar 22, 2021)
ff32248  add helper to auto pad block size (thevasudevgupta, Mar 25, 2021)
de2f812  make style (thevasudevgupta, Mar 25, 2021)
1b0e5f1  remove fast tokenizer for now (patrickvonplaten, Mar 25, 2021)
1ff2ff0  fix some (thevasudevgupta, Mar 25, 2021)
87a4e8c  add pad test (thevasudevgupta, Mar 25, 2021)
b20906c  finish (patrickvonplaten, Mar 28, 2021)
00cd6fb  fix some bugs (patrickvonplaten, Mar 28, 2021)
a719f1f  fix another bug (patrickvonplaten, Mar 28, 2021)
66fbec6  fix buffer tokens (thevasudevgupta, Mar 29, 2021)
184d361  :wqalMerge branch 'master' of https://github.com/huggingface/transfor… (patrickvonplaten, Mar 29, 2021)
1af7c98  Merge branch 'add_big_bird' of https://github.com/vasudevgupta7/trans… (patrickvonplaten, Mar 29, 2021)
aca2b4b  fix comment and merge from master (patrickvonplaten, Mar 29, 2021)
ef673bb  add comments (thevasudevgupta, Mar 29, 2021)
a6018bf  make style (patrickvonplaten, Mar 29, 2021)
58ef450  Merge branch 'master' of https://github.com/huggingface/transformers … (patrickvonplaten, Mar 29, 2021)
8a47841  commit some suggestions (thevasudevgupta, Mar 29, 2021)
dbc6e39  Fix typos (sgugger, Mar 29, 2021)
25164b9  fix some more suggestions (thevasudevgupta, Mar 29, 2021)
7bbbd6b  add another patch (thevasudevgupta, Mar 29, 2021)
ab6755e  fix copies (thevasudevgupta, Mar 29, 2021)
a9779b2  another path (thevasudevgupta, Mar 29, 2021)
df70258  update (thevasudevgupta, Mar 29, 2021)
0f110c5  update nit suggestions (thevasudevgupta, Mar 29, 2021)
8332604  make style (patrickvonplaten, Mar 30, 2021)

fix base prefix
thevasudevgupta committed Feb 25, 2021
commit ce66bac693c6a6182308ebdf1b30e414f3b5092e
35 changes: 15 additions & 20 deletions src/transformers/models/big_bird/modeling_big_bird.py
@@ -170,7 +170,7 @@ def load_tf_weights_in_big_bird(model, tf_checkpoint_path):
     logger.info("Weights not initialized in PyTorch model: {}".format(", ".join(pt_names)))
     return model
 
-
+# TODO: enable `relative_position_embedding`incase of `block_sparse`
 class BigBirdEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings."""
 
@@ -366,7 +366,7 @@ def forward(
         context_layer = context_layer.view(*new_context_layer_shape)
         # TODO
         # print(context_layer.shape)
-        self.attn_o = context_layer.view(2,128,12,64)
+        # self.attn_o = context_layer.view(2,128,12,64)
         #
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
 
@@ -1398,7 +1398,7 @@ class BigBirdPreTrainedModel(PreTrainedModel):
 
     config_class = BigBirdConfig
     load_tf_weights = load_tf_weights_in_big_bird
-    base_model_prefix = "big_bird"
+    base_model_prefix = "bert"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
@@ -1760,9 +1760,6 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def get_input_embeddings(self):
-        return self.bert.get_input_embeddings()
-
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
@@ -1874,9 +1871,6 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def get_input_embeddings(self):
-        return self.bert.get_input_embeddings()
-
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
@@ -1961,7 +1955,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_
 
         return {"input_ids": input_ids, "attention_mask": attention_mask}
 
-
+# TODO: check BigBird as decoder
 @add_start_docstrings(
     """BigBird Model with a `language modeling` head on top for CLM fine-tuning. """, BIG_BIRD_START_DOCSTRING
 )
@@ -1975,7 +1969,7 @@ def __init__(self, config):
         if not config.is_decoder:
             logger.warning("If you want to use `BigBirdForCausalLM` as a standalone, add `is_decoder=True.`")
 
-        self.big_bird = BigBirdModel(config)
+        self.bert = BigBirdModel(config)
         self.cls = BigBirdOnlyMLMHead(config)
 
         self.init_weights()
@@ -2047,7 +2041,7 @@ def forward(
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.big_bird(
+        outputs = self.bert(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -2136,7 +2130,7 @@ class BigBirdForSequenceClassification(BigBirdPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.big_bird = BigBirdModel(config)
+        self.bert = BigBirdModel(config)
         self.classifier = BigBirdClassificationHead(config)
 
         self.init_weights()
@@ -2170,7 +2164,7 @@ def forward(
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.big_bird(
+        outputs = self.bert(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -2206,6 +2200,7 @@ def forward(
             attentions=outputs.attentions,
         )
 
+
 @add_start_docstrings(
     """BigBird Model with a multiple choice classification head on top (a linear layer on top of
     the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
@@ -2215,7 +2210,7 @@ class BigBirdForMultipleChoice(BigBirdPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
-        self.big_bird = BigBirdModel(config)
+        self.bert = BigBirdModel(config)
         self.sequence_summary = SequenceSummary(config)
         self.classifier = nn.Linear(config.hidden_size, 1)
 
@@ -2260,7 +2255,7 @@ def forward(
             else None
         )
 
-        outputs = self.big_bird(
+        outputs = self.bert(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -2305,7 +2300,7 @@ def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
 
-        self.big_bird = BigBirdModel(config)
+        self.bert = BigBirdModel(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
 
@@ -2338,7 +2333,7 @@ def forward(
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.big_bird(
+        outputs = self.bert(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -2393,7 +2388,7 @@ def __init__(self, config):
         config.num_labels = 2
         self.num_labels = config.num_labels
 
-        self.big_bird = BigBirdModel(config)
+        self.bert = BigBirdModel(config)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
 
         self.init_weights()
@@ -2431,7 +2426,7 @@ def forward(
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.big_bird(
+        outputs = self.bert(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
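
A note on why this commit works: the switch to `base_model_prefix = "bert"` and the repeated `self.big_bird` to `self.bert` renames have to move together, because `transformers` uses the prefix string as the attribute name under which every head model stores its base transformer. Generic `PreTrainedModel` machinery (the `base_model` property, the default `get_input_embeddings`, and checkpoint-key reconciliation in `from_pretrained`) finds the base model via `getattr` on that name, which is also why the per-class `get_input_embeddings` overrides become redundant and are deleted here. Below is a minimal sketch of that lookup pattern; `PreTrainedModelSketch`, `TinyBase`, and `TinyHeadModel` are illustrative names for this sketch, not library API:

import torch.nn as nn


class PreTrainedModelSketch(nn.Module):
    # Simplified stand-in for transformers.PreTrainedModel (a sketch of the
    # pattern, not the library's exact code).
    base_model_prefix = ""

    @property
    def base_model(self) -> nn.Module:
        # Resolve the base transformer via the attribute name stored in
        # base_model_prefix; fall back to self when the model is the base.
        return getattr(self, self.base_model_prefix, self)

    def get_input_embeddings(self) -> nn.Module:
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            # Head models delegate to the base model, so they need no
            # override of their own.
            return base_model.get_input_embeddings()
        raise NotImplementedError


class TinyBase(PreTrainedModelSketch):
    # Stands in for BigBirdModel.
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(10, 4)

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings


class TinyHeadModel(PreTrainedModelSketch):
    # Stands in for a head model such as BigBirdForSequenceClassification.
    base_model_prefix = "bert"  # must match the attribute name used below

    def __init__(self):
        super().__init__()
        self.bert = TinyBase()  # mirrors `self.bert = BigBirdModel(config)`


model = TinyHeadModel()
assert model.base_model is model.bert
assert model.get_input_embeddings() is model.bert.embeddings

The prefix also has to line up with how weights are saved and loaded: `from_pretrained` uses it to add or strip the prefix when checkpoint keys and model attributes disagree, and the original BigBird TF checkpoints scope their variables under a BERT-style name, which is presumably why "bert" was chosen over "big_bird" here.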