hypercall/shared/
task_group.rs1use anyhow::{anyhow, Result};
32use std::future::Future;
33use std::time::Duration;
34use tokio::task::JoinHandle;
35use tracing::{error, info, warn};
36
37use super::shutdown::Shutdown;
38
39struct TaskHandle {
41 name: &'static str,
42 handle: JoinHandle<Result<()>>,
43}
44
45pub struct TaskGroup {
50 tasks: Vec<TaskHandle>,
51}
52
53impl Default for TaskGroup {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl TaskGroup {
60 pub fn new() -> Self {
62 Self { tasks: Vec::new() }
63 }
64
65 pub fn spawn<F>(&mut self, name: &'static str, fut: F)
74 where
75 F: Future<Output = Result<()>> + Send + 'static,
76 {
77 let handle = tokio::spawn(fut);
78 self.tasks.push(TaskHandle { name, handle });
79 }
80
81 pub async fn join_all(self, timeout: Duration) -> Result<()> {
91 let mut errors = Vec::new();
92 let deadline = tokio::time::Instant::now() + timeout;
93
94 for task in self.tasks {
95 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
96 match tokio::time::timeout(remaining, task.handle).await {
97 Ok(Ok(Ok(()))) => {
98 info!("Task '{}' completed successfully", task.name);
99 }
100 Ok(Ok(Err(e))) => {
101 error!("Task '{}' returned error: {}", task.name, e);
102 errors.push(format!("{}: {}", task.name, e));
103 }
104 Ok(Err(join_error)) => {
105 if join_error.is_panic() {
106 error!("Task '{}' panicked", task.name);
107 errors.push(format!("{}: panicked", task.name));
108 } else if join_error.is_cancelled() {
109 warn!("Task '{}' was cancelled", task.name);
110 }
112 }
113 Err(_) => {
114 error!("Task '{}' timed out after {:?}", task.name, timeout);
115 errors.push(format!("{}: timed out", task.name));
116 }
117 }
118 }
119
120 if errors.is_empty() {
121 Ok(())
122 } else {
123 Err(anyhow!("Task group shutdown failed: {}", errors.join(", ")))
124 }
125 }
126
127 pub async fn shutdown_and_join(self, shutdown: &Shutdown, timeout: Duration) -> Result<()> {
132 info!("Triggering shutdown for task group");
133 shutdown.trigger();
134 self.join_all(timeout).await
135 }
136
137 pub fn len(&self) -> usize {
139 self.tasks.len()
140 }
141
142 pub fn is_empty(&self) -> bool {
144 self.tasks.is_empty()
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Helper task that always panics; used to exercise the panic path.
    async fn panicking_task() -> Result<()> {
        panic!("intentional panic")
    }

    /// A task that finishes on its own is joined successfully.
    #[tokio::test]
    async fn test_task_group_basic() {
        let mut group = TaskGroup::new();
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        group.spawn("TestTask", async move {
            completed_clone.store(true, Ordering::SeqCst);
            Ok(())
        });

        let result = group.join_all(Duration::from_secs(1)).await;
        assert!(result.is_ok());
        assert!(completed.load(Ordering::SeqCst));
    }

    /// A group with no tasks reports itself as empty and joins immediately.
    #[tokio::test]
    async fn test_empty_group() {
        let group = TaskGroup::default();
        assert!(group.is_empty());
        assert_eq!(group.len(), 0);

        let result = group.join_all(Duration::from_millis(10)).await;
        assert!(result.is_ok());
    }

    /// A task waiting on the shutdown signal exits once it is triggered
    /// from elsewhere.
    #[tokio::test]
    async fn test_task_group_with_shutdown() {
        let shutdown = Shutdown::new();
        let mut group = TaskGroup::new();

        let mut shutdown_rx = shutdown.subscribe();
        group.spawn("ShutdownAwareTask", async move {
            let _ = shutdown_rx.recv().await;
            Ok(())
        });

        // Trigger the shutdown shortly after join_all starts waiting.
        let shutdown_clone = shutdown.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            shutdown_clone.trigger();
        });

        let result = group.join_all(Duration::from_secs(1)).await;
        assert!(result.is_ok());
    }

    /// `shutdown_and_join` triggers the signal itself before joining.
    #[tokio::test]
    async fn test_shutdown_and_join() {
        let shutdown = Shutdown::new();
        let mut group = TaskGroup::new();

        let mut shutdown_rx = shutdown.subscribe();
        group.spawn("WaitForShutdown", async move {
            let _ = shutdown_rx.recv().await;
            Ok(())
        });

        let result = group
            .shutdown_and_join(&shutdown, Duration::from_secs(1))
            .await;
        assert!(result.is_ok());
    }

    /// An `Err` returned by a task surfaces in the aggregated error.
    #[tokio::test]
    async fn test_task_error_propagation() {
        let mut group = TaskGroup::new();

        group.spawn(
            "FailingTask",
            async move { Err(anyhow!("intentional failure")) },
        );

        let result = group.join_all(Duration::from_secs(1)).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("FailingTask"));
    }

    /// A task that never finishes is reported as timed out.
    #[tokio::test]
    async fn test_task_timeout_reported() {
        let mut group = TaskGroup::new();

        group.spawn("StuckTask", std::future::pending::<Result<()>>());

        let result = group.join_all(Duration::from_millis(20)).await;
        let message = result.unwrap_err().to_string();
        assert!(message.contains("StuckTask: timed out"));
    }

    /// A panicking task is reported as a failure rather than being ignored.
    #[tokio::test]
    async fn test_task_panic_reported() {
        let mut group = TaskGroup::new();

        group.spawn("PanickingTask", panicking_task());

        let result = group.join_all(Duration::from_secs(1)).await;
        let message = result.unwrap_err().to_string();
        assert!(message.contains("PanickingTask: panicked"));
    }

    /// Several tasks sharing one shutdown signal are all joined.
    #[tokio::test]
    async fn test_multiple_tasks() {
        let shutdown = Shutdown::new();
        let mut group = TaskGroup::new();

        for name in ["Task0", "Task1", "Task2"] {
            let mut shutdown_rx = shutdown.subscribe();
            group.spawn(name, async move {
                let _ = shutdown_rx.recv().await;
                Ok(())
            });
        }

        assert_eq!(group.len(), 3);

        let result = group
            .shutdown_and_join(&shutdown, Duration::from_secs(1))
            .await;
        assert!(result.is_ok());
    }
}