[BigBird Pegasus] Make tests faster (huggingface#11744)
* improve tests

* remove bogus file

* make style

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
patrickvonplaten and Patrick von Platen committed May 17, 2021
1 parent a0531c8 commit 73893fc
Showing 1 changed file with 12 additions and 5 deletions.
tests/test_modeling_bigbird_pegasus.py (17 changes: 12 additions & 5 deletions)
@@ -368,17 +368,24 @@ def test_batched_forward_block_sparse(self):
         self._check_batched_forward(attn_type="block_sparse", tolerance=1e-1)
 
     def _check_batched_forward(self, attn_type, tolerance=1e-3):
-        config = BigBirdPegasusConfig(block_size=16, attention_type=attn_type)
+        config, _ = self.model_tester.prepare_config_and_inputs()
+        config.max_position_embeddings = 128
+        config.block_size = 16
+        config.attention_type = attn_type
         model = BigBirdPegasusForConditionalGeneration(config).to(torch_device)
         model.eval()
 
-        sample_with_padding = [3, 8, 11] * 128 + [0] * 128
-        sample_without_padding = [4, 7, 9, 13] * 128
+        chunk_length = 32
+
+        sample_with_padding = [3, 8, 11] * chunk_length + [0] * chunk_length
+        sample_without_padding = [4, 7, 9, 13] * chunk_length
         target_ids_without_padding = [2, 3] * 8
         target_ids_with_padding = [7, 8] * 6 + 4 * [-100]
 
         attention_mask = torch.tensor(
-            [[1] * 3 * 128 + [0] * 128, [1] * 4 * 128], device=torch_device, dtype=torch.long
+            [[1] * 3 * chunk_length + [0] * chunk_length, [1] * 4 * chunk_length],
+            device=torch_device,
+            dtype=torch.long,
         )
 
         input_ids = torch.tensor([sample_with_padding, sample_without_padding], device=torch_device, dtype=torch.long)
@@ -390,7 +397,7 @@ def _check_batched_forward(self, attn_type, tolerance=1e-3):
             logits_batched = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).logits
 
         with torch.no_grad():
-            logits_single_first = model(input_ids=input_ids[:1, :-128], labels=labels[:1]).logits
+            logits_single_first = model(input_ids=input_ids[:1, :-chunk_length], labels=labels[:1]).logits
 
         self.assertTrue(torch.allclose(logits_batched[0, -3:], logits_single_first[0, -3:], atol=tolerance))

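Note (not part of the commit): the speed-up comes from shrinking the synthetic test inputs. A rough sketch of the arithmetic, using only the constants visible in the diff above:

```python
# Illustrative sketch only, not code from the repository.
# Old test inputs: the repeated token patterns were built with a factor of 128.
old_padded_len = 3 * 128 + 128                     # sample_with_padding -> 512 tokens
# New test inputs: the same patterns use chunk_length = 32.
chunk_length = 32
new_padded_len = 3 * chunk_length + chunk_length   # -> 128 tokens

# The shorter sequence still fits the block-sparse layout (block_size = 16)
# and the reduced config.max_position_embeddings = 128 set in the diff.
assert new_padded_len % 16 == 0
assert new_padded_len <= 128
print(old_padded_len, "->", new_padded_len)        # 512 -> 128
```

Under these assumed numbers, every forward pass in the test now runs over sequences one quarter the previous length, which is where the test-time reduction comes from.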