// hypercall/snapshot/traits.rs

//! Generic snapshot traits.
//!
//! These traits define the contract for snapshotting any stateful service.
//! Implementations provide the concrete types and persistence logic.

use std::collections::HashMap;
use std::hash::Hash;

use async_trait::async_trait;

use super::error::SnapshotError;

11/// Trait for services that can be snapshotted.
12///
13/// Implement this for any service whose state needs to be persisted
14/// and restored atomically with stream offsets.
15#[async_trait]
16pub trait Snapshotable: Send + Sync {
17    /// The key type for identifying entities (e.g., WalletAddress)
18    type Key: Clone + Eq + Hash + Send + Sync;
19    /// The state type for each keyed entity (e.g., wallet -> balance)
20    type State: Clone + Send + Sync;
21
22    /// List all states keyed by identifier.
23    async fn list_all(&self) -> Result<HashMap<Self::Key, Self::State>, SnapshotError>;
24
25    /// Restore a single state entry.
26    async fn restore(&self, key: &Self::Key, state: Self::State) -> Result<(), SnapshotError>;
27
28    /// Clear all state before restore.
29    async fn clear_all(&self) -> Result<(), SnapshotError>;
30}
31
32/// Trait for offset storage.
33///
34/// Offsets use `next_offset_to_apply` semantics:
35/// - All messages with offset < stored value have been applied
36/// - On restore, consumers start from stored offset
37pub trait OffsetStore: Send + Sync {
38    /// Get all offsets: stream -> partition -> next_offset_to_apply
39    fn get_all_offsets(&self) -> Result<HashMap<String, HashMap<i32, i64>>, SnapshotError>;
40
41    /// Set offsets atomically.
42    fn set_offsets(&self, offsets: HashMap<String, HashMap<i32, i64>>)
43        -> Result<(), SnapshotError>;
44}
45
/// In-memory contents of one snapshot: entity states plus the stream offsets
/// captured alongside them.
#[derive(Clone, Debug)]
pub struct SnapshotState<K, S> {
    /// Per-entity state, keyed by entity identifier.
    pub states: HashMap<K, S>,
    /// Stream offsets as `stream -> partition -> next_offset_to_apply`.
    pub offsets: HashMap<String, HashMap<i32, i64>>,
}

55impl<K, S> SnapshotState<K, S>
56where
57    K: Eq + Hash,
58{
59    pub fn new() -> Self {
60        Self {
61            states: HashMap::new(),
62            offsets: HashMap::new(),
63        }
64    }
65
66    pub fn with_data(states: HashMap<K, S>, offsets: HashMap<String, HashMap<i32, i64>>) -> Self {
67        Self { states, offsets }
68    }
69
70    pub fn state_count(&self) -> usize {
71        self.states.len()
72    }
73
74    pub fn offset_count(&self) -> usize {
75        self.offsets.values().map(|p| p.len()).sum()
76    }
77}
78
79impl<K, S> Default for SnapshotState<K, S>
80where
81    K: Eq + Hash,
82{
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
88/// Trait for writing snapshots.
89pub trait SnapshotWriter: Send + Sync {
90    /// Take a snapshot, returning the snapshot ID.
91    fn take_snapshot(&self) -> Result<i64, SnapshotError>;
92}
93
94/// Trait for loading snapshots.
95pub trait SnapshotLoader: Send + Sync {
96    /// The key type for identifying entities
97    type Key: Clone + Eq + Hash + Send + Sync;
98    /// The state type stored in snapshots
99    type State: Clone + Send + Sync;
100
101    /// Load the latest snapshot, if any.
102    fn load_latest(
103        &self,
104    ) -> Result<Option<(i64, SnapshotState<Self::Key, Self::State>)>, SnapshotError>;
105
106    /// Load a specific snapshot by ID.
107    fn load(
108        &self,
109        snapshot_id: i64,
110    ) -> Result<SnapshotState<Self::Key, Self::State>, SnapshotError>;
111}
112
113/// Bootstrap a service from its latest snapshot.
114///
115/// Returns the snapshot ID and offsets if a snapshot was restored,
116/// or None if no snapshot exists.
117pub async fn bootstrap_from_snapshot<S, L>(
118    loader: &L,
119    service: &S,
120) -> Result<Option<(i64, HashMap<String, HashMap<i32, i64>>)>, SnapshotError>
121where
122    S: Snapshotable,
123    L: SnapshotLoader<Key = S::Key, State = S::State>,
124{
125    let snapshot_result = loader.load_latest()?;
126
127    match snapshot_result {
128        Some((snapshot_id, state)) => {
129            service.clear_all().await?;
130
131            let state_count = state.states.len();
132            for (key, s) in state.states {
133                service.restore(&key, s).await?;
134            }
135
136            tracing::info!(
137                "Restored snapshot id={} with {} entries",
138                snapshot_id,
139                state_count
140            );
141
142            Ok(Some((snapshot_id, state.offsets)))
143        }
144        None => {
145            tracing::info!("No snapshot found, starting with empty state");
146            Ok(None)
147        }
148    }
149}