Skip to main content

hypercall_db_diesel/
rds_iam.rs

1use std::sync::Arc;
2
3use anyhow::{Context, Result};
4use arc_swap::ArcSwap;
5use tracing::{error, info};
6
/// The pieces of a `DATABASE_URL` that are needed to request an IAM auth
/// token and to rebuild a connection URL around that token.
///
/// Any password present in the original URL is not retained.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParsedDbUrl {
    /// Host portion of the URL.
    pub hostname: String,
    /// Port from the URL; `ParsedDbUrl::parse` falls back to 5432 when the
    /// URL does not specify one.
    pub port: u16,
    /// User name taken from the URL's userinfo section.
    pub username: String,
    /// Database name: the URL path with its leading `/` removed.
    pub dbname: String,
    /// Query string including its leading `?`, or the empty string when the
    /// URL has no query component.
    pub query_params: String,
}
16
17impl ParsedDbUrl {
18    pub fn parse(database_url: &str) -> Result<Self> {
19        let parsed = url::Url::parse(database_url)
20            .with_context(|| "DATABASE_URL is not a valid URL for RDS IAM parsing")?;
21
22        let hostname = parsed
23            .host_str()
24            .context("DATABASE_URL missing host")?
25            .to_string();
26        let port = parsed.port().unwrap_or(5432);
27        let username = if parsed.username().is_empty() {
28            anyhow::bail!("DATABASE_URL missing username (required for RDS IAM)");
29        } else {
30            parsed.username().to_string()
31        };
32        let dbname = parsed.path().trim_start_matches('/').to_string();
33        if dbname.is_empty() {
34            anyhow::bail!("DATABASE_URL missing database name");
35        }
36        let query_params = parsed.query().map(|q| format!("?{q}")).unwrap_or_default();
37
38        Ok(Self {
39            hostname,
40            port,
41            username,
42            dbname,
43            query_params,
44        })
45    }
46
47    /// Build a full database URL with the given token as password.
48    pub fn url_with_token(&self, token: &str) -> Result<String> {
49        // The IAM token contains URL-encoded chars (%2F, etc.) that must survive
50        // libpq's single percent-decode pass intact. We percent-encode the token
51        // ourselves so that after libpq decodes, the proxy receives the original
52        // token byte-for-byte. url::Url::set_password would double-encode the
53        // existing %XX sequences, corrupting them.
54        let encoded = url::form_urlencoded::byte_serialize(token.as_bytes()).collect::<String>();
55        Ok(format!(
56            "postgres://{}:{}@{}:{}/{}{}",
57            self.username, encoded, self.hostname, self.port, self.dbname, self.query_params,
58        ))
59    }
60}
61
62/// Trait for generating database auth tokens. Production uses AWS IAM,
63/// tests can inject mock implementations.
64#[async_trait::async_trait]
65pub(crate) trait TokenGenerator: Send + Sync {
66    async fn generate(&self, parsed: &ParsedDbUrl) -> Result<String>;
67}
68
/// Token generator used in production: signs RDS IAM auth tokens through the
/// AWS SDK.
#[cfg(feature = "rds-iam")]
struct AwsTokenGenerator {
    /// SDK configuration passed to the signer on every call.
    aws_config: aws_config::SdkConfig,
}

