Skip to content

Commit

Permalink
Expose torch.futures.Future (pytorch#39008)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#39008

This commit adds a `torch.futures.Future` type and exposes its ctor,
`wait`, `then`, and `set_result` APIs. This type is currently a
wrapper around `c10::ivalue::Future` and is mainly used by RPC. Later,
we could revamp the c10d APIs to return this `Future` type as well. More
utilities will be added to the `torch.futures` package in follow-up PRs.

Test Plan: Imported from OSS

Differential Revision: D21723022

Pulled By: mrshenli

fbshipit-source-id: 92e56160544e9bf00d11db3e8347a1b9707882c9
  • Loading branch information
mrshenli authored and facebook-github-bot committed Jun 2, 2020
1 parent b3fac8a commit bb0377b
Show file tree
Hide file tree
Showing 10 changed files with 328 additions and 16 deletions.
5 changes: 4 additions & 1 deletion aten/src/ATen/core/ivalue_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,10 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
*/
void markCompleted(IValue value) {
std::unique_lock<std::mutex> lock(mutex_);
AT_ASSERT(!completed());
TORCH_CHECK(
!completed(),
"Attempting to mark a completed Future as complete again. Note that "
"a Future can only be marked completed once.");
completed_ = true;
value_ = std::move(value);

Expand Down
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ ignore_errors = True
[mypy-torch.functional.*]
ignore_errors = True

[mypy-torch.futures.*]
ignore_errors = True

[mypy-torch.testing._internal.*]
ignore_errors = True

Expand Down
2 changes: 2 additions & 0 deletions test/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
'test_determination',
'distributed/rpc/jit/test_rpc_spawn',
'distributed/rpc/faulty_agent/test_rpc_spawn',
'test_futures',
]

WINDOWS_BLACKLIST = [
Expand Down Expand Up @@ -154,6 +155,7 @@
'distributed/test_c10d_spawn',
'test_quantization',
'test_determination',
'test_futures',
]
_DEP_MODULES_CACHE = {}

Expand Down
98 changes: 98 additions & 0 deletions test/test_futures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import threading
import time
import torch
from torch.futures import Future
from torch.testing._internal.common_utils import TestCase, TemporaryFileName


def add_one(fut):
    """Return the value held by ``fut`` incremented by one.

    Used as a ``then`` callback: ``fut`` is the future the callback was
    attached to, and its value is read with ``fut.wait()``.
    """
    value = fut.wait()
    return value + 1


class TestFuture(TestCase):
    """Tests for the ``wait``, ``then`` and ``set_result`` APIs of ``Future``."""

    def test_wait(self):
        """``wait`` returns the value that was passed to ``set_result``."""
        f = Future()
        f.set_result(torch.ones(2, 2))

        self.assertEqual(f.wait(), torch.ones(2, 2))

    def test_wait_multi_thread(self):
        """``wait`` returns a value that is set later from another thread."""

        def slow_set_future(fut, value):
            # Delay completion so the main thread most likely reaches
            # ``wait`` before the result is available.
            time.sleep(0.5)
            fut.set_result(value)

        f = Future()

        t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
        t.start()

        try:
            self.assertEqual(f.wait(), torch.ones(2, 2))
        finally:
            # Join even when the assertion fails so the helper thread is
            # always waited for before the test returns.
            t.join()

    def test_mark_future_twice(self):
        """Calling ``set_result`` a second time raises ``RuntimeError``."""
        fut = Future()
        fut.set_result(1)
        with self.assertRaisesRegex(
            RuntimeError,
            "Future can only be marked completed once"
        ):
            fut.set_result(1)

    def test_pickle_future(self):
        """Saving a ``Future`` with ``torch.save`` raises ``RuntimeError``."""
        fut = Future()
        errMsg = "Can not pickle torch.futures.Future"
        with TemporaryFileName() as fname:
            with self.assertRaisesRegex(RuntimeError, errMsg):
                torch.save(fut, fname)

    def test_then(self):
        """A ``then`` callback produces a new future holding its return value."""
        fut = Future()
        then_fut = fut.then(lambda x: x.wait() + 1)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)

    def test_chained_then(self):
        """Each future in a chain of ``then`` calls adds one to its parent."""
        fut = Future()
        futs = []
        last_fut = fut
        for _ in range(20):
            last_fut = last_fut.then(add_one)
            futs.append(last_fut)

        fut.set_result(torch.ones(2, 2))

        # The future at position ``i`` has had ``add_one`` applied ``i + 1``
        # times to the original value.
        for i, chained_fut in enumerate(futs):
            self.assertEqual(chained_fut.wait(), torch.ones(2, 2) + i + 1)

    def _test_error(self, cb, errMsg):
        """Check that a failing callback surfaces its error on the child future.

        Attaches ``cb`` to a fresh future, completes that future with ``5``,
        verifies the parent still yields ``5``, and then expects ``wait`` on
        the future returned by ``then`` to raise a ``RuntimeError`` whose
        message matches the regex ``errMsg``.
        """
        fut = Future()
        then_fut = fut.then(cb)

        fut.set_result(5)
        self.assertEqual(5, fut.wait())
        with self.assertRaisesRegex(RuntimeError, errMsg):
            then_fut.wait()

    def test_then_wrong_arg(self):
        """A callback that uses the future itself as a value reports an error."""

        def wrong_arg(tensor):
            # The callback receives the Future, not its value, so ``+ 1``
            # is expected to fail.
            return tensor + 1

        self._test_error(wrong_arg, "unsupported operand type.*Future.*int")

    def test_then_no_arg(self):
        """A callback that accepts no arguments reports an error."""

        def no_arg():
            return True

        self._test_error(no_arg, "takes 0 positional arguments but 1 was given")

    def test_then_raise(self):
        """An exception raised inside the callback surfaces from ``wait``."""

        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_error(raise_value_error, "Expected error")
