1use futures::StreamExt;
2use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, Ordering};
3use std::sync::Arc;
4use tokio::sync::broadcast::error::RecvError;
5use tracing::{error, info, warn};
6
7use super::{NatsConfig, NatsReplaySubscriber, ReplayMessage};
8
/// Period of the timer that triggers a catch-up status refresh in the replay loop.
const CATCH_UP_TICK_INTERVAL: std::time::Duration = std::time::Duration::from_millis(250);
10
/// Shared, cheaply clonable snapshot of standby replay progress.
///
/// Every field is an atomic behind an `Arc`, so all clones observe and update
/// the same values. Signed sequence fields hold `-1` until a value has been
/// recorded; the accessors translate that sentinel into `None`.
///
/// All loads and stores use `Ordering::Relaxed`: each field is an independent
/// metric and no other memory is published through them, so readers may see
/// the fields of one update at slightly different moments.
#[derive(Clone)]
pub struct ReplayProgress {
    /// Number of commands applied so far.
    commands_replayed: Arc<AtomicU64>,
    /// Result of the most recent catch-up evaluation.
    caught_up: Arc<AtomicBool>,
    /// Stream sequence the replay has progressed to (`-1` = unset).
    replay_cursor_seq: Arc<AtomicI64>,
    /// Command sequence of the last applied command (`-1` = unset).
    last_replayed_seq: Arc<AtomicI64>,
    /// Wall-clock time of the last applied command in Unix ms (`0` = unset).
    last_replay_unix_ms: Arc<AtomicU64>,
    /// Most recently observed stream tail sequence (`-1` = unset).
    latest_stream_seq: Arc<AtomicI64>,
}

impl Default for ReplayProgress {
    fn default() -> Self {
        Self::new()
    }
}

impl ReplayProgress {
    /// Creates a tracker with nothing replayed, not caught up, and every
    /// sequence/timestamp unset.
    pub fn new() -> Self {
        Self {
            commands_replayed: Arc::new(AtomicU64::new(0)),
            caught_up: Arc::new(AtomicBool::new(false)),
            replay_cursor_seq: Arc::new(AtomicI64::new(-1)),
            last_replayed_seq: Arc::new(AtomicI64::new(-1)),
            last_replay_unix_ms: Arc::new(AtomicU64::new(0)),
            latest_stream_seq: Arc::new(AtomicI64::new(-1)),
        }
    }

    /// Converts an unsigned sequence to the signed storage type, clamping at
    /// `i64::MAX` so an out-of-range value can never wrap into the negative
    /// "unset" range.
    fn clamp_seq(seq: u64) -> i64 {
        seq.min(i64::MAX as u64) as i64
    }

    /// Total number of commands recorded via `record_replayed`.
    pub fn commands_replayed(&self) -> u64 {
        self.commands_replayed.load(Ordering::Relaxed)
    }

    /// Whether the last catch-up evaluation found the cursor at or past the
    /// stream tail.
    pub fn is_caught_up(&self) -> bool {
        self.caught_up.load(Ordering::Relaxed)
    }

    /// Command sequence of the last replayed command, or `None` if none yet.
    pub fn last_replayed_seq(&self) -> Option<i64> {
        let seq = self.last_replayed_seq.load(Ordering::Relaxed);
        (seq >= 0).then_some(seq)
    }

    /// Current replay cursor (stream sequence), or `None` while it is negative.
    pub fn replay_cursor_seq(&self) -> Option<i64> {
        let seq = self.replay_cursor_seq.load(Ordering::Relaxed);
        (seq >= 0).then_some(seq)
    }

    /// Last observed stream tail sequence, or `None` if never observed.
    pub fn latest_stream_seq(&self) -> Option<u64> {
        let seq = self.latest_stream_seq.load(Ordering::Relaxed);
        (seq >= 0).then_some(seq as u64)
    }

    /// Distance between the stream tail and the replay cursor, saturating at
    /// zero. `None` until both values are known.
    pub fn stream_lag(&self) -> Option<u64> {
        let latest = self.latest_stream_seq()?;
        let cursor = self.replay_cursor_seq()?;
        // `replay_cursor_seq()` only yields non-negative values, so the cast is lossless.
        Some(latest.saturating_sub(cursor as u64))
    }

