hypercall/snapshot/instruments/
db.rs1use crate::read_cache::instruments_registry::{serialization, InstrumentSnapshotState};
2use crate::snapshot::error::SnapshotError;
3use crate::snapshot::traits::{SnapshotLoader, SnapshotState, SnapshotWriter, Snapshotable};
4use anyhow::Result;
5use hypercall_db::types::snapshots::{InstrumentSnapshotEntry, InstrumentsSnapshotInput};
6use hypercall_db::{InstrumentsSnapshotReader, InstrumentsSnapshotWriter};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10const DEFAULT_RETENTION_COUNT: i64 = 5;
11const SNAPSHOT_TYPE: &str = "instruments";
12
13pub struct DbInstrumentsSnapshotWriter<S>
15where
16 S: Snapshotable,
17{
18 db: Arc<dyn InstrumentsSnapshotWriter>,
19 service: Arc<S>,
20 get_offsets:
21 Box<dyn Fn() -> Result<HashMap<String, HashMap<i32, i64>>, SnapshotError> + Send + Sync>,
22 retention_count: i64,
23}
24
25impl<S> DbInstrumentsSnapshotWriter<S>
26where
27 S: Snapshotable,
28{
29 pub fn new<F>(db: Arc<dyn InstrumentsSnapshotWriter>, service: Arc<S>, get_offsets: F) -> Self
30 where
31 F: Fn() -> Result<HashMap<String, HashMap<i32, i64>>, SnapshotError>
32 + Send
33 + Sync
34 + 'static,
35 {
36 Self {
37 db,
38 service,
39 get_offsets: Box::new(get_offsets),
40 retention_count: DEFAULT_RETENTION_COUNT,
41 }
42 }
43
44 pub fn with_retention(mut self, count: i64) -> Self {
45 self.retention_count = count;
46 self
47 }
48}
49
50impl<S> SnapshotWriter for DbInstrumentsSnapshotWriter<S>
51where
52 S: Snapshotable<Key = String, State = InstrumentSnapshotState> + 'static,
53{
54 fn take_snapshot(&self) -> Result<i64, SnapshotError> {
55 let offsets_before = (self.get_offsets)()?;
62
63 let states = tokio::task::block_in_place(|| {
65 tokio::runtime::Handle::current().block_on(self.service.list_all())
66 })?;
67
68 let offsets_after = (self.get_offsets)()?;
70 if offsets_before != offsets_after {
71 tracing::warn!(
72 "[SNAPSHOT_DRIFT] Offsets changed during instruments snapshot capture. \
73 Using offsets_before (safe). Before: {:?}, After: {:?}",
74 offsets_before,
75 offsets_after
76 );
77 }
78
79 let offsets = offsets_before;
81
82 if states.is_empty() {
85 tracing::warn!(
86 "Skipping instruments snapshot: cache is empty (would poison future restarts)"
87 );
88 return Err(SnapshotError::Serialization(
89 "Refusing to save empty instruments snapshot".to_string(),
90 ));
91 }
92
93 let mut entries = Vec::with_capacity(states.len());
95 for (symbol, state) in states {
96 let serialized = serialization::serialize(&state)?;
97 entries.push(InstrumentSnapshotEntry {
98 symbol,
99 data: serialized,
100 });
101 }
102
103 let input = InstrumentsSnapshotInput {
104 snapshot_type: SNAPSHOT_TYPE.to_string(),
105 entries,
106 offsets,
107 retention_count: self.retention_count,
108 };
109
110 self.db
111 .write_instruments_snapshot_sync(&input)
112 .map_err(|e| SnapshotError::DbError(format!("Failed to write snapshot: {}", e)))
113 }
114}
115
116pub struct DbInstrumentsSnapshotLoader {
118 db: Arc<dyn InstrumentsSnapshotReader>,
119}
120
121impl DbInstrumentsSnapshotLoader {
122 pub fn new(db: Arc<dyn InstrumentsSnapshotReader>) -> Self {
123 Self { db }
124 }
125}
126
127impl SnapshotLoader for DbInstrumentsSnapshotLoader {
128 type Key = String;
129 type State = InstrumentSnapshotState;
130
131 fn load_latest(
132 &self,
133 ) -> Result<Option<(i64, SnapshotState<Self::Key, Self::State>)>, SnapshotError> {
134 let snapshot_id = self
135 .db
136 .get_latest_instruments_snapshot_id_sync()
137 .map_err(|e| SnapshotError::DbError(format!("Failed to fetch snapshot id: {}", e)))?;
138
139 let Some(snapshot_id) = snapshot_id else {
140 return Ok(None);
141 };
142
143 let state = self.load(snapshot_id)?;
144 Ok(Some((snapshot_id, state)))
145 }
146
147 fn load(
148 &self,
149 snapshot_id: i64,
150 ) -> Result<SnapshotState<Self::Key, Self::State>, SnapshotError> {
151 let data = self
152 .db
153 .load_instruments_snapshot_sync(snapshot_id)
154 .map_err(|e| SnapshotError::DbError(format!("Failed to load snapshot: {}", e)))?;
155
156 let mut states = HashMap::new();
157 for entry in data.entries {
158 let state = serialization::deserialize(&entry.data)?;
159 states.insert(entry.symbol, state);
160 }
161
162 let mut offset_map: HashMap<String, HashMap<i32, i64>> = HashMap::new();
163 for offset in data.offsets {
164 offset_map
165 .entry(offset.stream)
166 .or_default()
167 .insert(offset.partition, offset.offset);
168 }
169
170 Ok(SnapshotState::with_data(states, offset_map))
171 }
172}