#[cfg(feature = "rds-iam")]
#[async_trait::async_trait]
impl TokenGenerator for AwsTokenGenerator {
    /// Request an auth token for the host, port, and user in `parsed`.
    ///
    /// Both the config-building step and the signing step map their errors
    /// into `anyhow` errors with a descriptive prefix.
    async fn generate(&self, parsed: &ParsedDbUrl) -> Result<String> {
        let ParsedDbUrl {
            hostname,
            port,
            username,
            ..
        } = parsed;

        let signer_config = aws_sdk_rds::auth_token::Config::builder()
            .hostname(hostname)
            // Lossless widening: u16 -> u64.
            .port(u64::from(*port))
            .username(username)
            .build()
            .map_err(|e| anyhow::anyhow!("failed to build IAM auth token config: {e}"))?;

        let signer = aws_sdk_rds::auth_token::AuthTokenGenerator::new(signer_config);
        match signer.auth_token(&self.aws_config).await {
            Ok(token) => Ok(token.to_string()),
            Err(e) => Err(anyhow::anyhow!(
                "failed to generate RDS IAM auth token: {e}"
            )),
        }
    }
}
92
/// Seconds between scheduled token refreshes (10 minutes). Exposed as a
/// constant so tests can reference it.
pub(crate) const REFRESH_INTERVAL_SECS: u64 = 10 * 60;
/// Delays, in seconds, before each successive retry after a scheduled refresh
/// fails. Exposed as a constant so tests can reference it.
pub(crate) const RETRY_DELAYS_SECS: [u64; 3] = [30, 60, 120];
96
97/// Generates RDS IAM auth tokens and caches a ready-to-use database URL.
98///
99/// The cached URL is atomically swapped via `ArcSwap` so both sync (r2d2) and
100/// async (deadpool) pools can read it lock-free. A background tokio task
101/// refreshes the token well before the 15-minute AWS expiry.
102pub struct RdsIamTokenProvider {
103    pub(crate) parsed: ParsedDbUrl,
104    pub(crate) cached_url: Arc<ArcSwap<String>>,
105    generator: Box<dyn TokenGenerator>,
106}
107
108impl RdsIamTokenProvider {
109    /// Initialize the provider: load AWS credentials, generate the first token,
110    /// and return a ready provider. Call [`Self::spawn_refresh_task`] after to
111    /// keep the token fresh.
112    #[cfg(feature = "rds-iam")]
113    pub async fn new(database_url: &str) -> Result<Arc<Self>> {
114        let parsed = ParsedDbUrl::parse(database_url)?;
115        let aws_config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
116
117        let provider = Arc::new(Self {
118            parsed,
119            cached_url: Arc::new(ArcSwap::from_pointee(String::new())),
120            generator: Box::new(AwsTokenGenerator { aws_config }),
121        });
122
123        provider
124            .refresh()
125            .await
126            .context("failed to generate initial RDS IAM token")?;
127
128        Ok(provider)
129    }
130
131    /// Create a provider with a custom token generator. Used for testing
132    /// without AWS credentials. Does NOT call `refresh()` — caller must
133    /// seed the URL via [`Self::store_url`] or call [`Self::refresh`].
134    pub(crate) fn with_generator(
135        parsed: ParsedDbUrl,
136        generator: Box<dyn TokenGenerator>,
137    ) -> Arc<Self> {
138        Arc::new(Self {
139            parsed,
140            cached_url: Arc::new(ArcSwap::from_pointee(String::new())),
141            generator,
142        })
143    }
144
145    /// Return the current database URL with a valid IAM token as password.
146    pub fn current_url(&self) -> Arc<String> {
147        self.cached_url.load_full()
148    }
149
150    /// Atomically store a new URL (used by `refresh` and tests).
151    pub fn store_url(&self, url: String) {
152        self.cached_url.store(Arc::new(url));
153    }
154
155    /// Generate a fresh auth token and update the cached URL.
156    pub async fn refresh(&self) -> Result<()> {
157        let token = self.generator.generate(&self.parsed).await?;
158        let url = self.parsed.url_with_token(&token)?;
159        self.cached_url.store(Arc::new(url));
160        info!(host = %self.parsed.hostname, "refreshed RDS IAM auth token");
161        Ok(())
162    }
163
164    /// Spawn a background task that refreshes the token every 10 minutes.
165    /// On failure, retries with exponential backoff (30s, 60s, 120s) before
166    /// falling back to the normal 10-minute cadence.
167    pub fn spawn_refresh_task(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
168        let provider = Arc::clone(self);
169        tokio::spawn(async move {
170            let normal_interval = std::time::Duration::from_secs(REFRESH_INTERVAL_SECS);
171            let retry_delays: Vec<std::time::Duration> = RETRY_DELAYS_SECS
172                .iter()
173                .map(|s| std::time::Duration::from_secs(*s))
174                .collect();
175
176            loop {
177                tokio::time::sleep(normal_interval).await;
178                if provider.refresh().await.is_ok() {
179                    continue;
180                }
181                error!("RDS IAM token refresh failed, starting retry backoff");
182                let mut recovered = false;
183                for delay in &retry_delays {
184                    tokio::time::sleep(*delay).await;
185                    match provider.refresh().await {
186                        Ok(()) => {
187                            recovered = true;
188                            break;
189                        }
190                        Err(e) => error!("RDS IAM retry failed: {e:#}"),
191                    }
192                }
193                if !recovered {
194                    error!("RDS IAM token refresh exhausted retries, resuming normal cadence");
195                }
196            }
197        })
198    }
199}
200
/// Helpers for tests that need to fake token generation.
#[cfg(test)]
pub(crate) mod test_support {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Token generator that replays a scripted queue of results, one per
    /// call, and counts how many times it has been invoked.
    pub(crate) struct MockTokenGenerator {
        /// Remaining scripted results, consumed front to back.
        pub(crate) results: Arc<Mutex<VecDeque<Result<String>>>>,
        /// Number of `generate` calls made so far.
        pub(crate) call_count: Arc<AtomicUsize>,
    }

