// File: hypercall/shared/service.rs

use std::sync::Arc;

use anyhow::{Context, Result};
use async_trait::async_trait;
use tracing::{error, info};

use super::shutdown::{Shutdown, ShutdownRx};
use super::task_group::TaskGroup;

/// Identifies which subsystem owns a service.
///
/// Rendered via [`std::fmt::Display`] as a lowercase label (`engine`, `api`, `shared`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ServiceOwner {
    Engine,
    Api,
    Shared,
}

impl ServiceOwner {
    /// Returns the lowercase label used when this owner is displayed.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Engine => "engine",
            Self::Api => "api",
            Self::Shared => "shared",
        }
    }
}

impl std::fmt::Display for ServiceOwner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write!`) honours width/fill/alignment flags such as `{:>8}`.
        f.pad(self.as_str())
    }
}

27#[async_trait]
28pub trait Service: Send + Sync + 'static {
29    fn name(&self) -> &'static str;
30    fn owner(&self) -> ServiceOwner;
31    async fn initialize(&self) -> Result<()> {
32        Ok(())
33    }
34    async fn run(self: Arc<Self>, shutdown: ShutdownRx) -> Result<()>;
35}
36
37pub struct ServiceRegistry {
38    services: Vec<Arc<dyn Service>>,
39}
40
41impl Default for ServiceRegistry {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47impl ServiceRegistry {
48    pub fn new() -> Self {
49        Self {
50            services: Vec::new(),
51        }
52    }
53
54    pub fn register<S: Service>(&mut self, service: Arc<S>) {
55        self.services.push(service);
56    }
57
58    pub fn register_dyn(&mut self, service: Arc<dyn Service>) {
59        self.services.push(service);
60    }
61
62    pub async fn initialize_all(&self) -> Result<()> {
63        for service in &self.services {
64            info!(
65                service = service.name(),
66                owner = %service.owner(),
67                "Initializing service"
68            );
69            service.initialize().await?;
70        }
71        Ok(())
72    }
73
74    pub fn start_all(self, shutdown: &Shutdown, task_group: &mut TaskGroup) {
75        for service in self.services {
76            let shutdown_rx = shutdown.subscribe();
77            let name = service.name();
78            let owner = service.owner();
79            info!(service = name, owner = %owner, "Starting service");
80            task_group.spawn(name, async move {
81                if let Err(e) = service.run(shutdown_rx).await {
82                    error!(service = name, owner = %owner, error = %e, "Service exited with error");
83                    return Err(e);
84                }
85                Ok(())
86            });
87        }
88    }
89
90    pub fn len(&self) -> usize {
91        self.services.len()
92    }
93
94    pub fn is_empty(&self) -> bool {
95        self.services.is_empty()
96    }
97}
98
#[cfg(test)]
mod tests {
    // `super::*` already brings `Shutdown`, `ShutdownRx`, `TaskGroup`, `Arc`, `Result`
    // and `async_trait` into scope via the parent module's imports.
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    /// Records whether `initialize` ran and signals once `run` has started.
    struct TestService {
        initialized: AtomicBool,
        started: Arc<tokio::sync::Notify>,
    }

    impl TestService {
        fn new() -> Self {
            Self {
                initialized: AtomicBool::new(false),
                started: Arc::new(tokio::sync::Notify::new()),
            }
        }
    }

    #[async_trait]
    impl Service for TestService {
        fn name(&self) -> &'static str {
            "TestService"
        }

        fn owner(&self) -> ServiceOwner {
            ServiceOwner::Api
        }

        async fn initialize(&self) -> Result<()> {
            self.initialized.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn run(self: Arc<Self>, mut shutdown: ShutdownRx) -> Result<()> {
            // `notify_one` stores a permit, so a later `notified()` still completes.
            self.started.notify_one();
            let _ = shutdown.recv().await;
            Ok(())
        }
    }

    /// Service whose `initialize` always fails.
    struct FailingService;

    #[async_trait]
    impl Service for FailingService {
        fn name(&self) -> &'static str {
            "FailingService"
        }

        fn owner(&self) -> ServiceOwner {
            ServiceOwner::Engine
        }

        async fn initialize(&self) -> Result<()> {
            Err(anyhow::anyhow!("boom"))
        }

        async fn run(self: Arc<Self>, _shutdown: ShutdownRx) -> Result<()> {
            Ok(())
        }
    }

    #[test]
    fn test_len_tracks_registrations() {
        let mut registry = ServiceRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register_dyn(Arc::new(FailingService));
        registry.register(Arc::new(TestService::new()));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 2);
    }

    #[tokio::test]
    async fn test_registry_lifecycle() {
        let shutdown = Shutdown::new();
        let mut registry = ServiceRegistry::new();
        let mut task_group = TaskGroup::new();

        let svc = Arc::new(TestService::new());
        let svc_ref = Arc::clone(&svc);
        registry.register(svc);

        assert_eq!(registry.len(), 1);

        registry.initialize_all().await.unwrap();
        assert!(svc_ref.initialized.load(Ordering::SeqCst));

        let started = Arc::clone(&svc_ref.started);
        registry.start_all(&shutdown, &mut task_group);

        tokio::time::timeout(Duration::from_secs(1), started.notified())
            .await
            .expect("service should start within 1s");

        task_group
            .shutdown_and_join(&shutdown, Duration::from_secs(1))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_initialize_all_stops_on_first_error() {
        let mut registry = ServiceRegistry::new();
        let after = Arc::new(TestService::new());
        registry.register(Arc::new(FailingService));
        registry.register(Arc::clone(&after));

        let err = registry
            .initialize_all()
            .await
            .expect_err("initialization should fail");
        assert!(format!("{:#}", err).contains("boom"));
        assert!(
            !after.initialized.load(Ordering::SeqCst),
            "services registered after a failure must not be initialized"
        );
    }
}
169}