Skip to main content

hypercall_api/
middleware.rs

1use axum::{
2    extract::{Request, State},
3    http::{Method, StatusCode},
4    middleware::Next,
5    response::{IntoResponse, Response},
6};
7use sonic_rs::json;
8use std::sync::Arc;
9use tokio::sync::Mutex;
10
11use crate::sonic_json::SonicJson;
12
13use crate::caches::rate_limit::{RateLimitAction, RateLimitCache, RateLimitError, RateLimitInfo};
14use crate::observability_boundary::AuthFailureRecorder;
15use crate::request_auth::{authorize_signer, RequestAuthError};
16use crate::runtime_status::{ReadinessGate, StandbyReplayProgress};
17use crate::signed_actions::recover_signed_action;
18use hypercall_types::{WalletAddress, API_ROUTE_PROFILE_IMAGE};
19use std::str::FromStr;
20
/// Largest request body, in bytes, buffered for signed JSON requests (1 MiB).
const SIGNED_JSON_BODY_LIMIT_BYTES: usize = 1024 * 1024;
/// Largest request body, in bytes, buffered for `POST` requests to the profile image route.
const PROFILE_IMAGE_BODY_LIMIT_BYTES: usize = 7 * 1_000_000;
23
/// Shared state handed to `signature_and_agent_middleware`.
#[derive(Clone)]
pub struct SignatureMiddlewareState {
    /// Provider passed to `authorize_signer` when checking the recovered signer against the wallet.
    pub agent_auth: Arc<dyn hypercall_runtime_api::AgentAuthProvider>,
    /// Notified with the failure reason when recovering a signed action fails with one attached.
    pub auth_failure_recorder: Arc<dyn AuthFailureRecorder>,
    /// Chain id forwarded to `recover_signed_action`.
    pub signing_chain_id: u64,
}
30
31#[cfg(feature = "otel-tracing")]
32use opentelemetry::propagation::TextMapPropagator;
33#[cfg(feature = "otel-tracing")]
34use opentelemetry_sdk::propagation::TraceContextPropagator;
35#[cfg(feature = "otel-tracing")]
36use tracing_opentelemetry::OpenTelemetrySpanExt;
37
38/// Helper function to create a JSON error response
39fn json_error_response(status: StatusCode, error: &str, message: &str) -> Response {
40    let body = json!({
41        "error": error,
42        "message": message
43    });
44    (status, SonicJson(body)).into_response()
45}
46
/// Borrowed view over an axum header map that lets an OpenTelemetry
/// propagator read headers through the `Extractor` trait.
#[cfg(feature = "otel-tracing")]
struct HeaderExtractor<'a>(&'a axum::http::HeaderMap);

