1use axum::{
2 extract::{Request, State},
3 http::{Method, StatusCode},
4 middleware::Next,
5 response::{IntoResponse, Response},
6};
7use sonic_rs::json;
8use std::sync::Arc;
9use tokio::sync::Mutex;
10
11use crate::sonic_json::SonicJson;
12
13use crate::caches::rate_limit::{RateLimitAction, RateLimitCache, RateLimitError, RateLimitInfo};
14use crate::observability_boundary::AuthFailureRecorder;
15use crate::request_auth::{authorize_signer, RequestAuthError};
16use crate::runtime_status::{ReadinessGate, StandbyReplayProgress};
17use crate::signed_actions::recover_signed_action;
18use hypercall_types::{WalletAddress, API_ROUTE_PROFILE_IMAGE};
19use std::str::FromStr;
20
/// Maximum request body size (1 MiB) buffered by `signature_and_agent_middleware`
/// for ordinary signed JSON requests.
const SIGNED_JSON_BODY_LIMIT_BYTES: usize = 1_048_576;
/// Larger body allowance (7,000,000 bytes) used only for `POST` requests on
/// `API_ROUTE_PROFILE_IMAGE`.
const PROFILE_IMAGE_BODY_LIMIT_BYTES: usize = 7_000_000;

/// State shared by [`signature_and_agent_middleware`].
#[derive(Clone)]
pub struct SignatureMiddlewareState {
    /// Provider handed to `authorize_signer` to decide whether a recovered
    /// signer may act for the recovered wallet.
    pub agent_auth: Arc<dyn hypercall_runtime_api::AgentAuthProvider>,
    /// Sink for the `auth_failure_reason` reported when `recover_signed_action` fails.
    pub auth_failure_recorder: Arc<dyn AuthFailureRecorder>,
    /// Chain id passed to `recover_signed_action` when recovering the signer.
    pub signing_chain_id: u64,
}
30
31#[cfg(feature = "otel-tracing")]
32use opentelemetry::propagation::TextMapPropagator;
33#[cfg(feature = "otel-tracing")]
34use opentelemetry_sdk::propagation::TraceContextPropagator;
35#[cfg(feature = "otel-tracing")]
36use tracing_opentelemetry::OpenTelemetrySpanExt;
37
38fn json_error_response(status: StatusCode, error: &str, message: &str) -> Response {
40 let body = json!({
41 "error": error,
42 "message": message
43 });
44 (status, SonicJson(body)).into_response()
45}
46
/// Read-only adapter that lets OpenTelemetry propagators read trace headers
/// from an `http::HeaderMap`.
#[cfg(feature = "otel-tracing")]
struct HeaderExtractor<'a>(&'a axum::http::HeaderMap);

