Skip to content

Commit

Permalink
[JIT] Allow unpacking tuple and assign their values to SELECT-type ex…
Browse files Browse the repository at this point in the history
…pressions (pytorch#55268)

Summary:
Fixes pytorch#51176

Pull Request resolved: pytorch#55268

Reviewed By: pbelevich, izdeby

Differential Revision: D27551950

Pulled By: gmagogsfm

fbshipit-source-id: 35324b728649bb1e6c5410a1004d2f6964f98304
  • Loading branch information
gmagogsfm authored and facebook-github-bot committed Apr 11, 2021
1 parent b80c6f8 commit fa29a64
Show file tree
Hide file tree
Showing 2 changed files with 245 additions and 1 deletion.
230 changes: 230 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11361,6 +11361,236 @@ def test_sizes(x):

self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))

# to avoid defining sum_list in multiple tests
def get_sum_list_fn(self):
def sum_list(a):
# type: (List[int]) -> int
sum = 0
for i in a:
sum += i

return sum

return sum_list

def test_sum_list_diff_elms(self):
self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],))

def test_sum_list_empty(self):
self.checkScript(self.get_sum_list_fn(), ([],))

def test_sum_list_one(self):
self.checkScript(self.get_sum_list_fn(), ([1],))

def test_sum_list_literal(self):

def sum_list():
# type: () -> int
sum = 0
for i in [1, 2, 3, 4, 5]:
sum += i

return sum

self.checkScript(sum_list, ())

def test_sum_list_wrong_type(self):

with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
@torch.jit.script
def sum_list(a):
# type: (int) -> int
sum = 0
for i in a: # noqa: T484
sum += i

return sum

sum_list(1)

def test_list_iterables(self):
with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
cu = torch.jit.CompilationUnit('''
def list_iterables(x):
for i, j in [2, 3, 4], [5, 6, 7]:
x += i
x += j
return x
''')

def test_for_in_string(self):
def test_strings(x):
# type: (str) -> str
reverse = ""
for c in x:
reverse = c + reverse
return reverse

self.checkScript(test_strings, ("hello",))
self.checkScript(test_strings, ("",))

def test_list_strings(x):
# type: (List[str]) -> str
result = ""
for sub_str in x:
result += sub_str
return result

self.checkScript(test_list_strings, (["hello", "world"],))
self.checkScript(test_list_strings, (["hello", " ", "world", ""],))

def test_for_in_dict(self):
def test_dicts(x):
# type: (Dict[str, int]) -> int
sum = 0
for key in x:
sum += x[key]
return sum

self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))

def test_dict_keys_values(x):
# type: (Dict[str, int]) -> Tuple[str, int]
key_str = ""
sum = 0
for key in x.keys():
key_str += key
for val in x.values():
sum += val
return key_str, sum

self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))

def test_for_tuple_unpack(self):
def for_tuple_unpack(x, y):
for i, j in [[3, 4], [5, 6], [7, 8]]:
x += i
y += j
return x, y

self.checkScript(for_tuple_unpack, (torch.tensor(3), torch.tensor(5)))

def nested_tuple_unpack(x, y):
# type: (List[int], List[int]) -> int
sum = 0
for i, (j, k), v in zip(x, enumerate(x), y):
sum += i + j + k + v
return sum

self.checkScript(nested_tuple_unpack, ([1, 3, 5], [2, 4, 6]))

def test_for_tuple_assign(self):
def test_simple_assign(x):
# type: (Tuple[int, float]) -> float
sum = 0.0
for a in x:
sum += float(a)
return sum

self.checkScript(test_simple_assign, ((1, 2.5),))

def test_tuple_assign(x):
# type: (Tuple[Tuple[int, int], Tuple[int, int]]) -> int
sum = 0
for a in x:
sum += a[0]
sum += a[1]
return sum

self.checkScript(test_tuple_assign, (((1, 2), (4, 7)), ))

def test_single_starred_lhs(self):
with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence'
' of another non-starred expression'):
cu = torch.jit.CompilationUnit('''
def single_starred_lhs(x):
a = (x, x, x)
*b, = a
return b
''')

