1use anyhow::{Context, Result};
7use diesel::prelude::*;
8use diesel::RunQueryDsl;
9
10use hypercall_types::WalletAddress;
11
12use crate::database_handler::DatabaseHandler;
13
14impl hypercall_db::TierReader for DatabaseHandler {
15 fn get_margin_mode_sync(&self, wallet: &WalletAddress) -> Result<hypercall_types::MarginMode> {
16 use crate::schema::user_tiers::dsl as ut;
17
18 let mut conn = self.pool().get()?;
19 let Some(mode) = ut::user_tiers
20 .filter(ut::wallet_address.eq(wallet))
21 .select(ut::margin_mode)
22 .first::<String>(&mut conn)
23 .optional()?
24 else {
25 return Ok(hypercall_types::MarginMode::Standard);
26 };
27 let margin_mode = mode
28 .parse::<hypercall_types::MarginMode>()
29 .with_context(|| format!("Invalid margin mode '{}' for wallet {}", mode, wallet))?;
30 Ok(margin_mode)
31 }
32
33 fn get_existing_margin_mode_sync(
34 &self,
35 wallet: &WalletAddress,
36 ) -> Result<Option<hypercall_types::MarginMode>> {
37 use crate::schema::user_tiers::dsl as ut;
38
39 let mut conn = self.pool().get()?;
40 let Some(mode) = ut::user_tiers
41 .filter(ut::wallet_address.eq(wallet))
42 .select(ut::margin_mode)
43 .first::<String>(&mut conn)
44 .optional()?
45 else {
46 return Ok(None);
47 };
48 let margin_mode = mode
49 .parse::<hypercall_types::MarginMode>()
50 .with_context(|| format!("Invalid margin mode '{}' for wallet {}", mode, wallet))?;
51 Ok(Some(margin_mode))
52 }
53
54 fn get_tier_defaults_sync(
55 &self,
56 tier_name: &str,
57 ) -> Result<Option<hypercall_db::TierDefaultsRecord>> {
58 use crate::schema::tier_defaults::dsl;
59
60 let mut conn = self.pool().get()?;
61 let result = dsl::tier_defaults
62 .filter(dsl::tier.eq(tier_name))
63 .first::<crate::models::TierDefaults>(&mut conn)
64 .optional()?;
65 Ok(result.map(Into::into))
66 }
67
68 fn get_user_tier_sync(
69 &self,
70 wallet: &WalletAddress,
71 ) -> Result<Option<hypercall_db::UserTierRecord>> {
72 use crate::schema::user_tiers::dsl;
73
74 let mut conn = self.pool().get()?;
75 let result = dsl::user_tiers
76 .filter(dsl::wallet_address.eq(wallet))
77 .first::<crate::models::UserTier>(&mut conn)
78 .optional()?;
79 Ok(result.map(Into::into))
80 }
81
82 fn get_all_user_tiers_sync(&self) -> Result<Vec<hypercall_db::UserTierRecord>> {
83 use crate::schema::user_tiers::dsl;
84
85 let mut conn = self.pool().get()?;
86 let results = dsl::user_tiers.load::<crate::models::UserTier>(&mut conn)?;
87 Ok(results.into_iter().map(Into::into).collect())
88 }
89}
90
91impl hypercall_db::TierWriter for DatabaseHandler {
92 fn save_user_tier_sync(&self, update: &hypercall_db::UserTierUpdate) -> Result<()> {
93 use crate::schema::user_tiers::dsl;
94
95 let new_tier = crate::models::NewUserTier {
96 wallet_address: update.wallet_address,
97 tier: update.tier.clone(),
98 margin_mode: update.margin_mode.map(|m| m.to_string()),
99 version: update.version,
100 max_open_orders: update.max_open_orders,
101 max_open_positions: update.max_open_positions,
102 orders_per_minute: update.orders_per_minute,
103 cancels_per_minute: update.cancels_per_minute,
104 api_requests_per_minute: update.api_requests_per_minute,
105 };
106
107 let mut conn = self.pool().get()?;
108 diesel::insert_into(crate::schema::user_tiers::table)
109 .values(&new_tier)
110 .on_conflict(dsl::wallet_address)
111 .do_update()
112 .set(&new_tier)
113 .execute(&mut conn)?;
114 Ok(())
115 }
116
117 fn set_margin_mode_sync(
118 &self,
119 wallet: &WalletAddress,
120 margin_mode: hypercall_types::MarginMode,
121 ) -> Result<i64> {
122 use diesel::sql_types::{BigInt, Bytea, Text};
123
124 #[derive(diesel::QueryableByName)]
125 struct VersionRow {
126 #[diesel(sql_type = BigInt)]
127 version: i64,
128 }
129
130 let mut conn = self.pool().get()?;
131 let row: VersionRow = diesel::sql_query(
132 "INSERT INTO user_tiers (wallet_address, tier, margin_mode, version)
133 VALUES ($1, 'tier2', $2, 1)
134 ON CONFLICT (wallet_address)
135 DO UPDATE SET margin_mode = $2, version = user_tiers.version + 1, updated_at = NOW()
136 RETURNING version",
137 )
138 .bind::<Bytea, _>(wallet.as_bytes())
139 .bind::<Text, _>(margin_mode.as_str())
140 .get_result(&mut conn)?;
141
142 Ok(row.version)
143 }
144
145 fn insert_margin_mode_if_missing_sync(
146 &self,
147 wallet: &WalletAddress,
148 margin_mode: hypercall_types::MarginMode,
149 ) -> Result<Option<i64>> {
150 use diesel::sql_types::{BigInt, Bytea, Text};
151
152 #[derive(diesel::QueryableByName)]
153 struct VersionRow {
154 #[diesel(sql_type = BigInt)]
155 version: i64,
156 }
157
158 let mut conn = self.pool().get()?;
159 let row = diesel::sql_query(
160 "INSERT INTO user_tiers (wallet_address, tier, margin_mode, version)
161 VALUES ($1, 'tier2', $2, 1)
162 ON CONFLICT (wallet_address) DO NOTHING
163 RETURNING version",
164 )
165 .bind::<Bytea, _>(wallet.as_bytes())
166 .bind::<Text, _>(margin_mode.as_str())
167 .get_result::<VersionRow>(&mut conn)
168 .optional()?;
169
170 Ok(row.map(|row| row.version))
171 }
172
173 fn delete_user_tier_sync(&self, wallet: &WalletAddress) -> Result<()> {
174 use diesel::sql_types::Bytea;
175
176 let mut conn = self.pool().get()?;
177 diesel::sql_query(
178 "INSERT INTO user_tiers (
179 wallet_address,
180 tier,
181 margin_mode,
182 version,
183 max_open_orders,
184 max_open_positions,
185 orders_per_minute,
186 cancels_per_minute,
187 api_requests_per_minute
188 )
189 VALUES ($1, 'tier2', 'standard', 0, NULL, 50, 60, 120, 600)
190 ON CONFLICT (wallet_address)
191 DO UPDATE SET
192 tier = 'tier2',
193 max_open_orders = NULL,
194 max_open_positions = 50,
195 orders_per_minute = 60,
196 cancels_per_minute = 120,
197 api_requests_per_minute = 600,
198 updated_at = NOW()",
199 )
200 .bind::<Bytea, _>(wallet.as_bytes())
201 .execute(&mut conn)?;
202
203 Ok(())
204 }
205}
206
#[cfg(test)]
mod tests {
    use crate::test_helpers::TestDb;
    use hypercall_db::*;
    use hypercall_types::wallet_address::test_wallet;
    use hypercall_types::{MarginMode, WalletAddress};