    impl MockTokenGenerator {
        /// Build a mock that will return `results` in order, starting with a
        /// call count of zero.
        pub(crate) fn new(results: Vec<Result<String>>) -> Self {
            let script: VecDeque<Result<String>> = results.into();
            Self {
                results: Arc::new(Mutex::new(script)),
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    #[async_trait::async_trait]
    impl TokenGenerator for MockTokenGenerator {
        /// Record the call, then hand back the next scripted result; once the
        /// script is exhausted every call yields an error.
        async fn generate(&self, _parsed: &ParsedDbUrl) -> Result<String> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            // The guard is a temporary, so the lock is released at the end of
            // this statement.
            let next = self.results.lock().unwrap().pop_front();
            match next {
                Some(result) => result,
                None => Err(anyhow::anyhow!("no more mock tokens")),
            }
        }
    }
}
236
#[cfg(test)]
mod tests {
    use super::test_support::*;
    use super::*;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;
    use std::time::Duration;

    /// How many times to yield so the spawned refresh task can run after the
    /// paused clock has been advanced.
    const SETTLE_YIELDS: usize = 10;

    /// Parse `url`, panicking if it is rejected.
    fn parse_ok(url: &str) -> ParsedDbUrl {
        ParsedDbUrl::parse(url).unwrap()
    }

    /// Build a provider for `postgres://user@host/db` whose generator replays
    /// `script`, and return it together with the generator's call counter.
    fn scripted_provider(
        script: Vec<Result<String>>,
    ) -> (Arc<RdsIamTokenProvider>, Arc<AtomicUsize>) {
        let calls = Arc::new(AtomicUsize::new(0));
        let generator = MockTokenGenerator {
            results: Arc::new(Mutex::new(VecDeque::from(script))),
            call_count: Arc::clone(&calls),
        };
        let provider = RdsIamTokenProvider::with_generator(
            parse_ok("postgres://user@host/db"),
            Box::new(generator),
        );
        (provider, calls)
    }

    /// Move the paused tokio clock forward by `secs`, then yield repeatedly so
    /// the refresh task gets to observe the new time.
    async fn advance_and_settle(secs: u64) {
        tokio::time::advance(Duration::from_secs(secs)).await;
        for _ in 0..SETTLE_YIELDS {
            tokio::task::yield_now().await;
        }
    }

    #[test]
    fn parse_full_url() {
        let parsed = parse_ok("postgres://myuser:mypass@db.example.com:5432/mydb?sslmode=require");
        assert_eq!(parsed.hostname, "db.example.com");
        assert_eq!(parsed.port, 5432);
        assert_eq!(parsed.username, "myuser");
        assert_eq!(parsed.dbname, "mydb");
        assert_eq!(parsed.query_params, "?sslmode=require");
    }

    #[test]
    fn parse_default_port() {
        assert_eq!(parse_ok("postgres://user@host/db").port, 5432);
    }