#[cfg(feature = "otel-tracing")]
impl opentelemetry::propagation::Extractor for HeaderExtractor<'_> {
    /// Returns the value stored under `key`, or `None` when the header is
    /// missing or its value cannot be viewed as a string.
    fn get(&self, key: &str) -> Option<&str> {
        let raw_value = self.0.get(key)?;
        raw_value.to_str().ok()
    }

    /// Returns the names of all headers in the wrapped map.
    fn keys(&self) -> Vec<&str> {
        let mut names = Vec::new();
        names.extend(self.0.keys().map(|name| name.as_str()));
        names
    }
}
61
/// Middleware that adopts the caller's W3C Trace Context as the parent of the current span.
///
/// Only compiled with the `otel-tracing` feature. The context is extracted from the
/// request headers by a `TraceContextPropagator` and set as the parent of
/// `tracing::Span::current()`, after which the request continues down the stack unchanged.
///
/// Apply this early in the middleware stack (after TraceLayer) so the extracted
/// context is already in place for every later middleware and handler.
#[cfg(feature = "otel-tracing")]
pub async fn trace_context_middleware(req: Request, next: Next) -> Response {
    // Scope the header borrow so `req` can be moved into `next.run` afterwards.
    let remote_cx = {
        let carrier = HeaderExtractor(req.headers());
        TraceContextPropagator::new().extract(&carrier)
    };

    tracing::Span::current().set_parent(remote_cx);

    next.run(req).await
}
80
/// State for `readiness_middleware`, which answers 503 to non-bypass requests
/// until the readiness gate reports that everything is ready.
///
/// The two optional standby fields are only consulted while the gate is not ready:
/// when both are present, replay progress reports caught up, and the promote slot
/// still holds its sender, `GET` requests to the standby monitoring routes are let through.
#[derive(Clone)]
pub struct ReadinessMiddlewareState {
    /// Gate queried through `all_ready()`; its `reports()` are embedded in the 503 body.
    pub readiness: Arc<dyn ReadinessGate>,
    /// Standby replay progress; `None` for state built with `ReadinessMiddlewareState::primary`.
    pub standby_progress: Option<Arc<dyn StandbyReplayProgress>>,
    /// Slot holding the promote signal sender.
    /// NOTE(review): presumably the sender is taken out of the slot on promotion — confirm
    /// against the promote handler.
    pub standby_promote: Option<Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>>,
}
94
95impl ReadinessMiddlewareState {
96    pub fn primary(readiness: Arc<dyn ReadinessGate>) -> Self {
97        Self {
98            readiness,
99            standby_progress: None,
100            standby_promote: None,
101        }
102    }
103
104    pub fn new(
105        readiness: Arc<dyn ReadinessGate>,
106        standby_progress: Option<Arc<dyn StandbyReplayProgress>>,
107        standby_promote: Option<Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>>,
108    ) -> Self {
109        Self {
110            readiness,
111            standby_progress,
112            standby_promote,
113        }
114    }
115}
116
117pub async fn readiness_middleware(
118    State(state): State<ReadinessMiddlewareState>,
119    req: Request,
120    next: Next,
121) -> Response {
122    let path = req.uri().path();
123
124    if is_readiness_bypass_path(path) {
125        return next.run(req).await;
126    }
127
128    // Block all other requests until ready
129    if !state.readiness.all_ready() {
130        if is_standby_monitoring_get(req.method(), path)
131            && standby_is_caught_up_and_unpromoted(&state).await
132        {
133            return next.run(req).await;
134        }
135
136        let components = state.readiness.reports();
137        return (
138            StatusCode::SERVICE_UNAVAILABLE,
139            SonicJson(json!({
140                "error": "service_not_ready",
141                "message": "Service is starting up. Please retry shortly.",
142                "components": components
143            })),
144        )
145            .into_response();
146    }
147
148    next.run(req).await
149}
150
/// Returns `true` when `path` is one of the routes that skip the readiness gate.
///
/// Matching is exact: no prefix matching and no trailing-slash normalization.
fn is_readiness_bypass_path(path: &str) -> bool {
    const BYPASS_PATHS: [&str; 11] = [
        "/health",
        "/ready",
        "/metrics",
        "/standby-ready",
        "/admin/promote",
        "/admin/drain",
        "/admin/undrain",
        "/drain-status",
        "/recovery-safety",
        "/recovery-safety/alert",
        "/monitoring/recovery-safety",
    ];

    BYPASS_PATHS.iter().any(|candidate| *candidate == path)
}
167
168fn is_standby_monitoring_get(method: &Method, path: &str) -> bool {
169    if method != Method::GET {
170        return false;
171    }
172
173    matches!(
174        path,
175        "/monitoring/integrity"
176            | "/monitoring/orderbooks"
177            | "/monitoring/accounts"
178            | "/monitoring/positions"
179            | "/monitoring/directive-outbox"
180            | "/monitoring/deposits"
181            | "/monitoring/engine-state-digest"
182            | "/recovery-safety"
183            | "/recovery-safety/alert"
184            | "/monitoring/recovery-safety"
185            | "/monitoring/trading-halts"
186            | "/monitoring/vol-oracles"
187            | "/monitoring/price-oracles"
188            | "/monitoring/vol-surface"
189    )
190}
191
192async fn standby_is_caught_up_and_unpromoted(state: &ReadinessMiddlewareState) -> bool {
193    let Some(progress) = &state.standby_progress else {
194        return false;
195    };
196    if !progress.is_caught_up() {
197        return false;
198    }
199
200    match &state.standby_promote {
201        Some(promote) => promote.lock().await.is_some(),
202        None => false,
203    }
204}
205
206#[derive(Clone)]
207pub struct AuthenticatedUser {
208    pub wallet_address: WalletAddress,
209}
210
211#[derive(Clone)]
212pub struct AccountContext {
213    pub account_contract_address: WalletAddress,
214    pub api_wallet_address: WalletAddress,
215}
216
/// Result of signature verification and signer authorization, inserted into
/// request extensions by the signature middleware.
#[derive(Clone, Debug)]
pub struct SignerContext {
    /// The account being acted upon.
    pub wallet_address: WalletAddress,
    /// The agent that signed the request (may equal `wallet_address`).
    pub signer_address: WalletAddress,
}
222
223/// Middleware for signature-based authentication with agent authorization
224pub async fn signature_and_agent_middleware(
225    State(state): State<SignatureMiddlewareState>,
226    req: Request,
227    next: Next,
228) -> Response {
229    match signature_and_agent_middleware_inner(state, req, next).await {
230        Ok(response) => response,
231        Err(response) => response,
232    }
233}
234
235async fn signature_and_agent_middleware_inner(
236    state: SignatureMiddlewareState,
237    req: Request,
238    next: Next,
239) -> Result<Response, Response> {
240    // Extract the request body to get signature, nonce, and wallet
241    let (mut parts, body) = req.into_parts();
242    let body_limit = match (parts.uri.path(), &parts.method) {
243        (API_ROUTE_PROFILE_IMAGE, &Method::POST) => PROFILE_IMAGE_BODY_LIMIT_BYTES,
244        _ => SIGNED_JSON_BODY_LIMIT_BYTES,
245    };
246
247    let body_bytes = match axum::body::to_bytes(body, body_limit).await {
248        Ok(bytes) => bytes,
249        Err(_) => {
250            return Err(json_error_response(
251                StatusCode::PAYLOAD_TOO_LARGE,
252                "payload_too_large",
253                "Request body exceeds the allowed size",
254            ));
255        }
256    };
257
258    // Parse the JSON body
259    let request_data: sonic_rs::Value = match sonic_rs::from_slice(&body_bytes) {
260        Ok(json) => json,
261        Err(e) => {
262            tracing::error!("Failed to parse request body: {}", e);
263            return Err(json_error_response(
264                StatusCode::BAD_REQUEST,
265                "invalid_json",
266                &format!("Failed to parse JSON: {}", e),
267            ));
268        }
269    };
270
271    let recovered = recover_signed_action(
272        parts.uri.path(),
273        &parts.method,
274        &request_data,
275        state.signing_chain_id,
276    )
277    .map_err(|e| {
278        if let Some(reason) = e.auth_failure_reason {
279            state.auth_failure_recorder.record_auth_failure(reason);
280        }
281        tracing::error!(
282            error = e.error,
283            message = %e.message,
284            "Failed to recover signed action"
285        );
286        json_error_response(e.status, e.error, &e.message)
287    })?;
288
289    let signer_ctx = authorize_signer(
290        state.agent_auth.as_ref(),
291        recovered.wallet,
292        recovered.signer,
293    )
294    .map_err(|e| match e {
295        RequestAuthError::Unauthorized { wallet, signer } => {
296            tracing::warn!(
297                "Signer {} is not authorized to act for wallet {}",
298                signer,
299                wallet
300            );
301            json_error_response(StatusCode::UNAUTHORIZED, "unauthorized", &e.to_string())
302        }
303        RequestAuthError::Signature(_) => json_error_response(
304            StatusCode::BAD_REQUEST,
305            "signature_verification_failed",
306            &e.to_string(),
307        ),
308    })?;
309
310    // Inject SignerContext into request extensions
311    parts.extensions.insert(signer_ctx);
312
313    // Reconstruct the request with the original body
314    let req = Request::from_parts(parts, axum::body::Body::from(body_bytes));
315
316    Ok(next.run(req).await)
317}
318
319/// Create a 429 Too Many Requests response with rate limit headers.
320fn rate_limit_exceeded_response(err: &RateLimitError) -> Response {
321    match err {
322        RateLimitError::Exceeded {
323            retry_after_secs,
324            limit,
325            ..
326        } => {
327            let body = json!({
328                "error": "rate_limit_exceeded",
329                "message": format!("{}", err),
330                "retry_after_secs": retry_after_secs,
331                "limit": limit,
332            });
333
334            let mut response = (StatusCode::TOO_MANY_REQUESTS, SonicJson(body)).into_response();
335
336            // Add Retry-After header
337            response
338                .headers_mut()
339                .insert("Retry-After", retry_after_secs.to_string().parse().unwrap());
340            response
341                .headers_mut()
342                .insert("X-RateLimit-Limit", limit.to_string().parse().unwrap());
343            response
344                .headers_mut()
345                .insert("X-RateLimit-Remaining", "0".parse().unwrap());
346
347            response
348        }
349        RateLimitError::ServiceUnavailable { message } => {
350            let body = json!({
351                "error": "service_unavailable",
352                "message": message,
353            });
354
355            (StatusCode::SERVICE_UNAVAILABLE, SonicJson(body)).into_response()
356        }
357    }
358}
359
360/// Add rate limit headers to a response.
361fn add_rate_limit_headers(response: &mut Response, info: &RateLimitInfo) {
362    response
363        .headers_mut()
364        .insert("X-RateLimit-Limit", info.limit.to_string().parse().unwrap());
365    response.headers_mut().insert(
366        "X-RateLimit-Remaining",
367        info.remaining.to_string().parse().unwrap(),
368    );
369    response.headers_mut().insert(
370        "X-RateLimit-Reset",
371        info.reset_at.to_string().parse().unwrap(),
372    );
373}
374
375/// Extract wallet address from SignerContext extension.
376fn extract_wallet(req: &Request) -> Option<WalletAddress> {
377    req.extensions()
378        .get::<SignerContext>()
379        .map(|ctx| ctx.wallet_address)
380}
381
/// State required for the rate limiting middleware in this module.
#[derive(Clone)]
pub struct RateLimitState {
    /// Shared limiter; the middleware call `check_and_increment` on it once per
    /// request with the caller's wallet and the applicable action.
    pub rate_limiter: Arc<RateLimitCache>,
}
387
388/// Middleware to check order placement rate limits.
389///
390/// Must be applied AFTER signature middleware (needs SignerContext).
391pub async fn order_rate_limit_middleware(
392    State(state): State<RateLimitState>,
393    req: Request,
394    next: Next,
395) -> Response {
396    // Extract wallet from SignerContext (set by signature middleware)
397    let wallet = req
398        .extensions()
399        .get::<SignerContext>()
400        .map(|ctx| ctx.wallet_address);
401
402    if let Some(wallet) = wallet {
403        // Check order placement rate limit
404        match state
405            .rate_limiter
406            .check_and_increment(&wallet, RateLimitAction::OrderPlacement)
407            .await
408        {
409            Ok(info) => {
410                let mut response = next.run(req).await;
411                // Add rate limit headers to successful response
412                response
413                    .headers_mut()
414                    .insert("X-RateLimit-Limit", info.limit.to_string().parse().unwrap());
415                response.headers_mut().insert(
416                    "X-RateLimit-Remaining",
417                    info.remaining.to_string().parse().unwrap(),
418                );
419                response.headers_mut().insert(
420                    "X-RateLimit-Reset",
421                    info.reset_at.to_string().parse().unwrap(),
422                );
423                response
424            }
425            Err(err) => rate_limit_exceeded_response(&err),
426        }
427    } else {
428        // No wallet context - allow through (signature middleware will handle auth)
429        next.run(req).await
430    }
431}
432
433/// Middleware to check order cancellation rate limits.
434///
435/// Must be applied AFTER signature middleware (needs SignerContext).
436pub async fn cancel_rate_limit_middleware(
437    State(state): State<RateLimitState>,
438    req: Request,
439    next: Next,
440) -> Response {
441    // Extract wallet from SignerContext
442    let wallet = req
443        .extensions()
444        .get::<SignerContext>()
445        .map(|ctx| ctx.wallet_address);
446
447    if let Some(wallet) = wallet {
448        match state
449            .rate_limiter
450            .check_and_increment(&wallet, RateLimitAction::OrderCancellation)
451            .await
452        {
453            Ok(info) => {
454                let mut response = next.run(req).await;
455                response
456                    .headers_mut()
457                    .insert("X-RateLimit-Limit", info.limit.to_string().parse().unwrap());
458                response.headers_mut().insert(
459                    "X-RateLimit-Remaining",
460                    info.remaining.to_string().parse().unwrap(),
461                );
462                response.headers_mut().insert(
463                    "X-RateLimit-Reset",
464                    info.reset_at.to_string().parse().unwrap(),
465                );
466                response
467            }
468            Err(err) => rate_limit_exceeded_response(&err),
469        }
470    } else {
471        next.run(req).await
472    }
473}
474
475/// Middleware to check general API rate limits.
476///
477/// This can be applied to any endpoint. For authenticated endpoints,
478/// it uses the wallet from SignerContext. For public endpoints,
479/// it attempts to extract wallet from query parameters.
480pub async fn api_rate_limit_middleware(
481    State(state): State<RateLimitState>,
482    req: Request,
483    next: Next,
484) -> Response {
485    // Try to get wallet from SignerContext first (authenticated endpoints)
486    let wallet = req
487        .extensions()
488        .get::<SignerContext>()
489        .map(|ctx| ctx.wallet_address);
490
491    // If no SignerContext, try to extract from query params (for public endpoints)
492    let wallet = wallet.or_else(|| {
493        req.uri()
494            .query()
495            .and_then(|q| {
496                // Parse query string to find "wallet=" parameter
497                q.split('&')
498                    .find(|p| p.starts_with("wallet="))
499                    .map(|p| &p[7..]) // Skip "wallet="
500            })
501            .and_then(|w| WalletAddress::from_str(w).ok())
502    });
503
504    if let Some(wallet) = wallet {
505        match state
506            .rate_limiter
507            .check_and_increment(&wallet, RateLimitAction::ApiRequest)
508            .await
509        {
510            Ok(info) => {
511                let mut response = next.run(req).await;
512                response
513                    .headers_mut()
514                    .insert("X-RateLimit-Limit", info.limit.to_string().parse().unwrap());
515                response.headers_mut().insert(
516                    "X-RateLimit-Remaining",
517                    info.remaining.to_string().parse().unwrap(),
518                );
519                response.headers_mut().insert(
520                    "X-RateLimit-Reset",
521                    info.reset_at.to_string().parse().unwrap(),
522                );
523                response
524            }
525            Err(err) => rate_limit_exceeded_response(&err),
526        }
527    } else {
528        // No wallet identified - allow through without rate limiting
529        // (could add IP-based rate limiting here in the future)
530        next.run(req).await
531    }
532}
533
534/// Combined middleware for write routes that applies appropriate rate limit
535/// based on the request path and method.
536///
537/// Must be applied AFTER signature middleware (needs SignerContext).
538///
539/// Rate limit mapping:
540/// - POST /order, PUT /order -> OrderPlacement
541/// - DELETE /order, DELETE /order_cloid -> OrderCancellation
542/// - All other write routes -> ApiRequest
543pub async fn write_route_rate_limit_middleware(
544    State(state): State<RateLimitState>,
545    req: Request,
546    next: Next,
547) -> Response {
548    use axum::http::Method;
549    use hypercall_types::{API_ROUTE_ORDER, API_ROUTE_ORDER_CLOID, API_ROUTE_RFQ_REQUEST};
550
551    // Determine which rate limit action applies based on path and method
552    let action = match (req.method(), req.uri().path()) {
553        (&Method::POST, API_ROUTE_ORDER) | (&Method::PUT, API_ROUTE_ORDER) => {
554            RateLimitAction::OrderPlacement
555        }
556        (&Method::DELETE, API_ROUTE_ORDER) | (&Method::DELETE, API_ROUTE_ORDER_CLOID) => {
557            RateLimitAction::OrderCancellation
558        }
559        // RFQ submit fans out to every connected QP — throttle per wallet
560        (&Method::POST, API_ROUTE_RFQ_REQUEST) => RateLimitAction::RfqSubmit,
561        // MMP and admin endpoints use general API rate limit
562        _ => RateLimitAction::ApiRequest,
563    };
564
565    // Extract wallet from SignerContext (set by signature middleware)
566    let wallet = extract_wallet(&req);
567
568    if let Some(wallet) = wallet {
569        match state
570            .rate_limiter
571            .check_and_increment(&wallet, action)
572            .await
573        {
574            Ok(info) => {
575                let mut response = next.run(req).await;
576                add_rate_limit_headers(&mut response, &info);
577                response
578            }
579            Err(err) => rate_limit_exceeded_response(&err),
580        }
581    } else {
582        // No wallet context yet - allow through (signature middleware will handle auth)
583        next.run(req).await
584    }
585}
586
587// Extension trait to extract authenticated user from request
588pub mod auth_ext {
589    use super::{AccountContext, AuthenticatedUser, SignerContext};
590    use crate::error::ApiError;
591    use axum::{async_trait, extract::FromRequestParts, http::request::Parts};
592
593    #[async_trait]
594    impl<S> FromRequestParts<S> for AuthenticatedUser
595    where
596        S: Send + Sync,
597    {
598        type Rejection = ApiError;
599
600        async fn from_request_parts(
601            parts: &mut Parts,
602            _state: &S,
603        ) -> Result<Self, Self::Rejection> {
604            parts
605                .extensions
606                .get::<AuthenticatedUser>()
607                .cloned()
608                .ok_or_else(|| ApiError::unauthorized("Authentication required"))
609        }
610    }
611
612    #[async_trait]
613    impl<S> FromRequestParts<S> for AccountContext
614    where
615        S: Send + Sync,
616    {
617        type Rejection = ApiError;
618
619        async fn from_request_parts(
620            parts: &mut Parts,
621            _state: &S,
622        ) -> Result<Self, Self::Rejection> {
623            parts
624                .extensions
625                .get::<AccountContext>()
626                .cloned()
627                .ok_or_else(|| ApiError::unauthorized("Account context required"))
628        }
629    }
630
631    #[async_trait]
632    impl<S> FromRequestParts<S> for SignerContext
633    where
634        S: Send + Sync,
635    {
636        type Rejection = ApiError;
637
638        async fn from_request_parts(
639            parts: &mut Parts,
640            _state: &S,
641        ) -> Result<Self, Self::Rejection> {
642            parts
643                .extensions
644                .get::<SignerContext>()
645                .cloned()
646                .ok_or_else(|| {
647                    ApiError::unauthorized("Signer context required - valid signature is required")
648                })
649        }
650    }
651}
652
#[cfg(test)]
mod tests {
    use super::*;

