Skip to main content

hypercall_db_diesel/
diesel_db.rs

1//! Async DieselDb struct with deadpool-based connection pool management.
2//!
3//! This module provides the `DieselDb` struct which owns an async connection
4//! pool to PostgreSQL via `diesel-async` + `deadpool`. Domain-specific SQL
5//! lives in sibling modules (`integrity`, `liquidation`, `nonces`, `analytics`).
6
7use crate::db_auth::DbAuthConfig;
8use anyhow::{Context, Result};
9use diesel_async::pooled_connection::deadpool::Pool;
10use diesel_async::pooled_connection::AsyncDieselConnectionManager;
11use diesel_async::{AsyncPgConnection, RunQueryDsl};
12use std::sync::Arc;
13use tracing::info;
14
15/// Async connection pool type alias (deadpool + diesel-async).
16pub type AsyncDbPool = Pool<AsyncPgConnection>;
17
18/// Accepts any server certificate. Used for `sslmode=require` where the DB
19/// uses a self-signed cert (Akamai managed Postgres). This matches libpq's
20/// behavior: `require` encrypts the connection but does not verify identity.
21#[derive(Debug)]
22struct NoVerifier;
23
24impl rustls::client::danger::ServerCertVerifier for NoVerifier {
25    fn verify_server_cert(
26        &self,
27        _end_entity: &rustls::pki_types::CertificateDer<'_>,
28        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
29        _server_name: &rustls::pki_types::ServerName<'_>,
30        _ocsp_response: &[u8],
31        _now: rustls::pki_types::UnixTime,
32    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
33        Ok(rustls::client::danger::ServerCertVerified::assertion())
34    }
35
36    fn verify_tls12_signature(
37        &self,
38        _message: &[u8],
39        _cert: &rustls::pki_types::CertificateDer<'_>,
40        _dss: &rustls::DigitallySignedStruct,
41    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
42        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
43    }
44
45    fn verify_tls13_signature(
46        &self,
47        _message: &[u8],
48        _cert: &rustls::pki_types::CertificateDer<'_>,
49        _dss: &rustls::DigitallySignedStruct,
50    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
51        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
52    }
53
54    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
55        rustls::crypto::ring::default_provider()
56            .signature_verification_algorithms
57            .supported_schemes()
58    }
59}
60
/// Render a PostgreSQL connection string for logging with credentials masked.
///
/// * URL form (`postgres://` or `postgresql://`) containing `@`: everything up
///   to the first `@` (the userinfo, split where `tokio_postgres` splits it) is
///   replaced by `***`. Any query string is collapsed to `?***` because it may
///   carry `password=...`.
/// * Anything else — URLs without userinfo, or key/value strings such as
///   `host=... user=... password=...` — is fully masked as `postgres://***`,
///   since no part of it is known to be credential-free.
fn redact_url(url: &str) -> String {
    const FULLY_REDACTED: &str = "postgres://***";

    if !(url.starts_with("postgres://") || url.starts_with("postgresql://")) {
        return FULLY_REDACTED.to_string();
    }
    let Some(at) = url.find('@') else {
        return FULLY_REDACTED.to_string();
    };

    let after_credentials = &url[at + 1..];
    match after_credentials.split_once('?') {
        Some((host_and_path, _query)) => format!("postgres://***@{host_and_path}?***"),
        None => format!("postgres://***@{after_credentials}"),
    }
}
66
67/// Ensure rustls has a process-wide crypto provider before TLS clients are built.
68fn ensure_default_rustls_crypto_provider() {
69    use std::sync::Once;
70    static INSTALL: Once = Once::new();
71    INSTALL.call_once(|| {
72        if rustls::crypto::CryptoProvider::get_default().is_none() {
73            let _ = rustls::crypto::ring::default_provider().install_default();
74            assert!(
75                rustls::crypto::CryptoProvider::get_default().is_some(),
76                "failed to install default rustls crypto provider"
77            );
78        }
79    });
80}
81
82/// Async persistence handler. Owns a deadpool-managed `diesel-async` connection
83/// pool and implements all API-path traits from `hypercall_db` (analytics,
84/// integrity, liquidation, notifications, nonces, push, usernames).
85pub struct DieselDb {
86    pool: Arc<AsyncDbPool>,
87}
88
89impl DieselDb {
90    /// Build a pool with TLS (NoVerifier for `sslmode=require`) and verify connectivity.
91    pub async fn new(database_url: &str, max_size: usize) -> Result<Self> {
92        ensure_default_rustls_crypto_provider();
93
94        let redacted = database_url
95            .find('@')
96            .map(|at| format!("postgres://***@{}", &database_url[at + 1..]))
97            .unwrap_or_else(|| "postgres://***".to_string());
98        info!("Initializing async DieselDb for PostgreSQL: {}", redacted);
99
100        let mut manager_config = diesel_async::pooled_connection::ManagerConfig::default();
101        manager_config.custom_setup = Box::new(|url| {
102            Box::pin(async move {
103                let tls_config = rustls::ClientConfig::builder_with_provider(
104                    rustls::crypto::ring::default_provider().into(),
105                )
106                .with_safe_default_protocol_versions()
107                .expect("TLS protocol versions")
108                .dangerous()
109                .with_custom_certificate_verifier(Arc::new(NoVerifier))
110                .with_no_client_auth();
111                let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
112                let (client, connection) = tokio_postgres::connect(url, tls)
113                    .await
114                    .map_err(|e| diesel::ConnectionError::BadConnection(e.to_string()))?;
115
116                tokio::spawn(async move {
117                    if let Err(e) = connection.await {
118                        tracing::error!("async diesel connection error: {}", e);
119                    }
120                });
121
122                let mut conn = AsyncPgConnection::try_from(client).await?;
123                diesel::sql_query("SET statement_timeout = '30000'")
124                    .execute(&mut conn)
125                    .await
126                    .map_err(diesel::ConnectionError::CouldntSetupConfiguration)?;
127                diesel::sql_query("SET lock_timeout = '10000'")
128                    .execute(&mut conn)
129                    .await
130                    .map_err(diesel::ConnectionError::CouldntSetupConfiguration)?;
131                Ok(conn)
132            })
133        });
134        let config = AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_config(
135            database_url,
136            manager_config,
137        );
138        let pool = Pool::builder(config)
139            .max_size(max_size)
140            .wait_timeout(Some(std::time::Duration::from_secs(5)))
141            .create_timeout(Some(std::time::Duration::from_secs(5)))
142            .recycle_timeout(Some(std::time::Duration::from_secs(5)))
143            .runtime(deadpool::Runtime::Tokio1)
144            .build()
145            .map_err(|e| anyhow::anyhow!("Failed to create async connection pool: {}", e))?;
146
147        // Verify connectivity
148        let _conn = pool
149            .get()
150            .await
151            .with_context(|| "Failed to get initial async connection")?;
152
153        Ok(Self {
154            pool: Arc::new(pool),
155        })
156    }
157
158    /// Build a pool using [`DbAuthConfig`] for credential management.
159    ///
160    /// In `Password` mode this behaves identically to [`Self::new`]. In `RdsIam`
161    /// mode, each new connection fetches the latest cached IAM token so
162    /// connections created after a token refresh use valid credentials.
163    pub async fn new_with_auth(auth: DbAuthConfig, max_size: usize) -> Result<Self> {
164        ensure_default_rustls_crypto_provider();
165
166        let initial_url = auth.current_url();
167        let redacted = redact_url(&initial_url);
168        info!(
169            "Initializing async DieselDb (auth={:?}): {}",
170            auth, redacted
171        );
172
173        let mut manager_config = diesel_async::pooled_connection::ManagerConfig::default();
174        manager_config.custom_setup = Box::new(move |url| {
175            let fresh_url = auth.current_url();
176            Box::pin(async move {
177                let connect_url = if fresh_url.is_empty() {
178                    url
179                } else {
180                    &fresh_url
181                };
182
183                let tls_config = rustls::ClientConfig::builder_with_provider(
184                    rustls::crypto::ring::default_provider().into(),
185                )
186                .with_safe_default_protocol_versions()
187                .expect("TLS protocol versions")
188                .dangerous()
189                .with_custom_certificate_verifier(Arc::new(NoVerifier))
190                .with_no_client_auth();
191                let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
192                let (client, connection) = tokio_postgres::connect(connect_url, tls)
193                    .await
194                    .map_err(|e| diesel::ConnectionError::BadConnection(e.to_string()))?;
195
196                tokio::spawn(async move {
197                    if let Err(e) = connection.await {
198                        tracing::error!("async diesel connection error: {}", e);
199                    }
200                });
201
202                let mut conn = AsyncPgConnection::try_from(client).await?;
203                diesel::sql_query("SET statement_timeout = '30000'")
204                    .execute(&mut conn)
205                    .await
206                    .map_err(diesel::ConnectionError::CouldntSetupConfiguration)?;
207                diesel::sql_query("SET lock_timeout = '10000'")
208                    .execute(&mut conn)
209                    .await
210                    .map_err(diesel::ConnectionError::CouldntSetupConfiguration)?;
211                Ok(conn)
212            })
213        });
214        let config = AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_config(
215            &initial_url,
216            manager_config,
217        );
218        let pool = Pool::builder(config)
219            .max_size(max_size)
220            .wait_timeout(Some(std::time::Duration::from_secs(5)))
221            .create_timeout(Some(std::time::Duration::from_secs(5)))
222            .recycle_timeout(Some(std::time::Duration::from_secs(5)))
223            .runtime(deadpool::Runtime::Tokio1)
224            .build()
225            .map_err(|e| anyhow::anyhow!("Failed to create async connection pool: {}", e))?;
226
227        let _conn = pool
228            .get()
229            .await
230            .with_context(|| "Failed to get initial async connection")?;
231
232        Ok(Self {
233            pool: Arc::new(pool),
234        })
235    }
236
237    /// Build a pool without TLS. Used for local development and tests.
238    pub async fn new_no_tls(database_url: &str, max_size: usize) -> Result<Self> {
239        let config = AsyncDieselConnectionManager::<AsyncPgConnection>::new(database_url);
240        let pool = Pool::builder(config)
241            .max_size(max_size)
242            .wait_timeout(Some(std::time::Duration::from_secs(5)))
243            .create_timeout(Some(std::time::Duration::from_secs(5)))
244            .runtime(deadpool::Runtime::Tokio1)
245            .build()
246            .map_err(|e| anyhow::anyhow!("Failed to create async pool: {}", e))?;
247        Ok(Self {
248            pool: Arc::new(pool),
249        })
250    }
251
252    /// Wrap an existing async pool (no connectivity check).
253    pub fn with_pool(pool: Arc<AsyncDbPool>) -> Self {
254        info!("Creating DieselDb with existing async pool");
255        Self { pool }
256    }
257
258    /// Checkout a connection from the pool.
259    pub async fn get_conn(
260        &self,
261    ) -> Result<diesel_async::pooled_connection::deadpool::Object<AsyncPgConnection>> {
262        self.pool
263            .get()
264            .await
265            .map_err(|e| anyhow::anyhow!("Failed to get async connection: {}", e))
266    }
267}