Skip to main content

hypercall/snapshot/portfolio/
db.rs

1//! Database implementations for Portfolio Snapshot Writer and Loader.
2
3use crate::portfolio::PortfolioBalance;
4use crate::snapshot::error::SnapshotError;
5use crate::snapshot::traits::{SnapshotLoader, SnapshotState, SnapshotWriter, Snapshotable};
6use anyhow::Result;
7use hypercall_db::types::snapshots::{PortfolioSnapshotAccountEntry, PortfolioSnapshotInput};
8use hypercall_db::{PortfolioSnapshotReader, PortfolioSnapshotWriter};
9use hypercall_types::WalletAddress;
10use std::collections::HashMap;
11use std::str::FromStr;
12use std::sync::Arc;
13use tracing::{debug, info};
14
15/// Serialization helpers for PortfolioBalance.
16pub mod serialization {
17    use super::*;
18
19    pub fn serialize(balance: &PortfolioBalance) -> Result<Vec<u8>, SnapshotError> {
20        bincode::serialize(balance).map_err(|e| SnapshotError::Serialization(e.to_string()))
21    }
22
23    pub fn deserialize(data: &[u8]) -> Result<PortfolioBalance, SnapshotError> {
24        bincode::deserialize(data).map_err(|e| SnapshotError::Serialization(e.to_string()))
25    }
26}
27
/// Default number of snapshots to retain.
///
/// `DbPortfolioSnapshotWriter::new` starts with this value; callers can replace it
/// through `DbPortfolioSnapshotWriter::with_retention`.
const DEFAULT_RETENTION_COUNT: i64 = 10;
30
31/// Database-backed portfolio snapshot writer.
32///
33/// Generic over any service that implements `Snapshotable<Key = WalletAddress, State = PortfolioBalance>`.
34/// Automatically cleans up old snapshots beyond the retention count.
35///
36/// # Barrier for Crash Consistency
37///
38/// If `capture_snapshot` is provided, it will be used to atomically capture (state, offsets)
39/// under a barrier lock. This ensures snapshots are crash-consistent: the captured state
40/// always includes all fills up to and including the captured seq.
41pub struct DbPortfolioSnapshotWriter<S>
42where
43    S: Snapshotable<Key = WalletAddress, State = PortfolioBalance>,
44{
45    db: Arc<dyn PortfolioSnapshotWriter>,
46    service: Arc<S>,
47    get_offsets: Box<dyn Fn() -> Result<HashMap<String, HashMap<i32, i64>>> + Send + Sync>,
48    capture_snapshot: Option<
49        Box<
50            dyn Fn() -> Result<(
51                    HashMap<WalletAddress, PortfolioBalance>,
52                    HashMap<String, HashMap<i32, i64>>,
53                )> + Send
54                + Sync,
55        >,
56    >,
57    /// Number of snapshots to retain (older ones are deleted)
58    retention_count: i64,
59}
60
61impl<S> DbPortfolioSnapshotWriter<S>
62where
63    S: Snapshotable<Key = WalletAddress, State = PortfolioBalance>,
64{
65    pub fn new<F>(db: Arc<dyn PortfolioSnapshotWriter>, service: Arc<S>, get_offsets: F) -> Self
66    where
67        F: Fn() -> Result<HashMap<String, HashMap<i32, i64>>> + Send + Sync + 'static,
68    {
69        Self {
70            db,
71            service,
72            get_offsets: Box::new(get_offsets),
73            capture_snapshot: None,
74            retention_count: DEFAULT_RETENTION_COUNT,
75        }
76    }
77
78    /// Set the number of snapshots to retain (default: 10).
79    pub fn with_retention(mut self, count: i64) -> Self {
80        self.retention_count = count;
81        self
82    }
83
84    /// Set an atomic capture function for both state and offsets.
85    ///
86    /// When set, `take_snapshot` uses this instead of calling `get_offsets` and
87    /// `service.list_all()` separately.
88    pub fn with_capture_snapshot<F>(mut self, capture: F) -> Self
89    where
90        F: Fn() -> Result<(
91                HashMap<WalletAddress, PortfolioBalance>,
92                HashMap<String, HashMap<i32, i64>>,
93            )> + Send
94            + Sync
95            + 'static,
96    {
97        self.capture_snapshot = Some(Box::new(capture));
98        self
99    }
100}
101
102impl<S> SnapshotWriter for DbPortfolioSnapshotWriter<S>
103where
104    S: Snapshotable<Key = WalletAddress, State = PortfolioBalance> + 'static,
105{
106    fn take_snapshot(&self) -> Result<i64, SnapshotError> {
107        let (states, offsets) = if let Some(ref capture_snapshot) = self.capture_snapshot {
108            (capture_snapshot)().map_err(|e| {
109                SnapshotError::DbError(format!("Failed to capture snapshot state+offsets: {}", e))
110            })?
111        } else {
112            // CRITICAL: Capture offsets BEFORE state to ensure consistency.
113            // If we capture state first and offsets second, updates processed between
114            // the two captures result in offsets ahead of state. On restore, we'd skip
115            // those updates, leaving permanent "holes" in the cache.
116            // By capturing offsets first, any concurrent updates result in offsets
117            // behind state, which is safe (replay includes duplicates; duplicates are idempotent).
118            let offsets_before = (self.get_offsets)()
119                .map_err(|e| SnapshotError::DbError(format!("Failed to get offsets: {}", e)))?;
120
121            let states = tokio::task::block_in_place(|| {
122                tokio::runtime::Handle::current().block_on(self.service.list_all())
123            })?;
124
125            let offsets_after = (self.get_offsets)().map_err(|e| {
126                SnapshotError::DbError(format!("Failed to get offsets for check: {}", e))
127            })?;
128            if offsets_before != offsets_after {
129                tracing::warn!(
130                    "[SNAPSHOT_DRIFT] Offsets changed during portfolio snapshot capture. \
131                     Using offsets_before (safe). Before: {:?}, After: {:?}",
132                    offsets_before,
133                    offsets_after
134                );
135            }
136
137            (states, offsets_before)
138        };
139
140        // Serialize accounts
141        let mut accounts = Vec::with_capacity(states.len());
142        for (wallet, balance) in &states {
143            let data = serialization::serialize(balance)?;
144            let wallet_hex = wallet.as_hex();
145            accounts.push(PortfolioSnapshotAccountEntry {
146                wallet: wallet_hex,
147                data,
148            });
149        }
150
151        let input = PortfolioSnapshotInput {
152            accounts,
153            offsets,
154            retention_count: self.retention_count,
155        };
156
157        self.db
158            .write_portfolio_snapshot_sync(&input)
159            .map_err(|e| SnapshotError::DbError(format!("Failed to write snapshot: {}", e)))
160    }
161}
162
/// Database-backed portfolio snapshot loader.
///
/// Reads snapshots through a [`PortfolioSnapshotReader`] and rebuilds the per-wallet
/// [`PortfolioBalance`] map together with the stored offsets.
pub struct DbPortfolioSnapshotLoader {
    /// Database handle used for every snapshot query.
    db: Arc<dyn PortfolioSnapshotReader>,
}
167
168impl DbPortfolioSnapshotLoader {
169    pub fn new(db: Arc<dyn PortfolioSnapshotReader>) -> Self {
170        Self { db }
171    }
172}
173
174impl SnapshotLoader for DbPortfolioSnapshotLoader {
175    type Key = WalletAddress;
176    type State = PortfolioBalance;
177
178    fn load_latest(
179        &self,
180    ) -> Result<Option<(i64, SnapshotState<Self::Key, Self::State>)>, SnapshotError> {
181        let latest_id = self
182            .db
183            .get_latest_portfolio_snapshot_id_sync()
184            .map_err(|e| {
185                SnapshotError::DbError(format!("Failed to query latest snapshot: {}", e))
186            })?;
187
188        match latest_id {
189            Some(snapshot_id) => {
190                let state = self.load(snapshot_id)?;
191                info!(
192                    "Loaded latest portfolio snapshot id={} with {} accounts",
193                    snapshot_id,
194                    state.state_count()
195                );
196                Ok(Some((snapshot_id, state)))
197            }
198            None => {
199                info!("No portfolio snapshots found");
200                Ok(None)
201            }
202        }
203    }
204
205    fn load(
206        &self,
207        snapshot_id: i64,
208    ) -> Result<SnapshotState<Self::Key, Self::State>, SnapshotError> {
209        // Verify snapshot exists
210        let exists = self
211            .db
212            .portfolio_snapshot_exists_sync(snapshot_id)
213            .map_err(|e| SnapshotError::DbError(format!("Failed to verify snapshot: {}", e)))?;
214
215        if !exists {
216            return Err(SnapshotError::NotFound(snapshot_id));
217        }
218
219        let data = self
220            .db
221            .load_portfolio_snapshot_sync(snapshot_id)
222            .map_err(|e| SnapshotError::DbError(format!("Failed to load snapshot: {}", e)))?;
223
224        let mut states = HashMap::new();
225        for account in data.accounts {
226            let balance = serialization::deserialize(&account.data)?;
227            if let Ok(wallet) = WalletAddress::from_str(&account.wallet) {
228                states.insert(wallet, balance);
229            } else {
230                debug!(
231                    "Skipping invalid wallet address in snapshot: {}",
232                    account.wallet
233                );
234            }
235        }
236
237        let offset_row_count = data.offsets.len();
238        let mut offsets: HashMap<String, HashMap<i32, i64>> = HashMap::new();
239        for offset in data.offsets {
240            offsets
241                .entry(offset.stream)
242                .or_default()
243                .insert(offset.partition, offset.offset);
244        }
245
246        debug!(
247            "Loaded snapshot id={}: {} accounts, {} offset entries",
248            snapshot_id,
249            states.len(),
250            offset_row_count
251        );
252
253        Ok(SnapshotState::with_data(states, offsets))
254    }
255}