include/boost/capy/when_all.hpp

96.9% Lines (95/98) 91.3% Functions (484/530)
Line TLA Hits Source Code
1 //
2 // Copyright (c) 2026 Steve Gerbino
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 // Official repository: https://github.com/cppalliance/capy
8 //
9
10 #ifndef BOOST_CAPY_WHEN_ALL_HPP
11 #define BOOST_CAPY_WHEN_ALL_HPP
12
13 #include <boost/capy/detail/config.hpp>
14 #include <boost/capy/concept/executor.hpp>
15 #include <boost/capy/concept/io_awaitable.hpp>
16 #include <coroutine>
17 #include <boost/capy/ex/io_env.hpp>
18 #include <boost/capy/ex/frame_allocator.hpp>
19 #include <boost/capy/task.hpp>
20
21 #include <array>
22 #include <atomic>
23 #include <exception>
24 #include <optional>
25 #include <stop_token>
26 #include <tuple>
27 #include <type_traits>
28 #include <utility>
29
30 namespace boost {
31 namespace capy {
32
33 namespace detail {
34
/** Map a result type onto a tuple that can be concatenated.

    A void result becomes the empty tuple; any other type `T` becomes
    `std::tuple<T>`. This is the per-element step used by
    `filter_void_tuple_t`.
*/
template<class T>
using wrap_non_void_t =
    std::conditional_t<
        std::is_void_v<T>,
        std::tuple<>,
        std::tuple<T>>;

/** Tuple of the given types with every void removed.

    Void-returning tasks do not contribute a value to the result tuple.
    Each type is wrapped with `wrap_non_void_t` and the wrappers are
    concatenated, so void entries disappear while the remaining types
    keep their relative order.

    Example: filter_void_tuple_t<int, void, string> = tuple<int, string>
*/
template<class... Ts>
using filter_void_tuple_t =
    decltype(std::tuple_cat(
        std::declval<wrap_non_void_t<Ts>>()...));
47
/** Holds the result of a single task within when_all.

    The slot starts empty, is filled once through `set`, and is emptied
    by moving the value out through `get`.

    @tparam T The task's non-void result type.
*/
template<typename T>
struct result_holder
{
    // Empty until set() is called.
    std::optional<T> value_;

    /** Store the task's result.

        The value is constructed in place instead of being assigned, so
        `T` only has to be move-constructible. Result types that are not
        move-assignable (for example aggregates with const members) are
        therefore supported.

        @param v The result to store; it is moved into the slot.
    */
    void set(T v)
    {
        value_.emplace(std::move(v));
    }

    /** Move the stored result out of the slot.

        The slot is dereferenced without checking, so `set` must have
        been called first.

        @return The stored value, move-constructed from the slot.
    */
    T get() &&
    {
        return std::move(*value_);
    }
};

/** Specialization for void tasks - no value storage needed.
*/
template<>
struct result_holder<void>
{
};
72
73 /** Shared state for when_all operation.
74
75 @tparam Ts The result types of the tasks.
76 */
77 template<typename... Ts>
78 struct when_all_state
79 {
80 static constexpr std::size_t task_count = sizeof...(Ts);
81
82 // Completion tracking - when_all waits for all children
83 std::atomic<std::size_t> remaining_count_;
84
85 // Result storage in input order
86 std::tuple<result_holder<Ts>...> results_;
87
88 // Runner handles - destroyed in await_resume while allocator is valid
89 std::array<std::coroutine_handle<>, task_count> runner_handles_{};
90
91 // Exception storage - first error wins, others discarded
92 std::atomic<bool> has_exception_{false};
93 std::exception_ptr first_exception_;
94
95 // Stop propagation - on error, request stop for siblings
96 std::stop_source stop_source_;
97
98 // Connects parent's stop_token to our stop_source
99 struct stop_callback_fn
100 {
101 std::stop_source* source_;
102 4x void operator()() const { source_->request_stop(); }
103 };
104 using stop_callback_t = std::stop_callback<stop_callback_fn>;
105 std::optional<stop_callback_t> parent_stop_callback_;
106
107 // Parent resumption
108 std::coroutine_handle<> continuation_;
109 io_env const* caller_env_ = nullptr;
110
111 61x when_all_state()
112 61x : remaining_count_(task_count)
113 {
114 61x }
115
116 // Runners self-destruct in final_suspend. No destruction needed here.
117
118 /** Capture an exception (first one wins).
119 */
120 20x void capture_exception(std::exception_ptr ep)
121 {
122 20x bool expected = false;
123 20x if(has_exception_.compare_exchange_strong(
124 expected, true, std::memory_order_relaxed))
125 17x first_exception_ = ep;
126 20x }
127
128 };
129
130 /** Wrapper coroutine that intercepts task completion.
131
132 This runner awaits its assigned task and stores the result in
133 the shared state, or captures the exception and requests stop.
134 */
135 template<typename T, typename... Ts>
136 struct when_all_runner
137 {
138 struct promise_type // : frame_allocating_base // DISABLED FOR TESTING
139 {
140 when_all_state<Ts...>* state_ = nullptr;
141 io_env env_;
142
143 134x when_all_runner get_return_object()
144 {
145 134x return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
146 }
147
148 134x std::suspend_always initial_suspend() noexcept
149 {
150 134x return {};
151 }
152
153 134x auto final_suspend() noexcept
154 {
155 struct awaiter
156 {
157 promise_type* p_;
158
159 58x bool await_ready() const noexcept
160 {
161 58x return false;
162 }
163
164 58x std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept
165 {
166 // Extract everything needed before self-destruction.
167 58x auto* state = p_->state_;
168 58x auto* counter = &state->remaining_count_;
169 58x auto* caller_env = state->caller_env_;
170 58x auto cont = state->continuation_;
171
172 58x h.destroy();
173
174 // If last runner, dispatch parent for symmetric transfer.
175 58x auto remaining = counter->fetch_sub(1, std::memory_order_acq_rel);
176 58x if(remaining == 1)
177 29x return caller_env->executor.dispatch(cont);
178 29x return std::noop_coroutine();
179 }
180
181 void await_resume() const noexcept
182 {
183 }
184 };
185 134x return awaiter{this};
186 }
187
188 114x void return_void()
189 {
190 114x }
191
192 20x void unhandled_exception()
193 {
194 20x state_->capture_exception(std::current_exception());
195 // Request stop for sibling tasks
196 20x state_->stop_source_.request_stop();
197 20x }
198
199 template<class Awaitable>
200 struct transform_awaiter
201 {
202 std::decay_t<Awaitable> a_;
203 promise_type* p_;
204
205 134x bool await_ready()
206 {
207 134x return a_.await_ready();
208 }
209
210 134x decltype(auto) await_resume()
211 {
212 134x return a_.await_resume();
213 }
214
215 template<class Promise>
216 133x auto await_suspend(std::coroutine_handle<Promise> h)
217 {
218 #ifdef _MSC_VER
219 using R = decltype(a_.await_suspend(h, &p_->env_));
220 if constexpr (std::is_same_v<R, std::coroutine_handle<>>)
221 a_.await_suspend(h, &p_->env_).resume();
222 else
223 return a_.await_suspend(h, &p_->env_);
224 #else
225 133x return a_.await_suspend(h, &p_->env_);
226 #endif
227 }
228 };
229
230 template<class Awaitable>
231 134x auto await_transform(Awaitable&& a)
232 {
233 using A = std::decay_t<Awaitable>;
234 if constexpr (IoAwaitable<A>)
235 {
236 return transform_awaiter<Awaitable>{
237 268x std::forward<Awaitable>(a), this};
238 }
239 else
240 {
241 static_assert(sizeof(A) == 0, "requires IoAwaitable");
242 }
243 134x }
244 };
245
246 std::coroutine_handle<promise_type> h_;
247
248 134x explicit when_all_runner(std::coroutine_handle<promise_type> h)
249 134x : h_(h)
250 {
251 134x }
252
253 // Enable move for all clang versions - some versions need it
254 when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
255
256 // Non-copyable
257 when_all_runner(when_all_runner const&) = delete;
258 when_all_runner& operator=(when_all_runner const&) = delete;
259 when_all_runner& operator=(when_all_runner&&) = delete;
260
261 134x auto release() noexcept
262 {
263 134x return std::exchange(h_, nullptr);
264 }
265 };
266
267 /** Create a runner coroutine for a single awaitable.
268
269 Awaitable is passed directly to ensure proper coroutine frame storage.
270 */
271 template<std::size_t Index, IoAwaitable Awaitable, typename... Ts>
272 when_all_runner<awaitable_result_t<Awaitable>, Ts...>
273 134x make_when_all_runner(Awaitable inner, when_all_state<Ts...>* state)
274 {
275 using T = awaitable_result_t<Awaitable>;
276 if constexpr (std::is_void_v<T>)
277 {
278 co_await std::move(inner);
279 }
280 else
281 {
282 std::get<Index>(state->results_).set(co_await std::move(inner));
283 }
284 268x }
285
286 /** Internal awaitable that launches all runner coroutines and waits.
287
288 This awaitable is used inside the when_all coroutine to handle
289 the concurrent execution of child awaitables.
290 */
291 template<IoAwaitable... Awaitables>
292 class when_all_launcher
293 {
294 using state_type = when_all_state<awaitable_result_t<Awaitables>...>;
295
296 std::tuple<Awaitables...>* awaitables_;
297 state_type* state_;
298
299 public:
300 61x when_all_launcher(
301 std::tuple<Awaitables...>* awaitables,
302 state_type* state)
303 61x : awaitables_(awaitables)
304 61x , state_(state)
305 {
306 61x }
307
308 61x bool await_ready() const noexcept
309 {
310 61x return sizeof...(Awaitables) == 0;
311 }
312
313 61x std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation, io_env const* caller_env)
314 {
315 61x state_->continuation_ = continuation;
316 61x state_->caller_env_ = caller_env;
317
318 // Forward parent's stop requests to children
319 61x if(caller_env->stop_token.stop_possible())
320 {
321 16x state_->parent_stop_callback_.emplace(
322 8x caller_env->stop_token,
323 8x typename state_type::stop_callback_fn{&state_->stop_source_});
324
325 8x if(caller_env->stop_token.stop_requested())
326 4x state_->stop_source_.request_stop();
327 }
328
329 // CRITICAL: If the last task finishes synchronously then the parent
330 // coroutine resumes, destroying its frame, and destroying this object
331 // prior to the completion of await_suspend. Therefore, await_suspend
332 // must ensure `this` cannot be referenced after calling `launch_one`
333 // for the last time.
334 61x auto token = state_->stop_source_.get_token();
335 [&]<std::size_t... Is>(std::index_sequence<Is...>) {
336 30x (..., launch_one<Is>(caller_env->executor, token));
337 61x }(std::index_sequence_for<Awaitables...>{});
338
339 // Let signal_completion() handle resumption
340 122x return std::noop_coroutine();
341 61x }
342
343 61x void await_resume() const noexcept
344 {
345 // Results are extracted by the when_all coroutine from state
346 61x }
347
348 private:
349 template<std::size_t I>
350 134x void launch_one(executor_ref caller_ex, std::stop_token token)
351 {
352 134x auto runner = make_when_all_runner<I>(
353 134x std::move(std::get<I>(*awaitables_)), state_);
354
355 134x auto h = runner.release();
356 134x h.promise().state_ = state_;
357 134x h.promise().env_ = io_env{caller_ex, token, state_->caller_env_->frame_allocator};
358
359 134x std::coroutine_handle<> ch{h};
360 134x state_->runner_handles_[I] = ch;
361 134x state_->caller_env_->executor.post(ch);
362 268x }
363 };
364
365 /** Compute the result type for when_all.
366
367 Returns void when all tasks are void (P2300 aligned),
368 otherwise returns a tuple with void types filtered out.
369 */
370 template<typename... Ts>
371 using when_all_result_t = std::conditional_t<
372 std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
373 void,
374 filter_void_tuple_t<Ts...>>;
375
376 /** Helper to extract a single result, returning empty tuple for void.
377 This is a separate function to work around a GCC-11 ICE that occurs
378 when using nested immediately-invoked lambdas with pack expansion.
379 */
380 template<std::size_t I, typename... Ts>
381 59x auto extract_single_result(when_all_state<Ts...>& state)
382 {
383 using T = std::tuple_element_t<I, std::tuple<Ts...>>;
384 if constexpr (std::is_void_v<T>)
385 4x return std::tuple<>();
386 else
387 55x return std::make_tuple(std::move(std::get<I>(state.results_)).get());
388 }
389
390 /** Extract results from state, filtering void types.
391 */
392 template<typename... Ts>
393 25x auto extract_results(when_all_state<Ts...>& state)
394 {
395 25x return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
396 5x return std::tuple_cat(extract_single_result<Is>(state)...);
397 50x }(std::index_sequence_for<Ts...>{});
398 }
399
400 } // namespace detail
401
402 /** Execute multiple awaitables concurrently and collect their results.
403
404 Launches all awaitables simultaneously and waits for all to complete
405 before returning. Results are collected in input order. If any
406 awaitable throws, cancellation is requested for siblings and the first
407 exception is rethrown after all awaitables complete.
408
409 @li All child awaitables run concurrently on the caller's executor
410 @li Results are returned as a tuple in input order
411 @li Void-returning awaitables do not contribute to the result tuple
412 @li If all awaitables return void, `when_all` returns `task<void>`
413 @li First exception wins; subsequent exceptions are discarded
414 @li Stop is requested for siblings on first error
415 @li Completes only after all children have finished
416
417 @par Thread Safety
418 The returned task must be awaited from a single execution context.
419 Child awaitables execute concurrently but complete through the caller's
420 executor.
421
422 @param awaitables The awaitables to execute concurrently. Each must
423 satisfy @ref IoAwaitable and is consumed (moved-from) when
424 `when_all` is awaited.
425
426 @return A task yielding a tuple of non-void results. Returns
427 `task<void>` when all input awaitables return void.
428
429 @par Example
430
431 @code
432 task<> example()
433 {
434 // Concurrent fetch, results collected in order
435 auto [user, posts] = co_await when_all(
436 fetch_user( id ), // task<User>
437 fetch_posts( id ) // task<std::vector<Post>>
438 );
439
440 // Void awaitables don't contribute to result
441 co_await when_all(
442 log_event( "start" ), // task<void>
443 notify_user( id ) // task<void>
444 );
445 // Returns task<void>, no result tuple
446 }
447 @endcode
448
449 @see IoAwaitable, task
450 */
451 template<IoAwaitable... As>
452 61x [[nodiscard]] auto when_all(As... awaitables)
453 -> task<detail::when_all_result_t<detail::awaitable_result_t<As>...>>
454 {
455 using result_type = detail::when_all_result_t<detail::awaitable_result_t<As>...>;
456
457 // State is stored in the coroutine frame, using the frame allocator
458 detail::when_all_state<detail::awaitable_result_t<As>...> state;
459
460 // Store awaitables in the frame
461 std::tuple<As...> awaitable_tuple(std::move(awaitables)...);
462
463 // Launch all awaitables and wait for completion
464 co_await detail::when_all_launcher<As...>(&awaitable_tuple, &state);
465
466 // Propagate first exception if any.
467 // Safe without explicit acquire: capture_exception() is sequenced-before
468 // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
469 // last task's decrement that resumes this coroutine.
470 if(state.first_exception_)
471 std::rethrow_exception(state.first_exception_);
472
473 // Extract and return results
474 if constexpr (std::is_void_v<result_type>)
475 co_return;
476 else
477 co_return detail::extract_results(state);
478 122x }
479
480 /// Compute the result type of `when_all` for the given task types.
481 template<typename... Ts>
482 using when_all_result_type = detail::when_all_result_t<Ts...>;
483
484 } // namespace capy
485 } // namespace boost
486
487 #endif
488