Skip to content

Commit

Permalink
refactor gather/wait_for
Browse files Browse the repository at this point in the history
  • Loading branch information
netcan committed Nov 23, 2021
1 parent 986b9a8 commit e3f8bea
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 22 deletions.
34 changes: 29 additions & 5 deletions include/asyncio/gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class GatherAwaiter {
Task<> collect_result(NoWaitAtInitialSuspend, Fut&& fut) {
try {
auto& results = std::get<ResultTypes>(result_);
if constexpr (std::is_void_v<AwaitResult<Fut>>) { co_await fut; }
else { std::get<Idx>(results) = std::move(co_await fut); }
if constexpr (std::is_void_v<AwaitResult<Fut>>) { co_await std::forward<Fut>(fut); }
else { std::get<Idx>(results) = std::move(co_await std::forward<Fut>(fut)); }
++count_;
} catch(...) {
result_ = std::current_exception();
Expand All @@ -67,12 +67,36 @@ class GatherAwaiter {

template<concepts::Awaitable... Futs> // C++17 deduction guide
GatherAwaiter(Futs&&...) -> GatherAwaiter<AwaitResult<Futs>...>;

template<concepts::Awaitable... Futs>
struct GatherAwaiterRepositry {
GatherAwaiterRepositry(Futs&&... futs)
: futs_(std::forward<Futs>(futs)...) { }

auto operator co_await() & {
return std::apply([]<concepts::Awaitable... F>(F&&... f) {
return GatherAwaiter { std::forward<F>(f)... };
}, futs_);
}

auto operator co_await() && {
return std::apply([]<concepts::Awaitable... F>(F&&... f) {
return GatherAwaiter { std::forward<F>(f)... };
}, std::move(futs_));
}

private:
std::tuple<Futs...> futs_;
};

template<concepts::Awaitable... Futs>
GatherAwaiterRepositry(Futs&&...) -> GatherAwaiterRepositry<Futs...>;
}

template<concepts::Awaitable... Futs>
[[nodiscard]]
auto gather(Futs&&... futs) -> detail::GatherAwaiter<AwaitResult<Futs>...> {
return { std::forward<Futs>(futs)... };
[[nodiscard("dicard gather doesn't make sense")]]
auto gather(Futs&&... futs) {
return detail::GatherAwaiterRepositry{ std::forward<Futs>(futs)... };
}

ASYNCIO_NS_END
Expand Down
11 changes: 9 additions & 2 deletions include/asyncio/sleep.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,19 @@ struct SleepAwaiter: private NonCopyable {
private:
Duration delay_;
};
template<typename Duration>
struct SleepAwaiterRepositry {
auto operator co_await () && {
return SleepAwaiter{delay_};
}
Duration delay_;
};
}

template<typename Rep, typename Period>
[[nodiscard]]
[[nodiscard("discard sleep doesn't make sense")]]
auto sleep(std::chrono::duration<Rep, Period> delay /* second */) {
return detail::SleepAwaiter {delay};
return detail::SleepAwaiterRepositry {delay};
}

using namespace std::chrono_literals;
Expand Down
26 changes: 22 additions & 4 deletions include/asyncio/wait_for.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ struct WaitForAwaiter {
template<concepts::Awaitable Fut>
Task<> wait_for_task(NoWaitAtInitialSuspend, Fut&& fut) {
try {
if constexpr (std::is_void_v<R>) { co_await fut; }
else { result_ = std::move(co_await fut); }
if constexpr (std::is_void_v<R>) { co_await std::forward<Fut>(fut); }
else { result_ = std::move(co_await std::forward<Fut>(fut)); }
} catch(...) {
result_ = std::current_exception();
}
Expand Down Expand Up @@ -77,12 +77,30 @@ struct WaitForAwaiter {

template<concepts::Awaitable Fut, typename Duration>
WaitForAwaiter(Fut, Duration) -> WaitForAwaiter<AwaitResult<Fut>, Duration>;

template<concepts::Awaitable Fut, typename Duration>
struct WaitForAwaiterRegistry {
WaitForAwaiterRegistry(Fut&& fut, Duration duration)
: fut_(std::forward<Fut>(fut)), duration_(duration)
{ }

auto operator co_await () && {
return WaitForAwaiter{std::move(fut_), duration_};
}
private:
Fut fut_;
Duration duration_;
};

template<concepts::Awaitable Fut, typename Duration>
WaitForAwaiterRegistry(Fut&& fut, Duration duration)
-> WaitForAwaiterRegistry<Fut, Duration>;
}

template<concepts::Awaitable Fut, typename Rep, typename Period>
[[nodiscard]]
[[nodiscard("discard wait_for doesn't make sense")]]
auto wait_for(Fut&& fut, std::chrono::duration<Rep, Period> timeout) {
return detail::WaitForAwaiter { std::forward<Fut>(fut), timeout };
return detail::WaitForAwaiterRegistry { std::forward<Fut>(fut), timeout };
}
ASYNCIO_NS_END
#endif // ASYNCIO_WAIT_FOR_H
28 changes: 17 additions & 11 deletions test/ut/task_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,18 +172,22 @@ SCENARIO("test gather") {
REQUIRE(! is_called);
loop.run_until_complete([&]() -> Task<> {
auto fac_lvalue = factorial("A", 2);
auto fac_xvalue = factorial("B", 3);
auto&& fac_rvalue = factorial("C", 4);
{
auto&& [a, b, c, _void] = co_await asyncio::gather(
fac_lvalue,
factorial("B", 3),
factorial("C", 4),
static_cast<Task<int>&&>(fac_xvalue),
std::move(fac_rvalue),
test_void_func()
);
REQUIRE(a == 2);
REQUIRE(b == 6);
REQUIRE(c == 24);
}
REQUIRE((co_await fac_lvalue) == 2);
REQUIRE(fac_xvalue.handle_ == nullptr);
REQUIRE(fac_rvalue.handle_ == nullptr);
is_called = true;
}());
REQUIRE(is_called);
Expand All @@ -209,15 +213,17 @@ SCENARIO("test gather") {

SECTION("test detach gather") {
REQUIRE(! is_called);
auto res = asyncio::gather(
factorial("A", 2),
factorial("B", 3)
);
loop.run_until_complete([&]() -> Task<> {
auto res = asyncio::gather(
factorial("A", 2),
factorial("B", 3)
);
auto&& [a, b] = co_await res; // no get result, factorial is rvalue, release when detach
auto&& [a, b] = co_await std::move(res);
REQUIRE(a == 2);
REQUIRE(b == 6);
is_called = true;
}());
REQUIRE(! is_called);
REQUIRE(is_called);
}

SECTION("test exception gather") {
Expand All @@ -244,13 +250,13 @@ SCENARIO("test timeout") {
co_return 0xbabababc;
};

auto wait_test = [&](auto duration, auto timeout) -> Task<int> {
auto wait_for_test = [&](auto duration, auto timeout) -> Task<int> {
co_return co_await wait_for(wait_duration(duration), timeout);
};

SECTION("no timeout") {
REQUIRE(! is_called);
REQUIRE(loop.run_until_complete(wait_test(12ms, 12000ms)) == 0xbabababc);
REQUIRE(loop.run_until_complete(wait_for_test(12ms, 12000ms)) == 0xbabababc);
REQUIRE(is_called);
}

Expand All @@ -263,7 +269,7 @@ SCENARIO("test timeout") {

SECTION("timeout error") {
REQUIRE(! is_called);
REQUIRE_THROWS_AS(loop.run_until_complete(wait_test(200ms, 100ms)), TimeoutError);
REQUIRE_THROWS_AS(loop.run_until_complete(wait_for_test(200ms, 100ms)), TimeoutError);
REQUIRE(! is_called);
}
}
Expand Down

0 comments on commit e3f8bea

Please sign in to comment.