Skip to main content

hypercall_db_diesel/
tiers.rs

1//! TierReader + TierWriter implementations for DatabaseHandler.
2//!
3//! Manages per-wallet tier configuration (margin mode, rate limits,
4//! position limits) and tier defaults.
5
6use anyhow::{Context, Result};
7use diesel::prelude::*;
8use diesel::RunQueryDsl;
9
10use hypercall_types::WalletAddress;
11
12use crate::database_handler::DatabaseHandler;
13
14impl hypercall_db::TierReader for DatabaseHandler {
15    fn get_margin_mode_sync(&self, wallet: &WalletAddress) -> Result<hypercall_types::MarginMode> {
16        use crate::schema::user_tiers::dsl as ut;
17
18        let mut conn = self.pool().get()?;
19        let Some(mode) = ut::user_tiers
20            .filter(ut::wallet_address.eq(wallet))
21            .select(ut::margin_mode)
22            .first::<String>(&mut conn)
23            .optional()?
24        else {
25            return Ok(hypercall_types::MarginMode::Standard);
26        };
27        let margin_mode = mode
28            .parse::<hypercall_types::MarginMode>()
29            .with_context(|| format!("Invalid margin mode '{}' for wallet {}", mode, wallet))?;
30        Ok(margin_mode)
31    }
32
33    fn get_existing_margin_mode_sync(
34        &self,
35        wallet: &WalletAddress,
36    ) -> Result<Option<hypercall_types::MarginMode>> {
37        use crate::schema::user_tiers::dsl as ut;
38
39        let mut conn = self.pool().get()?;
40        let Some(mode) = ut::user_tiers
41            .filter(ut::wallet_address.eq(wallet))
42            .select(ut::margin_mode)
43            .first::<String>(&mut conn)
44            .optional()?
45        else {
46            return Ok(None);
47        };
48        let margin_mode = mode
49            .parse::<hypercall_types::MarginMode>()
50            .with_context(|| format!("Invalid margin mode '{}' for wallet {}", mode, wallet))?;
51        Ok(Some(margin_mode))
52    }
53
54    fn get_tier_defaults_sync(
55        &self,
56        tier_name: &str,
57    ) -> Result<Option<hypercall_db::TierDefaultsRecord>> {
58        use crate::schema::tier_defaults::dsl;
59
60        let mut conn = self.pool().get()?;
61        let result = dsl::tier_defaults
62            .filter(dsl::tier.eq(tier_name))
63            .first::<crate::models::TierDefaults>(&mut conn)
64            .optional()?;
65        Ok(result.map(Into::into))
66    }
67
68    fn get_user_tier_sync(
69        &self,
70        wallet: &WalletAddress,
71    ) -> Result<Option<hypercall_db::UserTierRecord>> {
72        use crate::schema::user_tiers::dsl;
73
74        let mut conn = self.pool().get()?;
75        let result = dsl::user_tiers
76            .filter(dsl::wallet_address.eq(wallet))
77            .first::<crate::models::UserTier>(&mut conn)
78            .optional()?;
79        Ok(result.map(Into::into))
80    }
81
82    fn get_all_user_tiers_sync(&self) -> Result<Vec<hypercall_db::UserTierRecord>> {
83        use crate::schema::user_tiers::dsl;
84
85        let mut conn = self.pool().get()?;
86        let results = dsl::user_tiers.load::<crate::models::UserTier>(&mut conn)?;
87        Ok(results.into_iter().map(Into::into).collect())
88    }
89}
90
91impl hypercall_db::TierWriter for DatabaseHandler {
92    fn save_user_tier_sync(&self, update: &hypercall_db::UserTierUpdate) -> Result<()> {
93        use crate::schema::user_tiers::dsl;
94
95        let new_tier = crate::models::NewUserTier {
96            wallet_address: update.wallet_address,
97            tier: update.tier.clone(),
98            margin_mode: update.margin_mode.map(|m| m.to_string()),
99            version: update.version,
100            max_open_orders: update.max_open_orders,
101            max_open_positions: update.max_open_positions,
102            orders_per_minute: update.orders_per_minute,
103            cancels_per_minute: update.cancels_per_minute,
104            api_requests_per_minute: update.api_requests_per_minute,
105        };
106
107        let mut conn = self.pool().get()?;
108        diesel::insert_into(crate::schema::user_tiers::table)
109            .values(&new_tier)
110            .on_conflict(dsl::wallet_address)
111            .do_update()
112            .set(&new_tier)
113            .execute(&mut conn)?;
114        Ok(())
115    }
116
117    fn set_margin_mode_sync(
118        &self,
119        wallet: &WalletAddress,
120        margin_mode: hypercall_types::MarginMode,
121    ) -> Result<i64> {
122        use diesel::sql_types::{BigInt, Bytea, Text};
123
124        #[derive(diesel::QueryableByName)]
125        struct VersionRow {
126            #[diesel(sql_type = BigInt)]
127            version: i64,
128        }
129
130        let mut conn = self.pool().get()?;
131        let row: VersionRow = diesel::sql_query(
132            "INSERT INTO user_tiers (wallet_address, tier, margin_mode, version)
133             VALUES ($1, 'tier2', $2, 1)
134             ON CONFLICT (wallet_address)
135             DO UPDATE SET margin_mode = $2, version = user_tiers.version + 1, updated_at = NOW()
136             RETURNING version",
137        )
138        .bind::<Bytea, _>(wallet.as_bytes())
139        .bind::<Text, _>(margin_mode.as_str())
140        .get_result(&mut conn)?;
141
142        Ok(row.version)
143    }
144
145    fn insert_margin_mode_if_missing_sync(
146        &self,
147        wallet: &WalletAddress,
148        margin_mode: hypercall_types::MarginMode,
149    ) -> Result<Option<i64>> {
150        use diesel::sql_types::{BigInt, Bytea, Text};
151
152        #[derive(diesel::QueryableByName)]
153        struct VersionRow {
154            #[diesel(sql_type = BigInt)]
155            version: i64,
156        }
157
158        let mut conn = self.pool().get()?;
159        let row = diesel::sql_query(
160            "INSERT INTO user_tiers (wallet_address, tier, margin_mode, version)
161             VALUES ($1, 'tier2', $2, 1)
162             ON CONFLICT (wallet_address) DO NOTHING
163             RETURNING version",
164        )
165        .bind::<Bytea, _>(wallet.as_bytes())
166        .bind::<Text, _>(margin_mode.as_str())
167        .get_result::<VersionRow>(&mut conn)
168        .optional()?;
169
170        Ok(row.map(|row| row.version))
171    }
172
173    fn delete_user_tier_sync(&self, wallet: &WalletAddress) -> Result<()> {
174        use diesel::sql_types::Bytea;
175
176        let mut conn = self.pool().get()?;
177        diesel::sql_query(
178            "INSERT INTO user_tiers (
179                 wallet_address,
180                 tier,
181                 margin_mode,
182                 version,
183                 max_open_orders,
184                 max_open_positions,
185                 orders_per_minute,
186                 cancels_per_minute,
187                 api_requests_per_minute
188             )
189             VALUES ($1, 'tier2', 'standard', 0, NULL, 50, 60, 120, 600)
190             ON CONFLICT (wallet_address)
191             DO UPDATE SET
192                 tier = 'tier2',
193                 max_open_orders = NULL,
194                 max_open_positions = 50,
195                 orders_per_minute = 60,
196                 cancels_per_minute = 120,
197                 api_requests_per_minute = 600,
198                 updated_at = NOW()",
199        )
200        .bind::<Bytea, _>(wallet.as_bytes())
201        .execute(&mut conn)?;
202
203        Ok(())
204    }
205}
206
#[cfg(test)]
mod tests {
    use crate::test_helpers::TestDb;
    use hypercall_db::*;
    use hypercall_types::wallet_address::test_wallet;
    use hypercall_types::MarginMode;
    use hypercall_types::WalletAddress;