def test_singleton_tuple_unpack(self):
def foo(a):
b, = (a,)
return b + 1
self.checkScript(foo, (torch.rand(3),))

def test_tuple_assignments(self):
def var_tuple_assign(x, y):
# type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
(a, b), c = x, y
return a + b + c

tuple_inputs = (torch.randn(1, 4), torch.randn(3, 4))
self.checkScript(var_tuple_assign, (tuple_inputs, torch.randn(3, 4)))

def nested_tuple_assign(x, y, z):
# type: (int, Tuple[int, Tuple[int, int]], Tuple[int, int]) -> int
a, (b, (c, d)), (e, f) = x, y, z
return a + b + c + d + e + f

self.checkScript(nested_tuple_assign, ((1, (2, (3, 4)), (5, 6))))

def subscript_tuple_assign(a, x, i):
# type: (List[int], Tensor, int) -> Tuple[int, Tensor, int]
a[i], (x[i], b) = 1, (2, 3)
return a[i] + 1, x + 5, b

self.checkScript(subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0))

def star_tuple_assign():
# type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]]
a, (b, *c), *d = 1, (2, 3, 4), 5, 6
return a, b, c, d

self.checkScript(star_tuple_assign, ())

def subscript_tuple_augmented_assign(a):
# type: (Tuple[int, int]) -> Tuple[int, int]
a[0] += 1
return a

with self.assertRaisesRegex(RuntimeError, 'does not support augmented assign'):
scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign)

class AttrTupleAssignmentTestClass:
def __init__(self, a: int, b: int):
self.a = a
self.b = b

def set_ab(self, a: int, b: int):
self.a, self.b = (a, b)

def get(self) -> Tuple[int, int]:
return (self.a, self.b)

make_global(AttrTupleAssignmentTestClass)

@torch.jit.script
def attr_tuple_assignment(o: AttrTupleAssignmentTestClass, a: int, b: int):
o.set_ab(a, b)
return o

o = AttrTupleAssignmentTestClass(1, 2)
self.assertEqual(attr_tuple_assignment(o, 3, 4).get(), (3, 4))

def test_multiple_assign(self):
def test():
a = b, c = d, f = (1, 1)

# side effect
ten = torch.tensor(1)
ten1 = ten2 = ten.add_(1)

# ordering
x = 1
y = 3
x, y = y, x + y

return a, b, c, d, f, ten, ten1, ten2, x, y

self.checkScript(test, ())

def test_multi_reduction(self):
with self.assertRaisesRegex(
RuntimeError,
Expand Down
16 changes: 15 additions & 1 deletion torch/csrc/jit/frontend/ir_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2030,7 +2030,7 @@ struct to_ir {
size_t num_starred = 0;
for (const auto& assignee : lhs) {
if (assignee.kind() == TK_VAR || assignee.kind() == TK_SUBSCRIPT ||
assignee.kind() == TK_TUPLE_LITERAL) {
assignee.kind() == TK_TUPLE_LITERAL || assignee.kind() == '.') {
num_normal_assign++;
} else if (assignee.kind() == TK_STARRED) {
num_starred++;
Expand Down Expand Up @@ -2482,6 +2482,10 @@ struct to_ir {
sub_starred_unpack);
i++;
} break;
case '.': {
emitSelectAssign(assignee, outputs.at(i), rhs_loc);
i++;
} break;
default:
throw ErrorReport(assignee)
<< "unexpected expression on the left-hand side";
Expand Down Expand Up @@ -2596,6 +2600,16 @@ struct to_ir {
lhsObject->setAttr(stmt.range(), method, lhs.selector().name(), rhsValue);
}

void emitSelectAssign(
const Expr& lhs,
SugaredValuePtr rhs,
const SourceRange& loc) {
const auto lhs_select = Select(lhs);
auto lhs_sv = emitSugaredExpr(lhs_select.value(), 1);
const auto rhs_value = rhs->asValue(loc, method);
lhs_sv->setAttr(loc, method, lhs_select.selector().name(), rhs_value);
}

NodeKind getNodeKind(int kind, int ninputs) {
switch (kind) {
case '+':
Expand Down

0 comments on commit fa29a64

Please sign in to comment.