Skip to content

Commit

Permalink
Async methods now takes Receiver
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Jun 13, 2020
1 parent f322771 commit 78577c6
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 19 deletions.
9 changes: 6 additions & 3 deletions pyo3-derive-backend/src/defs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,21 @@ pub const ASYNC: Proto = Proto {
slot_table: "pyo3::ffi::PyAsyncMethods",
set_slot_table: "set_async_methods",
methods: &[
MethodProto::Unary {
MethodProto::UnaryS {
name: "__await__",
arg: "Receiver",
pyres: true,
proto: "pyo3::class::pyasync::PyAsyncAwaitProtocol",
},
MethodProto::Unary {
MethodProto::UnaryS {
name: "__aiter__",
arg: "Receiver",
pyres: true,
proto: "pyo3::class::pyasync::PyAsyncAiterProtocol",
},
MethodProto::Unary {
MethodProto::UnaryS {
name: "__anext__",
arg: "Receiver",
pyres: true,
proto: "pyo3::class::pyasync::PyAsyncAnextProtocol",
},
Expand Down
5 changes: 3 additions & 2 deletions src/class/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ macro_rules! py_unarys_func {
{
$crate::callback_body!(py, {
let slf = py.from_borrowed_ptr::<$crate::PyCell<T>>(slf);
let borrow = <T::Receiver>::try_from_pycell(slf)
.map_err(|e| e.into())?;
let borrow =
<T::Receiver as $crate::derive_utils::TryFromPyCell<_>>::try_from_pycell(slf)
.map_err(|e| e.into())?;

$class::$f(borrow).into()$(.map($conv))?
})
Expand Down
26 changes: 13 additions & 13 deletions src/class/pyasync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//! [PEP-0492](https://www.python.org/dev/peps/pep-0492/)
//!

use crate::derive_utils::TryFromPyCell;
use crate::err::PyResult;
use crate::{ffi, PyClass, PyObject};

Expand All @@ -16,21 +17,21 @@ use crate::{ffi, PyClass, PyObject};
/// Each method in this trait corresponds to Python async/await implementation.
#[allow(unused_variables)]
pub trait PyAsyncProtocol<'p>: PyClass {
fn __await__(&'p self) -> Self::Result
fn __await__(slf: Self::Receiver) -> Self::Result
where
Self: PyAsyncAwaitProtocol<'p>,
{
unimplemented!()
}

fn __aiter__(&'p self) -> Self::Result
fn __aiter__(slf: Self::Receiver) -> Self::Result
where
Self: PyAsyncAiterProtocol<'p>,
{
unimplemented!()
}

fn __anext__(&'p mut self) -> Self::Result
fn __anext__(slf: Self::Receiver) -> Self::Result
where
Self: PyAsyncAnextProtocol<'p>,
{
Expand Down Expand Up @@ -58,16 +59,19 @@ pub trait PyAsyncProtocol<'p>: PyClass {
}

pub trait PyAsyncAwaitProtocol<'p>: PyAsyncProtocol<'p> {
type Receiver: TryFromPyCell<'p, Self>;
type Success: crate::IntoPy<PyObject>;
type Result: Into<PyResult<Self::Success>>;
}

pub trait PyAsyncAiterProtocol<'p>: PyAsyncProtocol<'p> {
type Receiver: TryFromPyCell<'p, Self>;
type Success: crate::IntoPy<PyObject>;
type Result: Into<PyResult<Self::Success>>;
}

pub trait PyAsyncAnextProtocol<'p>: PyAsyncProtocol<'p> {
type Receiver: TryFromPyCell<'p, Self>;
type Success: crate::IntoPy<PyObject>;
type Result: Into<PyResult<Option<Self::Success>>>;
}
Expand All @@ -90,13 +94,13 @@ impl ffi::PyAsyncMethods {
where
T: for<'p> PyAsyncAwaitProtocol<'p>,
{
self.am_await = py_unary_func!(PyAsyncAwaitProtocol, T::__await__);
self.am_await = py_unarys_func!(PyAsyncAwaitProtocol, T::__await__);
}
pub fn set_aiter<T>(&mut self)
where
T: for<'p> PyAsyncAiterProtocol<'p>,
{
self.am_aiter = py_unary_func!(PyAsyncAiterProtocol, T::__aiter__);
self.am_aiter = py_unarys_func!(PyAsyncAiterProtocol, T::__aiter__);
}
pub fn set_anext<T>(&mut self)
where
Expand All @@ -123,7 +127,9 @@ mod anext {
fn convert(self, py: Python) -> PyResult<*mut ffi::PyObject> {
match self.0 {
Some(val) => Ok(val.into_py(py).into_ptr()),
None => Err(crate::exceptions::StopAsyncIteration::py_err(())),
None => Err(crate::exceptions::StopAsyncIteration::py_err(
"Task Completed",
)),
}
}
}
Expand All @@ -133,12 +139,6 @@ mod anext {
where
T: for<'p> PyAsyncAnextProtocol<'p>,
{
py_unary_func!(
PyAsyncAnextProtocol,
T::__anext__,
call_mut,
*mut crate::ffi::PyObject,
IterANextOutput
)
py_unarys_func!(PyAsyncAnextProtocol, T::__anext__, IterANextOutput)
}
}
77 changes: 76 additions & 1 deletion tests/test_dunder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use pyo3::class::{
PyContextProtocol, PyIterProtocol, PyMappingProtocol, PyObjectProtocol, PySequenceProtocol,
PyAsyncProtocol, PyContextProtocol, PyIterProtocol, PyMappingProtocol, PyObjectProtocol,
PySequenceProtocol,
};
use pyo3::exceptions::{IndexError, ValueError};
use pyo3::prelude::*;
Expand Down Expand Up @@ -552,3 +553,77 @@ fn getattr_doesnt_override_member() {
py_assert!(py, inst, "inst.data == 4");
py_assert!(py, inst, "inst.a == 8");
}

/// Wraps a Python future and yield it once.
#[pyclass]
struct OnceFuture {
future: PyObject,
polled: bool,
}

#[pymethods]
impl OnceFuture {
#[new]
fn new(future: PyObject) -> Self {
OnceFuture {
future,
polled: false,
}
}
}

#[pyproto]
impl PyAsyncProtocol for OnceFuture {
fn __await__(slf: PyRef<Self>) -> PyResult<Py<Self>> {
Ok(slf.into())
}
}

#[pyproto]
impl PyIterProtocol for OnceFuture {
fn __iter__(slf: PyRef<Self>) -> PyResult<Py<Self>> {
Ok(slf.into())
}
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<PyObject>> {
if !slf.polled {
slf.polled = true;
Ok(Some(slf.future.clone()))
} else {
Ok(None)
}
}
}

#[test]
fn test_await() {
let gil = Python::acquire_gil();
let py = gil.python();
let once = py.get_type::<OnceFuture>();
let source = pyo3::indoc::indoc!(
r#"
import asyncio
import sys
async def main():
res = await Once(await asyncio.sleep(0.1))
return res
# It looks like that https://bugs.python.org/issue38563 solves this problem,
# but we see still errors on Github actions...
if sys.platform == "win32" and sys.version_info >= (3, 8, 0):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
loop = asyncio.get_event_loop()
assert loop.run_until_complete(main()) is None
loop.close()
"#
);
let globals = PyModule::import(py, "__main__").unwrap().dict();
globals.set_item("Once", once).unwrap();
py.run(source, None, Some(globals))
.map_err(|e| {
e.print(py);
py.run("import sys; sys.stderr.flush()", None, None)
.unwrap();
})
.unwrap();
}

0 comments on commit 78577c6

Please sign in to comment.