diff --git a/test/test_jit.py b/test/test_jit.py index 4fb9d2ca618f4..cc7da8deb88cd 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -11361,6 +11361,236 @@ def test_sizes(x): self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),)) + # to avoid defining sum_list in multiple tests + def get_sum_list_fn(self): + def sum_list(a): + # type: (List[int]) -> int + sum = 0 + for i in a: + sum += i + + return sum + + return sum_list + + def test_sum_list_diff_elms(self): + self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],)) + + def test_sum_list_empty(self): + self.checkScript(self.get_sum_list_fn(), ([],)) + + def test_sum_list_one(self): + self.checkScript(self.get_sum_list_fn(), ([1],)) + + def test_sum_list_literal(self): + + def sum_list(): + # type: () -> int + sum = 0 + for i in [1, 2, 3, 4, 5]: + sum += i + + return sum + + self.checkScript(sum_list, ()) + + def test_sum_list_wrong_type(self): + + with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"): + @torch.jit.script + def sum_list(a): + # type: (int) -> int + sum = 0 + for i in a: # noqa: T484 + sum += i + + return sum + + sum_list(1) + + def test_list_iterables(self): + with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'): + cu = torch.jit.CompilationUnit(''' + def list_iterables(x): + for i, j in [2, 3, 4], [5, 6, 7]: + x += i + x += j + return x + ''') + + def test_for_in_string(self): + def test_strings(x): + # type: (str) -> str + reverse = "" + for c in x: + reverse = c + reverse + return reverse + + self.checkScript(test_strings, ("hello",)) + self.checkScript(test_strings, ("",)) + + def test_list_strings(x): + # type: (List[str]) -> str + result = "" + for sub_str in x: + result += sub_str + return result + + self.checkScript(test_list_strings, (["hello", "world"],)) + self.checkScript(test_list_strings, (["hello", " ", "world", ""],)) + + def test_for_in_dict(self): + def test_dicts(x): + # type: (Dict[str, int]) -> int + sum = 0 + for key in x: + sum += x[key] + return sum + + self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},)) + + def test_dict_keys_values(x): + # type: (Dict[str, int]) -> Tuple[str, int] + key_str = "" + sum = 0 + for key in x.keys(): + key_str += key + for val in x.values(): + sum += val + return key_str, sum + + self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},)) + + def test_for_tuple_unpack(self): + def for_tuple_unpack(x, y): + for i, j in [[3, 4], [5, 6], [7, 8]]: + x += i + y += j + return x, y + + self.checkScript(for_tuple_unpack, (torch.tensor(3), torch.tensor(5))) + + def nested_tuple_unpack(x, y): + # type: (List[int], List[int]) -> int + sum = 0 + for i, (j, k), v in zip(x, enumerate(x), y): + sum += i + j + k + v + return sum + + self.checkScript(nested_tuple_unpack, ([1, 3, 5], [2, 4, 6])) + + def test_for_tuple_assign(self): + def test_simple_assign(x): + # type: (Tuple[int, float]) -> float + sum = 0.0 + for a in x: + sum += float(a) + return sum + + self.checkScript(test_simple_assign, ((1, 2.5),)) + + def test_tuple_assign(x): + # type: (Tuple[Tuple[int, int], Tuple[int, int]]) -> int + sum = 0 + for a in x: + sum += a[0] + sum += a[1] + return sum + + self.checkScript(test_tuple_assign, (((1, 2), (4, 7)), )) + + def test_single_starred_lhs(self): + with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence' + ' of another non-starred expression'): + cu = torch.jit.CompilationUnit(''' + def single_starred_lhs(x): + a = (x, x, x) + *b, = a + return b + ''') + + def test_singleton_tuple_unpack(self): + def foo(a): + b, = (a,) + return b + 1 + self.checkScript(foo, (torch.rand(3),)) + + def test_tuple_assignments(self): + def var_tuple_assign(x, y): + # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor + (a, b), c = x, y + return a + b + c + + tuple_inputs = (torch.randn(1, 4), torch.randn(3, 4)) + self.checkScript(var_tuple_assign, (tuple_inputs, torch.randn(3, 4))) + + def nested_tuple_assign(x, y, z): + # type: (int, Tuple[int, Tuple[int, int]], Tuple[int, int]) -> int + a, (b, (c, d)), (e, f) = x, y, z + return a + b + c + d + e + f + + self.checkScript(nested_tuple_assign, ((1, (2, (3, 4)), (5, 6)))) + + def subscript_tuple_assign(a, x, i): + # type: (List[int], Tensor, int) -> Tuple[int, Tensor, int] + a[i], (x[i], b) = 1, (2, 3) + return a[i] + 1, x + 5, b + + self.checkScript(subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0)) + + def star_tuple_assign(): + # type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]] + a, (b, *c), *d = 1, (2, 3, 4), 5, 6 + return a, b, c, d + + self.checkScript(star_tuple_assign, ()) + + def subscript_tuple_augmented_assign(a): + # type: (Tuple[int, int]) -> Tuple[int, int] + a[0] += 1 + return a + + with self.assertRaisesRegex(RuntimeError, 'does not support augmented assign'): + scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign) + + class AttrTupleAssignmentTestClass: + def __init__(self, a: int, b: int): + self.a = a + self.b = b + + def set_ab(self, a: int, b: int): + self.a, self.b = (a, b) + + def get(self) -> Tuple[int, int]: + return (self.a, self.b) + + make_global(AttrTupleAssignmentTestClass) + + @torch.jit.script + def attr_tuple_assignment(o: AttrTupleAssignmentTestClass, a: int, b: int): + o.set_ab(a, b) + return o + + o = AttrTupleAssignmentTestClass(1, 2) + self.assertEqual(attr_tuple_assignment(o, 3, 4).get(), (3, 4)) + + def test_multiple_assign(self): + def test(): + a = b, c = d, f = (1, 1) + + # side effect + ten = torch.tensor(1) + ten1 = ten2 = ten.add_(1) + + # ordering + x = 1 + y = 3 + x, y = y, x + y + + return a, b, c, d, f, ten, ten1, ten2, x, y + + self.checkScript(test, ()) + def test_multi_reduction(self): with self.assertRaisesRegex( RuntimeError, diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 9ed068e10455b..30d9351a7a1e3 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -2030,7 +2030,7 @@ struct to_ir { size_t num_starred = 0; for (const auto& assignee : lhs) { if (assignee.kind() == TK_VAR || assignee.kind() == TK_SUBSCRIPT || - assignee.kind() == TK_TUPLE_LITERAL) { + assignee.kind() == TK_TUPLE_LITERAL || assignee.kind() == '.') { num_normal_assign++; } else if (assignee.kind() == TK_STARRED) { num_starred++; @@ -2482,6 +2482,10 @@ struct to_ir { sub_starred_unpack); i++; } break; + case '.': { + emitSelectAssign(assignee, outputs.at(i), rhs_loc); + i++; + } break; default: throw ErrorReport(assignee) << "unexpected expression on the left-hand side"; @@ -2596,6 +2600,16 @@ struct to_ir { lhsObject->setAttr(stmt.range(), method, lhs.selector().name(), rhsValue); } + void emitSelectAssign( + const Expr& lhs, + SugaredValuePtr rhs, + const SourceRange& loc) { + const auto lhs_select = Select(lhs); + auto lhs_sv = emitSugaredExpr(lhs_select.value(), 1); + const auto rhs_value = rhs->asValue(loc, method); + lhs_sv->setAttr(loc, method, lhs_select.selector().name(), rhs_value); + } + NodeKind getNodeKind(int kind, int ninputs) { switch (kind) { case '+':