    #[test]
    fn parse_no_query_params() {
        assert_eq!(parse_ok("postgres://user@host/db").query_params, "");
    }

    #[test]
    fn parse_multiple_query_params() {
        let parsed = parse_ok("postgres://user@host/db?sslmode=require&connect_timeout=10");
        assert_eq!(parsed.query_params, "?sslmode=require&connect_timeout=10");
    }

    #[test]
    fn parse_missing_host_fails() {
        assert!(ParsedDbUrl::parse("postgres:///db").is_err());
    }

    #[test]
    fn parse_missing_username_fails() {
        assert!(ParsedDbUrl::parse("postgres://host/db").is_err());
    }

    #[test]
    fn parse_missing_dbname_fails() {
        // With and without the trailing slash.
        for url in ["postgres://user@host/", "postgres://user@host"] {
            assert!(ParsedDbUrl::parse(url).is_err());
        }
    }

    #[test]
    fn parse_invalid_url_fails() {
        assert!(ParsedDbUrl::parse("not-a-url").is_err());
    }

    #[test]
    fn url_with_token_preserves_structure() {
        let parsed = parse_ok("postgres://user@host:5432/db?sslmode=require");

        let plain = parsed.url_with_token("mytoken123").unwrap();
        assert!(plain.starts_with("postgres://user:mytoken123@host:5432/db"));
        assert!(plain.contains("sslmode=require"));

        // A token with URL-significant characters must not disturb the other parts.
        let tricky = parsed.url_with_token("tok/en=val").unwrap();
        let reparsed = url::Url::parse(&tricky).unwrap();
        assert_eq!(reparsed.host_str(), Some("host"));
        assert_eq!(reparsed.port(), Some(5432));
        assert_eq!(reparsed.path(), "/db");
        assert_eq!(reparsed.query(), Some("sslmode=require"));
    }

    #[test]
    fn url_with_token_round_trip() {
        let parsed = parse_ok(
            "postgres://admin@mydb.cluster.us-east-1.rds.amazonaws.com:5432/hypercall?sslmode=require",
        );
        let rebuilt = parsed.url_with_token("simpletoken123").unwrap();
        let reparsed = url::Url::parse(&rebuilt).unwrap();
        assert_eq!(reparsed.username(), "admin");
        assert_eq!(reparsed.password(), Some("simpletoken123"));
        assert_eq!(
            reparsed.host_str(),
            Some("mydb.cluster.us-east-1.rds.amazonaws.com")
        );
        assert_eq!(reparsed.port(), Some(5432));
        assert_eq!(reparsed.path(), "/hypercall");
        assert_eq!(reparsed.query(), Some("sslmode=require"));
    }

    #[test]
    fn store_url_visible_via_current_url() {
        let provider = RdsIamTokenProvider::with_generator(
            parse_ok("postgres://user@host/db"),
            Box::new(MockTokenGenerator::new(vec![])),
        );

        // Nothing stored yet.
        assert_eq!(provider.current_url().as_ref(), "");

        provider.store_url("postgres://user:token1@host/db".into());
        assert_eq!(
            provider.current_url().as_ref(),
            "postgres://user:token1@host/db"
        );

        provider.store_url("postgres://user:token2@host/db".into());
        assert_eq!(
            provider.current_url().as_ref(),
            "postgres://user:token2@host/db"
        );
    }

    #[tokio::test]
    async fn refresh_uses_generator_and_updates_url() {
        let script = vec![Ok("first_token".into()), Ok("second_token".into())];
        let provider = RdsIamTokenProvider::with_generator(
            parse_ok("postgres://user@host:5432/db"),
            Box::new(MockTokenGenerator::new(script)),
        );

        provider.refresh().await.unwrap();
        let url1 = provider.current_url();
        assert!(url1.contains("first_token"), "url={url1}");

        provider.refresh().await.unwrap();
        let url2 = provider.current_url();
        assert!(url2.contains("second_token"), "url={url2}");
    }

