hypercall/shared/
task_leak.rs1use anyhow::{anyhow, Result};
33use std::future::Future;
34use std::sync::atomic::{AtomicUsize, Ordering};
35use std::sync::Arc;
36use std::time::Duration;
37use tokio::sync::{mpsc, Mutex};
38use tokio::task::JoinHandle;
39use tracing::{error, info, warn};
40
41pub struct TaskTracker {
46 tasks: Arc<Mutex<Vec<TrackedTask>>>,
47 task_tx: mpsc::UnboundedSender<TrackedTask>,
48 spawn_count: AtomicUsize,
49 complete_count: AtomicUsize,
50}
51
52struct TrackedTask {
53 name: &'static str,
54 handle: JoinHandle<Result<()>>,
55}
56
57impl TaskTracker {
58 pub fn new() -> Arc<Self> {
60 let (task_tx, mut task_rx) = mpsc::unbounded_channel();
61 let tasks = Arc::new(Mutex::new(Vec::new()));
62
63 let tasks_for_bg = tasks.clone();
65 tokio::spawn(async move {
66 while let Some(task) = task_rx.recv().await {
67 tasks_for_bg.lock().await.push(task);
68 }
69 });
70
71 Arc::new(Self {
72 tasks,
73 task_tx,
74 spawn_count: AtomicUsize::new(0),
75 complete_count: AtomicUsize::new(0),
76 })
77 }
78
79 pub fn spawn<F>(self: &Arc<Self>, name: &'static str, fut: F)
83 where
84 F: Future<Output = Result<()>> + Send + 'static,
85 {
86 self.spawn_count.fetch_add(1, Ordering::SeqCst);
87 let tracker = Arc::clone(self);
88
89 let wrapped = async move {
90 let result = fut.await;
91 tracker.complete_count.fetch_add(1, Ordering::SeqCst);
92 result
93 };
94
95 let handle = tokio::spawn(wrapped);
96 let task = TrackedTask { name, handle };
97
98 let _ = self.task_tx.send(task);
100 }
101
102 pub fn spawned_count(&self) -> usize {
104 self.spawn_count.load(Ordering::SeqCst)
105 }
106
107 pub fn completed_count(&self) -> usize {
109 self.complete_count.load(Ordering::SeqCst)
110 }
111
112 pub async fn verify_no_leaks(self: Arc<Self>, timeout: Duration) -> Result<()> {
117 tokio::time::sleep(Duration::from_millis(10)).await;
119 let mut tasks = self.tasks.lock().await;
120 let total = tasks.len();
121
122 if total == 0 {
123 return Ok(());
124 }
125
126 info!("Verifying {} tracked tasks complete...", total);
127
128 let mut leaked = Vec::new();
129 let mut failed = Vec::new();
130 let mut completed = 0;
131
132 for task in tasks.drain(..) {
133 match tokio::time::timeout(timeout, task.handle).await {
134 Ok(Ok(Ok(()))) => {
135 completed += 1;
136 }
137 Ok(Ok(Err(e))) => {
138 error!("Task '{}' returned error: {}", task.name, e);
139 failed.push(format!("{}: {}", task.name, e));
140 }
141 Ok(Err(join_error)) => {
142 if join_error.is_panic() {
143 error!("Task '{}' panicked", task.name);
144 failed.push(format!("{}: panicked", task.name));
145 } else if join_error.is_cancelled() {
146 warn!("Task '{}' was cancelled", task.name);
147 completed += 1;
149 }
150 }
151 Err(_) => {
152 error!(
153 "Task '{}' leaked (timed out after {:?})",
154 task.name, timeout
155 );
156 leaked.push(task.name);
157 }
158 }
159 }
160
161 info!(
162 "Task verification: {} completed, {} leaked, {} failed",
163 completed,
164 leaked.len(),
165 failed.len()
166 );
167
168 if !leaked.is_empty() || !failed.is_empty() {
169 let mut msg = String::new();
170 if !leaked.is_empty() {
171 msg.push_str(&format!("Leaked tasks: {:?}. ", leaked));
172 }
173 if !failed.is_empty() {
174 msg.push_str(&format!("Failed tasks: {}.", failed.join(", ")));
175 }
176 return Err(anyhow!("Task leak detected: {}", msg));
177 }
178
179 Ok(())
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shared::shutdown::Shutdown;

    /// A task that returns `Ok(())` right away passes verification.
    #[tokio::test]
    async fn test_tracker_detects_completed_tasks() {
        let tracker = TaskTracker::new();
        tracker.spawn("ImmediateTask", async move { Ok(()) });

        let outcome = tracker.verify_no_leaks(Duration::from_secs(1)).await;

        assert!(outcome.is_ok());
    }

    /// A task that waits on a shutdown subscription passes verification once
    /// the shutdown has been triggered.
    #[tokio::test]
    async fn test_tracker_detects_shutdown_aware_tasks() {
        let tracker = TaskTracker::new();
        let shutdown = Shutdown::new();

        // Subscribe before spawning so the receiver exists when `trigger` runs.
        let mut signal = shutdown.subscribe();
        tracker.spawn("ShutdownAwareTask", async move {
            let _ = signal.recv().await;
            Ok(())
        });
        assert_eq!(tracker.spawned_count(), 1);

        shutdown.trigger();

        let outcome = tracker.verify_no_leaks(Duration::from_secs(1)).await;
        assert!(outcome.is_ok());
    }

    /// A task that outlives the verification timeout is reported by name as
    /// leaked.
    #[tokio::test]
    async fn test_tracker_detects_leaked_tasks() {
        let tracker = TaskTracker::new();

        // Sleeps far longer than the 100 ms verification window below.
        tracker.spawn("LeakyTask", async move {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(())
        });

        let outcome = tracker.verify_no_leaks(Duration::from_millis(100)).await;
        assert!(outcome.is_err());

        let err_msg = outcome.unwrap_err().to_string();
        assert!(err_msg.contains("LeakyTask"), "Error: {}", err_msg);
        assert!(err_msg.contains("Leaked tasks"), "Error: {}", err_msg);
    }

    /// A task that returns an error makes verification fail and is named in
    /// the message.
    #[tokio::test]
    async fn test_tracker_detects_failed_tasks() {
        let tracker = TaskTracker::new();
        tracker.spawn("FailingTask", async move {
            Err(anyhow!("intentional test failure"))
        });

        let outcome = tracker.verify_no_leaks(Duration::from_secs(1)).await;
        assert!(outcome.is_err());

        let err_msg = outcome.unwrap_err().to_string();
        assert!(err_msg.contains("FailingTask"), "Error: {}", err_msg);
    }

    /// Several shutdown-aware tasks are all counted and all pass verification
    /// after a single trigger.
    #[tokio::test]
    async fn test_tracker_multiple_tasks() {
        let tracker = TaskTracker::new();
        let shutdown = Shutdown::new();

        (0..5).for_each(|_| {
            let mut signal = shutdown.subscribe();
            tracker.spawn("Task", async move {
                let _ = signal.recv().await;
                Ok(())
            });
        });
        assert_eq!(tracker.spawned_count(), 5);

        shutdown.trigger();

        let outcome = tracker.verify_no_leaks(Duration::from_secs(1)).await;
        assert!(outcome.is_ok());
    }
}
279}