Skip to main content

hypercall/nats/
replay_loop.rs

1use futures::StreamExt;
2use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, Ordering};
3use std::sync::Arc;
4use tokio::sync::broadcast::error::RecvError;
5use tracing::{error, info, warn};
6
7use super::{NatsConfig, NatsReplaySubscriber, ReplayMessage};
8
9const CATCH_UP_TICK_INTERVAL: std::time::Duration = std::time::Duration::from_millis(250);
10
/// Tracks how far behind the standby is from the primary.
///
/// Every field is an `Arc`-wrapped atomic, so clones of a `ReplayProgress`
/// read and update the same underlying values. Sequence fields store `-1`
/// (and the timestamp stores `0`) to mean "nothing recorded yet"; the public
/// getters translate those sentinels into `None`.
#[derive(Clone)]
pub struct ReplayProgress {
    /// Incremented once per call to `record_replayed`.
    commands_replayed: Arc<AtomicU64>,
    /// Result of the most recent `update_catch_up_state` comparison.
    caught_up: Arc<AtomicBool>,
    /// Stream sequence of the replay cursor, or a negative value when unknown.
    replay_cursor_seq: Arc<AtomicI64>,
    /// Command sequence passed to the most recent `record_replayed` call.
    last_replayed_seq: Arc<AtomicI64>,
    /// Wall-clock time (ms since the Unix epoch) of the most recent
    /// `record_replayed` call; `0` when none has been recorded.
    last_replay_unix_ms: Arc<AtomicU64>,
    /// Stream tail sequence passed to the most recent `update_catch_up_state` call.
    latest_stream_seq: Arc<AtomicI64>,
}

impl Default for ReplayProgress {
    fn default() -> Self {
        Self::new()
    }
}

impl ReplayProgress {
    /// Creates a tracker with no commands replayed, not caught up, and every
    /// sequence/timestamp in its "unknown" state.
    pub fn new() -> Self {
        Self {
            commands_replayed: Arc::new(AtomicU64::new(0)),
            caught_up: Arc::new(AtomicBool::new(false)),
            replay_cursor_seq: Arc::new(AtomicI64::new(-1)),
            last_replayed_seq: Arc::new(AtomicI64::new(-1)),
            last_replay_unix_ms: Arc::new(AtomicU64::new(0)),
            latest_stream_seq: Arc::new(AtomicI64::new(-1)),
        }
    }

    /// Converts an unsigned sequence to the signed storage representation.
    ///
    /// Saturates at `i64::MAX` instead of wrapping, so an out-of-range value
    /// can never be mistaken for the negative "unknown" sentinel.
    fn clamp_seq(seq: u64) -> i64 {
        seq.min(i64::MAX as u64) as i64
    }

    /// Number of commands recorded via `record_replayed` so far.
    pub fn commands_replayed(&self) -> u64 {
        self.commands_replayed.load(Ordering::Relaxed)
    }

    /// Whether the last catch-up check found the replay cursor at the stream tail.
    pub fn is_caught_up(&self) -> bool {
        self.caught_up.load(Ordering::Relaxed)
    }

    /// Command sequence of the last recorded command, or `None` if none yet.
    pub fn last_replayed_seq(&self) -> Option<i64> {
        let seq = self.last_replayed_seq.load(Ordering::Relaxed);
        (seq >= 0).then_some(seq)
    }

    /// Current replay cursor (stream sequence), or `None` while it is negative/unknown.
    pub fn replay_cursor_seq(&self) -> Option<i64> {
        let seq = self.replay_cursor_seq.load(Ordering::Relaxed);
        (seq >= 0).then_some(seq)
    }

    /// Last observed stream tail sequence, or `None` if never observed.
    pub fn latest_stream_seq(&self) -> Option<u64> {
        let seq = self.latest_stream_seq.load(Ordering::Relaxed);
        (seq >= 0).then_some(seq as u64)
    }

    /// Distance between the last observed stream tail and the replay cursor.
    ///
    /// Returns `None` until both values are known. The two values are loaded
    /// independently, so the subtraction saturates at zero rather than
    /// underflowing if the cursor is momentarily ahead of the recorded tail.
    pub fn stream_lag(&self) -> Option<u64> {
        let latest = self.latest_stream_seq()?;
        // `replay_cursor_seq()` only yields non-negative values, so the cast is lossless.
        let cursor = self.replay_cursor_seq()? as u64;
        Some(latest.saturating_sub(cursor))
    }

