Skip to content

Commit

Permalink
Add huggingface gpt2 fake tensor unit test for torch.onnx.dynamo_export (pytorch#115380)
Browse files Browse the repository at this point in the history

Open Llama, Dolly v2, and Falcon are still broken regardless of `ExportedProgram`, so they were not moved from `test_fx_to_onnx.py` to `test_fx_to_onnx_with_onnxruntime.py`.

Dolly and falcon already have tracking issues, but a tracking issue was created for open llama: pytorch#115552

A tracking issue was created for `xfail_if_model_type_is_exportedprogram` and `xfail_if_model_type_is_not_exportedprogram` issues with unexpected success runs: pytorch#115747
Pull Request resolved: pytorch#115380
Approved by: https://github.com/titaiwangms
  • Loading branch information
Thiago Crepaldi authored and dmenig committed Dec 21, 2023
1 parent 778cf1a commit fac26c2
Showing 1 changed file with 44 additions and 1 deletion.
45 changes: 44 additions & 1 deletion test/onnx/test_fx_to_onnx_with_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,7 @@ def create_kwargs():
"AssertionError: Dynamic shape check failed for graph inputs",
skip_model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
)
def test_large_scale_exporter_with_tiny_gpt2(self):
def test_fake_tensor_mode_huggingface_tiny_gpt2(self):
model_name = "sshleifer/tiny-gpt2"
device = "cpu"

Expand Down Expand Up @@ -1345,6 +1345,49 @@ def create_model():
model_type=self.model_type,
)

@pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
    # NOTE: the two literals below are implicitly concatenated; the trailing
    # space keeps the reason readable ("...got 3 Github issue: ..." instead of
    # the run-on "...got 3Github issue: ...").
    "AssertionError: Expected 5 inputs, got 3 "
    "Github issue: https://github.com/pytorch/pytorch/issues/115745"
)
@pytorch_test_common.skip_dynamic_fx_test(
    "AssertionError: Dynamic shape check failed for graph inputs",
    skip_model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
)
def test_fake_tensor_mode_huggingface_gpt2(self):
    """Export a HuggingFace GPT-2 model to ONNX under fake tensor mode.

    Uses a deliberately tiny, randomly initialized GPT-2 configuration
    (2 layers, 2 heads, 256-dim embeddings) so the test stays fast and
    needs no checkpoint download.
    """
    config = transformers.GPT2Config(
        vocab_size=8096, n_positions=256, n_embd=256, n_layer=2, n_head=2
    )

    def create_model():
        # eval() disables dropout so the exported graph is deterministic.
        return transformers.GPT2Model(config).eval()

    def create_args():
        # All model inputs are supplied as keyword arguments; no positional args.
        return tuple()

    def create_kwargs():
        batch, seq = 4, 256

        input_ids = torch.randint(0, config.vocab_size, (batch, seq))
        attention_mask = torch.ones(batch, seq, dtype=torch.bool)
        # position_ids: shape (1, seq), broadcast across the batch by the model.
        position_ids = torch.arange(0, seq, dtype=torch.long)
        position_ids = position_ids.unsqueeze(0).view(-1, seq)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
        }

    self._test_fake_tensor_mode_exporter(
        "huggingface_gpt2",
        create_model,
        create_args,
        create_kwargs,
        load_checkpoint_during_init=self.load_checkpoint_during_init,
        export_within_fake_mode=self.export_within_fake_mode,
        model_type=self.model_type,
    )


if __name__ == "__main__":
    # Run this file's tests through PyTorch's common test harness
    # (handles test discovery, device/seed setup, and CLI flags).
    common_utils.run_tests()

0 comments on commit fac26c2

Please sign in to comment.