Skip to content

Commit

Permalink
Clean up DictLiteral and DictComprehension emission logic (pytorc…
Browse files Browse the repository at this point in the history
…h#64953)

Summary: Pull Request resolved: pytorch#64953

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D30914687

Pulled By: ansley

fbshipit-source-id: ab9b9192a29f05b90c113c678e7c795bc087dc99
  • Loading branch information
Ansley Ussery authored and facebook-github-bot committed Oct 15, 2021
1 parent a7b7903 commit a108440
Show file tree
Hide file tree
Showing 4 changed files with 498 additions and 380 deletions.
58 changes: 56 additions & 2 deletions test/jit/test_list_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,8 @@ def fn():
self.checkScript(fn, ())

def test_dict_keyword_with_mismatched_annotations(self):
err_msg = r"is annotated with type Dict\[int, str\] but is " \
r"being assigned to a value of type Dict\[str, int\]"
err_msg = r"Dict type annotation `Dict\[int, str\]` did not " \
"match the type of an actual key type `str`"
with self.assertRaisesRegex(RuntimeError, err_msg):
@torch.jit.script
def fn():
Expand Down Expand Up @@ -1441,6 +1441,60 @@ def fn(x: Dict[str, int]) -> Dict[str, int]:
with self.assertRaisesRegexWithHighlight(RuntimeError, "KeyError", "x['hi']"):
self.checkScript(fn, [{}])

def test_dict_variance(self):
"""
`Dict[T1, _]` is not a subtype of `Dict[T2, _]`, even if `T1` is
a subtype of `T2`; similarly `Dict[_, T1]` would not be a
subtype of `Dict[_, T2]`.
However, if we have a temporary dict object (that is, a dict
comprehension or a dict literal) on the rhs of an assignment
statement, we want to ignore the inferred type of the rhs if we
can prove that: 1) both the lhs and the rhs are dicts with the
same key types (TorchScript has a restricted set of allowed key
types, so we don't need to worry about subtyping relationships
here), and 2) the value type of the dict is a subtype of the
value type of the rhs dict.
"""
def test_dictliteral_is_typed_from_annotation():
x: Dict[str, Optional[int]] = {"foo": None, "bar": None, "baz": None}
return x

self.checkScript(test_dictliteral_is_typed_from_annotation, ())

def test_dictcomprehension_is_typed_from_annotation():
metasyntactics = ["foo", "bar", "baz"]
x: Dict[str, Optional[int]] = {word: None for word in metasyntactics}
return x

self.checkScript(test_dictcomprehension_is_typed_from_annotation, ())

def test_dicts_with_different_value_types_are_invariant(self):
x: Dict[str, int] = {"foo": 1, "bar": 2, "baz": 3}
y: Dict[str, Optional[int]] = x
return x

with self.assertRaisesRegex(RuntimeError, "Variable 'y' is "
"annotated with type "
r"Dict\[str, Optional\[int\]\] but "
"is being assigned to a value of "
r"type Dict\[str, int\]"):
torch.jit.script(test_dicts_with_different_value_types_are_invariant)

def test_dicts_with_different_value_types_are_invariant_recursive(self):
x: Dict[str, int] = {"foo": 1, "bar": 2, "baz": 3}
y: Dict[str, Dict[str, int]] = {"foo": x, "bar": x, "baz": x}
z: Dict[str, Dict[str, Optional[int]]] = y
return x

with self.assertRaisesRegex(RuntimeError, "Variable 'z' is "
"annotated with type "
r"Dict\[str, Dict\[str, Optional"
r"\[int\]\]\] but is being assigned"
r" to a value of type Dict\[str, "
r"Dict\[str, int\]\]"):
torch.jit.script(test_dicts_with_different_value_types_are_invariant_recursive)

def test_keys(self):
@torch.jit.script
def keys(x: Dict[str, Tensor]) -> List[str]:
Expand Down
43 changes: 4 additions & 39 deletions test/jit/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import sys

import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import IS_WINDOWS
from collections import namedtuple
Expand Down Expand Up @@ -84,40 +83,6 @@ def fn():
"types of the given list elements"):
torch.jit.script(fn)

def test_dict_type_refinement_defaults_to_Any_dict_creation(self):
def fn(x):
d = dict(foo=torch.tensor(2),
bar={"23": torch.tensor(3)})
d["baz"] = x
t = d["foo"]
if isinstance(t, torch.Tensor):
d["bar"] = torch.add(t, t)
return d

self.checkScript(fn, (torch.arange(5),))

graph = torch.jit.script(fn).graph

FileCheck().check("Dict(str, Union[Tensor, Dict(str, Tensor)])"
" = prim::DictConstruct").run(graph)

def test_dict_type_refinement_defaults_to_Any_dict_comprehension(self):
def fn(x):
d = {"foo": torch.tensor(2),
"bar": {"23": torch.tensor(3)}}
d["baz"] = x
t = d["foo"]
if isinstance(t, torch.Tensor):
d["bar"] = torch.add(t, t)
return d

self.checkScript(fn, (torch.arange(5),))