    /// Baseline tier record for the tests. Override fields with struct-update syntax.
    fn sample_record(wallet: WalletAddress, tier: &str) -> UserTierRecord {
        UserTierRecord {
            wallet_address: wallet,
            tier: tier.to_string(),
            margin_mode: MarginMode::Standard,
            version: 1,
            max_open_orders: Some(100),
            max_open_positions: 50,
            orders_per_minute: 60,
            cancels_per_minute: 120,
            api_requests_per_minute: 600,
            created_at: None,
            updated_at: None,
        }
    }

    /// Persists `record` via `save_user_tier_sync`, supplying every optional update field.
    fn save_record<D: TierWriter + ?Sized>(db: &D, record: &UserTierRecord) {
        db.save_user_tier_sync(&UserTierUpdate {
            wallet_address: record.wallet_address,
            tier: record.tier.clone(),
            margin_mode: Some(record.margin_mode),
            version: Some(record.version),
            max_open_orders: record.max_open_orders,
            max_open_positions: Some(record.max_open_positions),
            orders_per_minute: Some(record.orders_per_minute),
            cancels_per_minute: Some(record.cancels_per_minute),
            api_requests_per_minute: Some(record.api_requests_per_minute),
        })
        .unwrap();
    }

    #[tokio::test]
    async fn tier_write_read_roundtrip() {
        let test_db = TestDb::new().await.unwrap();
        let db = test_db.handler.as_ref();
        let wallet = test_wallet(1);

        save_record(db, &sample_record(wallet, "tier2"));

        let loaded = db.get_user_tier_sync(&wallet).unwrap().unwrap();
        assert_eq!(loaded.wallet_address, wallet);
        assert_eq!(loaded.tier, "tier2");
        assert_eq!(loaded.margin_mode, MarginMode::Standard);
        assert_eq!(loaded.max_open_orders, Some(100));
        assert_eq!(loaded.max_open_positions, 50);
    }