    /// Unix timestamp (ms) of the last replayed command, or `None` if unset.
    pub fn last_replay_unix_ms(&self) -> Option<u64> {
        let ts = self.last_replay_unix_ms.load(Ordering::Relaxed);
        (ts > 0).then_some(ts)
    }

    /// Overwrites the replay cursor with `seq` (negative values read back as `None`).
    fn set_replay_cursor(&self, seq: i64) {
        self.replay_cursor_seq.store(seq, Ordering::Relaxed);
    }

    /// Records one successfully replayed command: bumps the counter, advances
    /// the cursor to `stream_seq`, remembers `command_seq`, and stamps the
    /// current wall-clock time.
    fn record_replayed(&self, command_seq: u64, stream_seq: u64) {
        self.commands_replayed.fetch_add(1, Ordering::Relaxed);
        self.replay_cursor_seq
            .store(Self::clamp_seq(stream_seq), Ordering::Relaxed);
        self.last_replayed_seq
            .store(Self::clamp_seq(command_seq), Ordering::Relaxed);
        // A clock set before the Unix epoch must not panic the replay task over
        // a metrics timestamp; fall back to 0, which reads back as "unset".
        let now_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|elapsed| elapsed.as_millis().min(u128::from(u64::MAX)) as u64)
            .unwrap_or(0);
        self.last_replay_unix_ms.store(now_ms, Ordering::Relaxed);
    }

    /// Stores the observed stream tail and replay cursor, re-evaluates the
    /// caught-up flag, and returns `true` when that flag changed value.
    fn update_catch_up_state(&self, latest_stream_seq: u64, replay_cursor_seq: i64) -> bool {
        self.latest_stream_seq
            .store(Self::clamp_seq(latest_stream_seq), Ordering::Relaxed);
        self.set_replay_cursor(replay_cursor_seq);
        let caught_up = is_caught_up_to_stream_tail(replay_cursor_seq, latest_stream_seq);
        // `swap` returns the previous flag, so inequality means a transition.
        self.caught_up.swap(caught_up, Ordering::Relaxed) != caught_up
    }

    /// Test-only access to `record_replayed`.
    #[cfg(test)]
    pub(crate) fn test_record_replayed(&self, command_seq: u64, stream_seq: u64) {
        self.record_replayed(command_seq, stream_seq);
    }

    /// Test-only access to `update_catch_up_state`; the change flag is discarded.
    #[cfg(test)]
    pub(crate) fn test_update_catch_up_state(
        &self,
        latest_stream_seq: u64,
        replay_cursor_seq: i64,
    ) {
        self.update_catch_up_state(latest_stream_seq, replay_cursor_seq);
    }
}

