Adds ATen fallback for scaled_dot_product_attention #21107

Merged Jul 22, 2024 (35 commits)

Commits:
f22b8dc  attn aten fallback (Jun 19, 2024)
612e425  use correct operator names (Jun 20, 2024)
bdcfebb  formatting (Jun 20, 2024)
80c3107  add unit test (Jun 20, 2024)
2b29b4c  formatting (Jun 20, 2024)
d2b8566  use pytorch sdpa kernel (Jun 26, 2024)
0ca8fa0  bug fix (Jun 26, 2024)
8999ff2  lint (Jun 26, 2024)
6bf3018  use different kernel (Jun 27, 2024)
35bd07a  formatting (Jun 27, 2024)
dd1849a  include Peng's & Vincent's edits (Jul 2, 2024)
8219ec9  adjust test and comments (prathikr, Jul 2, 2024)
65c2cb7  move import inside test (prathikr, Jul 2, 2024)
b5f5863  merge with master (prathikr, Jul 2, 2024)
18648ad  feature flag (prathikr, Jul 2, 2024)
be9ce0a  add documentation (prathikr, Jul 2, 2024)
e269e89  minor fixes (prathikr, Jul 2, 2024)
d3cc487  doc update (prathikr, Jul 2, 2024)
f500528  peng fix, xavier suggestion (prathikr, Jul 8, 2024)
c4cdab6  bug fix (prathikr, Jul 8, 2024)
c05a5ee  bug fix (prathikr, Jul 8, 2024)
f82bd48  bug fix (prathikr, Jul 8, 2024)
668409b  adjust unit test (prathikr, Jul 8, 2024)
b5f1169  adjust checks (prathikr, Jul 9, 2024)
31becab  grad input fix (prathikr, Jul 9, 2024)
5aa147d  handle both with and without bias (prathikr, Jul 9, 2024)
37eb6bc  full mask (prathikr, Jul 9, 2024)
ae3b5e7  merge with main (prathikr, Jul 12, 2024)
3484926  lint (prathikr, Jul 12, 2024)
8d0e879  add version check for test (prathikr, Jul 15, 2024)
b72a042  grad output adjustment (prathikr, Jul 16, 2024)
6b4dd10  add more docs (prathikr, Jul 16, 2024)
999b04b  remove support for masked attention (prathikr, Jul 17, 2024)
b1fe489  adjust docs (prathikr, Jul 18, 2024)
4ab54e6  Merge remote-tracking branch 'origin' into prathikrao/attn-aten-fallback (prathikr, Jul 22, 2024)

Changes from commit 37eb6bc2dda88c7b65ee67daa67327bd178a702e ("full mask"), committed by prathikr on Jul 9, 2024:
@@ -6999,7 +6999,7 @@
     # reset manual seed to reset the generator
     torch.manual_seed(2333)
     pt_input = gen_inputs(device=device, dtype=torch.float32)
-    attn_mask = torch.randint(2, (32, 8, 128, 128), dtype=torch.float32, device=device, requires_grad=True)
+    attn_mask = torch.ones(32, 8, 128, 128, dtype=torch.float32, device=device, requires_grad=True)
     ort_input = copy.deepcopy(pt_input)
     pt_prediction = run_step(pt_model, pt_input, attn_mask)
     ort_prediction = run_step(ort_model, ort_input, attn_mask)
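
For context, this hunk swaps the random 0/1 float mask for an all-ones ("full") mask. Below is a minimal sketch of how such a mask feeds into torch.nn.functional.scaled_dot_product_attention; the head dimension of 64 is an assumption for illustration, since the test only fixes the mask shape (32, 8, 128, 128).

    import torch
    import torch.nn.functional as F

    # Shapes mirror the test: batch 32, 8 heads, sequence length 128.
    # head_dim=64 is an assumption, not taken from the test.
    q = torch.randn(32, 8, 128, 64)
    k = torch.randn(32, 8, 128, 64)
    v = torch.randn(32, 8, 128, 64)

    # A float attn_mask is added to the attention scores before softmax.
    # An all-ones "full" mask shifts every score by the same constant, so the
    # softmax output is unchanged; the previous random 0/1 mask perturbed
    # individual scores instead.
    attn_mask = torch.ones(32, 8, 128, 128, dtype=torch.float32, requires_grad=True)

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    print(out.shape)  # torch.Size([32, 8, 128, 64])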
@@ -7020,7 +7020,7 @@
 )

 onnx_model = onnx.load(path)
 onnx_nodes = onnx_model.graph.node

 mem_eff_attn_nodes = 0
 for node in onnx_nodes:
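
The hunk above ends where the test starts counting attention nodes in the exported graph. A minimal sketch of that kind of check follows, assuming the ATen fallback is exported as an "ATen" node whose "operator" attribute names the memory-efficient SDPA kernel; the exact attribute string and the model path are assumptions, not copied from the test.

    import onnx

    path = "exported_model.onnx"  # placeholder; the test uses its own export path
    onnx_model = onnx.load(path)

    mem_eff_attn_nodes = 0
    for node in onnx_model.graph.node:
        # ATen fallback nodes carry the original PyTorch operator name in
        # their "operator" attribute (stored as bytes in the ONNX proto).
        if node.op_type == "ATen":
            for attr in node.attribute:
                if attr.name == "operator" and b"_scaled_dot_product_efficient_attention" in attr.s:
                    mem_eff_attn_nodes += 1

    assert mem_eff_attn_nodes > 0  # the fallback should appear in the graph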