hypercall/shared/
service.rs1use std::sync::Arc;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use tracing::{error, info};
6
7use super::shutdown::{Shutdown, ShutdownRx};
8use super::task_group::TaskGroup;
9
/// Subsystem that a [`Service`] belongs to.
///
/// Emitted as the structured `owner` field on service lifecycle log events.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ServiceOwner {
    /// Engine-side service.
    Engine,
    /// API-side service.
    Api,
    /// Service not attributed to a single side.
    Shared,
}

impl ServiceOwner {
    /// Returns the lowercase label used when this owner is displayed
    /// (`"engine"`, `"api"` or `"shared"`).
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Engine => "engine",
            Self::Api => "api",
            Self::Shared => "shared",
        }
    }
}

impl std::fmt::Display for ServiceOwner {
    /// Writes the lowercase label from [`ServiceOwner::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write!` with a literal) honours width/fill/alignment
        // flags, so e.g. `{:<6}` lines owners up in tabular output.
        f.pad(self.as_str())
    }
}
26
27#[async_trait]
28pub trait Service: Send + Sync + 'static {
29 fn name(&self) -> &'static str;
30 fn owner(&self) -> ServiceOwner;
31 async fn initialize(&self) -> Result<()> {
32 Ok(())
33 }
34 async fn run(self: Arc<Self>, shutdown: ShutdownRx) -> Result<()>;
35}
36
37pub struct ServiceRegistry {
38 services: Vec<Arc<dyn Service>>,
39}
40
41impl Default for ServiceRegistry {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl ServiceRegistry {
48 pub fn new() -> Self {
49 Self {
50 services: Vec::new(),
51 }
52 }
53
54 pub fn register<S: Service>(&mut self, service: Arc<S>) {
55 self.services.push(service);
56 }
57
58 pub fn register_dyn(&mut self, service: Arc<dyn Service>) {
59 self.services.push(service);
60 }
61
62 pub async fn initialize_all(&self) -> Result<()> {
63 for service in &self.services {
64 info!(
65 service = service.name(),
66 owner = %service.owner(),
67 "Initializing service"
68 );
69 service.initialize().await?;
70 }
71 Ok(())
72 }
73
74 pub fn start_all(self, shutdown: &Shutdown, task_group: &mut TaskGroup) {
75 for service in self.services {
76 let shutdown_rx = shutdown.subscribe();
77 let name = service.name();
78 let owner = service.owner();
79 info!(service = name, owner = %owner, "Starting service");
80 task_group.spawn(name, async move {
81 if let Err(e) = service.run(shutdown_rx).await {
82 error!(service = name, owner = %owner, error = %e, "Service exited with error");
83 return Err(e);
84 }
85 Ok(())
86 });
87 }
88 }
89
90 pub fn len(&self) -> usize {
91 self.services.len()
92 }
93
94 pub fn is_empty(&self) -> bool {
95 self.services.is_empty()
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shared::shutdown::Shutdown;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    /// Records whether `initialize` ran and signals `started` as soon as `run`
    /// begins; `run` then waits for the shutdown signal.
    struct TestService {
        initialized: AtomicBool,
        started: Arc<tokio::sync::Notify>,
    }

    impl TestService {
        fn new() -> Self {
            Self {
                initialized: AtomicBool::new(false),
                started: Arc::new(tokio::sync::Notify::new()),
            }
        }
    }

    #[async_trait]
    impl Service for TestService {
        fn name(&self) -> &'static str {
            "TestService"
        }

        fn owner(&self) -> ServiceOwner {
            ServiceOwner::Api
        }

        async fn initialize(&self) -> Result<()> {
            self.initialized.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn run(self: Arc<Self>, mut shutdown: ShutdownRx) -> Result<()> {
            // `notify_one` stores a permit when nobody is waiting yet, so the
            // test's later `notified().await` cannot miss this signal.
            self.started.notify_one();
            let _ = shutdown.recv().await;
            Ok(())
        }
    }

    /// Service whose `initialize` always fails; its `run` is never reached in tests.
    struct FailingService;

    #[async_trait]
    impl Service for FailingService {
        fn name(&self) -> &'static str {
            "FailingService"
        }

        fn owner(&self) -> ServiceOwner {
            ServiceOwner::Engine
        }

        async fn initialize(&self) -> Result<()> {
            Err(anyhow::anyhow!("boom"))
        }

        async fn run(self: Arc<Self>, _shutdown: ShutdownRx) -> Result<()> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_registry_lifecycle() {
        let shutdown = Shutdown::new();
        let mut registry = ServiceRegistry::new();
        let mut task_group = TaskGroup::new();

        let svc = Arc::new(TestService::new());
        let svc_ref = svc.clone();
        registry.register(svc);

        assert_eq!(registry.len(), 1);

        registry.initialize_all().await.unwrap();
        assert!(svc_ref.initialized.load(Ordering::SeqCst));

        let started = svc_ref.started.clone();
        registry.start_all(&shutdown, &mut task_group);

        tokio::time::timeout(Duration::from_secs(1), started.notified())
            .await
            .expect("service should start within 1s");

        task_group
            .shutdown_and_join(&shutdown, Duration::from_secs(1))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_initialize_all_stops_at_first_failure() {
        let mut registry = ServiceRegistry::new();
        assert!(registry.is_empty());

        let later = Arc::new(TestService::new());
        registry.register_dyn(Arc::new(FailingService));
        registry.register(later.clone());
        assert_eq!(registry.len(), 2);

        let err = registry
            .initialize_all()
            .await
            .expect_err("initialization should fail");
        assert!(
            format!("{err:#}").contains("boom"),
            "the original cause must be preserved in the error chain"
        );
        assert!(
            !later.initialized.load(Ordering::SeqCst),
            "services registered after a failing one must not be initialized"
        );
    }
}