#[cfg(feature = "otel-tracing")]
impl opentelemetry::propagation::Extractor for HeaderExtractor<'_> {
    /// Returns the first value stored for `key`, or `None` when the header is
    /// absent or its value is not representable as `&str` (`to_str` fails).
    fn get(&self, key: &str) -> Option<&str> {
        let value = self.0.get(key)?;
        value.to_str().ok()
    }

    /// Returns the name of every header present in the map.
    fn keys(&self) -> Vec<&str> {
        self.0
            .keys()
            .map(axum::http::HeaderName::as_str)
            .collect()
    }
}
61
/// Continues a distributed trace started by the caller.
///
/// Reads trace-context headers from the incoming request with
/// `TraceContextPropagator` and installs the extracted context as the parent
/// of the current `tracing` span before forwarding the request unchanged.
#[cfg(feature = "otel-tracing")]
pub async fn trace_context_middleware(req: Request, next: Next) -> Response {
    let remote_parent = {
        let headers = HeaderExtractor(req.headers());
        TraceContextPropagator::new().extract(&headers)
    };

    tracing::Span::current().set_parent(remote_parent);

    next.run(req).await
}
80
81#[derive(Clone)]
89pub struct ReadinessMiddlewareState {
90 pub readiness: Arc<dyn ReadinessGate>,
91 pub standby_progress: Option<Arc<dyn StandbyReplayProgress>>,
92 pub standby_promote: Option<Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>>,
93}
94
95impl ReadinessMiddlewareState {
96 pub fn primary(readiness: Arc<dyn ReadinessGate>) -> Self {
97 Self {
98 readiness,
99 standby_progress: None,
100 standby_promote: None,
101 }
102 }
103
104 pub fn new(
105 readiness: Arc<dyn ReadinessGate>,
106 standby_progress: Option<Arc<dyn StandbyReplayProgress>>,
107 standby_promote: Option<Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>>,
108 ) -> Self {
109 Self {
110 readiness,
111 standby_progress,
112 standby_promote,
113 }
114 }
115}
116
117pub async fn readiness_middleware(
118 State(state): State<ReadinessMiddlewareState>,
119 req: Request,
120 next: Next,
121) -> Response {
122 let path = req.uri().path();
123
124 if is_readiness_bypass_path(path) {
125 return next.run(req).await;
126 }
127
128 if !state.readiness.all_ready() {
130 if is_standby_monitoring_get(req.method(), path)
131 && standby_is_caught_up_and_unpromoted(&state).await
132 {
133 return next.run(req).await;
134 }
135
136 let components = state.readiness.reports();
137 return (
138 StatusCode::SERVICE_UNAVAILABLE,
139 SonicJson(json!({
140 "error": "service_not_ready",
141 "message": "Service is starting up. Please retry shortly.",
142 "components": components
143 })),
144 )
145 .into_response();
146 }
147
148 next.run(req).await
149}
150
/// Returns true for paths that are served even while the readiness gate is
/// closed. This covers health/readiness probes, metrics, standby promotion and
/// drain controls, and recovery-safety reporting. Matching is exact, with no
/// prefix or trailing-slash tolerance.
fn is_readiness_bypass_path(path: &str) -> bool {
    const BYPASS_PATHS: &[&str] = &[
        "/health",
        "/ready",
        "/metrics",
        "/standby-ready",
        "/admin/promote",
        "/admin/drain",
        "/admin/undrain",
        "/drain-status",
        "/recovery-safety",
        "/recovery-safety/alert",
        "/monitoring/recovery-safety",
    ];

    BYPASS_PATHS.contains(&path)
}
167
168fn is_standby_monitoring_get(method: &Method, path: &str) -> bool {
169 if method != Method::GET {
170 return false;
171 }
172
173 matches!(
174 path,
175 "/monitoring/integrity"
176 | "/monitoring/orderbooks"
177 | "/monitoring/accounts"
178 | "/monitoring/positions"
179 | "/monitoring/directive-outbox"
180 | "/monitoring/deposits"
181 | "/monitoring/engine-state-digest"
182 | "/recovery-safety"
183 | "/recovery-safety/alert"
184 | "/monitoring/recovery-safety"
185 | "/monitoring/trading-halts"
186 | "/monitoring/vol-oracles"
187 | "/monitoring/price-oracles"
188 | "/monitoring/vol-surface"
189 )
190}
191
192async fn standby_is_caught_up_and_unpromoted(state: &ReadinessMiddlewareState) -> bool {
193 let Some(progress) = &state.standby_progress else {
194 return false;
195 };
196 if !progress.is_caught_up() {
197 return false;
198 }
199
200 match &state.standby_promote {
201 Some(promote) => promote.lock().await.is_some(),
202 None => false,
203 }
204}
205
206#[derive(Clone)]
207pub struct AuthenticatedUser {
208 pub wallet_address: WalletAddress,
209}
210
211#[derive(Clone)]
212pub struct AccountContext {
213 pub account_contract_address: WalletAddress,
214 pub api_wallet_address: WalletAddress,
215}
216
217#[derive(Clone, Debug)]
218pub struct SignerContext {
219 pub wallet_address: WalletAddress, pub signer_address: WalletAddress, }
222
223pub async fn signature_and_agent_middleware(
225 State(state): State<SignatureMiddlewareState>,
226 req: Request,
227 next: Next,
228) -> Response {
229 match signature_and_agent_middleware_inner(state, req, next).await {
230 Ok(response) => response,
231 Err(response) => response,
232 }
233}
234
235async fn signature_and_agent_middleware_inner(
236 state: SignatureMiddlewareState,
237 req: Request,
238 next: Next,
239) -> Result<Response, Response> {
240 let (mut parts, body) = req.into_parts();
242 let body_limit = match (parts.uri.path(), &parts.method) {
243 (API_ROUTE_PROFILE_IMAGE, &Method::POST) => PROFILE_IMAGE_BODY_LIMIT_BYTES,
244 _ => SIGNED_JSON_BODY_LIMIT_BYTES,
245 };
246
247 let body_bytes = match axum::body::to_bytes(body, body_limit).await {
248 Ok(bytes) => bytes,
249 Err(_) => {
250 return Err(json_error_response(
251 StatusCode::PAYLOAD_TOO_LARGE,
252 "payload_too_large",
253 "Request body exceeds the allowed size",
254 ));
255 }
256 };
257
258 let request_data: sonic_rs::Value = match sonic_rs::from_slice(&body_bytes) {
260 Ok(json) => json,
261 Err(e) => {
262 tracing::error!("Failed to parse request body: {}", e);
263 return Err(json_error_response(
264 StatusCode::BAD_REQUEST,
265 "invalid_json",
266 &format!("Failed to parse JSON: {}", e),
267 ));
268 }
269 };
270
271 let recovered = recover_signed_action(
272 parts.uri.path(),
273 &parts.method,
274 &request_data,
275 state.signing_chain_id,
276 )
277 .map_err(|e| {
278 if let Some(reason) = e.auth_failure_reason {
279 state.auth_failure_recorder.record_auth_failure(reason);
280 }
281 tracing::error!(
282 error = e.error,
283 message = %e.message,
284 "Failed to recover signed action"
285 );
286 json_error_response(e.status, e.error, &e.message)
287 })?;
288
289 let signer_ctx = authorize_signer(
290 state.agent_auth.as_ref(),
291 recovered.wallet,
292 recovered.signer,
293 )
294 .map_err(|e| match e {
295 RequestAuthError::Unauthorized { wallet, signer } => {
296 tracing::warn!(
297 "Signer {} is not authorized to act for wallet {}",
298 signer,
299 wallet
300 );
301 json_error_response(StatusCode::UNAUTHORIZED, "unauthorized", &e.to_string())
302 }
303 RequestAuthError::Signature(_) => json_error_response(
304 StatusCode::BAD_REQUEST,
305 "signature_verification_failed",
306 &e.to_string(),
307 ),
308 })?;
309
310 parts.extensions.insert(signer_ctx);
312
313 let req = Request::from_parts(parts, axum::body::Body::from(body_bytes));
315
316 Ok(next.run(req).await)
317}
318
319fn rate_limit_exceeded_response(err: &RateLimitError) -> Response {
321 match err {
322 RateLimitError::Exceeded {
323 retry_after_secs,
324 limit,
325 ..
326 } => {
327 let body = json!({
328 "error": "rate_limit_exceeded",
329 "message": format!("{}", err),
330 "retry_after_secs": retry_after_secs,
331 "limit": limit,
332 });
333
334 let mut response = (StatusCode::TOO_MANY_REQUESTS, SonicJson(body)).into_response();
335
336 response
338 .headers_mut()
339 .insert("Retry-After", retry_after_secs.to_string().parse().unwrap());
340 response
341 .headers_mut()
342 .insert("X-RateLimit-Limit", limit.to_string().parse().unwrap());
343 response
344 .headers_mut()
345 .insert("X-RateLimit-Remaining", "0".parse().unwrap());
346
347 response
348 }
349 RateLimitError::ServiceUnavailable { message } => {
350 let body = json!({
351 "error": "service_unavailable",
352 "message": message,
353 });
354
355 (StatusCode::SERVICE_UNAVAILABLE, SonicJson(body)).into_response()
356 }
357 }
358}
359
360fn add_rate_limit_headers(response: &mut Response, info: &RateLimitInfo) {
362 response
363 .headers_mut()
364 .insert("X-RateLimit-Limit", info.limit.to_string().parse().unwrap());
365 response.headers_mut().insert(
366 "X-RateLimit-Remaining",
367 info.remaining.to_string().parse().unwrap(),
368 );
369 response.headers_mut().insert(
370 "X-RateLimit-Reset",
371 info.reset_at.to_string().parse().unwrap(),
372 );
373}
374
375fn extract_wallet(req: &Request) -> Option<WalletAddress> {
377 req.extensions()
378 .get::<SignerContext>()
379 .map(|ctx| ctx.wallet_address)
380}
381
/// State shared by the rate-limit middlewares in this module.
#[derive(Clone)]
pub struct RateLimitState {
    /// Per-wallet limiter queried via `check_and_increment`.
    pub rate_limiter: Arc<RateLimitCache>,
}
387
388pub async fn order_rate_limit_middleware(
392 State(state): State<RateLimitState>,
393 req: Request,
394 next: Next,
395) -> Response {
396 let wallet = req
398 .extensions()
399 .get::<SignerContext>()
400 .map(|ctx| ctx.wallet_address);
401
402 if let Some(wallet) = wallet {
403 match state
405 .rate_limiter
406 .check_and_increment(&wallet, RateLimitAction::OrderPlacement)
407 .await
408 {
409 Ok(info) => {
410 let mut response = next.run(req).await;
411 response
413 .headers_mut()
414 .insert("X-RateLimit-Limit", info.limit.to_string().parse().unwrap());
415 response.headers_mut().insert(
416 "X-RateLimit-Remaining",
417 info.remaining.to_string().parse().unwrap(),
418 );
419 response.headers_mut().insert(
420 "X-RateLimit-Reset",
421 info.reset_at.to_string().parse().unwrap(),
422 );
423 response
424 }
425 Err(err) => rate_limit_exceeded_response(&err),
426 }
427 } else {
428 next.run(req).await
430 }
431}
432
433pub async fn cancel_rate_limit_middleware(
437 State(state): State<RateLimitState>,
438 req: Request,
439 next: Next,
440) -> Response {
441 let wallet = req
443 .extensions()
444 .get::<SignerContext>()
445 .map(|ctx| ctx.wallet_address);
446
447 if let Some(wallet) = wallet {
448 match state
449 .rate_limiter
450 .check_and_increment(&wallet, RateLimitAction::OrderCancellation)
451 .await
452 {
453 Ok(info) => {
454 let mut response = next.run(req).await;
455 response
456 .headers_mut()
457 .insert("X-RateLimit-Limit", info.limit.to_string().parse().unwrap());
458 response.headers_mut().insert(
459 "X-RateLimit-Remaining",
460 info.remaining.to_string().parse().unwrap(),
461 );
462 response.headers_mut().insert(
463 "X-RateLimit-Reset",
464 info.reset_at.to_string().parse().unwrap(),
465 );
466 response
467 }
468 Err(err) => rate_limit_exceeded_response(&err),
469 }
470 } else {
471 next.run(req).await
472 }
473}
474
475pub async fn api_rate_limit_middleware(
481 State(state): State<RateLimitState>,
482 req: Request,
483 next: Next,
484) -> Response {
485 let wallet = req
487 .extensions()
488 .get::<SignerContext>()
489 .map(|ctx| ctx.wallet_address);
490
491 let wallet = wallet.or_else(|| {
493 req.uri()
494 .query()
495 .and_then(|q| {
496 q.split('&')
498 .find(|p| p.starts_with("wallet="))
499 .map(|p| &p[7..]) })
501 .and_then(|w| WalletAddress::from_str(w).ok())
502 });
503
504 if let Some(wallet) = wallet {
505 match state
506 .rate_limiter
507 .check_and_increment(&wallet, RateLimitAction::ApiRequest)
508 .await
509 {
510 Ok(info) => {
511 let mut response = next.run(req).await;
512 response
513 .headers_mut()
514 .insert("X-RateLimit-Limit", info.limit.to_string().parse().unwrap());
515 response.headers_mut().insert(
516 "X-RateLimit-Remaining",
517 info.remaining.to_string().parse().unwrap(),
518 );
519 response.headers_mut().insert(
520 "X-RateLimit-Reset",
521 info.reset_at.to_string().parse().unwrap(),
522 );
523 response
524 }
525 Err(err) => rate_limit_exceeded_response(&err),
526 }
527 } else {
528 next.run(req).await
531 }
532}
533
534pub async fn write_route_rate_limit_middleware(
544 State(state): State<RateLimitState>,
545 req: Request,
546 next: Next,
547) -> Response {
548 use axum::http::Method;
549 use hypercall_types::{API_ROUTE_ORDER, API_ROUTE_ORDER_CLOID, API_ROUTE_RFQ_REQUEST};
550
551 let action = match (req.method(), req.uri().path()) {
553 (&Method::POST, API_ROUTE_ORDER) | (&Method::PUT, API_ROUTE_ORDER) => {
554 RateLimitAction::OrderPlacement
555 }
556 (&Method::DELETE, API_ROUTE_ORDER) | (&Method::DELETE, API_ROUTE_ORDER_CLOID) => {
557 RateLimitAction::OrderCancellation
558 }
559 (&Method::POST, API_ROUTE_RFQ_REQUEST) => RateLimitAction::RfqSubmit,
561 _ => RateLimitAction::ApiRequest,
563 };
564
565 let wallet = extract_wallet(&req);
567
568 if let Some(wallet) = wallet {
569 match state
570 .rate_limiter
571 .check_and_increment(&wallet, action)
572 .await
573 {
574 Ok(info) => {
575 let mut response = next.run(req).await;
576 add_rate_limit_headers(&mut response, &info);
577 response
578 }
579 Err(err) => rate_limit_exceeded_response(&err),
580 }
581 } else {
582 next.run(req).await
584 }
585}
586
587pub mod auth_ext {
589 use super::{AccountContext, AuthenticatedUser, SignerContext};
590 use crate::error::ApiError;
591 use axum::{async_trait, extract::FromRequestParts, http::request::Parts};
592
593 #[async_trait]
594 impl<S> FromRequestParts<S> for AuthenticatedUser
595 where
596 S: Send + Sync,
597 {
598 type Rejection = ApiError;
599
600 async fn from_request_parts(
601 parts: &mut Parts,
602 _state: &S,
603 ) -> Result<Self, Self::Rejection> {
604 parts
605 .extensions
606 .get::<AuthenticatedUser>()
607 .cloned()
608 .ok_or_else(|| ApiError::unauthorized("Authentication required"))
609 }
610 }
611
612 #[async_trait]
613 impl<S> FromRequestParts<S> for AccountContext
614 where
615 S: Send + Sync,
616 {
617 type Rejection = ApiError;
618
619 async fn from_request_parts(
620 parts: &mut Parts,
621 _state: &S,
622 ) -> Result<Self, Self::Rejection> {
623 parts
624 .extensions
625 .get::<AccountContext>()
626 .cloned()
627 .ok_or_else(|| ApiError::unauthorized("Account context required"))
628 }
629 }
630
631 #[async_trait]
632 impl<S> FromRequestParts<S> for SignerContext
633 where
634 S: Send + Sync,
635 {
636 type Rejection = ApiError;
637
638 async fn from_request_parts(
639 parts: &mut Parts,
640 _state: &S,
641 ) -> Result<Self, Self::Rejection> {
642 parts
643 .extensions
644 .get::<SignerContext>()
645 .cloned()
646 .ok_or_else(|| {
647 ApiError::unauthorized("Signer context required - valid signature is required")
648 })
649 }
650 }
651}
652
#[cfg(test)]
mod tests {
    use super::*;

