forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Expose torch.futures.Future (pytorch#39008)
Summary: Pull Request resolved: pytorch#39008 This commit adds a `torch.futures.Future` type and exposes its ctor, `wait`, `then`, and `set_result` APIs. This type is currently a wrapper of `c10::ivalue::Future` and mainly used by RPC for now. Later, we could revamp c10d APIs to return this `Future` type as well. More utils will be added into `torch.futures` package in followup PRs. Test Plan: Imported from OSS Differential Revision: D21723022 Pulled By: mrshenli fbshipit-source-id: 92e56160544e9bf00d11db3e8347a1b9707882c9
- Loading branch information
1 parent
b3fac8a
commit bb0377b
Showing
10 changed files
with
328 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import threading | ||
import time | ||
import torch | ||
from torch.futures import Future | ||
from torch.testing._internal.common_utils import TestCase, TemporaryFileName | ||
|
||
|
||
def add_one(fut): | ||
return fut.wait() + 1 | ||
|
||
|
||
class TestFuture(TestCase): | ||
def test_wait(self): | ||
f = Future() | ||
f.set_result(torch.ones(2, 2)) | ||
|
||
self.assertEqual(f.wait(), torch.ones(2, 2)) | ||
|
||
def test_wait_multi_thread(self): | ||
|
||
def slow_set_future(fut, value): | ||
time.sleep(0.5) | ||
fut.set_result(value) | ||
|
||
f = Future() | ||
|
||
t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2))) | ||
t.start() | ||
|
||
self.assertEqual(f.wait(), torch.ones(2, 2)) | ||
t.join() | ||
|
||
def test_mark_future_twice(self): | ||
fut = Future() | ||
fut.set_result(1) | ||
with self.assertRaisesRegex( | ||
RuntimeError, | ||
"Future can only be marked completed once" | ||
): | ||
fut.set_result(1) | ||
|
||
def test_pickle_future(self): | ||
fut = Future() | ||
errMsg = "Can not pickle torch.futures.Future" | ||
with TemporaryFileName() as fname: | ||
with self.assertRaisesRegex(RuntimeError, errMsg): | ||
torch.save(fut, fname) | ||
|
||
def test_then(self): | ||
fut = Future() | ||
then_fut = fut.then(lambda x: x.wait() + 1) | ||
|
||
fut.set_result(torch.ones(2, 2)) | ||
self.assertEqual(fut.wait(), torch.ones(2, 2)) | ||
self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1) | ||
|
||
def test_chained_then(self): | ||
fut = Future() | ||
futs = [] | ||
last_fut = fut | ||
for _ in range(20): | ||
last_fut = last_fut.then(add_one) | ||
futs.append(last_fut) | ||
|
||
fut.set_result(torch.ones(2, 2)) | ||
|
||
for i in range(len(futs)): | ||
self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1) | ||
|
||
def _test_error(self, cb, errMsg): | ||
fut = Future() | ||
then_fut = fut.then(cb) | ||
|
||
fut.set_result(5) | ||
self.assertEqual(5, fut.wait()) | ||
with self.assertRaisesRegex(RuntimeError, errMsg): | ||
then_fut.wait() | ||
|
||
def test_then_wrong_arg(self): | ||
|
||
def wrong_arg(tensor): | ||
return tensor + 1 | ||
|
||
self._test_error(wrong_arg, "unsupported operand type.*Future.*int") | ||
|
||
def test_then_no_arg(self): | ||
|
||
def no_arg(): | ||
return True | ||
|
||
self._test_error(no_arg, "takes 0 positional arguments but 1 was given") | ||
|
||
def test_then_raise(self): | ||
|
||
def raise_value_error(fut): | ||
raise ValueError("Expected error") | ||
|
||
self._test_error(raise_value_error, "Expected error") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
""" | ||
The ``torch.futures`` package contains a ``Future`` type and corresponding | ||
utility functions. | ||
""" | ||
import torch | ||
|
||
|
||
class Future(torch._C.Future): | ||
r""" | ||
Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous | ||
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It | ||
also exposes a set of APIs to add callback functions and set results. | ||
.. warning:: | ||
The ``torch.futures.Future`` is experimental and subject to change. | ||
""" | ||
def __new__(cls): | ||
return super(Future, cls).__new__(cls) | ||
|
||
def wait(self): | ||
r""" | ||
Block until the value of this ``Future`` is ready. | ||
Returns: | ||
The value held by this ``Future``. If the function (callback or RPC) | ||
creating the value has thrown an error, this ``wait`` method will | ||
also throw an error. | ||
""" | ||
return super(Future, self).wait() | ||
|
||
def then(self, callback): | ||
r""" | ||
Append the given callback function to this ``Future``, which will be run | ||
when the ``Future`` is completed. Multiple callbacks can be added to | ||
the same ``Future``, and will be invoked in the same order as they were | ||
added. The callback must take one argument, which is the reference to | ||
this ``Future``. The callback function can use the ``Future.wait()`` API | ||
to get the value. | ||
Arguments: | ||
callback(``Callable``): a ``Callable`` that takes this ``Future`` as | ||
the only argument. | ||
Returns: | ||
A new ``Future`` object that holds the return value of the | ||
``callback`` and will be marked as completed when the given | ||
``callback`` finishes. | ||
Example:: | ||
>>> import torch | ||
>>> | ||
>>> def callback(fut): | ||
>>> print(f"RPC return value is {fut.wait()}.") | ||
>>> | ||
>>> fut = torch.futures.Future() | ||
>>> # The inserted callback will print the return value when | ||
>>> # receiving the response from "worker1" | ||
>>> cb_fut = fut.then(callback) | ||
>>> chain_cb_fut = cb_fut.then(lambda x : print(f"Chained cb done. {x.wait()}")) | ||
>>> fut.set_result(5) | ||
>>> | ||
>>> # Outputs are: | ||
>>> # RPC return value is 5. | ||
>>> # Chained cb done. None | ||
""" | ||
return super(Future, self).then(callback) | ||
|
||
def set_result(self, result): | ||
r""" | ||
Set the result for this ``Future``, which will mark this ``Future`` as | ||
completed and trigger all attached callbacks. Note that a ``Future`` | ||
cannot be marked completed twice. | ||
Arguments: | ||
result (object): the result object of this ``Future``. | ||
Example:: | ||
>>> import threading | ||
>>> import time | ||
>>> import torch | ||
>>> | ||
>>> def slow_set_future(fut, value): | ||
>>> time.sleep(0.5) | ||
>>> fut.set_result(value) | ||
>>> | ||
>>> fut = torch.futures.Future() | ||
>>> t = threading.Thread( | ||
>>> target=slow_set_future, | ||
>>> args=(fut, torch.ones(2) * 3) | ||
>>> ) | ||
>>> t.start() | ||
>>> | ||
>>> print(fut.wait()) # tensor([3., 3.]) | ||
>>> t.join() | ||
""" | ||
super(Future, self).set_result(result) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.