    /// Readiness gate stub reporting a single "engine" component with fixed readiness.
    struct TestReadinessGate {
        ready: bool,
    }

    impl ReadinessGate for TestReadinessGate {
        fn all_ready(&self) -> bool {
            self.ready
        }

        fn reports(&self) -> Vec<crate::models::ReadinessComponentReport> {
            let detail = if self.ready { "Ready" } else { "CatchingUp" };
            vec![crate::models::ReadinessComponentReport {
                name: "engine".to_string(),
                ready: self.ready,
                detail: Some(detail.to_string()),
            }]
        }
    }

    /// Replay progress stub: only `is_caught_up` varies; every other metric is empty or zero.
    #[derive(Clone, Default)]
    struct TestStandbyReplayProgress {
        caught_up: bool,
    }

    impl StandbyReplayProgress for TestStandbyReplayProgress {
        fn is_caught_up(&self) -> bool {
            self.caught_up
        }

        fn commands_replayed(&self) -> u64 {
            0
        }

        fn replay_cursor_seq(&self) -> Option<i64> {
            None
        }

        fn latest_stream_seq(&self) -> Option<u64> {
            None
        }

        fn stream_lag(&self) -> Option<u64> {
            None
        }

        fn last_replayed_seq(&self) -> Option<i64> {
            None
        }