    /// Builds a `Standard`-margin record with the baseline limits used by these tests.
    fn sample_record(wallet_address: WalletAddress, tier: &str) -> UserTierRecord {
        UserTierRecord {
            wallet_address,
            tier: tier.to_string(),
            margin_mode: MarginMode::Standard,
            version: 1,
            max_open_orders: Some(100),
            max_open_positions: 50,
            orders_per_minute: 60,
            cancels_per_minute: 120,
            api_requests_per_minute: 600,
            created_at: None,
            updated_at: None,
        }
    }

    /// Converts a record into the upsert payload accepted by `save_user_tier_sync`,
    /// setting every optional field explicitly.
    fn update_for(record: &UserTierRecord) -> UserTierUpdate {
        UserTierUpdate {
            wallet_address: record.wallet_address,
            tier: record.tier.clone(),
            margin_mode: Some(record.margin_mode),
            version: Some(record.version),
            max_open_orders: record.max_open_orders,
            max_open_positions: Some(record.max_open_positions),
            orders_per_minute: Some(record.orders_per_minute),
            cancels_per_minute: Some(record.cancels_per_minute),
            api_requests_per_minute: Some(record.api_requests_per_minute),
        }
    }

    #[tokio::test]
    async fn tier_write_read_roundtrip() {
        let test_db = TestDb::new().await.unwrap();
        let db = test_db.handler.as_ref();
        let wallet = test_wallet(1);

        db.save_user_tier_sync(&update_for(&sample_record(wallet, "tier2")))
            .unwrap();

        let loaded = db.get_user_tier_sync(&wallet).unwrap().unwrap();
        assert_eq!(loaded.wallet_address, wallet);
        assert_eq!(loaded.tier, "tier2");
        assert_eq!(loaded.max_open_positions, 50);
    }

