Skip to content

Commit

Permalink
add numpy typing plugin to mypy config (pytorch#92930)
Browse files Browse the repository at this point in the history
This added the numpy typing plugin to mypy config so that we could
use it for DeviceMesh typing annotations

Please see pytorch#92931 about why we need this. For example, we are currently saving the DeviceMesh's mesh field as torch.Tensor, where when we do sth like:
```python
with FakeTensorMode():
    device_mesh = DeviceMesh("cuda", torch.arange(4))
```
It would throw error because FakeTensorMode or any TorchDispatchMode tracks every tensor creation and interactions. While DeviceMesh just want to save a nd-array to record the mesh topology, and would like to avoid the interaction with subsystems like FakeTensor, so we want to support saving `mesh` as numpy array instead.

Pull Request resolved: pytorch#92930
Approved by: https://github.com/ezyang, https://github.com/malfet
  • Loading branch information
wanchaol authored and pytorchmergebot committed Jan 31, 2023
1 parent 2a6e085 commit 5f1ac18
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion mypy-nofollow.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[mypy]
plugins = mypy_plugins/check_mypy_version.py
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin

cache_dir = .mypy_cache/nofollow
warn_unused_configs = True
Expand Down
2 changes: 1 addition & 1 deletion mypy-strict.ini
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

[mypy]
python_version = 3.8
plugins = mypy_plugins/check_mypy_version.py
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin

cache_dir = .mypy_cache/strict
strict_optional = True
Expand Down
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# test_run_mypy in test/test_type_hints.py uses this string)

[mypy]
plugins = mypy_plugins/check_mypy_version.py
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin

cache_dir = .mypy_cache/normal
warn_unused_configs = True
Expand Down

0 comments on commit 5f1ac18

Please sign in to comment.