hypercall_db_diesel/
transaction.rs1use anyhow::Result;
7use diesel::connection::SimpleConnection;
8use diesel::r2d2;
9
10use crate::database_handler::{DatabaseHandler, DynConnectionManager};
11
12pub struct DieselTransaction {
15 conn: Option<r2d2::PooledConnection<DynConnectionManager>>,
16 committed: bool,
17}
18
19impl DieselTransaction {
20 pub(crate) fn new(mut conn: r2d2::PooledConnection<DynConnectionManager>) -> Result<Self> {
21 conn.batch_execute("BEGIN")?;
22 Ok(Self {
23 conn: Some(conn),
24 committed: false,
25 })
26 }
27}
28
29impl Drop for DieselTransaction {
30 fn drop(&mut self) {
31 if !self.committed {
32 if let Some(ref mut conn) = self.conn {
33 let _ = conn.batch_execute("ROLLBACK");
34 }
35 }
36 }
37}
38
39impl hypercall_db::Transaction for DieselTransaction {
40 fn commit(mut self) -> Result<()> {
41 if let Some(ref mut conn) = self.conn {
42 conn.batch_execute("COMMIT")?;
43 }
44 self.committed = true;
45 Ok(())
46 }
47
48 fn rollback(mut self) -> Result<()> {
49 if let Some(ref mut conn) = self.conn {
50 conn.batch_execute("ROLLBACK")?;
51 }
52 self.committed = true; Ok(())
54 }
55}
56
57impl hypercall_db::Transactional for DatabaseHandler {
58 type Tx = DieselTransaction;
59
60 fn begin_transaction(&self) -> Result<DieselTransaction> {
61 let conn = self.pool().get()?;
62 DieselTransaction::new(conn)
63 }
64}
65
#[cfg(test)]
mod tests {
    use crate::test_helpers::TestDb;
    use hypercall_db::*;
    use hypercall_types::wallet_address::test_wallet;
    use hypercall_types::MarginMode;

    /// After a transaction is begun and committed, the handler can still write
    /// and then read back a margin mode.
    #[tokio::test]
    async fn transaction_commit_persists() {
        let fixture = TestDb::new().await.unwrap();
        let handler = fixture.handler.as_ref();
        let owner = test_wallet(8);

        // Begin and immediately commit; the transaction value is consumed here.
        handler.begin_transaction().unwrap().commit().unwrap();

        handler
            .insert_margin_mode_if_missing_sync(&owner, MarginMode::Standard)
            .unwrap();
        let stored = handler.get_margin_mode_sync(&owner).unwrap();
        assert_eq!(stored, MarginMode::Standard);
    }

    /// Dropping a transaction without finishing it triggers the implicit
    /// rollback and leaves no user tiers behind.
    #[tokio::test]
    async fn transaction_rollback_on_drop() {
        let fixture = TestDb::new().await.unwrap();
        let handler = fixture.handler.as_ref();

        // Explicit drop instead of letting a scoped binding fall out of scope.
        drop(handler.begin_transaction().unwrap());

        let tiers = handler.get_all_user_tiers_sync().unwrap();
        assert!(tiers.is_empty());
    }
}