    /// Readiness gate with a fixed answer and a single "engine" component report.
    struct TestReadinessGate {
        ready: bool,
    }

    impl ReadinessGate for TestReadinessGate {
        fn reports(&self) -> Vec<crate::models::ReadinessComponentReport> {
            vec![crate::models::ReadinessComponentReport {
                name: "engine".to_string(),
                ready: self.ready,
                detail: Some(if self.ready { "Ready" } else { "CatchingUp" }.to_string()),
            }]
        }

        fn all_ready(&self) -> bool {
            self.ready
        }
    }

    /// Replay progress stub; only `is_caught_up` matters to the middleware.
    #[derive(Clone, Default)]
    struct TestStandbyReplayProgress {
        caught_up: bool,
    }

    impl StandbyReplayProgress for TestStandbyReplayProgress {
        fn commands_replayed(&self) -> u64 {
            0
        }

        fn is_caught_up(&self) -> bool {
            self.caught_up
        }

        fn replay_cursor_seq(&self) -> Option<i64> {
            None
        }

        fn latest_stream_seq(&self) -> Option<u64> {
            None
        }

        fn stream_lag(&self) -> Option<u64> {
            None
        }

        fn last_replayed_seq(&self) -> Option<i64> {
            None
        }

        fn last_replay_unix_ms(&self) -> Option<u64> {
            None
        }
    }

