hypercall_db_diesel/
db_auth.rs1use anyhow::Result;
2
/// Describes how the connection URL for the database is obtained.
///
/// Only `Clone` is derived: the variants can carry credentials (for example a
/// password embedded in `database_url`), so formatting is not derived here.
#[derive(Clone)]
pub enum DbAuthConfig {
    /// A fixed connection URL, used verbatim.
    Password {
        /// Full connection URL, possibly containing inline credentials.
        database_url: String,
    },

    /// URLs supplied by a shared RDS IAM token provider whose stored URL can
    /// change over time. Only present when the `rds-iam` feature is enabled.
    #[cfg(feature = "rds-iam")]
    RdsIam {
        /// Shared handle to the provider; clones of this config share the same provider.
        provider: ::std::sync::Arc<crate::rds_iam::RdsIamTokenProvider>,
    },
}
20
21impl DbAuthConfig {
22 pub fn password(database_url: impl Into<String>) -> Self {
24 Self::Password {
25 database_url: database_url.into(),
26 }
27 }
28
29 #[cfg(feature = "rds-iam")]
32 pub async fn rds_iam(database_url: &str) -> Result<Self> {
33 let provider = crate::rds_iam::RdsIamTokenProvider::new(database_url).await?;
34 provider.spawn_refresh_task();
35 Ok(Self::RdsIam { provider })
36 }
37
38 pub fn current_url(&self) -> String {
40 match self {
41 Self::Password { database_url } => database_url.clone(),
42 #[cfg(feature = "rds-iam")]
43 Self::RdsIam { provider } => provider.current_url().as_ref().clone(),
44 }
45 }
46
47 pub async fn from_env(database_url: &str) -> Result<Self> {
52 let mode = std::env::var("DB_AUTH_MODE").unwrap_or_default();
53 match mode.as_str() {
54 #[cfg(feature = "rds-iam")]
55 "rds_iam" | "rds-iam" => Self::rds_iam(database_url).await,
56 #[cfg(not(feature = "rds-iam"))]
57 "rds_iam" | "rds-iam" => anyhow::bail!(
58 "DB_AUTH_MODE=rds_iam requested but the rds-iam feature is not compiled in"
59 ),
60 _ => Ok(Self::password(database_url)),
61 }
62 }
63}
64
#[cfg(feature = "rds-iam")]
impl DbAuthConfig {
    /// Wraps an already-constructed token provider as an `RdsIam` config.
    ///
    /// Unlike `rds_iam`, this does not call `spawn_refresh_task`. The caller
    /// decides how, or whether, the provider's URL is kept fresh, for example
    /// tests that store URLs on the provider directly.
    pub fn rds_iam_with_provider(
        provider: std::sync::Arc<crate::rds_iam::RdsIamTokenProvider>,
    ) -> Self {
        DbAuthConfig::RdsIam { provider }
    }
}
75
76impl std::fmt::Debug for DbAuthConfig {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 match self {
79 Self::Password { .. } => f.debug_struct("Password").finish_non_exhaustive(),
80 #[cfg(feature = "rds-iam")]
81 Self::RdsIam { .. } => f.debug_struct("RdsIam").finish_non_exhaustive(),
82 }
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Restores `DB_AUTH_MODE` to the value it had before a test touched it.
    /// Restoration runs on drop, so it also happens when an assertion panics.
    struct EnvModeGuard(Option<std::ffi::OsString>);

    impl EnvModeGuard {
        /// Captures the current value of `DB_AUTH_MODE` (or its absence).
        fn capture() -> Self {
            Self(std::env::var_os("DB_AUTH_MODE"))
        }
    }

    impl Drop for EnvModeGuard {
        fn drop(&mut self) {
            match self.0.take() {
                Some(value) => std::env::set_var("DB_AUTH_MODE", value),
                None => std::env::remove_var("DB_AUTH_MODE"),
            }
        }
    }

    #[test]
    fn password_current_url_round_trip() {
        let url = "postgres://user:pass@host:5432/db";
        let auth = DbAuthConfig::password(url);
        assert_eq!(auth.current_url(), url);
    }

    #[test]
    fn password_clone_is_independent() {
        let auth = DbAuthConfig::password("postgres://a@b/c");
        let cloned = auth.clone();
        assert_eq!(auth.current_url(), cloned.current_url());
    }

    #[test]
    fn debug_does_not_leak_credentials() {
        let auth = DbAuthConfig::password("postgres://user:secret@host/db");
        let debug = format!("{:?}", auth);
        assert!(!debug.contains("secret"));
        assert!(!debug.contains("user"));
        assert!(debug.contains("Password"));
    }

    /// Both password-fallback cases of `from_env` live in one test on purpose.
    /// They mutate the process-wide `DB_AUTH_MODE` variable, and as separate
    /// tests the parallel test runner could interleave them. The "unset" case
    /// could then actually run with the variable set, or vice versa.
    #[tokio::test]
    async fn from_env_falls_back_to_password() {
        let _restore = EnvModeGuard::capture();

        // Unset variable -> password mode.
        std::env::remove_var("DB_AUTH_MODE");
        let auth = DbAuthConfig::from_env("postgres://u@h/d").await.unwrap();
        assert_eq!(auth.current_url(), "postgres://u@h/d");

        // Unrecognised value -> password mode.
        std::env::set_var("DB_AUTH_MODE", "something_else");
        let auth = DbAuthConfig::from_env("postgres://u@h/d").await.unwrap();
        assert_eq!(auth.current_url(), "postgres://u@h/d");
    }

    #[cfg(feature = "rds-iam")]
    #[test]
    fn rds_iam_current_url_reads_from_provider() {
        let parsed = crate::rds_iam::ParsedDbUrl::parse("postgres://user@host/db").unwrap();
        let mock_gen = crate::rds_iam::test_support::MockTokenGenerator::new(vec![]);
        let provider =
            crate::rds_iam::RdsIamTokenProvider::with_generator(parsed, Box::new(mock_gen));
        provider.store_url("postgres://user:fresh_token@host/db".into());

        let auth = DbAuthConfig::rds_iam_with_provider(provider.clone());
        assert_eq!(auth.current_url(), "postgres://user:fresh_token@host/db");

        provider.store_url("postgres://user:rotated@host/db".into());
        assert_eq!(auth.current_url(), "postgres://user:rotated@host/db");
    }
}