Adds ATen fallback for scaled_dot_product_attention #21107

Merged · 35 commits · Jul 22, 2024

Changes from 1 commit

Commits (35)
f22b8dc
attn aten fallback
Jun 19, 2024
612e425
use correct operator names
Jun 20, 2024
bdcfebb
formatting
Jun 20, 2024
80c3107
add unit test
Jun 20, 2024
2b29b4c
formatting
Jun 20, 2024
d2b8566
use pytorch sdpa kernel
Jun 26, 2024
0ca8fa0
bug fix
Jun 26, 2024
8999ff2
lint
Jun 26, 2024
6bf3018
use different kernel
Jun 27, 2024
35bd07a
formatting
Jun 27, 2024
dd1849a
include Peng's & Vincent's edits
Jul 2, 2024
8219ec9
adjust test and comments
prathikr Jul 2, 2024
65c2cb7
move import inside test
prathikr Jul 2, 2024
b5f5863
merge with master
prathikr Jul 2, 2024
18648ad
feature flag
prathikr Jul 2, 2024
be9ce0a
add documentation
prathikr Jul 2, 2024
e269e89
minor fixes
prathikr Jul 2, 2024
d3cc487
doc update
prathikr Jul 2, 2024
f500528
peng fix, xavier suggestion
prathikr Jul 8, 2024
c4cdab6
bug fix
prathikr Jul 8, 2024
c05a5ee
bug fix
prathikr Jul 8, 2024
f82bd48
bug fix
prathikr Jul 8, 2024
668409b
adjust unit test
prathikr Jul 8, 2024
b5f1169
adjust checks
prathikr Jul 9, 2024
31becab
grad input fix
prathikr Jul 9, 2024
5aa147d
handle both with and without bias
prathikr Jul 9, 2024
37eb6bc
full mask
prathikr Jul 9, 2024
ae3b5e7
merge with main
prathikr Jul 12, 2024
3484926
lint
prathikr Jul 12, 2024
8d0e879
add version check for test
prathikr Jul 15, 2024
b72a042
grad output adjustment
prathikr Jul 16, 2024
6b4dd10
add more docs
prathikr Jul 16, 2024
999b04b
remove support for masked attention
prathikr Jul 17, 2024
b1fe489
adjust docs
prathikr Jul 18, 2024
4ab54e6
Merge remote-tracking branch 'origin' into prathikrao/attn-aten-fallback
prathikr Jul 22, 2024
handle both with and without bias
prathikr committed Jul 9, 2024
commit 5aa147d90c47025dc652a3266f34ed5f72854539
docs/ORTModule_Training_Guidelines.md (4 changes: 3 additions & 1 deletion)
@@ -310,7 +310,9 @@ A classical usage of disabling the deep copy: when the deep copy before module e
- **Description**: By default, this is disabled. This environment variable enables a pre-export fallback of scaled_dot_product_attention to PyTorch's efficient_attention ATen kernel for execution.

```bash
-export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE
+export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE **WITHOUT** ATTN_MASK INPUT
+export ORTMODULE_ATEN_SDPA_FALLBACK=MASKED # ENABLE **WITH** ATTN_MASK INPUT
unset ORTMODULE_ATEN_SDPA_FALLBACK # DISABLE
```
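
As a point of reference, a minimal sketch of exercising this flag end to end with ORTModule; the example model name, device, and tensor shapes below are illustrative and not part of this change:

```python
# Illustrative sketch: enable the ATen SDPA fallback before wrapping the model.
import os

os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1"  # no attn_mask input expected

import torch
from onnxruntime.training.ortmodule import ORTModule


class AttentionOnly(torch.nn.Module):  # hypothetical example model
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


model = ORTModule(AttentionOnly().to("cuda"))
q, k, v = (torch.randn(32, 8, 128, 64, device="cuda", requires_grad=True) for _ in range(3))
model(q, k, v).sum().backward()
```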

### 2.2 Memory Optimization
@@ -285,12 +285,13 @@ def upsample_bicubic2d_gradient():
# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784
@register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
def scaled_dot_product_attention_gradient():
+    grad_input_mask = [1, 1, 1, 1] if ATEN_SDPA_FALLBACK.upper() == "MASKED" else [1, 1, 1, 0]
    return [
        (
            "Constant",
            [],
            ["grad_input_mask"],
-            {"value": {"value": [1, 1, 1, 0], "dtype": "int", "is_tensor": True}},
+            {"value": {"value": grad_input_mask, "dtype": "int", "is_tensor": True}},
        ),
        (
            ("ATen", "org.pytorch.aten"),
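
For reference, a small standalone sketch of the mask selection above; the helper name is made up here, and the (query, key, value, attn_bias) interpretation of the four entries follows my reading of PyTorch's `_scaled_dot_product_efficient_attention_backward` signature:

```python
import os


def sdpa_grad_input_mask():
    """Illustrative helper mirroring the grad_input_mask choice in this diff.

    The four entries request gradients for (query, key, value, attn_bias);
    only the MASKED mode asks for a gradient w.r.t. the attention bias/mask.
    """
    flag = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", "")
    return [1, 1, 1, 1] if flag.upper() == "MASKED" else [1, 1, 1, 0]


print(sdpa_grad_input_mask())  # [1, 1, 1, 0] unless ORTMODULE_ATEN_SDPA_FALLBACK=MASKED
```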
@@ -6930,15 +6930,13 @@
def test_aten_attention():
    from torch.nn.attention import SDPBackend, sdpa_kernel

-    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1"

    class _NeuralNetAttention(torch.nn.Module):
        def __init__(self):
            super().__init__()

-        def forward(self, q, k, v):
+        def forward(self, q, k, v, attn_mask=None):
            with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
-                return torch.nn.functional.scaled_dot_product_attention(q, k, v)
+                return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)

    def gen_inputs(device, dtype):
        return [
@@ -6947,15 +6945,18 @@
            torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
        ]

+    def run_step(model, inputs, attn_mask=None):
+        prediction = model(*inputs, attn_mask)
+        prediction.sum().backward()
+        return prediction

    device = "cuda"

+    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1" # TESTING WITHOUT ATTN_MASK

    pt_model = _NeuralNetAttention().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn"))

-    def run_step(model, inputs):
-        prediction = model(*inputs)
-        prediction.sum().backward()
-        return prediction

    # reset manual seed to reset the generator
    torch.manual_seed(2333)
    pt_input = gen_inputs(device=device, dtype=torch.float32)
@@ -6990,4 +6991,44 @@

    assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"

+    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "MASKED" # TESTING WITH ATTN_MASK

+    pt_model = _NeuralNetAttention().to(device)
+    ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn_masked"))

+    # reset manual seed to reset the generator
+    torch.manual_seed(2333)
+    pt_input = gen_inputs(device=device, dtype=torch.float32)
+    attn_mask = torch.randint(2, (32, 8, 128, 128), dtype=torch.float32, device=device, requires_grad=True)
+    ort_input = copy.deepcopy(pt_input)
+    pt_prediction = run_step(pt_model, pt_input, attn_mask)
+    ort_prediction = run_step(ort_model, ort_input, attn_mask)

+    _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
+    _test_helpers.assert_values_are_close(ort_input[0].grad, pt_input[0].grad)
+    _test_helpers.assert_values_are_close(ort_input[1].grad, pt_input[1].grad)
+    _test_helpers.assert_values_are_close(ort_input[2].grad, pt_input[2].grad)

+    execution_mgr = ort_model._torch_module._execution_manager._training_manager
+    from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name

+    path = os.path.join(
+        execution_mgr._debug_options.save_onnx_models.path,
+        _get_onnx_file_name(
+            execution_mgr._debug_options.save_onnx_models.name_prefix, "execution_model", execution_mgr._export_mode
+        ),
+    )

+    onnx_model = onnx.load(path)
+    onnx_nodes = onnx_model.graph.node

+    mem_eff_attn_nodes = 0
+    for node in onnx_nodes:
+        if "ATen" in node.name:
+            for attr in node.attribute:
+                if b"_scaled_dot_product_efficient_attention" in attr.s:
+                    mem_eff_attn_nodes += 1

+    assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"

    del os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"]