        fn last_replay_unix_ms(&self) -> Option<u64> {
            None
        }
    }

    /// Gate that never reports ready, so only the bypass logic can let requests through.
    fn not_ready_registry() -> Arc<dyn ReadinessGate> {
        Arc::new(TestReadinessGate { ready: false })
    }

    /// Promote slot that still holds its sender, i.e. an unpromoted standby.
    /// The receiver is returned so the channel stays open for the test's duration.
    fn armed_promote_slot() -> (
        Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
        tokio::sync::oneshot::Receiver<()>,
    ) {
        let (promote_tx, promote_rx) = tokio::sync::oneshot::channel();
        (Arc::new(Mutex::new(Some(promote_tx))), promote_rx)
    }

    #[test]
    fn recovery_safety_bypasses_readiness_gate() {
        let recovery_paths = [
            "/monitoring/recovery-safety",
            "/recovery-safety",
            "/recovery-safety/alert",
        ];
        for path in recovery_paths {
            assert!(is_readiness_bypass_path(path));
        }
    }

    #[test]
    fn standby_monitoring_bypass_only_allows_read_only_state_routes() {
        let allowed_gets = [
            "/monitoring/integrity",
            "/monitoring/orderbooks",
            "/monitoring/accounts",
            "/monitoring/positions",
            "/monitoring/directive-outbox",
            "/monitoring/engine-state-digest",
            "/monitoring/recovery-safety",
            "/recovery-safety",
            "/recovery-safety/alert",
        ];
        for path in allowed_gets {
            assert!(is_standby_monitoring_get(&Method::GET, path));
        }

        // A non-GET method is refused even on an allowed path.
        assert!(!is_standby_monitoring_get(
            &Method::POST,
            "/monitoring/integrity"
        ));

        // GETs outside the read-only list are refused.
        let rejected_gets = [
            "/monitoring/reload_instruments_cache",
            hypercall_types::API_ROUTE_ORDERS,
        ];
        for path in rejected_gets {
            assert!(!is_standby_monitoring_get(&Method::GET, path));
        }
    }