    /// Wall-clock time of the last recorded replay in ms since the Unix epoch,
    /// or `None` if nothing has been recorded.
    pub fn last_replay_unix_ms(&self) -> Option<u64> {
        let ts = self.last_replay_unix_ms.load(Ordering::Relaxed);
        (ts > 0).then_some(ts)
    }

    /// Overwrites the replay cursor; negative values read back as `None`.
    fn set_replay_cursor(&self, seq: i64) {
        self.replay_cursor_seq.store(seq, Ordering::Relaxed);
    }

    /// Records one replayed command: bumps the counter, moves the cursor to
    /// `stream_seq`, remembers `command_seq`, and stamps the current wall-clock time.
    fn record_replayed(&self, command_seq: u64, stream_seq: u64) {
        self.commands_replayed.fetch_add(1, Ordering::Relaxed);
        self.replay_cursor_seq
            .store(Self::clamp_seq(stream_seq), Ordering::Relaxed);
        self.last_replayed_seq
            .store(Self::clamp_seq(command_seq), Ordering::Relaxed);
        // The timestamp is informational only: if the system clock reads before
        // the Unix epoch, store the "no timestamp" sentinel instead of panicking
        // in the middle of replay.
        let now_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_millis() as u64);
        self.last_replay_unix_ms.store(now_ms, Ordering::Relaxed);
    }

    /// Stores the observed stream tail and cursor, recomputes the caught-up
    /// flag, and returns `true` when that flag changed as a result.
    fn update_catch_up_state(&self, latest_stream_seq: u64, replay_cursor_seq: i64) -> bool {
        self.latest_stream_seq
            .store(Self::clamp_seq(latest_stream_seq), Ordering::Relaxed);
        self.set_replay_cursor(replay_cursor_seq);
        let caught_up = is_caught_up_to_stream_tail(replay_cursor_seq, latest_stream_seq);
        // `swap` hands back the previous flag; a difference means a state transition.
        self.caught_up.swap(caught_up, Ordering::Relaxed) != caught_up
    }

    /// Test-only access to `record_replayed`.
    #[cfg(test)]
    pub(crate) fn test_record_replayed(&self, command_seq: u64, stream_seq: u64) {
        self.record_replayed(command_seq, stream_seq);
    }

    /// Test-only access to `update_catch_up_state` (discards the change flag).
    #[cfg(test)]
    pub(crate) fn test_update_catch_up_state(
        &self,
        latest_stream_seq: u64,
        replay_cursor_seq: i64,
    ) {
        self.update_catch_up_state(latest_stream_seq, replay_cursor_seq);
    }
}