/// Returns `true` when the replay cursor has reached or passed the stream tail.
///
/// A negative cursor is never caught up. The comparison is done in `u64` so a
/// tail above `i64::MAX` cannot wrap negative and be mistaken for "behind the
/// cursor".
fn is_caught_up_to_stream_tail(replay_cursor_seq: i64, latest_stream_seq: u64) -> bool {
    replay_cursor_seq >= 0 && (replay_cursor_seq as u64) >= latest_stream_seq
}
117
118async fn refresh_catch_up_state(subscriber: &NatsReplaySubscriber, progress: &ReplayProgress) {
119 let latest_stream_seq = match subscriber.latest_stream_sequence().await {
120 Ok(seq) => seq,
121 Err(error) => {
122 warn!(
123 error = %error,
124 "Cannot determine standby catch-up status while stream state is unavailable"
125 );
126 return;
127 }
128 };
129 let replay_cursor_seq = subscriber.last_stream_sequence_processed().await;
130 let changed = progress.update_catch_up_state(latest_stream_seq, replay_cursor_seq);
131
132 if changed && progress.is_caught_up() {
133 info!(
134 commands_replayed = progress.commands_replayed(),
135 replay_cursor_seq,
136 latest_stream_seq,
137 "Standby caught up with primary after replaying through stream tail"
138 );
139 } else if changed {
140 info!(
141 commands_replayed = progress.commands_replayed(),
142 replay_cursor_seq,
143 latest_stream_seq,
144 stream_lag = progress.stream_lag().unwrap_or(0),
145 "Standby replay fell behind stream tail"
146 );
147 }
148}
149
150pub trait ReplayHandler: Send + 'static {
153 fn handle_command(
154 &mut self,
155 seq: u64,
156 command_type: super::CommandType,
157 command_version: u8,
158 command_data: Vec<u8>,
159 ) -> impl std::future::Future<Output = anyhow::Result<()>> + Send;
160}
161
162async fn apply_replay_message<H: ReplayHandler>(
163 subscriber: &NatsReplaySubscriber,
164 progress: &ReplayProgress,
165 handler: &mut H,
166 msg: ReplayMessage,
167) -> anyhow::Result<()> {
168 let ReplayMessage {
169 stream_seq,
170 seq,
171 command_type,
172 command_version,
173 command_data,
174 msg,
175 } = msg;
176 handler
177 .handle_command(seq, command_type, command_version, command_data)
178 .await
179 .map_err(|error| anyhow::anyhow!("failed to replay command seq={}: {}", seq, error))?;
180 msg.ack().await.map_err(|error| {
181 anyhow::anyhow!("NATS ack failed for stream_seq={}: {}", stream_seq, error)
182 })?;
183 subscriber.mark_processed(stream_seq).await;
184 progress.record_replayed(seq, stream_seq);
185 Ok(())
186}
187
188fn is_promote_signal(result: Result<(), RecvError>) -> bool {
191 match result {
192 Ok(()) => true,
193 Err(RecvError::Closed) => {
194 warn!("Shutdown channel closed (sender dropped) — not a promote signal, ignoring");
195 false
196 }
197 Err(RecvError::Lagged(_)) => {
198 warn!("Shutdown channel lagged — treating as promote signal");
199 true
200 }
201 }
202}
203
204pub async fn run_replay_loop<H: ReplayHandler>(
212 config: &NatsConfig,
213 start_from_seq: i64,
214 mut handler: H,
215 progress: ReplayProgress,
216 mut shutdown: tokio::sync::broadcast::Receiver<()>,
217) -> anyhow::Result<H> {
218 let subscriber = NatsReplaySubscriber::connect(config, start_from_seq).await?;
219 let mut stream = subscriber.subscribe().await?;
220 progress.set_replay_cursor(start_from_seq);
221
222 info!(start_from_seq, "Standby replay loop started");
223
224 let mut last_log = std::time::Instant::now();
225 let mut batch_count: u64 = 0;
226 let mut shutdown_active = true;
227 let mut catch_up_tick = tokio::time::interval(CATCH_UP_TICK_INTERVAL);
228 catch_up_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
229
230 loop {
231 if !shutdown_active {
234 tokio::select! {
235 msg = stream.next() => {
236 match msg {
237 Some(Ok(msg)) => {
238 apply_replay_message(&subscriber, &progress, &mut handler, msg).await?;
239 batch_count += 1;
240 }
241 Some(Err(e)) => {
242 error!(error = %e, "NATS stream error");
243 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
244 }
245 None => {
246 warn!("NATS stream ended, reconnecting in 1s");
247 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
248 match subscriber.subscribe().await {
249 Ok(new_stream) => {
250 stream = new_stream;
251 info!("NATS stream reconnected");
252 }
253 Err(e) => {
254 error!(error = %e, "Failed to reconnect NATS stream, retrying in 5s");
255 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
256 }
257 }
258 }
259 }
260 }
261 _ = catch_up_tick.tick() => {
262 refresh_catch_up_state(&subscriber, &progress).await;
263 }
264 }
265 } else {
266 tokio::select! {
267 biased;
268
269 result = shutdown.recv() => {
270 if is_promote_signal(result) {
271 info!(
272 commands_replayed = progress.commands_replayed(),
273 "Standby replay loop shutting down (promote signal received)"
274 );
275 break;
276 }
277 shutdown_active = false;
278 }
279
280 msg = stream.next() => {
281 match msg {
282 Some(Ok(msg)) => {
283 apply_replay_message(&subscriber, &progress, &mut handler, msg).await?;
284 batch_count += 1;
285 }
286 Some(Err(e)) => {
287 error!(error = %e, "NATS stream error");
288 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
289 }
290 None => {
291 warn!("NATS stream ended, reconnecting in 1s");
292 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
293 match subscriber.subscribe().await {
294 Ok(new_stream) => {
295 stream = new_stream;
296 info!("NATS stream reconnected");
297 }
298 Err(e) => {
299 error!(error = %e, "Failed to reconnect NATS stream, retrying in 5s");
300 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
301 }
302 }
303 }
304 }
305 }
306
307 _ = catch_up_tick.tick() => {
308 refresh_catch_up_state(&subscriber, &progress).await;
309 }
310 }
311 } if last_log.elapsed() >= std::time::Duration::from_secs(10) {
314 info!(
315 commands_replayed = progress.commands_replayed(),
316 batch_since_last_log = batch_count,
317 caught_up = progress.is_caught_up(),
318 "Standby replay progress"
319 );
320 batch_count = 0;
321 last_log = std::time::Instant::now();
322 }
323 }
324
325 Ok(handler)
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// A closed channel (sender dropped) must not be read as a promote.
    #[test]
    fn test_sender_drop_is_not_promote() {
        assert!(!is_promote_signal(Err(RecvError::Closed)));
    }

    /// A delivered value is the explicit promote request.
    #[test]
    fn test_explicit_send_is_promote() {
        assert!(is_promote_signal(Ok(())));
    }

    /// A lagged receiver is treated as a promote.
    #[test]
    fn test_lagged_is_promote() {
        assert!(is_promote_signal(Err(RecvError::Lagged(1))));
    }

    #[test]
    fn test_caught_up_when_replay_cursor_reaches_stream_tail() {
        assert!(is_caught_up_to_stream_tail(10, 10));
    }

    #[test]
    fn test_caught_up_allows_empty_tail_at_start_cursor() {
        assert!(is_caught_up_to_stream_tail(0, 0));
    }

    #[test]
    fn test_caught_up_waits_for_stream_tail() {
        assert!(!is_caught_up_to_stream_tail(9, 10));
    }

    /// The caught-up flag follows the cursor/tail relation in both directions,
    /// and the returned flag reports only genuine transitions.
    #[test]
    fn test_progress_can_move_from_caught_up_to_behind() {
        let progress = ReplayProgress::new();

        // not caught up -> caught up
        assert!(progress.update_catch_up_state(10, 10));
        assert!(progress.is_caught_up());
        assert_eq!(progress.stream_lag(), Some(0));

        // caught up -> behind
        assert!(progress.update_catch_up_state(12, 10));
        assert!(!progress.is_caught_up());
        assert_eq!(progress.stream_lag(), Some(2));

        // still behind: the lag grows but no transition is reported
        assert!(!progress.update_catch_up_state(13, 10));
        assert!(!progress.is_caught_up());
        assert_eq!(progress.stream_lag(), Some(3));
    }

    /// Receiving on a real channel whose sender was dropped yields a result
    /// that is classified as "not a promote".
    #[tokio::test]
    async fn test_replay_loop_ignores_sender_drop() {
        let (tx, mut rx) = tokio::sync::broadcast::channel::<()>(1);

        drop(tx);

        let result = rx.recv().await;
        assert!(!is_promote_signal(result));
    }

    /// Receiving an explicitly sent value on a real channel is classified as
    /// a promote.
    #[tokio::test]
    async fn test_replay_loop_exits_on_explicit_promote() {
        let (tx, mut rx) = tokio::sync::broadcast::channel::<()>(1);

        tx.send(()).unwrap();

        let result = rx.recv().await;
        assert!(is_promote_signal(result));
    }
}