hypercall_db_diesel/
diesel_db.rs1use crate::db_auth::DbAuthConfig;
8use anyhow::{Context, Result};
9use diesel_async::pooled_connection::deadpool::Pool;
10use diesel_async::pooled_connection::AsyncDieselConnectionManager;
11use diesel_async::{AsyncPgConnection, RunQueryDsl};
12use std::sync::Arc;
13use tracing::info;
14
15pub type AsyncDbPool = Pool<AsyncPgConnection>;
17
18#[derive(Debug)]
22struct NoVerifier;
23
24impl rustls::client::danger::ServerCertVerifier for NoVerifier {
25 fn verify_server_cert(
26 &self,
27 _end_entity: &rustls::pki_types::CertificateDer<'_>,
28 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
29 _server_name: &rustls::pki_types::ServerName<'_>,
30 _ocsp_response: &[u8],
31 _now: rustls::pki_types::UnixTime,
32 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
33 Ok(rustls::client::danger::ServerCertVerified::assertion())
34 }
35
36 fn verify_tls12_signature(
37 &self,
38 _message: &[u8],
39 _cert: &rustls::pki_types::CertificateDer<'_>,
40 _dss: &rustls::DigitallySignedStruct,
41 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
42 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
43 }
44
45 fn verify_tls13_signature(
46 &self,
47 _message: &[u8],
48 _cert: &rustls::pki_types::CertificateDer<'_>,
49 _dss: &rustls::DigitallySignedStruct,
50 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
51 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
52 }
53
54 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
55 rustls::crypto::ring::default_provider()
56 .signature_verification_algorithms
57 .supported_schemes()
58 }
59}
60
/// Returns a log-safe rendering of a PostgreSQL connection string.
///
/// Only `postgres://` / `postgresql://` URIs that contain a userinfo section keep any
/// detail. Their query string is discarded first, because it may carry `password=...`.
/// Everything up to and including the last remaining `@` is then replaced, giving
/// `postgres://***@host[:port][/db]`.
///
/// Every other input is rendered as the fully masked `postgres://***`. This covers
/// key/value strings such as `host=... password=...`, scheme-less strings, and URIs
/// without credentials.
fn redact_url(url: &str) -> String {
    const MASKED: &str = "postgres://***";

    let Some(rest) = url
        .strip_prefix("postgres://")
        .or_else(|| url.strip_prefix("postgresql://"))
    else {
        // Key/value connection strings can hold `password=` anywhere (and usernames like
        // `user@server`), so nothing after an '@' is safe to show.
        return MASKED.to_string();
    };

    // Cut the query first so an '@' inside a parameter value (e.g. `?password=a@b`) can
    // never be mistaken for the userinfo delimiter.
    let before_query = rest.split_once('?').map_or(rest, |(head, _)| head);

    // Use the *last* '@': usernames such as `admin@server` are sometimes left unencoded,
    // and splitting at the first one would print the remainder of the credentials.
    match before_query.rfind('@') {
        Some(at) => format!("postgres://***@{}", &before_query[at + 1..]),
        None => MASKED.to_string(),
    }
}
66
67fn ensure_default_rustls_crypto_provider() {
69 use std::sync::Once;
70 static INSTALL: Once = Once::new();
71 INSTALL.call_once(|| {
72 if rustls::crypto::CryptoProvider::get_default().is_none() {
73 let _ = rustls::crypto::ring::default_provider().install_default();
74 assert!(
75 rustls::crypto::CryptoProvider::get_default().is_some(),
76 "failed to install default rustls crypto provider"
77 );
78 }
79 });
80}
81
82pub struct DieselDb {
86 pool: Arc<AsyncDbPool>,
87}
88
89impl DieselDb {
90 pub async fn new(database_url: &str, max_size: usize) -> Result<Self> {
92 ensure_default_rustls_crypto_provider();
93
94 let redacted = database_url
95 .find('@')
96 .map(|at| format!("postgres://***@{}", &database_url[at + 1..]))
97 .unwrap_or_else(|| "postgres://***".to_string());
98 info!("Initializing async DieselDb for PostgreSQL: {}", redacted);
99
100 let mut manager_config = diesel_async::pooled_connection::ManagerConfig::default();
101 manager_config.custom_setup = Box::new(|url| {
102 Box::pin(async move {
103 let tls_config = rustls::ClientConfig::builder_with_provider(
104 rustls::crypto::ring::default_provider().into(),
105 )
106 .with_safe_default_protocol_versions()
107 .expect("TLS protocol versions")
108 .dangerous()
109 .with_custom_certificate_verifier(Arc::new(NoVerifier))
110 .with_no_client_auth();
111 let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
112 let (client, connection) = tokio_postgres::connect(url, tls)
113 .await
114 .map_err(|e| diesel::ConnectionError::BadConnection(e.to_string()))?;
115
116 tokio::spawn(async move {
117 if let Err(e) = connection.await {
118 tracing::error!("async diesel connection error: {}", e);
119 }
120 });
121
122 let mut conn = AsyncPgConnection::try_from(client).await?;
123 diesel::sql_query("SET statement_timeout = '30000'")
124 .execute(&mut conn)
125 .await
126 .map_err(diesel::ConnectionError::CouldntSetupConfiguration)?;
127 diesel::sql_query("SET lock_timeout = '10000'")
128 .execute(&mut conn)
129 .await
130 .map_err(diesel::ConnectionError::CouldntSetupConfiguration)?;
131 Ok(conn)
132 })
133 });
134 let config = AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_config(
135 database_url,
136 manager_config,
137 );
138 let pool = Pool::builder(config)
139 .max_size(max_size)
140 .wait_timeout(Some(std::time::Duration::from_secs(5)))
141 .create_timeout(Some(std::time::Duration::from_secs(5)))
142 .recycle_timeout(Some(std::time::Duration::from_secs(5)))
143 .runtime(deadpool::Runtime::Tokio1)
144 .build()
145 .map_err(|e| anyhow::anyhow!("Failed to create async connection pool: {}", e))?;
146
147 let _conn = pool
149 .get()
150 .await
151 .with_context(|| "Failed to get initial async connection")?;
152
153 Ok(Self {
154 pool: Arc::new(pool),
155 })
156 }
157
158 pub async fn new_with_auth(auth: DbAuthConfig, max_size: usize) -> Result<Self> {
164 ensure_default_rustls_crypto_provider();
165
166 let initial_url = auth.current_url();
167 let redacted = redact_url(&initial_url);
168 info!(
169 "Initializing async DieselDb (auth={:?}): {}",
170 auth, redacted
171 );
172
173 let mut manager_config = diesel_async::pooled_connection::ManagerConfig::default();
174 manager_config.custom_setup = Box::new(move |url| {
175 let fresh_url = auth.current_url();
176 Box::pin(async move {
177 let connect_url = if fresh_url.is_empty() {
178 url
179 } else {
180 &fresh_url
181 };
182
183 let tls_config = rustls::ClientConfig::builder_with_provider(
184 rustls::crypto::ring::default_provider().into(),
185 )
186 .with_safe_default_protocol_versions()
187 .expect("TLS protocol versions")
188 .dangerous()
189 .with_custom_certificate_verifier(Arc::new(NoVerifier))
190 .with_no_client_auth();
191 let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
192 let (client, connection) = tokio_postgres::connect(connect_url, tls)
193 .await
194 .map_err(|e| diesel::ConnectionError::BadConnection(e.to_string()))?;
195
196 tokio::spawn(async move {
197 if let Err(e) = connection.await {
198 tracing::error!("async diesel connection error: {}", e);
199 }
200 });
201
202 let mut conn = AsyncPgConnection::try_from(client).await?;
203 diesel::sql_query("SET statement_timeout = '30000'")
204 .execute(&mut conn)
205 .await
206 .map_err(diesel::ConnectionError::CouldntSetupConfiguration)?;
207 diesel::sql_query("SET lock_timeout = '10000'")
208 .execute(&mut conn)
209 .await
210 .map_err(diesel::ConnectionError::CouldntSetupConfiguration)?;
211 Ok(conn)
212 })
213 });
214 let config = AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_config(
215 &initial_url,
216 manager_config,
217 );
218 let pool = Pool::builder(config)
219 .max_size(max_size)
220 .wait_timeout(Some(std::time::Duration::from_secs(5)))
221 .create_timeout(Some(std::time::Duration::from_secs(5)))
222 .recycle_timeout(Some(std::time::Duration::from_secs(5)))
223 .runtime(deadpool::Runtime::Tokio1)
224 .build()
225 .map_err(|e| anyhow::anyhow!("Failed to create async connection pool: {}", e))?;
226
227 let _conn = pool
228 .get()
229 .await
230 .with_context(|| "Failed to get initial async connection")?;
231
232 Ok(Self {
233 pool: Arc::new(pool),
234 })
235 }
236
237 pub async fn new_no_tls(database_url: &str, max_size: usize) -> Result<Self> {
239 let config = AsyncDieselConnectionManager::<AsyncPgConnection>::new(database_url);
240 let pool = Pool::builder(config)
241 .max_size(max_size)
242 .wait_timeout(Some(std::time::Duration::from_secs(5)))
243 .create_timeout(Some(std::time::Duration::from_secs(5)))
244 .runtime(deadpool::Runtime::Tokio1)
245 .build()
246 .map_err(|e| anyhow::anyhow!("Failed to create async pool: {}", e))?;
247 Ok(Self {
248 pool: Arc::new(pool),
249 })
250 }
251
252 pub fn with_pool(pool: Arc<AsyncDbPool>) -> Self {
254 info!("Creating DieselDb with existing async pool");
255 Self { pool }
256 }
257
258 pub async fn get_conn(
260 &self,
261 ) -> Result<diesel_async::pooled_connection::deadpool::Object<AsyncPgConnection>> {
262 self.pool
263 .get()
264 .await
265 .map_err(|e| anyhow::anyhow!("Failed to get async connection: {}", e))
266 }
267}