/// Returns `true` when the replay cursor has reached (or passed) the stream tail.
///
/// The comparison is done in the unsigned domain: a negative cursor is never
/// caught up, and a tail above `i64::MAX` cannot wrap into a negative number
/// that any cursor would appear to have passed.
fn is_caught_up_to_stream_tail(replay_cursor_seq: i64, latest_stream_seq: u64) -> bool {
    replay_cursor_seq >= 0 && replay_cursor_seq as u64 >= latest_stream_seq
}
117
118async fn refresh_catch_up_state(subscriber: &NatsReplaySubscriber, progress: &ReplayProgress) {
119    let latest_stream_seq = match subscriber.latest_stream_sequence().await {
120        Ok(seq) => seq,
121        Err(error) => {
122            warn!(
123                error = %error,
124                "Cannot determine standby catch-up status while stream state is unavailable"
125            );
126            return;
127        }
128    };
129    let replay_cursor_seq = subscriber.last_stream_sequence_processed().await;
130    let changed = progress.update_catch_up_state(latest_stream_seq, replay_cursor_seq);
131
132    if changed && progress.is_caught_up() {
133        info!(
134            commands_replayed = progress.commands_replayed(),
135            replay_cursor_seq,
136            latest_stream_seq,
137            "Standby caught up with primary after replaying through stream tail"
138        );
139    } else if changed {
140        info!(
141            commands_replayed = progress.commands_replayed(),
142            replay_cursor_seq,
143            latest_stream_seq,
144            stream_lag = progress.stream_lag().unwrap_or(0),
145            "Standby replay fell behind stream tail"
146        );
147    }
148}
149
150/// Callback that the replay loop calls for each command.
151/// The implementation should deserialize and call apply() on the standby engine.
152pub trait ReplayHandler: Send + 'static {
153    fn handle_command(
154        &mut self,
155        seq: u64,
156        command_type: super::CommandType,
157        command_version: u8,
158        command_data: Vec<u8>,
159    ) -> impl std::future::Future<Output = anyhow::Result<()>> + Send;
160}
161
162async fn apply_replay_message<H: ReplayHandler>(
163    subscriber: &NatsReplaySubscriber,
164    progress: &ReplayProgress,
165    handler: &mut H,
166    msg: ReplayMessage,
167) -> anyhow::Result<()> {
168    let ReplayMessage {
169        stream_seq,
170        seq,
171        command_type,
172        command_version,
173        command_data,
174        msg,
175    } = msg;
176    handler
177        .handle_command(seq, command_type, command_version, command_data)
178        .await
179        .map_err(|error| anyhow::anyhow!("failed to replay command seq={}: {}", seq, error))?;
180    msg.ack().await.map_err(|error| {
181        anyhow::anyhow!("NATS ack failed for stream_seq={}: {}", stream_seq, error)
182    })?;
183    subscriber.mark_processed(stream_seq).await;
184    progress.record_replayed(seq, stream_seq);
185    Ok(())
186}
187
188/// Check if the shutdown signal is a real promote (Ok) vs sender dropped (Err).
189/// Only returns true for an explicit promote signal.
190fn is_promote_signal(result: Result<(), RecvError>) -> bool {
191    match result {
192        Ok(()) => true,
193        Err(RecvError::Closed) => {
194            warn!("Shutdown channel closed (sender dropped) — not a promote signal, ignoring");
195            false
196        }
197        Err(RecvError::Lagged(_)) => {
198            warn!("Shutdown channel lagged — treating as promote signal");
199            true
200        }
201    }
202}
203
204/// Runs the standby replay loop: subscribes to NATS JetStream and
205/// continuously replays engine commands to keep the standby engine
206/// caught up with the primary.
207///
208/// Returns the handler ONLY when an explicit promote signal fires.
209/// Ignores sender drops and stream disconnects — only POST /admin/promote
210/// can cause this function to return.
211pub async fn run_replay_loop<H: ReplayHandler>(
212    config: &NatsConfig,
213    start_from_seq: i64,
214    mut handler: H,
215    progress: ReplayProgress,
216    mut shutdown: tokio::sync::broadcast::Receiver<()>,
217) -> anyhow::Result<H> {
218    let subscriber = NatsReplaySubscriber::connect(config, start_from_seq).await?;
219    let mut stream = subscriber.subscribe().await?;
220    progress.set_replay_cursor(start_from_seq);
221
222    info!(start_from_seq, "Standby replay loop started");
223
224    let mut last_log = std::time::Instant::now();
225    let mut batch_count: u64 = 0;
226    let mut shutdown_active = true;
227    let mut catch_up_tick = tokio::time::interval(CATCH_UP_TICK_INTERVAL);
228    catch_up_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
229
230    loop {
231        // If shutdown channel closed (sender dropped), only poll the stream.
232        // This prevents biased select from spinning on Err(Closed).
233        if !shutdown_active {
234            tokio::select! {
235                msg = stream.next() => {
236                    match msg {
237                        Some(Ok(msg)) => {
238                            apply_replay_message(&subscriber, &progress, &mut handler, msg).await?;
239                            batch_count += 1;
240                        }
241                        Some(Err(e)) => {
242                            error!(error = %e, "NATS stream error");
243                            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
244                        }
245                        None => {
246                            warn!("NATS stream ended, reconnecting in 1s");
247                            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
248                            match subscriber.subscribe().await {
249                                Ok(new_stream) => {
250                                    stream = new_stream;
251                                    info!("NATS stream reconnected");
252                                }
253                                Err(e) => {
254                                    error!(error = %e, "Failed to reconnect NATS stream, retrying in 5s");
255                                    tokio::time::sleep(std::time::Duration::from_secs(5)).await;
256                                }
257                            }
258                        }
259                    }
260                }
261                _ = catch_up_tick.tick() => {
262                    refresh_catch_up_state(&subscriber, &progress).await;
263                }
264            }
265        } else {
266            tokio::select! {
267                biased;
268
269                result = shutdown.recv() => {
270                    if is_promote_signal(result) {
271                        info!(
272                            commands_replayed = progress.commands_replayed(),
273                            "Standby replay loop shutting down (promote signal received)"
274                        );
275                        break;
276                    }
277                    shutdown_active = false;
278                }
279
280                msg = stream.next() => {
281                    match msg {
282                        Some(Ok(msg)) => {
283                            apply_replay_message(&subscriber, &progress, &mut handler, msg).await?;
284                            batch_count += 1;
285                        }
286                        Some(Err(e)) => {
287                            error!(error = %e, "NATS stream error");
288                            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
289                        }
290                        None => {
291                            warn!("NATS stream ended, reconnecting in 1s");
292                            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
293                            match subscriber.subscribe().await {
294                                Ok(new_stream) => {
295                                    stream = new_stream;
296                                    info!("NATS stream reconnected");
297                                }
298                                Err(e) => {
299                                    error!(error = %e, "Failed to reconnect NATS stream, retrying in 5s");
300                                    tokio::time::sleep(std::time::Duration::from_secs(5)).await;
301                                }
302                            }
303                        }
304                    }
305                }
306
307                _ = catch_up_tick.tick() => {
308                    refresh_catch_up_state(&subscriber, &progress).await;
309                }
310            }
311        } // else shutdown_active
312
313        if last_log.elapsed() >= std::time::Duration::from_secs(10) {
314            info!(
315                commands_replayed = progress.commands_replayed(),
316                batch_since_last_log = batch_count,
317                caught_up = progress.is_caught_up(),
318                "Standby replay progress"
319            );
320            batch_count = 0;
321            last_log = std::time::Instant::now();
322        }
323    }
324
325    Ok(handler)
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    // --- is_promote_signal ---------------------------------------------------

    #[test]
    fn test_sender_drop_is_not_promote() {
        assert!(!is_promote_signal(Err(RecvError::Closed)));
    }

    #[test]
    fn test_explicit_send_is_promote() {
        assert!(is_promote_signal(Ok(())));
    }

    #[test]
    fn test_lagged_is_promote() {
        assert!(is_promote_signal(Err(RecvError::Lagged(1))));
    }

    // --- is_caught_up_to_stream_tail -----------------------------------------

    #[test]
    fn test_caught_up_when_replay_cursor_reaches_stream_tail() {
        assert!(is_caught_up_to_stream_tail(10, 10));
    }

    #[test]
    fn test_caught_up_allows_empty_tail_at_start_cursor() {
        assert!(is_caught_up_to_stream_tail(0, 0));
    }

    #[test]
    fn test_caught_up_waits_for_stream_tail() {
        assert!(!is_caught_up_to_stream_tail(9, 10));
    }

    #[test]
    fn test_negative_cursor_is_never_caught_up() {
        assert!(!is_caught_up_to_stream_tail(-1, 0));
    }

    // --- ReplayProgress ------------------------------------------------------

    #[test]
    fn test_progress_can_move_from_caught_up_to_behind() {
        let progress = ReplayProgress::new();
        progress.update_catch_up_state(10, 10);
        assert!(progress.is_caught_up());
        assert_eq!(progress.stream_lag(), Some(0));

        progress.update_catch_up_state(12, 10);
        assert!(!progress.is_caught_up());
        assert_eq!(progress.stream_lag(), Some(2));
    }

    #[test]
    fn test_update_catch_up_state_reports_only_transitions() {
        let progress = ReplayProgress::new();
        // A fresh tracker starts as "not caught up", so reaching the tail is a change.
        assert!(progress.update_catch_up_state(10, 10));
        // Still at the tail: no transition.
        assert!(!progress.update_catch_up_state(10, 10));
        // Tail moved ahead: transition back to "behind".
        assert!(progress.update_catch_up_state(11, 10));
        assert!(!progress.update_catch_up_state(12, 10));
    }

    // --- shutdown channel semantics relied on by the replay loop -------------

    #[tokio::test]
    async fn test_replay_loop_ignores_sender_drop() {
        let (tx, mut rx) = tokio::sync::broadcast::channel::<()>(1);

        // Drop the sender — this should NOT cause the replay loop to exit
        drop(tx);

        // The recv should return Err(Closed), which is_promote_signal rejects
        let result = rx.recv().await;
        assert!(!is_promote_signal(result));
    }

    #[tokio::test]
    async fn test_replay_loop_exits_on_explicit_promote() {
        let (tx, mut rx) = tokio::sync::broadcast::channel::<()>(1);

        // Send explicit promote signal
        tx.send(()).unwrap();

        let result = rx.recv().await;
        assert!(is_promote_signal(result));
    }
}