1 change: 1 addition & 0 deletions torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ def manager_path():
import torch.cuda
import torch.autograd
from torch.autograd import no_grad, enable_grad, set_grad_enabled
import torch.futures
import torch.nn
import torch.nn.intrinsic
import torch.nn.quantized
Expand Down
29 changes: 28 additions & 1 deletion torch/csrc/jit/python/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -833,13 +833,40 @@ void initJITBindings(PyObject* module) {

py::class_<PythonFutureWrapper, std::shared_ptr<PythonFutureWrapper>>(
m, "Future")
.def(py::init([]() {
return std::make_shared<PythonFutureWrapper>(
c10::make_intrusive<c10::ivalue::Future>(PyObjectType::get()));
}))
.def(
"wait",
&PythonFutureWrapper::wait,
py::call_guard<py::gil_scoped_release>())
.def(
"_then",
"then",
&PythonFutureWrapper::then,
py::call_guard<py::gil_scoped_release>())
.def(
"set_result",
// Intentionally not releasing GIL
&PythonFutureWrapper::markCompleted)
.def(
py::pickle(
/* __getstate__ */
[](const PythonFutureWrapper& /* unused */) {
TORCH_CHECK(false, "Can not pickle torch.futures.Future");
// Note that this return has no meaning since we always
// throw, it's only here to satisfy Pybind API's
// requirement.
return py::make_tuple();
},
/* __setstate__ */
[](const py::tuple& /* unused */) { // NOLINT
TORCH_CHECK(false, "Can not unpickle torch.futures.Future");
// Note that this return has no meaning since we always
// throw, it's only here to satisfy PyBind's API
// requirement.
return nullptr;
}),
py::call_guard<py::gil_scoped_release>());

m.def("fork", [](const py::args& args, const py::kwargs& kwargs) {
Expand Down
10 changes: 9 additions & 1 deletion torch/csrc/jit/python/pybind_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
c10::optional<UnwrapFunc> unwrap_func = c10::nullopt)
: fut(std::move(fut)), unwrap_func(std::move(unwrap_func)) {}

PythonFutureWrapper(const PythonFutureWrapper&) = delete;
explicit PythonFutureWrapper(const PythonFutureWrapper&) = delete;
PythonFutureWrapper& operator=(const PythonFutureWrapper&) = delete;

py::object wait() {
Expand Down Expand Up @@ -136,6 +136,14 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
PyObjectType::get()));
}

// Completes the wrapped future with the given Python object as its value.
// The caller is expected to hold the GIL on entry (checked by the DCHECK).
void markCompleted(const py::object& pyValue) {
  DCHECK(PyGILState_Check());
  // Convert the Python object into an IValue of PyObjectType before the GIL
  // is released below.
  IValue value = toIValue(pyValue, PyObjectType::get());

  // Release the GIL for the remainder of this scope, then complete the
  // underlying future.
  // NOTE(review): presumably the GIL is dropped so that work triggered by
  // completion can acquire it itself -- confirm.
  py::gil_scoped_release release;
  fut->markCompleted(std::move(value));
}

