Skip to main content

hypercall_db_diesel/
db_auth.rs

1use anyhow::Result;
2
/// Strategy the application uses to authenticate against PostgreSQL.
///
/// Two strategies are modelled:
///
/// - `Password` — the legacy path, where a static `DATABASE_URL` already
///   carries the credentials.
/// - `RdsIam` — short-lived AWS IAM tokens that are refreshed automatically
///   in the background. This variant only exists when the crate is built
///   with the `rds-iam` feature.
///
/// The type is `Clone`; clones of an `RdsIam` value share the same provider
/// because it is held behind an `Arc`.
#[derive(Clone)]
pub enum DbAuthConfig {
    /// Fixed credentials carried inside the connection URL itself.
    Password {
        /// Complete connection URL, handed back as-is by `current_url()`.
        database_url: String,
    },

    /// AWS RDS IAM authentication. The provider caches a fresh token and
    /// exposes `current_url()` for connection managers.
    #[cfg(feature = "rds-iam")]
    RdsIam {
        /// Shared handle to the token provider that supplies the current URL.
        provider: std::sync::Arc<crate::rds_iam::RdsIamTokenProvider>,
    },
}
20
21impl DbAuthConfig {
22    /// Build a `Password` config from a database URL string.
23    pub fn password(database_url: impl Into<String>) -> Self {
24        Self::Password {
25            database_url: database_url.into(),
26        }
27    }
28
29    /// Build an `RdsIam` config. Loads AWS credentials, generates the first
30    /// token, and spawns a background refresh task.
31    #[cfg(feature = "rds-iam")]
32    pub async fn rds_iam(database_url: &str) -> Result<Self> {
33        let provider = crate::rds_iam::RdsIamTokenProvider::new(database_url).await?;
34        provider.spawn_refresh_task();
35        Ok(Self::RdsIam { provider })
36    }
37
38    /// Return the current database URL (with valid credentials).
39    pub fn current_url(&self) -> String {
40        match self {
41            Self::Password { database_url } => database_url.clone(),
42            #[cfg(feature = "rds-iam")]
43            Self::RdsIam { provider } => provider.current_url().as_ref().clone(),
44        }
45    }
46
47    /// Build from the `DB_AUTH_MODE` env var convention.
48    ///
49    /// - `DB_AUTH_MODE=rds_iam` (or `rds-iam`) → [`Self::rds_iam`]
50    /// - anything else (or unset) → [`Self::password`]
51    pub async fn from_env(database_url: &str) -> Result<Self> {
52        let mode = std::env::var("DB_AUTH_MODE").unwrap_or_default();
53        match mode.as_str() {
54            #[cfg(feature = "rds-iam")]
55            "rds_iam" | "rds-iam" => Self::rds_iam(database_url).await,
56            #[cfg(not(feature = "rds-iam"))]
57            "rds_iam" | "rds-iam" => anyhow::bail!(
58                "DB_AUTH_MODE=rds_iam requested but the rds-iam feature is not compiled in"
59            ),
60            _ => Ok(Self::password(database_url)),
61        }
62    }
63}
64
#[cfg(feature = "rds-iam")]
impl DbAuthConfig {
    /// Wrap an already-constructed token provider in the `RdsIam` variant.
    ///
    /// This function only wraps the handle it is given: it does not build a
    /// provider and does not spawn a refresh task. It is meant for tests and
    /// for callers that manage the token lifecycle themselves.
    pub fn rds_iam_with_provider(
        provider: std::sync::Arc<crate::rds_iam::RdsIamTokenProvider>,
    ) -> Self {
        DbAuthConfig::RdsIam { provider }
    }
}
75
76impl std::fmt::Debug for DbAuthConfig {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        match self {
79            Self::Password { .. } => f.debug_struct("Password").finish_non_exhaustive(),
80            #[cfg(feature = "rds-iam")]
81            Self::RdsIam { .. } => f.debug_struct("RdsIam").finish_non_exhaustive(),
82        }
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// URL shared by the `from_env` tests; it carries no password on purpose.
    const ENV_TEST_URL: &str = "postgres://u@h/d";

    /// A password config hands back exactly the URL it was built from.
    #[test]
    fn password_current_url_round_trip() {
        let expected = "postgres://user:pass@host:5432/db";
        let config = DbAuthConfig::password(expected);
        assert_eq!(config.current_url(), expected);
    }

    /// A clone of a password config reports the same URL as its source.
    #[test]
    fn password_clone_is_independent() {
        let original = DbAuthConfig::password("postgres://a@b/c");
        let duplicate = original.clone();
        assert_eq!(original.current_url(), duplicate.current_url());
    }

    /// Debug output names the variant but contains no part of the credentials.
    #[test]
    fn debug_does_not_leak_credentials() {
        let rendered = format!(
            "{:?}",
            DbAuthConfig::password("postgres://user:secret@host/db")
        );
        assert!(!rendered.contains("secret"));
        assert!(!rendered.contains("user"));
        assert!(rendered.contains("Password"));
    }

    /// With `DB_AUTH_MODE` unset, `from_env` yields the password strategy.
    #[tokio::test]
    async fn from_env_defaults_to_password() {
        std::env::remove_var("DB_AUTH_MODE");
        let config = DbAuthConfig::from_env(ENV_TEST_URL).await.unwrap();
        assert_eq!(config.current_url(), ENV_TEST_URL);
    }

    /// An unrecognised `DB_AUTH_MODE` value also yields the password strategy.
    #[tokio::test]
    async fn from_env_unknown_mode_defaults_to_password() {
        std::env::set_var("DB_AUTH_MODE", "something_else");
        let config = DbAuthConfig::from_env(ENV_TEST_URL).await.unwrap();
        assert_eq!(config.current_url(), ENV_TEST_URL);
        // Clean up so the variable does not linger for later tests.
        std::env::remove_var("DB_AUTH_MODE");
    }

    /// `current_url` reflects each URL stored in the provider, including a
    /// later rotation, without rebuilding the config.
    #[cfg(feature = "rds-iam")]
    #[test]
    fn rds_iam_current_url_reads_from_provider() {
        let parsed_url = crate::rds_iam::ParsedDbUrl::parse("postgres://user@host/db").unwrap();
        let generator = crate::rds_iam::test_support::MockTokenGenerator::new(vec![]);
        let provider =
            crate::rds_iam::RdsIamTokenProvider::with_generator(parsed_url, Box::new(generator));

        provider.store_url("postgres://user:fresh_token@host/db".into());
        let config = DbAuthConfig::rds_iam_with_provider(provider.clone());
        assert_eq!(config.current_url(), "postgres://user:fresh_token@host/db");

        provider.store_url("postgres://user:rotated@host/db".into());
        assert_eq!(config.current_url(), "postgres://user:rotated@host/db");
    }
}