    #[tokio::test]
    async fn tier_margin_mode_roundtrip() {
        let test_db = TestDb::new().await.unwrap();
        let db = test_db.handler.as_ref();
        let wallet = test_wallet(2);

        // First insert creates the row at version 1; a second insert is a no-op.
        let inserted = db
            .insert_margin_mode_if_missing_sync(&wallet, MarginMode::Standard)
            .unwrap();
        assert_eq!(inserted, Some(1));
        let again = db
            .insert_margin_mode_if_missing_sync(&wallet, MarginMode::Portfolio)
            .unwrap();
        assert_eq!(again, None);
        assert_eq!(db.get_margin_mode_sync(&wallet).unwrap(), MarginMode::Standard);

        // An explicit set overwrites the mode and bumps the version.
        let version = db
            .set_margin_mode_sync(&wallet, MarginMode::Portfolio)
            .unwrap();
        assert_eq!(version, 2);
        assert_eq!(db.get_margin_mode_sync(&wallet).unwrap(), MarginMode::Portfolio);
        assert_eq!(
            db.get_existing_margin_mode_sync(&wallet).unwrap(),
            Some(MarginMode::Portfolio)
        );
    }

    #[tokio::test]
    async fn tier_missing_margin_mode_defaults_to_standard_without_row() {
        let test_db = TestDb::new().await.unwrap();
        let db = test_db.handler.as_ref();
        let wallet = test_wallet(22);

        let mode = db.get_margin_mode_sync(&wallet).unwrap();
        assert_eq!(mode, MarginMode::Standard);
        assert_eq!(db.get_existing_margin_mode_sync(&wallet).unwrap(), None);
        assert!(db.get_user_tier_sync(&wallet).unwrap().is_none());
    }

    #[tokio::test]
    async fn tier_get_all_roundtrip() {
        let test_db = TestDb::new().await.unwrap();
        let db = test_db.handler.as_ref();

        for (i, tier_name) in ["tier1", "tier2", "market_maker"].iter().enumerate() {
            let wallet = test_wallet(u8::try_from(i + 1).expect("three wallets fit in u8"));
            save_record(db, &sample_record(wallet, tier_name));
        }

        let all = db.get_all_user_tiers_sync().unwrap();
        assert_eq!(all.len(), 3);
        let mut tiers: Vec<&str> = all.iter().map(|t| t.tier.as_str()).collect();
        tiers.sort_unstable();
        assert_eq!(tiers, ["market_maker", "tier1", "tier2"]);
    }

    #[tokio::test]
    async fn tier_delete_resets_to_defaults() {
        let test_db = TestDb::new().await.unwrap();
        let db = test_db.handler.as_ref();
        let wallet = test_wallet(4);

        let record = UserTierRecord {
            margin_mode: MarginMode::Portfolio,
            version: 5,
            max_open_orders: Some(500),
            max_open_positions: 200,
            orders_per_minute: 300,
            cancels_per_minute: 600,
            api_requests_per_minute: 3000,
            ..sample_record(wallet, "market_maker")
        };
        save_record(db, &record);
        db.delete_user_tier_sync(&wallet).unwrap();

        let loaded = db.get_user_tier_sync(&wallet).unwrap().unwrap();
        assert_eq!(loaded.tier, "tier2");
        assert_eq!(loaded.max_open_orders, None);
        assert_eq!(loaded.max_open_positions, 50);
        assert_eq!(loaded.orders_per_minute, 60);
        assert_eq!(loaded.cancels_per_minute, 120);
        assert_eq!(loaded.api_requests_per_minute, 600);
        // The reset does not touch the wallet's margin mode.
        assert_eq!(loaded.margin_mode, MarginMode::Portfolio);
    }

    #[tokio::test]
    async fn tier_delete_without_row_inserts_defaults() {
        let test_db = TestDb::new().await.unwrap();
        let db = test_db.handler.as_ref();
        let wallet = test_wallet(5);

        db.delete_user_tier_sync(&wallet).unwrap();

        let loaded = db.get_user_tier_sync(&wallet).unwrap().unwrap();
        assert_eq!(loaded.tier, "tier2");
        assert_eq!(loaded.margin_mode, MarginMode::Standard);
        assert_eq!(loaded.version, 0);
        assert_eq!(loaded.max_open_positions, 50);
    }
}