Skip to main content

hypercall/shared/
task_leak.rs

//! Task leak detection utilities for testing.
//!
//! This module provides a mechanism to verify that all spawned tasks have exited
//! after a test completes, similar to [goleak](https://github.com/uber-go/goleak)
//! for Go goroutines.
//!
//! # Example
//! ```ignore
//! use std::time::Duration;
//!
//! use crate::shared::task_leak::TaskTracker;
//! use crate::shared::shutdown::Shutdown;
//!
//! #[tokio::test]
//! async fn test_no_task_leaks() {
//!     let shutdown = Shutdown::new();
//!     let tracker = TaskTracker::new();
//!
//!     // Spawn tasks using the tracker
//!     let mut shutdown_rx = shutdown.subscribe();
//!     tracker.spawn("MyTask", async move {
//!         let _ = shutdown_rx.recv().await;
//!         Ok(())
//!     });
//!
//!     // Trigger shutdown
//!     shutdown.trigger();
//!
//!     // Verify all tasks complete within timeout
//!     tracker.verify_no_leaks(Duration::from_secs(5)).await.unwrap();
//! }
//! ```
31
32use anyhow::{anyhow, Result};
33use std::future::Future;
34use std::sync::atomic::{AtomicUsize, Ordering};
35use std::sync::Arc;
36use std::time::Duration;
37use tokio::sync::{mpsc, Mutex};
38use tokio::task::JoinHandle;
39use tracing::{error, info, warn};
40
41/// Tracks spawned tasks and detects leaks.
42///
43/// Similar to goleak, this helps detect tasks that don't properly
44/// respond to shutdown signals and continue running indefinitely.
45pub struct TaskTracker {
46    tasks: Arc<Mutex<Vec<TrackedTask>>>,
47    task_tx: mpsc::UnboundedSender<TrackedTask>,
48    spawn_count: AtomicUsize,
49    complete_count: AtomicUsize,
50}
51
52struct TrackedTask {
53    name: &'static str,
54    handle: JoinHandle<Result<()>>,
55}
56
57impl TaskTracker {
58    /// Create a new task tracker wrapped in an Arc.
59    pub fn new() -> Arc<Self> {
60        let (task_tx, mut task_rx) = mpsc::unbounded_channel();
61        let tasks = Arc::new(Mutex::new(Vec::new()));
62
63        // Background task to process queued tasks
64        let tasks_for_bg = tasks.clone();
65        tokio::spawn(async move {
66            while let Some(task) = task_rx.recv().await {
67                tasks_for_bg.lock().await.push(task);
68            }
69        });
70
71        Arc::new(Self {
72            tasks,
73            task_tx,
74            spawn_count: AtomicUsize::new(0),
75            complete_count: AtomicUsize::new(0),
76        })
77    }
78
79    /// Spawn a task and track it.
80    ///
81    /// The task should respect shutdown signals to allow proper cleanup.
82    pub fn spawn<F>(self: &Arc<Self>, name: &'static str, fut: F)
83    where
84        F: Future<Output = Result<()>> + Send + 'static,
85    {
86        self.spawn_count.fetch_add(1, Ordering::SeqCst);
87        let tracker = Arc::clone(self);
88
89        let wrapped = async move {
90            let result = fut.await;
91            tracker.complete_count.fetch_add(1, Ordering::SeqCst);
92            result
93        };
94
95        let handle = tokio::spawn(wrapped);
96        let task = TrackedTask { name, handle };
97
98        // Queue task for tracking (non-blocking)
99        let _ = self.task_tx.send(task);
100    }
101
102    /// Get the number of tasks spawned.
103    pub fn spawned_count(&self) -> usize {
104        self.spawn_count.load(Ordering::SeqCst)
105    }
106
107    /// Get the number of tasks that have completed.
108    pub fn completed_count(&self) -> usize {
109        self.complete_count.load(Ordering::SeqCst)
110    }
111
112    /// Verify that all tracked tasks complete within the timeout.
113    ///
114    /// Returns `Ok(())` if all tasks complete successfully.
115    /// Returns `Err` with details about any leaked (timed out) or failed tasks.
116    pub async fn verify_no_leaks(self: Arc<Self>, timeout: Duration) -> Result<()> {
117        // Give background task a moment to process any queued tasks
118        tokio::time::sleep(Duration::from_millis(10)).await;
119        let mut tasks = self.tasks.lock().await;
120        let total = tasks.len();
121
122        if total == 0 {
123            return Ok(());
124        }
125
126        info!("Verifying {} tracked tasks complete...", total);
127
128        let mut leaked = Vec::new();
129        let mut failed = Vec::new();
130        let mut completed = 0;
131
132        for task in tasks.drain(..) {
133            match tokio::time::timeout(timeout, task.handle).await {
134                Ok(Ok(Ok(()))) => {
135                    completed += 1;
136                }
137                Ok(Ok(Err(e))) => {
138                    error!("Task '{}' returned error: {}", task.name, e);
139                    failed.push(format!("{}: {}", task.name, e));
140                }
141                Ok(Err(join_error)) => {
142                    if join_error.is_panic() {
143                        error!("Task '{}' panicked", task.name);
144                        failed.push(format!("{}: panicked", task.name));
145                    } else if join_error.is_cancelled() {
146                        warn!("Task '{}' was cancelled", task.name);
147                        // Cancelled is not necessarily a leak
148                        completed += 1;
149                    }
150                }
151                Err(_) => {
152                    error!(
153                        "Task '{}' leaked (timed out after {:?})",
154                        task.name, timeout
155                    );
156                    leaked.push(task.name);
157                }
158            }
159        }
160
161        info!(
162            "Task verification: {} completed, {} leaked, {} failed",
163            completed,
164            leaked.len(),
165            failed.len()
166        );
167
168        if !leaked.is_empty() || !failed.is_empty() {
169            let mut msg = String::new();
170            if !leaked.is_empty() {
171                msg.push_str(&format!("Leaked tasks: {:?}. ", leaked));
172            }
173            if !failed.is_empty() {
174                msg.push_str(&format!("Failed tasks: {}.", failed.join(", ")));
175            }
176            return Err(anyhow!("Task leak detected: {}", msg));
177        }
178
179        Ok(())
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shared::shutdown::Shutdown;

    /// A task that finishes right away must pass verification.
    #[tokio::test]
    async fn test_tracker_detects_completed_tasks() {
        let tracker = TaskTracker::new();

        // Spawn a task that completes immediately
        tracker.spawn("ImmediateTask", async move { Ok(()) });

        // Verify it completed
        let result = tracker.verify_no_leaks(Duration::from_secs(1)).await;
        assert!(result.is_ok(), "unexpected verification error: {:?}", result);
    }

    /// A task that exits when shutdown is triggered must pass verification.
    #[tokio::test]
    async fn test_tracker_detects_shutdown_aware_tasks() {
        let shutdown = Shutdown::new();
        let tracker = TaskTracker::new();

        // Spawn a task that waits for shutdown
        let mut shutdown_rx = shutdown.subscribe();
        tracker.spawn("ShutdownAwareTask", async move {
            let _ = shutdown_rx.recv().await;
            Ok(())
        });

        // The spawn has been counted
        assert_eq!(tracker.spawned_count(), 1);

        // Trigger shutdown
        shutdown.trigger();

        // Verify all tasks complete
        let result = tracker.verify_no_leaks(Duration::from_secs(1)).await;
        assert!(result.is_ok(), "unexpected verification error: {:?}", result);
    }

    /// A task that outlives the timeout must be reported as leaked, by name.
    #[tokio::test]
    async fn test_tracker_detects_leaked_tasks() {
        let tracker = TaskTracker::new();

        // Spawn a task that never exits within the verification window
        tracker.spawn("LeakyTask", async move {
            // Simulate a task that doesn't respect shutdown
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(())
        });

        // Verify with a short timeout - should detect the leak
        let result = tracker.verify_no_leaks(Duration::from_millis(100)).await;
        assert!(result.is_err(), "leak was not detected");
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("LeakyTask"), "Error: {}", err_msg);
        assert!(err_msg.contains("Leaked tasks"), "Error: {}", err_msg);
    }

    /// A task that returns an error must be reported as failed, by name.
    #[tokio::test]
    async fn test_tracker_detects_failed_tasks() {
        let tracker = TaskTracker::new();

        // Spawn a task that returns an error
        tracker.spawn("FailingTask", async move {
            Err(anyhow!("intentional test failure"))
        });

        // Verify - should detect the failure
        let result = tracker.verify_no_leaks(Duration::from_secs(1)).await;
        assert!(result.is_err(), "failure was not detected");
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("FailingTask"), "Error: {}", err_msg);
    }

    /// Several shutdown-aware tasks must all be counted and all complete.
    #[tokio::test]
    async fn test_tracker_multiple_tasks() {
        let shutdown = Shutdown::new();
        let tracker = TaskTracker::new();

        // Spawn multiple tasks
        for _ in 0..5 {
            let mut rx = shutdown.subscribe();
            tracker.spawn("Task", async move {
                let _ = rx.recv().await;
                Ok(())
            });
        }

        assert_eq!(tracker.spawned_count(), 5);

        // Shutdown and verify. `verify_no_leaks` consumes its `Arc`, so keep
        // our own reference to inspect the counters afterwards.
        shutdown.trigger();
        let result = Arc::clone(&tracker)
            .verify_no_leaks(Duration::from_secs(1))
            .await;
        assert!(result.is_ok(), "unexpected verification error: {:?}", result);

        // Every task was joined successfully, so each one has bumped the
        // completion counter before its handle resolved.
        assert_eq!(tracker.completed_count(), 5);
    }
}