Skip to main content

hypercall/shared/
task_group.rs

1//! Task group for managing background tasks with graceful shutdown.
2//!
3//! This module provides a `TaskGroup` that tracks spawned tasks and allows
4//! for coordinated shutdown with timeout handling.
5//!
6//! # Example
7//! ```ignore
8//! use crate::shared::task_group::TaskGroup;
9//! use crate::shared::shutdown::Shutdown;
10//! use std::time::Duration;
11//!
12//! let shutdown = Shutdown::new();
13//! let mut group = TaskGroup::new();
14//!
15//! // Spawn a task that respects shutdown
16//! let mut shutdown_rx = shutdown.subscribe();
17//! group.spawn("MyTask", async move {
18//!     loop {
19//!         tokio::select! {
20//!             _ = shutdown_rx.recv() => break,
21//!             // ... do work
22//!         }
23//!     }
24//!     Ok(())
25//! });
26//!
27//! // Later, shutdown all tasks
28//! group.shutdown_and_join(&shutdown, Duration::from_secs(5)).await?;
29//! ```
30
31use anyhow::{anyhow, Result};
32use std::future::Future;
33use std::time::Duration;
34use tokio::task::JoinHandle;
35use tracing::{error, info, warn};
36
37use super::shutdown::Shutdown;
38
39/// A handle to a spawned background task.
40struct TaskHandle {
41    name: &'static str,
42    handle: JoinHandle<Result<()>>,
43}
44
45/// A group of background tasks that can be shut down together.
46///
47/// Use `TaskGroup` to spawn related background tasks and ensure they all
48/// shut down gracefully with a timeout.
49pub struct TaskGroup {
50    tasks: Vec<TaskHandle>,
51}
52
53impl Default for TaskGroup {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl TaskGroup {
60    /// Create a new empty task group.
61    pub fn new() -> Self {
62        Self { tasks: Vec::new() }
63    }
64
65    /// Spawn a new task and add it to the group.
66    ///
67    /// The task should respect shutdown signals (e.g., via `tokio::select!`
68    /// on a shutdown receiver) to allow graceful termination.
69    ///
70    /// # Arguments
71    /// * `name` - A static name for the task (used in error messages)
72    /// * `fut` - The future to spawn
73    pub fn spawn<F>(&mut self, name: &'static str, fut: F)
74    where
75        F: Future<Output = Result<()>> + Send + 'static,
76    {
77        let handle = tokio::spawn(fut);
78        self.tasks.push(TaskHandle { name, handle });
79    }
80
81    /// Wait for all tasks to complete with a timeout.
82    ///
83    /// Returns `Ok(())` if all tasks complete successfully within the timeout.
84    /// Returns `Err` if any task times out, panics, or returns an error.
85    ///
86    /// Note: Tasks are joined sequentially. Each task gets the full timeout duration,
87    /// so the total wait time may exceed the timeout if multiple tasks are slow.
88    /// Also note that timed-out tasks continue running in the background - the timeout
89    /// means "gave up waiting" not "task was stopped".
90    pub async fn join_all(self, timeout: Duration) -> Result<()> {
91        let mut errors = Vec::new();
92        let deadline = tokio::time::Instant::now() + timeout;
93
94        for task in self.tasks {
95            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
96            match tokio::time::timeout(remaining, task.handle).await {
97                Ok(Ok(Ok(()))) => {
98                    info!("Task '{}' completed successfully", task.name);
99                }
100                Ok(Ok(Err(e))) => {
101                    error!("Task '{}' returned error: {}", task.name, e);
102                    errors.push(format!("{}: {}", task.name, e));
103                }
104                Ok(Err(join_error)) => {
105                    if join_error.is_panic() {
106                        error!("Task '{}' panicked", task.name);
107                        errors.push(format!("{}: panicked", task.name));
108                    } else if join_error.is_cancelled() {
109                        warn!("Task '{}' was cancelled", task.name);
110                        // Cancelled tasks are not treated as errors
111                    }
112                }
113                Err(_) => {
114                    error!("Task '{}' timed out after {:?}", task.name, timeout);
115                    errors.push(format!("{}: timed out", task.name));
116                }
117            }
118        }
119
120        if errors.is_empty() {
121            Ok(())
122        } else {
123            Err(anyhow!("Task group shutdown failed: {}", errors.join(", ")))
124        }
125    }
126
127    /// Trigger shutdown and wait for all tasks to complete.
128    ///
129    /// This is a convenience method that calls `shutdown.trigger()` and then
130    /// waits for all tasks to complete with the given timeout.
131    pub async fn shutdown_and_join(self, shutdown: &Shutdown, timeout: Duration) -> Result<()> {
132        info!("Triggering shutdown for task group");
133        shutdown.trigger();
134        self.join_all(timeout).await
135    }
136
137    /// Get the number of tasks in the group.
138    pub fn len(&self) -> usize {
139        self.tasks.len()
140    }
141
142    /// Check if the group is empty.
143    pub fn is_empty(&self) -> bool {
144        self.tasks.is_empty()
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Task body that always panics; used to exercise the panic branch.
    async fn panicking_task() -> Result<()> {
        panic!("intentional panic")
    }

    #[tokio::test]
    async fn test_task_group_basic() {
        let mut group = TaskGroup::new();
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        group.spawn("TestTask", async move {
            completed_clone.store(true, Ordering::SeqCst);
            Ok(())
        });

        let result = group.join_all(Duration::from_secs(1)).await;
        assert!(result.is_ok());
        assert!(completed.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_empty_group_joins_immediately() {
        let group = TaskGroup::new();
        assert!(group.is_empty());
        assert_eq!(group.len(), 0);

        let result = group.join_all(Duration::from_secs(1)).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_task_group_with_shutdown() {
        let shutdown = Shutdown::new();
        let mut group = TaskGroup::new();

        // Task that waits for shutdown
        let mut shutdown_rx = shutdown.subscribe();
        group.spawn("ShutdownAwareTask", async move {
            let _ = shutdown_rx.recv().await;
            Ok(())
        });

        // Spawn a task that triggers shutdown after a short delay
        let shutdown_clone = shutdown.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            shutdown_clone.trigger();
        });

        let result = group.join_all(Duration::from_secs(1)).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_shutdown_and_join() {
        let shutdown = Shutdown::new();
        let mut group = TaskGroup::new();

        let mut shutdown_rx = shutdown.subscribe();
        group.spawn("WaitForShutdown", async move {
            let _ = shutdown_rx.recv().await;
            Ok(())
        });

        // This should trigger and wait
        let result = group
            .shutdown_and_join(&shutdown, Duration::from_secs(1))
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_task_error_propagation() {
        let mut group = TaskGroup::new();

        group.spawn(
            "FailingTask",
            async move { Err(anyhow!("intentional failure")) },
        );

        let result = group.join_all(Duration::from_secs(1)).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("FailingTask"));
    }

    #[tokio::test]
    async fn test_task_timeout_is_reported() {
        let mut group = TaskGroup::new();

        // A task that never finishes and ignores shutdown entirely.
        group.spawn("SlowTask", std::future::pending::<Result<()>>());

        let result = group.join_all(Duration::from_millis(20)).await;
        let message = result.unwrap_err().to_string();
        assert!(message.contains("SlowTask: timed out"));
    }

    #[tokio::test]
    async fn test_task_panic_is_reported() {
        let mut group = TaskGroup::new();

        group.spawn("PanickingTask", panicking_task());

        let result = group.join_all(Duration::from_secs(1)).await;
        let message = result.unwrap_err().to_string();
        assert!(message.contains("PanickingTask: panicked"));
    }

    #[tokio::test]
    async fn test_cancelled_task_is_not_an_error() {
        let mut group = TaskGroup::new();

        group.spawn("CancelledTask", std::future::pending::<Result<()>>());

        // Abort the task directly through its join handle before joining.
        for task in &group.tasks {
            task.handle.abort();
        }

        let result = group.join_all(Duration::from_secs(1)).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_multiple_tasks() {
        let shutdown = Shutdown::new();
        let mut group = TaskGroup::new();

        // Spawn multiple tasks, each waiting on its own shutdown receiver
        for name in ["Task0", "Task1", "Task2"] {
            let mut shutdown_rx = shutdown.subscribe();
            group.spawn(name, async move {
                let _ = shutdown_rx.recv().await;
                Ok(())
            });
        }

        assert_eq!(group.len(), 3);

        let result = group
            .shutdown_and_join(&shutdown, Duration::from_secs(1))
            .await;
        assert!(result.is_ok());
    }
}