graph = torch.jit.script(fn).graph

FileCheck().check("Dict(str, Union[Tensor, Dict(str, Tensor)])"
" = prim::DictConstruct").run(graph)

def test_dict_type_refinement_annotation_key_mismatch(self):
def fn():
l1 = [1, 2, "foo", 3]
Expand All @@ -138,10 +103,10 @@ def fn():
d: Dict[str, int] = {k : v for k, v in zip(l1, l2)}
return d

with self.assertRaisesRegex(RuntimeError, "annotated with type "
r"Dict\[str, int\] but is being "
"assigned to a value of type "
r"Dict\[str, Union\[int, str\]\]"):
with self.assertRaisesRegex(RuntimeError, "Dict type annotation"
r" `Dict\[str, int\]` did not match"
" the type of an actual value type"
r" `Union\[int, str\]`"):
torch.jit.script(fn)

def test_dict_invalid_annotations(self):
Expand Down
97 changes: 81 additions & 16 deletions test/jit/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ def fn():
"not match the types of the given list "
"elements")

# TODO: Support mixed list comprehensions
# TODO(@ansley): Support mixed list comprehensions
self._assert_raises(template,
"Union[List[torch.Tensor], int]",
lhs["list_comprehension_of_mixed"],
Expand Down Expand Up @@ -822,7 +822,24 @@ def fn():
zip([\"foo\", \"bar\"], [torch.arange(3), 2])}",

"dict_keyword" :
"dict(foo=torch.arange(3), baz=torch.arange(5))"}
"dict(foo=torch.arange(3), baz=torch.arange(5))",

"dict_keyword_with_iterable" :
"dict([(\"foo\", torch.arange(3)), (\"bar\", torch.arange(5))])",

"dict_keyword_with_empty_iterable" :
"dict([])",

"dict_keyword_with_internal_aggregate_function" :
"dict(zip([\"foo\", \"bar\"], [torch.arange(3), torch.arange(5)])",

"dict_keyword_with_mapping" :
"dict({\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)})",

"dict_keyword_with_mapping_and_kwargs" :
"dict({\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)}, baz=torch.arange(7))",

}

"""
Union[Dict[str, torch.Tensor], Dict[str, int]]
Expand All @@ -843,8 +860,8 @@ def fn():

self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_literal_of_mixed"],
"none of those types can hold the types "
"of the given dict elements")
"none of those dict types can hold the "
"types of the given keys and values")

# TODO: String frontend does not support tuple unpacking
# https://github.com/pytorch/pytorch/issues/64096
Expand All @@ -858,9 +875,37 @@ def fn():
# lhs["dict_comprehension_of_mixed"],
# "foobar")

self._assert_passes(template,
# self._assert_passes(template,
# "Union[Dict[str, torch.Tensor], Dict[str, int]]",
# lhs["dict_keyword_with_internal_aggregate_function"])

# TODO(@ansley): Follow-up project needed for full type
# inference with dict keyword (supported for dict comprehension
# and dict literal already; should not be a blocker for anyone)
self._assert_raises(template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword"])
lhs["dict_keyword"],
"full type inference is not yet supported")

self._assert_raises(template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword_with_iterable"],
"full type inference is not yet supported")

self._assert_raises(template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword_with_empty_iterable"],
"full type inference is not yet supported")

self._assert_raises(template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword_with_mapping"],
"full type inference is not yet supported")

self._assert_raises(template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword_with_mapping_and_kwargs"],
"full type inference is not yet supported")

"""
Union[int, torch.Tensor]
Expand Down Expand Up @@ -896,20 +941,44 @@ def fn():
self._assert_raises(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_literal_of_str_int"],
r"Type hint for dict was Dict\[str, Tensor\]"
", but the value at index 0 has type int, "
"which is not a valid subtype of Tensor")
"Type annotation was inferred to be "
r"`Dict\[str, Tensor\]`, but the type of "
"values given by the dict literal is")

self._assert_raises(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_literal_of_mixed"],
r"Type hint for dict was Dict\[str, Tensor\]"
", but the value at index 1 has type int, "
"which is not a valid subtype of Tensor")
"Type annotation was inferred to be "
r"`Dict\[str, Tensor\]`, but the type of "
"values given by the dict literal is")

self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword"])

self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_iterable"])

self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_empty_iterable"])

self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_mapping"])

self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_mapping_and_kwargs"])

# See above--string frontend does not support tuple unpacking
# self._assert_passes(template,
# "Union[Dict[str, torch.Tensor], int]",
# lhs["dict_keyword_with_internal_aggregate_function"])
#
# self._assert_passes(template,
# "Union[Dict[str, torch.Tensor], int]",
# lhs["dict_comprehension_of_str_tensor"])

# self._assert_raises(template,
Expand All @@ -921,7 +990,3 @@ def fn():
# "Union[Dict[str, torch.Tensor], int]",
# lhs["dict_comprehension_of_mixed"],
# "foobar")

self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword"])
Loading

0 comments on commit a108440

Please sign in to comment.