c10::intrusive_ptr<c10::ivalue::Future> fut;
// unwrap_func works like a callback for the value returned by
// PythonFutureWrapper::wait().
Expand Down
96 changes: 96 additions & 0 deletions torch/futures/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""
The ``torch.futures`` package contains a ``Future`` type and corresponding
utility functions.
"""
import torch


class Future(torch._C.Future):
    r"""
    Python-facing subclass of ``torch._C.Future`` representing the result of
    an asynchronous execution of a callable, e.g.
    :meth:`~torch.distributed.rpc.rpc_async`. It provides APIs to wait for the
    value, attach callback functions, and set the result.

    .. warning::
        The ``torch.futures.Future`` is experimental and subject to change.
    """

    def __new__(cls):
        return super().__new__(cls)

    def wait(self):
        r"""
        Block until the value of this ``Future`` is ready.

        Returns:
            The value held by this ``Future``. If the function (callback or
            RPC) that creates the value has thrown an error, this ``wait``
            method will throw an error as well.
        """
        return super().wait()

    def then(self, callback):
        r"""
        Register ``callback`` to run once this ``Future`` is completed.
        Several callbacks may be attached to the same ``Future``; they are
        invoked in the order in which they were added. The callback receives
        a single argument, the reference to this ``Future``, and can call
        ``Future.wait()`` on it to obtain the value.

        Arguments:
            callback(``Callable``): a ``Callable`` that takes this ``Future``
                                    as the only argument.

        Returns:
            A new ``Future`` object that holds the return value of the
            ``callback`` and will be marked as completed when the given
            ``callback`` finishes.

        Example::

            >>> import torch
            >>>
            >>> def callback(fut):
            >>>     print(f"RPC return value is {fut.wait()}.")
            >>>
            >>> fut = torch.futures.Future()
            >>> # The attached callback prints the value once ``fut`` is
            >>> # completed by ``set_result`` below.
            >>> cb_fut = fut.then(callback)
            >>> chain_cb_fut = cb_fut.then(lambda x : print(f"Chained cb done. {x.wait()}"))
            >>> fut.set_result(5)
            >>>
            >>> # Outputs are:
            >>> # RPC return value is 5.
            >>> # Chained cb done. None
        """
        return super().then(callback)

    def set_result(self, result):
        r"""
        Mark this ``Future`` as completed with ``result`` as its value, which
        triggers all attached callbacks. A ``Future`` cannot be marked
        completed twice.

        Arguments:
            result (object): the result object of this ``Future``.

        Example::

            >>> import threading
            >>> import time
            >>> import torch
            >>>
            >>> def slow_set_future(fut, value):
            >>>     time.sleep(0.5)
            >>>     fut.set_result(value)
            >>>
            >>> fut = torch.futures.Future()
            >>> t = threading.Thread(
            >>>     target=slow_set_future,
            >>>     args=(fut, torch.ones(2) * 3)
            >>> )
            >>> t.start()
            >>>
            >>> print(fut.wait())  # tensor([3., 3.])
            >>> t.join()
        """
        super().set_result(result)
6 changes: 3 additions & 3 deletions torch/testing/_internal/distributed/rpc/jit/rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,7 @@ def callback(fut):
worker_name((self.rank + 1) % self.world_size),
script_fork_wait_udf,
args=(torch.ones(2),)
)._then(callback)
).then(callback)
self.assertEqual(future.wait(), torch.ones(2) * 2 + 1)

@dist_init
Expand All @@ -963,7 +963,7 @@ def callback(fut):

num_cbs = 20
for _ in range(num_cbs):
fut = fut._then(callback)
fut = fut.then(callback)

self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs)

Expand All @@ -988,7 +988,7 @@ def callback(fut):
worker_name((self.rank + 1) % self.world_size),
script_fork_wait_throw,
args=(torch.ones(2),)
)._then(callback)
).then(callback)

with self.assertRaisesRegex(RuntimeError, "Another expected error"):
future.wait()
Expand Down
Loading

0 comments on commit bb0377b

Please sign in to comment.