    #[tokio::test]
    async fn tier_margin_mode_roundtrip() {
        let test_db = TestDb::new().await.unwrap();
        let db = test_db.handler.as_ref();
        let wallet = test_wallet(2);

        // A fresh wallet gets a new row at version 1.
        let inserted = db
            .insert_margin_mode_if_missing_sync(&wallet, MarginMode::Standard)
            .unwrap();
        assert_eq!(inserted, Some(1));
        assert_eq!(db.get_margin_mode_sync(&wallet).unwrap(), MarginMode::Standard);

        // Updating an existing row bumps its version.
        let version = db
            .set_margin_mode_sync(&wallet, MarginMode::Portfolio)
            .unwrap();
        assert_eq!(version, 2);
        assert_eq!(db.get_margin_mode_sync(&wallet).unwrap(), MarginMode::Portfolio);

        // Insert-if-missing is a no-op once the row exists.
        let again = db
            .insert_margin_mode_if_missing_sync(&wallet, MarginMode::Standard)
            .unwrap();
        assert_eq!(again, None);
        assert_eq!(db.get_margin_mode_sync(&wallet).unwrap(), MarginMode::Portfolio);
    }

    #[tokio::test]
    async fn tier_missing_margin_mode_defaults_to_standard_without_row() {
        let test_db = TestDb::new().await.unwrap();
        let db = test_db.handler.as_ref();
        let wallet = test_wallet(22);

        assert_eq!(db.get_margin_mode_sync(&wallet).unwrap(), MarginMode::Standard);
        assert_eq!(db.get_existing_margin_mode_sync(&wallet).unwrap(), None);
        assert!(db.get_user_tier_sync(&wallet).unwrap().is_none());
    }

    #[tokio::test]
    async fn tier_get_all_roundtrip() {
        let test_db = TestDb::new().await.unwrap();
        let db = test_db.handler.as_ref();

        let tier_names = ["tier1", "tier2", "market_maker"];
        for (i, tier_name) in tier_names.iter().enumerate() {
            let record = sample_record(test_wallet((i + 1) as u8), tier_name);
            db.save_user_tier_sync(&update_for(&record)).unwrap();
        }

        let all = db.get_all_user_tiers_sync().unwrap();
        assert_eq!(all.len(), tier_names.len());
    }

    #[tokio::test]
    async fn tier_delete_resets_to_defaults() {
        let test_db = TestDb::new().await.unwrap();
        let db = test_db.handler.as_ref();
        let wallet = test_wallet(4);

        let tier = UserTierRecord {
            margin_mode: MarginMode::Portfolio,
            version: 5,
            max_open_orders: Some(500),
            max_open_positions: 200,
            orders_per_minute: 300,
            cancels_per_minute: 600,
            api_requests_per_minute: 3000,
            ..sample_record(wallet, "market_maker")
        };
        db.save_user_tier_sync(&update_for(&tier)).unwrap();
        db.delete_user_tier_sync(&wallet).unwrap();

        let loaded = db.get_user_tier_sync(&wallet).unwrap().unwrap();
        assert_eq!(loaded.tier, "tier2");
        assert_eq!(loaded.max_open_orders, None);
        assert_eq!(loaded.max_open_positions, 50);
        assert_eq!(loaded.orders_per_minute, 60);
        assert_eq!(loaded.cancels_per_minute, 120);
        assert_eq!(loaded.api_requests_per_minute, 600);
        // The reset only touches tier and limits; the saved margin mode survives.
        assert_eq!(loaded.margin_mode, MarginMode::Portfolio);
    }

    #[tokio::test]
    async fn tier_delete_without_row_inserts_defaults() {
        let test_db = TestDb::new().await.unwrap();
        let db = test_db.handler.as_ref();
        let wallet = test_wallet(5);

        db.delete_user_tier_sync(&wallet).unwrap();

        let loaded = db.get_user_tier_sync(&wallet).unwrap().unwrap();
        assert_eq!(loaded.tier, "tier2");
        assert_eq!(loaded.margin_mode, MarginMode::Standard);
        assert_eq!(loaded.version, 0);
        assert_eq!(loaded.max_open_orders, None);
        assert_eq!(loaded.max_open_positions, 50);
        assert_eq!(db.get_margin_mode_sync(&wallet).unwrap(), MarginMode::Standard);
    }
}