1use std::sync::Arc;
2
3use anyhow::{Context, Result};
4use arc_swap::ArcSwap;
5use tracing::{error, info};
6
/// Pieces of a PostgreSQL `DATABASE_URL` needed for RDS IAM authentication.
///
/// Any password present in the source URL is not stored; it is replaced by a
/// freshly generated token when the URL is rebuilt.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParsedDbUrl {
    /// Database server host, taken from the URL.
    pub hostname: String,
    /// TCP port; 5432 when the URL does not specify one.
    pub port: u16,
    /// Database user the IAM token is requested for, as written in the URL
    /// (still percent-encoded).
    pub username: String,
    /// Database name from the URL path, without the leading `/`.
    pub dbname: String,
    /// The URL's query string including its leading `?`, or empty if none.
    pub query_params: String,
}
16
17impl ParsedDbUrl {
18 pub fn parse(database_url: &str) -> Result<Self> {
19 let parsed = url::Url::parse(database_url)
20 .with_context(|| "DATABASE_URL is not a valid URL for RDS IAM parsing")?;
21
22 let hostname = parsed
23 .host_str()
24 .context("DATABASE_URL missing host")?
25 .to_string();
26 let port = parsed.port().unwrap_or(5432);
27 let username = if parsed.username().is_empty() {
28 anyhow::bail!("DATABASE_URL missing username (required for RDS IAM)");
29 } else {
30 parsed.username().to_string()
31 };
32 let dbname = parsed.path().trim_start_matches('/').to_string();
33 if dbname.is_empty() {
34 anyhow::bail!("DATABASE_URL missing database name");
35 }
36 let query_params = parsed.query().map(|q| format!("?{q}")).unwrap_or_default();
37
38 Ok(Self {
39 hostname,
40 port,
41 username,
42 dbname,
43 query_params,
44 })
45 }
46
47 pub fn url_with_token(&self, token: &str) -> Result<String> {
49 let encoded = url::form_urlencoded::byte_serialize(token.as_bytes()).collect::<String>();
55 Ok(format!(
56 "postgres://{}:{}@{}:{}/{}{}",
57 self.username, encoded, self.hostname, self.port, self.dbname, self.query_params,
58 ))
59 }
60}
61
62#[async_trait::async_trait]
65pub(crate) trait TokenGenerator: Send + Sync {
66 async fn generate(&self, parsed: &ParsedDbUrl) -> Result<String>;
67}
68
#[cfg(feature = "rds-iam")]
/// Production `TokenGenerator` that signs RDS IAM auth tokens with the AWS SDK.
struct AwsTokenGenerator {
    // AWS SDK configuration loaded by `RdsIamTokenProvider::new` and passed to
    // `AuthTokenGenerator::auth_token` for every token request.
    aws_config: aws_config::SdkConfig,
}
74
#[cfg(feature = "rds-iam")]
#[async_trait::async_trait]
impl TokenGenerator for AwsTokenGenerator {
    /// Sign a fresh RDS IAM auth token for the host, port and user in `parsed`.
    ///
    /// SDK errors are kept as the source of the returned error, so `{:#}` or
    /// `{:?}` formatting shows the full cause chain rather than only the
    /// outermost message.
    async fn generate(&self, parsed: &ParsedDbUrl) -> Result<String> {
        // NOTE(review): `parsed.username` is the percent-encoded form taken from
        // the URL; it equals the real DB user only if no characters were escaped.
        let config = aws_sdk_rds::auth_token::Config::builder()
            .hostname(&parsed.hostname)
            .port(u64::from(parsed.port))
            .username(&parsed.username)
            .build()
            .map_err(|e| anyhow::anyhow!(e).context("failed to build IAM auth token config"))?;
        let token = aws_sdk_rds::auth_token::AuthTokenGenerator::new(config)
            .auth_token(&self.aws_config)
            .await
            .map_err(|e| anyhow::anyhow!(e).context("failed to generate RDS IAM auth token"))?;
        Ok(token.to_string())
    }
}
92
/// Seconds between scheduled token refreshes (10 minutes).
///
/// AWS documents RDS IAM auth tokens as valid for 15 minutes, so a successful
/// refresh at this cadence replaces the token before it lapses.
pub(crate) const REFRESH_INTERVAL_SECS: u64 = 10 * 60;
/// Delays, in seconds, before each retry after a failed scheduled refresh;
/// every delay doubles the previous one.
pub(crate) const RETRY_DELAYS_SECS: [u64; 3] = [30, 30 * 2, 30 * 4];
96
/// Holds a PostgreSQL connection URL whose password is an RDS IAM auth token,
/// and swaps in a URL with a freshly generated token on each refresh.
pub struct RdsIamTokenProvider {
    /// Connection details parsed from `DATABASE_URL`.
    pub(crate) parsed: ParsedDbUrl,
    /// Latest connection URL; starts as an empty string. Wrapped in its own
    /// `Arc`, presumably so it can be shared beyond the provider —
    /// NOTE(review): confirm against callers.
    pub(crate) cached_url: Arc<ArcSwap<String>>,
    /// Source of new tokens (AWS SDK in production, a mock in tests).
    generator: Box<dyn TokenGenerator>,
}
107
108impl RdsIamTokenProvider {
109 #[cfg(feature = "rds-iam")]
113 pub async fn new(database_url: &str) -> Result<Arc<Self>> {
114 let parsed = ParsedDbUrl::parse(database_url)?;
115 let aws_config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
116
117 let provider = Arc::new(Self {
118 parsed,
119 cached_url: Arc::new(ArcSwap::from_pointee(String::new())),
120 generator: Box::new(AwsTokenGenerator { aws_config }),
121 });
122
123 provider
124 .refresh()
125 .await
126 .context("failed to generate initial RDS IAM token")?;
127
128 Ok(provider)
129 }
130
131 pub(crate) fn with_generator(
135 parsed: ParsedDbUrl,
136 generator: Box<dyn TokenGenerator>,
137 ) -> Arc<Self> {
138 Arc::new(Self {
139 parsed,
140 cached_url: Arc::new(ArcSwap::from_pointee(String::new())),
141 generator,
142 })
143 }
144
145 pub fn current_url(&self) -> Arc<String> {
147 self.cached_url.load_full()
148 }
149
150 pub fn store_url(&self, url: String) {
152 self.cached_url.store(Arc::new(url));
153 }
154
155 pub async fn refresh(&self) -> Result<()> {
157 let token = self.generator.generate(&self.parsed).await?;
158 let url = self.parsed.url_with_token(&token)?;
159 self.cached_url.store(Arc::new(url));
160 info!(host = %self.parsed.hostname, "refreshed RDS IAM auth token");
161 Ok(())
162 }
163
164 pub fn spawn_refresh_task(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
168 let provider = Arc::clone(self);
169 tokio::spawn(async move {
170 let normal_interval = std::time::Duration::from_secs(REFRESH_INTERVAL_SECS);
171 let retry_delays: Vec<std::time::Duration> = RETRY_DELAYS_SECS
172 .iter()
173 .map(|s| std::time::Duration::from_secs(*s))
174 .collect();
175
176 loop {
177 tokio::time::sleep(normal_interval).await;
178 if provider.refresh().await.is_ok() {
179 continue;
180 }
181 error!("RDS IAM token refresh failed, starting retry backoff");
182 let mut recovered = false;
183 for delay in &retry_delays {
184 tokio::time::sleep(*delay).await;
185 match provider.refresh().await {
186 Ok(()) => {
187 recovered = true;
188 break;
189 }
190 Err(e) => error!("RDS IAM retry failed: {e:#}"),
191 }
192 }
193 if !recovered {
194 error!("RDS IAM token refresh exhausted retries, resuming normal cadence");
195 }
196 }
197 })
198 }
199}
200
#[cfg(test)]
/// Test doubles shared by the unit tests in this crate.
pub(crate) mod test_support {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Scripted `TokenGenerator` that replays queued results in order.
    pub(crate) struct MockTokenGenerator {
        /// Results handed out front to back, one per `generate` call.
        pub(crate) results: Arc<Mutex<VecDeque<Result<String>>>>,
        /// Number of times `generate` has been called.
        pub(crate) call_count: Arc<AtomicUsize>,
    }