    fn not_ready_registry() -> Arc<dyn ReadinessGate> {
        Arc::new(TestReadinessGate { ready: false })
    }

    #[test]
    fn recovery_safety_bypasses_readiness_gate() {
        assert!(is_readiness_bypass_path("/monitoring/recovery-safety"));
        assert!(is_readiness_bypass_path("/recovery-safety"));
        assert!(is_readiness_bypass_path("/recovery-safety/alert"));
    }

    #[test]
    fn probe_and_admin_paths_bypass_readiness_gate() {
        for path in [
            "/health",
            "/ready",
            "/metrics",
            "/standby-ready",
            "/admin/promote",
            "/admin/drain",
            "/admin/undrain",
            "/drain-status",
        ] {
            assert!(
                is_readiness_bypass_path(path),
                "{path} should bypass the readiness gate"
            );
        }
        // Matching is exact: no trailing-slash or prefix leniency.
        assert!(!is_readiness_bypass_path("/health/"));
        assert!(!is_readiness_bypass_path("/monitoring/integrity"));
    }

    #[test]
    fn standby_monitoring_bypass_only_allows_read_only_state_routes() {
        for path in [
            "/monitoring/integrity",
            "/monitoring/orderbooks",
            "/monitoring/accounts",
            "/monitoring/positions",
            "/monitoring/directive-outbox",
            "/monitoring/deposits",
            "/monitoring/engine-state-digest",
            "/monitoring/recovery-safety",
            "/recovery-safety",
            "/recovery-safety/alert",
            "/monitoring/trading-halts",
            "/monitoring/vol-oracles",
            "/monitoring/price-oracles",
            "/monitoring/vol-surface",
        ] {
            assert!(
                is_standby_monitoring_get(&Method::GET, path),
                "{path} should be readable on a caught-up standby"
            );
        }
        assert!(!is_standby_monitoring_get(
            &Method::POST,
            "/monitoring/integrity"
        ));
        assert!(!is_standby_monitoring_get(
            &Method::GET,
            "/monitoring/reload_instruments_cache"
        ));
        assert!(!is_standby_monitoring_get(
            &Method::GET,
            hypercall_types::API_ROUTE_ORDERS
        ));
    }