    #[tokio::test]
    async fn refresh_failure_preserves_old_url() {
        let script = vec![Ok("good_token".into()), Err(anyhow::anyhow!("AWS timeout"))];
        let provider = RdsIamTokenProvider::with_generator(
            parse_ok("postgres://user@host/db"),
            Box::new(MockTokenGenerator::new(script)),
        );

        provider.refresh().await.unwrap();
        let good_url: String = provider.current_url().to_string();

        let second = provider.refresh().await;
        assert!(second.is_err());

        // Old URL preserved after failed refresh
        assert_eq!(provider.current_url().to_string(), good_url);
    }

    #[tokio::test(start_paused = true)]
    async fn refresh_task_runs_on_schedule() {
        let (provider, calls) = scripted_provider(vec![
            Ok("t1".into()),
            Ok("t2".into()),
            Ok("t3".into()),
            Ok("t4".into()),
            Ok("t5".into()),
        ]);

        let handle = provider.spawn_refresh_task();
        // Let the spawned task start and register its first sleep
        tokio::task::yield_now().await;

        // No calls yet (first refresh happens after REFRESH_INTERVAL_SECS)
        assert_eq!(calls.load(Ordering::SeqCst), 0);

        // Advance past first interval
        advance_and_settle(REFRESH_INTERVAL_SECS + 1).await;
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Advance past second interval
        advance_and_settle(REFRESH_INTERVAL_SECS).await;
        assert_eq!(calls.load(Ordering::SeqCst), 2);

        handle.abort();
    }

    #[tokio::test(start_paused = true)]
    async fn refresh_task_retries_on_failure() {
        let (provider, calls) = scripted_provider(vec![
            // First scheduled refresh: fails
            Err(anyhow::anyhow!("fail")),
            // Retry 1 (30s): fails
            Err(anyhow::anyhow!("fail")),
            // Retry 2 (60s): succeeds
            Ok("recovered_token".into()),
        ]);

        let handle = provider.spawn_refresh_task();
        tokio::task::yield_now().await;

        // Advance past normal interval — triggers first (failing) refresh
        advance_and_settle(REFRESH_INTERVAL_SECS + 1).await;
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Advance 30s — first retry (also fails)
        advance_and_settle(31).await;
        assert_eq!(calls.load(Ordering::SeqCst), 2);

        // Advance 60s — second retry (succeeds)
        advance_and_settle(61).await;
        assert_eq!(calls.load(Ordering::SeqCst), 3);
        assert!(
            provider.current_url().contains("recovered_token"),
            "url={}",
            provider.current_url()
        );

        handle.abort();
    }

    #[tokio::test(start_paused = true)]
    async fn refresh_task_exhausts_retries_then_resumes_normal_cadence() {
        let (provider, calls) = scripted_provider(vec![
            // Normal refresh: fail
            Err(anyhow::anyhow!("fail")),
            // Retry 1 (30s): fail
            Err(anyhow::anyhow!("fail")),
            // Retry 2 (60s): fail
            Err(anyhow::anyhow!("fail")),
            // Retry 3 (120s): fail — exhausted
            Err(anyhow::anyhow!("fail")),
            // Back to normal cadence (600s): succeeds
            Ok("finally".into()),
        ]);
        provider.store_url("postgres://user:initial@host/db".into());

        let handle = provider.spawn_refresh_task();
        tokio::task::yield_now().await;

        // Normal interval → fail
        advance_and_settle(REFRESH_INTERVAL_SECS + 1).await;
        // Retry 1 (30s) → fail
        advance_and_settle(31).await;
        // Retry 2 (60s) → fail
        advance_and_settle(61).await;
        // Retry 3 (120s) → fail, exhausted
        advance_and_settle(121).await;

        assert_eq!(calls.load(Ordering::SeqCst), 4);
        // Old URL still cached
        assert!(provider.current_url().contains("initial"));

        // Back to normal cadence → succeeds
        advance_and_settle(REFRESH_INTERVAL_SECS + 1).await;

        assert_eq!(calls.load(Ordering::SeqCst), 5);
        assert!(
            provider.current_url().contains("finally"),
            "url={}",
            provider.current_url()
        );

        handle.abort();
    }
}