Skip to content

Commit

Permalink
[small] multithreaded-pg guard attr (pytorch#93883)
Browse files Browse the repository at this point in the history
currently the test
```
pytest test/distributed/test_multi_threaded_pg.py -vs
```

has errors

```
Traceback (most recent call last):
  File "/private/home/howardhuang/.conda/envs/pytorch/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/private/home/howardhuang/.conda/envs/pytorch/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/private/home/howardhuang/pytorch-projects/pytorch/torch/testing/_internal/common_distributed.py", line 1029, in _run
    self._tls.precision = TestCase._precision
AttributeError: 'TestCollectivesWithBaseClass' object has no attribute '_tls'
```
Pull Request resolved: pytorch#93883
Approved by: https://github.com/awgu, https://github.com/wanchaol
  • Loading branch information
H-Huang authored and pytorchmergebot committed Feb 3, 2023
1 parent 6d597c5 commit 5c7f453
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions torch/testing/_internal/common_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,9 +1025,10 @@ def _run(cls, test_name, rank, world_size):
# every thread have the same value. This would be relevant when we use op db tests, where it
# needs those states to be set i.e. using instantiate_device_type_tests()
# TODO: figure out a better way to do this
self._tls = threading.local()
self._tls.precision = TestCase._precision
self._tls.rel_tol = TestCase._rel_tol
if hasattr(self, "_tls"):
self._tls = threading.local()
self._tls.precision = TestCase._precision
self._tls.rel_tol = TestCase._rel_tol

self.run_test_with_threaded_pg(test_name, rank, world_size)

Expand Down

0 comments on commit 5c7f453

Please sign in to comment.