    #[tokio::test]
    async fn caught_up_unpromoted_standby_can_bypass_monitoring_readiness() {
        let progress = TestStandbyReplayProgress { caught_up: true };
        let (promote_tx, _promote_rx) = tokio::sync::oneshot::channel();
        let state = ReadinessMiddlewareState::new(
            not_ready_registry(),
            Some(Arc::new(progress)),
            Some(Arc::new(Mutex::new(Some(promote_tx)))),
        );

        assert!(standby_is_caught_up_and_unpromoted(&state).await);
    }

    #[tokio::test]
    async fn standby_bypass_is_disabled_before_catchup_or_without_promote_handle() {
        let (promote_tx, _promote_rx) = tokio::sync::oneshot::channel();
        let catching_up_state = ReadinessMiddlewareState::new(
            not_ready_registry(),
            Some(Arc::new(TestStandbyReplayProgress { caught_up: false })),
            Some(Arc::new(Mutex::new(Some(promote_tx)))),
        );
        assert!(!standby_is_caught_up_and_unpromoted(&catching_up_state).await);

        let no_handle_state = ReadinessMiddlewareState::new(
            not_ready_registry(),
            Some(Arc::new(TestStandbyReplayProgress { caught_up: true })),
            None,
        );
        assert!(!standby_is_caught_up_and_unpromoted(&no_handle_state).await);
    }

    #[tokio::test]
    async fn standby_bypass_closes_once_promote_sender_is_taken() {
        let (promote_tx, _promote_rx) = tokio::sync::oneshot::channel::<()>();
        let promote_slot = Arc::new(Mutex::new(Some(promote_tx)));
        let state = ReadinessMiddlewareState::new(
            not_ready_registry(),
            Some(Arc::new(TestStandbyReplayProgress { caught_up: true })),
            Some(Arc::clone(&promote_slot)),
        );
        assert!(standby_is_caught_up_and_unpromoted(&state).await);

        // Models promotion consuming the sender: the handle stays configured but its slot is empty.
        let _taken = promote_slot.lock().await.take();
        assert!(!standby_is_caught_up_and_unpromoted(&state).await);
    }

    #[test]
    fn primary_middleware_state_never_has_standby_bypass_inputs() {
        let state = ReadinessMiddlewareState::primary(not_ready_registry());

        assert!(state.standby_progress.is_none());
        assert!(state.standby_promote.is_none());
    }
}