    impl MockTokenGenerator {
        /// Create a mock that returns `results` in order, with a zeroed counter.
        pub(crate) fn new(results: Vec<Result<String>>) -> Self {
            let queue: VecDeque<Result<String>> = results.into();
            Self {
                results: Arc::new(Mutex::new(queue)),
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    #[async_trait::async_trait]
    impl TokenGenerator for MockTokenGenerator {
        /// Count the call, then pop the next scripted result. Once the queue
        /// is empty, every call fails.
        async fn generate(&self, _parsed: &ParsedDbUrl) -> Result<String> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let next = self.results.lock().unwrap().pop_front();
            match next {
                Some(result) => result,
                None => Err(anyhow::anyhow!("no more mock tokens")),
            }
        }
    }
}
236
#[cfg(test)]
mod tests {
    use super::test_support::*;
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Parse a URL the test expects to be accepted.
    fn must_parse(url: &str) -> ParsedDbUrl {
        ParsedDbUrl::parse(url).unwrap()
    }

    /// Box a mock generator scripted with `results`, returning it alongside a
    /// shared handle on its call counter.
    fn scripted(results: Vec<Result<String>>) -> (Box<MockTokenGenerator>, Arc<AtomicUsize>) {
        let generator = MockTokenGenerator::new(results);
        let counter = Arc::clone(&generator.call_count);
        (Box::new(generator), counter)
    }

