Skip to main content

hypercall_runtime_api/
sonic_json.rs

1use async_trait::async_trait;
2use axum::{
3    body::Bytes,
4    extract::{FromRequest, Request},
5    http::{header, StatusCode},
6    response::{IntoResponse, Response},
7};
8use serde::{de::DeserializeOwned, Serialize};
9use std::ops::{Deref, DerefMut};
10
/// A JSON extractor/response type that uses sonic-rs for (de)serialization.
///
/// Drop-in replacement for `axum::Json<T>` backed by sonic-rs instead of serde_json.
///
/// The wrapper is transparent: it derives `Debug`, `Clone`, `Copy` and
/// `Default` (each only when `T` implements the corresponding trait) and can
/// be built from a bare `T` via [`From`] / `.into()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct SonicJson<T>(pub T);

impl<T> From<T> for SonicJson<T> {
    /// Wraps `inner` without transforming it.
    fn from(inner: T) -> Self {
        SonicJson(inner)
    }
}
15
16#[async_trait]
17impl<T, S> FromRequest<S> for SonicJson<T>
18where
19    T: DeserializeOwned,
20    S: Send + Sync,
21{
22    type Rejection = Response;
23
24    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
25        // Validate Content-Type header (matches axum::Json behavior)
26        let content_type = req
27            .headers()
28            .get(header::CONTENT_TYPE)
29            .and_then(|v| v.to_str().ok())
30            .unwrap_or("");
31        if !content_type.starts_with("application/json")
32            && !content_type
33                .strip_prefix("application/")
34                .is_some_and(|rest| rest.contains("+json"))
35        {
36            return Err((
37                StatusCode::UNSUPPORTED_MEDIA_TYPE,
38                [(header::CONTENT_TYPE, "application/json")],
39                b"{\"error\":\"unsupported_media_type\",\"message\":\"Expected Content-Type: application/json\"}".to_vec(),
40            )
41                .into_response());
42        }
43
44        let bytes = Bytes::from_request(req, state)
45            .await
46            .map_err(|err| err.into_response())?;
47
48        sonic_rs::from_slice(&bytes).map(SonicJson).map_err(|err| {
49            let body = sonic_rs::json!({
50                "error": "Invalid JSON",
51                "message": err.to_string(),
52            });
53            let body_bytes = sonic_rs::to_vec(&body)
54                .unwrap_or_else(|_| b"{\"error\":\"Invalid JSON\"}".to_vec());
55            (
56                StatusCode::UNPROCESSABLE_ENTITY,
57                [(header::CONTENT_TYPE, "application/json")],
58                body_bytes,
59            )
60                .into_response()
61        })
62    }
63}
64
65impl<T> IntoResponse for SonicJson<T>
66where
67    T: Serialize,
68{
69    fn into_response(self) -> Response {
70        match sonic_rs::to_vec(&self.0) {
71            Ok(bytes) => (
72                StatusCode::OK,
73                [(header::CONTENT_TYPE, "application/json")],
74                bytes,
75            )
76                .into_response(),
77            Err(err) => {
78                tracing::error!("Failed to serialize response: {}", err);
79                (
80                    StatusCode::INTERNAL_SERVER_ERROR,
81                    [(header::CONTENT_TYPE, "application/json")],
82                    b"{\"error\":\"internal_error\",\"message\":\"Internal server error\"}"
83                        .to_vec(),
84                )
85                    .into_response()
86            }
87        }
88    }
89}
90
91impl<T> Deref for SonicJson<T> {
92    type Target = T;
93
94    fn deref(&self) -> &Self::Target {
95        &self.0
96    }
97}
98
99impl<T> DerefMut for SonicJson<T> {
100    fn deref_mut(&mut self) -> &mut Self::Target {
101        &mut self.0
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::Request as HttpRequest,
        routing::{get, post},
        Router,
    };
    use serde::{Deserialize, Serialize};
    use sonic_rs::JsonValueTrait;
    use tower::ServiceExt;

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestPayload {
        name: String,
        value: i64,
    }

    /// Router covering both directions: `/echo` extracts and returns a
    /// `SonicJson<TestPayload>`, while `/hello` only returns one.
    fn app() -> Router {
        Router::new()
            .route(
                "/echo",
                post(
                    |SonicJson(payload): SonicJson<TestPayload>| async move { SonicJson(payload) },
                ),
            )
            .route(
                "/hello",
                get(|| async {
                    SonicJson(TestPayload {
                        name: "hello".into(),
                        value: 42,
                    })
                }),
            )
    }

    /// Sends `POST /echo` with `body`, setting `Content-Type` only when
    /// `content_type` is `Some`.
    async fn post_echo(content_type: Option<&str>, body: &'static str) -> Response {
        let mut builder = HttpRequest::builder().method("POST").uri("/echo");
        if let Some(content_type) = content_type {
            builder = builder.header("content-type", content_type);
        }
        app()
            .oneshot(builder.body(Body::from(body)).unwrap())
            .await
            .unwrap()
    }

    /// Collects the full response body.
    async fn body_bytes(response: Response) -> Vec<u8> {
        axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap()
            .to_vec()
    }

    /// Parses a JSON error body and returns its `"error"` field.
    fn error_code(body: &[u8]) -> String {
        let value: sonic_rs::Value = sonic_rs::from_slice(body).unwrap();
        value.get("error").unwrap().as_str().unwrap().to_owned()
    }

    #[tokio::test]
    async fn test_valid_json_roundtrip() {
        let response = post_echo(
            Some("application/json"),
            r#"{"name":"test","value":123}"#,
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = body_bytes(response).await;
        let result: TestPayload = sonic_rs::from_slice(&body).unwrap();
        assert_eq!(
            result,
            TestPayload {
                name: "test".into(),
                value: 123
            }
        );
    }

    #[tokio::test]
    async fn test_malformed_json_returns_422() {
        let response = post_echo(Some("application/json"), "not valid json").await;

        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(error_code(&body_bytes(response).await), "Invalid JSON");
    }

    #[tokio::test]
    async fn test_wrong_content_type_returns_415() {
        let response = post_echo(Some("text/plain"), r#"{"name":"test","value":1}"#).await;

        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/json"
        );
        assert_eq!(
            error_code(&body_bytes(response).await),
            "unsupported_media_type"
        );
    }

    #[tokio::test]
    async fn test_missing_content_type_returns_415() {
        let response = post_echo(None, r#"{"name":"test","value":1}"#).await;

        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn test_structured_json_content_type_accepted() {
        let response = post_echo(
            Some("application/cloudevents+json"),
            r#"{"name":"test","value":1}"#,
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_json_content_type_with_charset_accepted() {
        let response = post_echo(
            Some("application/json; charset=utf-8"),
            r#"{"name":"test","value":1}"#,
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_into_response_serializes_correctly() {
        let response = app()
            .oneshot(
                HttpRequest::builder()
                    .method("GET")
                    .uri("/hello")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/json"
        );
        let body = body_bytes(response).await;
        let result: TestPayload = sonic_rs::from_slice(&body).unwrap();
        assert_eq!(
            result,
            TestPayload {
                name: "hello".into(),
                value: 42
            }
        );
    }
}