    #[tokio::test]
    async fn caught_up_unpromoted_standby_can_bypass_monitoring_readiness() {
        let (promote_slot, _promote_rx) = armed_promote_slot();
        let state = ReadinessMiddlewareState::new(
            not_ready_registry(),
            Some(Arc::new(TestStandbyReplayProgress { caught_up: true })),
            Some(promote_slot),
        );

        assert!(standby_is_caught_up_and_unpromoted(&state).await);
    }

    #[tokio::test]
    async fn standby_bypass_is_disabled_before_catchup_and_after_promote() {
        // Still replaying: no bypass even though the promote sender is present.
        let lagging = TestStandbyReplayProgress { caught_up: false };
        let (promote_slot, _promote_rx) = armed_promote_slot();
        let state = ReadinessMiddlewareState::new(
            not_ready_registry(),
            Some(Arc::new(lagging.clone())),
            Some(promote_slot),
        );
        assert!(!standby_is_caught_up_and_unpromoted(&state).await);

        // Caught up but without a promote slot: no bypass either.
        let promoted_state = ReadinessMiddlewareState::new(
            not_ready_registry(),
            Some(Arc::new(TestStandbyReplayProgress { caught_up: true })),
            None,
        );
        assert!(!standby_is_caught_up_and_unpromoted(&promoted_state).await);
    }

    #[test]
    fn primary_middleware_state_never_has_standby_bypass_inputs() {
        let state = ReadinessMiddlewareState::primary(not_ready_registry());

        assert!(state.standby_progress.is_none());
        assert!(state.standby_promote.is_none());
    }
}
805
806        assert!(state.standby_progress.is_none());
807        assert!(state.standby_promote.is_none());
808    }
809}