    /// Read the mock's call counter.
    fn calls(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::SeqCst)
    }

    /// Move the paused clock forward by `secs`, then yield a fixed number of
    /// times so the spawned refresh task can run up to its next await point.
    async fn advance_and_settle(secs: u64) {
        tokio::time::advance(Duration::from_secs(secs)).await;
        for _ in 0..10 {
            tokio::task::yield_now().await;
        }
    }

    #[test]
    fn parse_full_url() {
        let expected = ParsedDbUrl {
            hostname: "db.example.com".into(),
            port: 5432,
            username: "myuser".into(),
            dbname: "mydb".into(),
            query_params: "?sslmode=require".into(),
        };
        assert_eq!(
            must_parse("postgres://myuser:mypass@db.example.com:5432/mydb?sslmode=require"),
            expected
        );
    }

    #[test]
    fn parse_default_port() {
        assert_eq!(must_parse("postgres://user@host/db").port, 5432);
    }

    #[test]
    fn parse_no_query_params() {
        assert!(must_parse("postgres://user@host/db").query_params.is_empty());
    }

    #[test]
    fn parse_multiple_query_params() {
        let parsed = must_parse("postgres://user@host/db?sslmode=require&connect_timeout=10");
        assert_eq!(parsed.query_params, "?sslmode=require&connect_timeout=10");
    }

    #[test]
    fn parse_missing_host_fails() {
        assert!(ParsedDbUrl::parse("postgres:///db").is_err());
    }

    #[test]
    fn parse_missing_username_fails() {
        assert!(ParsedDbUrl::parse("postgres://host/db").is_err());
    }

    #[test]
    fn parse_missing_dbname_fails() {
        for url in ["postgres://user@host/", "postgres://user@host"] {
            assert!(ParsedDbUrl::parse(url).is_err(), "accepted {url}");
        }
    }

    #[test]
    fn parse_invalid_url_fails() {
        assert!(ParsedDbUrl::parse("not-a-url").is_err());
    }

    #[test]
    fn url_with_token_preserves_structure() {
        let parsed = must_parse("postgres://user@host:5432/db?sslmode=require");

        let plain = parsed.url_with_token("mytoken123").unwrap();
        assert!(plain.starts_with("postgres://user:mytoken123@host:5432/db"));
        assert!(plain.contains("sslmode=require"));

        let escaped = parsed.url_with_token("tok/en=val").unwrap();
        let reparsed = url::Url::parse(&escaped).unwrap();
        assert_eq!(reparsed.host_str(), Some("host"));
        assert_eq!(reparsed.port(), Some(5432));
        assert_eq!(reparsed.path(), "/db");
        assert_eq!(reparsed.query(), Some("sslmode=require"));
    }

    #[test]
    fn url_with_token_round_trip() {
        let host = "mydb.cluster.us-east-1.rds.amazonaws.com";
        let parsed = must_parse(
            "postgres://admin@mydb.cluster.us-east-1.rds.amazonaws.com:5432/hypercall?sslmode=require",
        );
        let rebuilt = parsed.url_with_token("simpletoken123").unwrap();
        let reparsed = url::Url::parse(&rebuilt).unwrap();

        assert_eq!(reparsed.username(), "admin");
        assert_eq!(reparsed.password(), Some("simpletoken123"));
        assert_eq!(reparsed.host_str(), Some(host));
        assert_eq!(reparsed.port(), Some(5432));
        assert_eq!(reparsed.path(), "/hypercall");
        assert_eq!(reparsed.query(), Some("sslmode=require"));
    }

