hypercall_runtime_api/
sonic_json.rs1use async_trait::async_trait;
2use axum::{
3 body::Bytes,
4 extract::{FromRequest, Request},
5 http::{header, StatusCode},
6 response::{IntoResponse, Response},
7};
8use serde::{de::DeserializeOwned, Serialize};
9use std::ops::{Deref, DerefMut};
10
/// JSON extractor and response wrapper backed by `sonic_rs` rather than `serde_json`.
///
/// As an extractor it requires a JSON `Content-Type` and deserializes the request
/// body into `T`; as a response it serializes `T` into an `application/json` body.
/// The derives mirror axum's own `Json` wrapper so the two are interchangeable.
#[derive(Debug, Clone, Copy, Default)]
pub struct SonicJson<T>(pub T);

16#[async_trait]
17impl<T, S> FromRequest<S> for SonicJson<T>
18where
19 T: DeserializeOwned,
20 S: Send + Sync,
21{
22 type Rejection = Response;
23
24 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
25 let content_type = req
27 .headers()
28 .get(header::CONTENT_TYPE)
29 .and_then(|v| v.to_str().ok())
30 .unwrap_or("");
31 if !content_type.starts_with("application/json")
32 && !content_type
33 .strip_prefix("application/")
34 .is_some_and(|rest| rest.contains("+json"))
35 {
36 return Err((
37 StatusCode::UNSUPPORTED_MEDIA_TYPE,
38 [(header::CONTENT_TYPE, "application/json")],
39 b"{\"error\":\"unsupported_media_type\",\"message\":\"Expected Content-Type: application/json\"}".to_vec(),
40 )
41 .into_response());
42 }
43
44 let bytes = Bytes::from_request(req, state)
45 .await
46 .map_err(|err| err.into_response())?;
47
48 sonic_rs::from_slice(&bytes).map(SonicJson).map_err(|err| {
49 let body = sonic_rs::json!({
50 "error": "Invalid JSON",
51 "message": err.to_string(),
52 });
53 let body_bytes = sonic_rs::to_vec(&body)
54 .unwrap_or_else(|_| b"{\"error\":\"Invalid JSON\"}".to_vec());
55 (
56 StatusCode::UNPROCESSABLE_ENTITY,
57 [(header::CONTENT_TYPE, "application/json")],
58 body_bytes,
59 )
60 .into_response()
61 })
62 }
63}
64
65impl<T> IntoResponse for SonicJson<T>
66where
67 T: Serialize,
68{
69 fn into_response(self) -> Response {
70 match sonic_rs::to_vec(&self.0) {
71 Ok(bytes) => (
72 StatusCode::OK,
73 [(header::CONTENT_TYPE, "application/json")],
74 bytes,
75 )
76 .into_response(),
77 Err(err) => {
78 tracing::error!("Failed to serialize response: {}", err);
79 (
80 StatusCode::INTERNAL_SERVER_ERROR,
81 [(header::CONTENT_TYPE, "application/json")],
82 b"{\"error\":\"internal_error\",\"message\":\"Internal server error\"}"
83 .to_vec(),
84 )
85 .into_response()
86 }
87 }
88 }
89}
90
91impl<T> Deref for SonicJson<T> {
92 type Target = T;
93
94 fn deref(&self) -> &Self::Target {
95 &self.0
96 }
97}
98
99impl<T> DerefMut for SonicJson<T> {
100 fn deref_mut(&mut self) -> &mut Self::Target {
101 &mut self.0
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::Request as HttpRequest,
        routing::{get, post},
        Router,
    };
    use serde::{Deserialize, Serialize};
    use sonic_rs::JsonValueTrait;
    use tower::ServiceExt;

    /// Small payload used to exercise both extraction and serialization.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestPayload {
        name: String,
        value: i64,
    }

    /// Router with `POST /echo`, which extracts a `SonicJson<TestPayload>` and
    /// returns it unchanged, and `GET /hello`, which returns a fixed payload.
    fn app() -> Router {
        Router::new()
            .route(
                "/echo",
                post(
                    |SonicJson(payload): SonicJson<TestPayload>| async move { SonicJson(payload) },
                ),
            )
            .route(
                "/hello",
                get(|| async {
                    SonicJson(TestPayload {
                        name: "hello".into(),
                        value: 42,
                    })
                }),
            )
    }

    /// Collects the full response body into a byte vector (no size limit).
    async fn body_bytes(response: Response) -> Vec<u8> {
        axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap()
            .to_vec()
    }

    /// Sends `body` to `POST /echo`, setting `Content-Type` only when
    /// `content_type` is `Some`.
    async fn post_echo(content_type: Option<&str>, body: &'static str) -> Response {
        let mut builder = HttpRequest::builder().method("POST").uri("/echo");
        if let Some(content_type) = content_type {
            builder = builder.header("content-type", content_type);
        }
        app()
            .oneshot(builder.body(Body::from(body)).unwrap())
            .await
            .unwrap()
    }

    /// Parses a JSON error response body and returns its `"error"` field.
    async fn error_code(response: Response) -> String {
        let body = body_bytes(response).await;
        let error: sonic_rs::Value = sonic_rs::from_slice(&body).unwrap();
        error.get("error").unwrap().as_str().unwrap().to_owned()
    }

    #[tokio::test]
    async fn test_valid_json_roundtrip() {
        let response =
            post_echo(Some("application/json"), r#"{"name":"test","value":123}"#).await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = body_bytes(response).await;
        let result: TestPayload = sonic_rs::from_slice(&body).unwrap();
        assert_eq!(
            result,
            TestPayload {
                name: "test".into(),
                value: 123
            }
        );
    }

    #[tokio::test]
    async fn test_json_content_type_with_charset_accepted() {
        let response = post_echo(
            Some("application/json; charset=utf-8"),
            r#"{"name":"test","value":7}"#,
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_malformed_json_returns_422() {
        let response = post_echo(Some("application/json"), "not valid json").await;

        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/json"
        );
        assert_eq!(error_code(response).await, "Invalid JSON");
    }

    #[tokio::test]
    async fn test_missing_field_returns_422() {
        let response = post_echo(Some("application/json"), r#"{"name":"test"}"#).await;

        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(error_code(response).await, "Invalid JSON");
    }

    #[tokio::test]
    async fn test_wrong_content_type_returns_415() {
        let response = post_echo(Some("text/plain"), r#"{"name":"test","value":1}"#).await;

        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
        assert_eq!(error_code(response).await, "unsupported_media_type");
    }

    #[tokio::test]
    async fn test_missing_content_type_returns_415() {
        let response = post_echo(None, r#"{"name":"test","value":1}"#).await;

        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn test_structured_json_content_type_accepted() {
        let response = post_echo(
            Some("application/cloudevents+json"),
            r#"{"name":"test","value":1}"#,
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_into_response_serializes_correctly() {
        let response = app()
            .oneshot(
                HttpRequest::builder()
                    .method("GET")
                    .uri("/hello")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/json"
        );
        let body = body_bytes(response).await;
        let result: TestPayload = sonic_rs::from_slice(&body).unwrap();
        assert_eq!(
            result,
            TestPayload {
                name: "hello".into(),
                value: 42
            }
        );
    }
}