Don't wrap negative indexing in scatter reduce (pytorch#131503)
Fix for pytorch#131321: scatter/scatter_reduce does not accept negative indices, so the Inductor lowering must not wrap them into range. The index is now passed through unmodified, so the generated bounds assert fires on a negative index, matching eager behavior.

Pull Request resolved: pytorch#131503
Approved by: https://github.com/shunting314
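
A minimal repro sketch of the behavior this change targets (assumes a CUDA device; the values and shapes below are illustrative, not taken from the linked issue). Eager scatter_reduce treats a negative index as out of bounds, while the compiled kernel previously wrapped it to index + size and silently scattered to a valid slot:

import torch

src = torch.tensor([1.0, 2.0, 3.0], device="cuda")
index = torch.tensor([0, 1, -1], device="cuda")  # -1 is out of bounds for scatter
inp = torch.zeros(4, device="cuda")

compiled = torch.compile(torch.scatter_reduce)

# Before this change the generated Triton kernel normalized -1 to 3 and
# summed src[2] into inp[3]; after it, the emitted tl.device_assert fires,
# matching the eager device-side assert.
out = compiled(inp, 0, index, src, "sum")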
eellison authored and pull[bot] committed Aug 4, 2024
1 parent c9cfb3a commit 9718f42
Showing 11 changed files with 50 additions and 21 deletions.
16 changes: 16 additions & 0 deletions test/inductor/test_cuda_repro.py
@@ -694,6 +694,22 @@ def fn(arg3_1, arg3_2, relu, permute_1):
ref = torch.compile(fn, fullgraph=True)(*args)
assert same(ref, correct)

def test_scatter_index_not_wrapped(self):
src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], device=self.device)
index = torch.tensor([0, 1, 0, 1, 2, 0], device=self.device)
input = torch.tensor([1.0, 2.0, 3.0, 4.0], device=self.device)
compiled_sr = torch.compile(torch.scatter_reduce)

input_orig = input.clone()
out, code = run_and_get_code(compiled_sr, input, 0, index, src, "sum")
# tmp0 is used directly: negative indices are not wrapped
FileCheck().check("tl.device_assert(((0 <= tmp0) & (tmp0 < 4))").check_next(
"atomic_add"
).run(code[0])
self.assertEqual(
out, torch.scatter_reduce(input_orig.clone(), 0, index, src, "sum")
)

def test_embedding_var_mean(self):
def forward(arg0_1):
full = torch.ops.aten.full.default(
18 changes: 12 additions & 6 deletions torch/_inductor/codegen/common.py
@@ -1834,19 +1834,25 @@ def arg_to_bound(x):

@staticmethod
def indirect_indexing(
var: CSEVariable, size: Union[sympy.Expr, int], check: bool = True
var: CSEVariable,
size: Union[sympy.Expr, int],
check: bool = True,
wrap_neg=True,
):
if isinstance(size, int):
size = sympy.Integer(size)
assert isinstance(size, sympy.Expr), size
# Skip CSE since this doesn't return an expression

if var.bounds.lower < 0: # type: ignore[operator]
stm = ops.add(var, ops.index_expr(size, torch.long))
# Mixed negative and non-negative
if var.bounds.upper >= 0: # type: ignore[operator]
lt = ops.lt(var, 0)
stm = ops.where(lt, stm, var)
if wrap_neg:
stm = ops.add(var, ops.index_expr(size, torch.long))
# Mixed negative and non-negative
if var.bounds.upper >= 0: # type: ignore[operator]
lt = ops.lt(var, 0)
stm = ops.where(lt, stm, var)
else:
stm = var

# Propagate bounds as we know how to compute them properly
new_bounds = ValueRanges.unknown()
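For reference, the wrapping that the new wrap_neg flag gates is ordinary Python-style index normalization. A scalar sketch of the guard logic in indirect_indexing (illustrative only, not Inductor's actual codegen API):

def guard_index(idx: int, size: int, wrap_neg: bool = True, check: bool = True) -> int:
    # wrap_neg=True: map idx in [-size, size) onto [0, size), like Python indexing.
    if wrap_neg and idx < 0:
        idx += size
    # check=True: this bounds check survives as tl.device_assert in the Triton
    # kernel; with wrap_neg=False a negative index reaches it unmodified and fails.
    if check:
        assert 0 <= idx < size
    return idx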
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/cpp.py
@@ -3058,7 +3058,7 @@ def can_use_int32():
return tmp_var

@staticmethod
def indirect_indexing(index_var, size, check=True):
def indirect_indexing(index_var, size, check=True, wrap_neg=True):
return sympy_index_symbol(str(index_var))

@staticmethod
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/halide.py
@@ -518,7 +518,7 @@ def index_expr(cls, expr, dtype):
return var

@classmethod
def indirect_indexing(cls, index_var, size, check=True):
def indirect_indexing(cls, index_var, size, check=True, wrap_neg=True):
# TODO(jansel): Halide only supports 32-bit indexing, we should error on overflow
index_var = ops.to_dtype(index_var, torch.int32)
index_var = ops.halide_clamp(index_var, size, check)
4 changes: 3 additions & 1 deletion torch/_inductor/dependencies.py
@@ -596,7 +596,9 @@ def inner(*args, **kwargs):

return inner

def indirect_indexing(self, index_var, size, check=True) -> sympy.Symbol:
def indirect_indexing(
self, index_var, size, check=True, wrap_neg=True
) -> sympy.Symbol:
assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean))
self.symbols |= free_unbacked_symbols(size)
return sympy_index_symbol(f"({str(index_var)})")
11 changes: 8 additions & 3 deletions torch/_inductor/index_propagation.py
@@ -329,7 +329,11 @@ def statically_true(self, e):
return bool(evaluated)

def indirect_indexing(
self, index: Union[Any, IndexPropVar], size: Any, check: bool = True
self,
index: Union[Any, IndexPropVar],
size: Any,
check: bool = True,
wrap_neg=True,
) -> Any:
if isinstance(index, IndexPropVar) and index.is_symbolic:
# If we find something we can convert into a direct indexing we do so
@@ -354,7 +358,8 @@ def wrap_expr(expr):
-size <= expr
)
can_prove_upper = self.statically_true(expr < size)
expr = wrap_expr(expr)
if wrap_neg:
expr = wrap_expr(expr)
if generate_assert(check):
self.fallback(
"check_bounds",
@@ -364,6 +369,6 @@ def wrap_expr(expr):
return expr

indirect_var = self.fallback(
"indirect_indexing", (index, size, check), {}
"indirect_indexing", (index, size, check, wrap_neg), {}
).value
return indirect_var
4 changes: 2 additions & 2 deletions torch/_inductor/ir.py
@@ -6810,7 +6810,7 @@ def frexp(self, value_proxy):
return (result[0], result[1])

@staticmethod
def indirect_indexing(index_proxy, size, check=True):
def indirect_indexing(index_proxy, size, check=True, wrap_neg=True):
"""
Flow data from tensors into indexing formulas.
Introduce a call_module to update the indexing.
@@ -6820,7 +6820,7 @@ def indirect_indexing(index_proxy, size, check=True):

def set_indirect(new_var):
self.body.replace_indirect(
var, V.ops.indirect_indexing(new_var, size, check)
var, V.ops.indirect_indexing(new_var, size, check, wrap_neg)
)

tracer.create_proxy(
2 changes: 1 addition & 1 deletion torch/_inductor/lowering.py
@@ -3368,7 +3368,7 @@ def output_indexer(idx):
ndim = len(shape)
indirect_idx = list(idx)
indirect_idx[dim] = ops.indirect_indexing(
index_loader(idx), 1 if ndim == 0 else shape[dim]
index_loader(idx), 1 if ndim == 0 else shape[dim], wrap_neg=False
)
return indirect_idx

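This scatter output indexer is the only call site that passes wrap_neg=False; the remaining files simply thread the new flag through with its default of True, so wrapping is preserved for ops where negative indices are valid (e.g. plain tensor indexing via aten.index).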
6 changes: 3 additions & 3 deletions torch/_inductor/ops_handler.py
@@ -197,7 +197,7 @@ def identity(self, x: T) -> T:
# in scope, which are typically used by sympy.Expr indexing.

def indirect_indexing(
self, x: T, size: sympy.Expr, check: bool = True
self, x: T, size: sympy.Expr, check: bool = True, wrap_neg=True
) -> sympy.Expr:
"""
Convert an integral x into a sympy.Expr that can be subsequently used in
@@ -764,7 +764,7 @@ def sort(dtypes, values, stable, descending) -> Tuple[None, ...]:
return (None,) * len(values)

@staticmethod
def indirect_indexing(index_var, size, check=True) -> sympy.Symbol:
def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol:
return sympy.Integer(0)


@@ -808,7 +808,7 @@ def sort(dtypes, values, stable, descending):
)

@staticmethod
def indirect_indexing(index_var, size, check=True) -> sympy.Symbol:
def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol:
return sympy_index_symbol(str(index_var))

@classmethod
2 changes: 1 addition & 1 deletion torch/_inductor/select_algorithm.py
@@ -384,7 +384,7 @@ def load(self, name: str, index: sympy.Expr):

return f"({fixed_inputs[name]})"

def indirect_indexing(self, index_var, size, check):
def indirect_indexing(self, index_var, size, check, wrap_neg=True):
return sympy_index_symbol(str(index_var))

with V.set_ops_handler(PlaceholderSubstitution(V.ops)):
4 changes: 2 additions & 2 deletions torch/_inductor/virtualized.py
@@ -283,10 +283,10 @@ def _wrap(x):
return OpsValue(x)

@staticmethod
def indirect_indexing(index, size, check=True):
def indirect_indexing(index, size, check=True, wrap_neg=True):
# Returns a sympy value, not IR value
index = OpsWrapper._unwrap(index)
return _ops.indirect_indexing(index, size, check)
return _ops.indirect_indexing(index, size, check, wrap_neg)


ops = OpsWrapper()