    #[test]
    fn store_url_visible_via_current_url() {
        let (generator, _) = scripted(vec![]);
        let provider =
            RdsIamTokenProvider::with_generator(must_parse("postgres://user@host/db"), generator);

        assert_eq!(provider.current_url().as_str(), "");

        for url in ["postgres://user:token1@host/db", "postgres://user:token2@host/db"] {
            provider.store_url(url.into());
            assert_eq!(provider.current_url().as_str(), url);
        }
    }

    #[tokio::test]
    async fn refresh_uses_generator_and_updates_url() {
        let (generator, _) = scripted(vec![Ok("first_token".into()), Ok("second_token".into())]);
        let provider = RdsIamTokenProvider::with_generator(
            must_parse("postgres://user@host:5432/db"),
            generator,
        );

        for expected in ["first_token", "second_token"] {
            provider.refresh().await.unwrap();
            let url = provider.current_url();
            assert!(url.contains(expected), "url={url}");
        }
    }

    #[tokio::test]
    async fn refresh_failure_preserves_old_url() {
        let (generator, _) = scripted(vec![
            Ok("good_token".into()),
            Err(anyhow::anyhow!("AWS timeout")),
        ]);
        let provider =
            RdsIamTokenProvider::with_generator(must_parse("postgres://user@host/db"), generator);

        provider.refresh().await.unwrap();
        let before = provider.current_url();

        assert!(provider.refresh().await.is_err());
        assert_eq!(provider.current_url(), before);
    }

    #[tokio::test(start_paused = true)]
    async fn refresh_task_runs_on_schedule() {
        let (generator, call_count) = scripted(
            ["t1", "t2", "t3", "t4", "t5"]
                .iter()
                .map(|token| Ok(token.to_string()))
                .collect(),
        );
        let provider =
            RdsIamTokenProvider::with_generator(must_parse("postgres://user@host/db"), generator);

        let handle = provider.spawn_refresh_task();
        tokio::task::yield_now().await;
        assert_eq!(calls(&call_count), 0);

        advance_and_settle(REFRESH_INTERVAL_SECS + 1).await;
        assert_eq!(calls(&call_count), 1);

        advance_and_settle(REFRESH_INTERVAL_SECS).await;
        assert_eq!(calls(&call_count), 2);

        handle.abort();
    }

    #[tokio::test(start_paused = true)]
    async fn refresh_task_retries_on_failure() {
        let (generator, call_count) = scripted(vec![
            Err(anyhow::anyhow!("fail")),
            Err(anyhow::anyhow!("fail")),
            Ok("recovered_token".into()),
        ]);
        let provider =
            RdsIamTokenProvider::with_generator(must_parse("postgres://user@host/db"), generator);

        let handle = provider.spawn_refresh_task();
        tokio::task::yield_now().await;

        advance_and_settle(REFRESH_INTERVAL_SECS + 1).await;
        assert_eq!(calls(&call_count), 1);

        advance_and_settle(31).await;
        assert_eq!(calls(&call_count), 2);

        advance_and_settle(61).await;
        assert_eq!(calls(&call_count), 3);
        let url = provider.current_url();
        assert!(url.contains("recovered_token"), "url={url}");

        handle.abort();
    }

    #[tokio::test(start_paused = true)]
    async fn refresh_task_exhausts_retries_then_resumes_normal_cadence() {
        let (generator, call_count) = scripted(vec![
            Err(anyhow::anyhow!("fail")),
            Err(anyhow::anyhow!("fail")),
            Err(anyhow::anyhow!("fail")),
            Err(anyhow::anyhow!("fail")),
            Ok("finally".into()),
        ]);
        let provider =
            RdsIamTokenProvider::with_generator(must_parse("postgres://user@host/db"), generator);
        provider.store_url("postgres://user:initial@host/db".into());

        let handle = provider.spawn_refresh_task();
        tokio::task::yield_now().await;

        // Scheduled attempt, then the three back-off retries.
        for secs in [REFRESH_INTERVAL_SECS + 1, 31, 61, 121] {
            advance_and_settle(secs).await;
        }

        assert_eq!(calls(&call_count), 4);
        assert!(provider.current_url().contains("initial"));

        advance_and_settle(REFRESH_INTERVAL_SECS + 1).await;

        assert_eq!(calls(&call_count), 5);
        let url = provider.current_url();
        assert!(url.contains("finally"), "url={url}");

        handle.abort();
    }
}