1use crate::rsm::MarginMode;
2use anyhow::{anyhow, Result};
3use hypercall_db::TierWriter;
4use hypercall_db::UserTierRecord as UserTier;
5use hypercall_db::UserTierUpdate;
6use hypercall_types::api_models::{TradingLimits, UserTierData};
7use hypercall_types::WalletAddress;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use tracing::{debug, info};
12
13fn default_user_tier(wallet: WalletAddress) -> UserTierData {
14 UserTierData {
15 wallet_address: wallet,
16 tier: "tier2".to_string(),
17 }
18}
19
/// Built-in fallback cap on open orders per wallet (`TierCacheConfig::default`).
const DEFAULT_MAX_OPEN_ORDERS: i32 = 100;
/// Built-in fallback cap on open positions per wallet (`TierCacheConfig::default`).
const DEFAULT_MAX_OPEN_POSITIONS: i32 = 50;

/// Fallback trading limits used by `TierCache` when a wallet's own limits are unavailable.
#[derive(Debug, Clone, Copy)]
pub struct TierCacheConfig {
    /// Open-order cap used when a wallet has no cached record, or its record leaves
    /// `max_open_orders` unset.
    pub max_open_orders_default: i32,
    /// Open-position cap used when a wallet has no cached record.
    pub max_open_positions_default: i32,
}

impl Default for TierCacheConfig {
    /// Uses `DEFAULT_MAX_OPEN_ORDERS` (100) and `DEFAULT_MAX_OPEN_POSITIONS` (50).
    fn default() -> Self {
        let max_open_orders_default = DEFAULT_MAX_OPEN_ORDERS;
        let max_open_positions_default = DEFAULT_MAX_OPEN_POSITIONS;
        TierCacheConfig {
            max_open_orders_default,
            max_open_positions_default,
        }
    }
}
37
38pub struct TierCache {
39 tiers: Arc<RwLock<HashMap<WalletAddress, UserTier>>>,
40 db: Arc<dyn TierWriter>,
41 config: TierCacheConfig,
42}
43
44impl TierCache {
45 pub fn new(db: Arc<dyn TierWriter>) -> Result<Self> {
46 Self::new_with_config(db, TierCacheConfig::default())
47 }
48
49 pub fn new_with_config(db: Arc<dyn TierWriter>, config: TierCacheConfig) -> Result<Self> {
50 Ok(Self {
51 tiers: Arc::new(RwLock::new(HashMap::new())),
52 db,
53 config,
54 })
55 }
56
57 pub async fn load_from_db(&self) -> Result<()> {
59 let handler = self.db.clone();
60 let tiers = tokio::task::spawn_blocking(move || handler.get_all_user_tiers_sync())
61 .await
62 .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??;
63 let mut cache = self.tiers.write().await;
64 for tier in tiers {
65 cache.insert(tier.wallet_address, tier);
66 }
67 info!("Loaded {} user tiers from database", cache.len());
68 Ok(())
69 }
70
71 pub async fn get_tier(&self, wallet: &WalletAddress) -> Option<UserTierData> {
73 let cache = self.tiers.read().await;
74 cache.get(wallet).map(|t| UserTierData {
75 wallet_address: t.wallet_address,
76 tier: t.tier.clone(),
77 })
78 }
79
80 pub async fn get_tier_record(&self, wallet: &WalletAddress) -> Result<Option<UserTier>> {
81 let handler = self.db.clone();
82 let wallet_owned = *wallet;
83 tokio::task::spawn_blocking(move || handler.get_user_tier_sync(&wallet_owned))
84 .await
85 .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))?
86 }
87
88 pub async fn restore_tier_record(
89 &self,
90 wallet: &WalletAddress,
91 previous_tier: Option<&UserTier>,
92 ) -> Result<()> {
93 match previous_tier {
94 Some(previous) => self.set_tier(new_user_tier_from_existing(previous)).await,
95 None => self.delete_tier(wallet).await,
96 }
97 }
98
99 pub fn get_tier_sync(&self, wallet: &WalletAddress) -> UserTierData {
102 if let Some(tier_data) = self.try_get_from_cache(wallet) {
103 return tier_data;
104 }
105
106 default_user_tier(*wallet)
107 }
108
109 fn try_get_from_cache(&self, wallet: &WalletAddress) -> Option<UserTierData> {
110 if let Ok(cache) = self.tiers.try_read() {
112 cache.get(wallet).map(|t| UserTierData {
113 wallet_address: t.wallet_address,
114 tier: t.tier.clone(),
115 })
116 } else {
117 None
118 }
119 }
120
121 pub async fn set_tier(&self, update: UserTierUpdate) -> Result<()> {
123 debug!(
124 "Setting tier for wallet {} to {}",
125 update.wallet_address, update.tier
126 );
127
128 let wallet = update.wallet_address;
129 let tier_name = update.tier.clone();
130 let handler = self.db.clone();
131 tokio::task::spawn_blocking(move || handler.save_user_tier_sync(&update))
132 .await
133 .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??;
134
135 let handler = self.db.clone();
137 let tier = tokio::task::spawn_blocking(move || handler.get_user_tier_sync(&wallet))
138 .await
139 .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??
140 .expect("Tier should exist after save");
141
142 let mut cache = self.tiers.write().await;
143 let wallet_address = tier.wallet_address;
144 cache.insert(wallet_address, tier);
145
146 info!(
147 "Updated tier for wallet {} to {}",
148 wallet_address, tier_name
149 );
150
151 Ok(())
152 }
153
154 pub async fn delete_tier(&self, wallet: &WalletAddress) -> Result<()> {
156 debug!("Resetting tier for wallet {} to defaults", wallet);
157
158 let handler = self.db.clone();
159 let wallet_owned = *wallet;
160 tokio::task::spawn_blocking(move || handler.delete_user_tier_sync(&wallet_owned))
161 .await
162 .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??;
163
164 let handler = self.db.clone();
165 let wallet_owned = *wallet;
166 let tier = tokio::task::spawn_blocking(move || handler.get_user_tier_sync(&wallet_owned))
167 .await
168 .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??
169 .ok_or_else(|| anyhow!("missing user_tiers row after default tier reset"))?;
170
171 let mut cache = self.tiers.write().await;
172 cache.insert(*wallet, tier);
173
174 info!("Reset tier for wallet {} to default tier2", wallet);
175
176 Ok(())
177 }
178
179 pub fn is_tier1(&self, wallet: &WalletAddress) -> bool {
181 let tier_data = self.get_tier_sync(wallet);
182 tier_data.tier == "tier1"
183 }
184
185 pub fn is_tier2(&self, wallet: &WalletAddress) -> bool {
187 let tier_data = self.get_tier_sync(wallet);
188 tier_data.tier == "tier2"
189 }
190
191 pub async fn get_margin_mode(&self, wallet: &WalletAddress) -> Result<MarginMode> {
197 let cache = self.tiers.read().await;
198 Ok(cache
205 .get(wallet)
206 .map(|tier| tier.margin_mode)
207 .unwrap_or(MarginMode::Standard))
208 }
209
210 pub async fn get_existing_margin_mode(
215 &self,
216 wallet: &WalletAddress,
217 ) -> Result<Option<MarginMode>> {
218 let handler = self.db.clone();
219 let wallet_owned = *wallet;
220 tokio::task::spawn_blocking(move || handler.get_existing_margin_mode_sync(&wallet_owned))
221 .await
222 .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))?
223 }
224
225 pub fn get_existing_margin_mode_sync(
227 &self,
228 wallet: &WalletAddress,
229 ) -> Result<Option<MarginMode>> {
230 let cache = self.tiers.try_read().map_err(|_| {
231 anyhow!(
232 "tier cache lock busy while reading margin mode for {}",
233 wallet
234 )
235 })?;
236 let Some(tier) = cache.get(wallet) else {
237 return Ok(None);
238 };
239 Ok(Some(tier.margin_mode))
240 }
241
242 pub fn get_margin_mode_sync(&self, wallet: &WalletAddress) -> Result<MarginMode> {
247 let cache = self.tiers.try_read().map_err(|_| {
248 anyhow!(
249 "tier cache lock busy while reading margin mode for {}",
250 wallet
251 )
252 })?;
253 Ok(cache
254 .get(wallet)
255 .map(|tier| tier.margin_mode)
256 .unwrap_or(MarginMode::Standard))
257 }
258
259 pub async fn set_margin_mode(&self, wallet: &WalletAddress, mode: MarginMode) -> Result<i64> {
264 debug!("Setting margin mode for wallet {} to {:?}", wallet, mode);
265
266 let handler = self.db.clone();
268 let wallet_owned = *wallet;
269 let new_version =
270 tokio::task::spawn_blocking(move || handler.set_margin_mode_sync(&wallet_owned, mode))
271 .await
272 .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??;
273
274 let handler = self.db.clone();
276 let wallet_owned = *wallet;
277 let tier_opt =
278 tokio::task::spawn_blocking(move || handler.get_user_tier_sync(&wallet_owned))
279 .await
280 .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??;
281 if let Some(tier) = tier_opt {
282 let mut cache = self.tiers.write().await;
285 let current_cached_version = cache.get(wallet).map(|t| t.version).unwrap_or(0);
286 if tier.version > current_cached_version {
287 cache.insert(*wallet, tier);
288 }
289 }
290
291 info!(
292 "Updated margin mode for wallet {} to {:?}, version={}",
293 wallet, mode, new_version
294 );
295
296 Ok(new_version)
297 }
298
299 pub async fn insert_margin_mode_if_missing(
304 &self,
305 wallet: &WalletAddress,
306 mode: MarginMode,
307 ) -> Result<Option<i64>> {
308 debug!(
309 "Inserting margin mode for wallet {} to {:?} if missing",
310 wallet, mode
311 );
312
313 let handler = self.db.clone();
314 let wallet_owned = *wallet;
315 let inserted_version = tokio::task::spawn_blocking(move || {
316 handler.insert_margin_mode_if_missing_sync(&wallet_owned, mode)
317 })
318 .await
319 .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??;
320
321 if inserted_version.is_some() {
322 let handler = self.db.clone();
323 let wallet_owned = *wallet;
324 let tier_opt =
325 tokio::task::spawn_blocking(move || handler.get_user_tier_sync(&wallet_owned))
326 .await
327 .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??;
328
329 if let Some(tier) = tier_opt {
330 let mut cache = self.tiers.write().await;
331 cache.insert(*wallet, tier);
332 }
333 }
334
335 Ok(inserted_version)
336 }
337
338 pub fn get_all_margin_modes_sync(&self) -> Result<HashMap<WalletAddress, MarginMode>> {
343 let cache = self
344 .tiers
345 .try_read()
346 .map_err(|_| anyhow!("tier cache lock busy during margin-mode seed"))?;
347 Ok(cache
348 .iter()
349 .map(|(wallet, tier)| (*wallet, tier.margin_mode))
350 .collect())
351 }
352
353 pub fn get_all_trading_limits_sync(
358 &self,
359 ) -> Result<HashMap<WalletAddress, TradingLimits>, String> {
360 let cache = self
361 .tiers
362 .try_read()
363 .map_err(|_| "tier cache lock busy during trading-limits seed".to_string())?;
364 Ok(cache
365 .iter()
366 .map(|(wallet, tier)| (*wallet, self.trading_limits_from_tier(tier)))
367 .collect())
368 }
369
370 pub fn get_all_tiers_sync(&self) -> Result<HashMap<WalletAddress, String>, String> {
374 let cache = self
375 .tiers
376 .try_read()
377 .map_err(|_| "tier cache lock busy during tier seed".to_string())?;
378 Ok(cache
379 .iter()
380 .map(|(wallet, tier)| (*wallet, tier.tier.clone()))
381 .collect())
382 }
383
384 pub fn is_standard_margin(&self, wallet: &WalletAddress) -> Result<bool> {
386 Ok(self.get_margin_mode_sync(wallet)?.is_standard())
387 }
388
389 pub fn is_portfolio_margin(&self, wallet: &WalletAddress) -> Result<bool> {
391 Ok(self.get_margin_mode_sync(wallet)?.is_portfolio())
392 }
393
394 pub async fn apply_margin_mode_update(
399 &self,
400 wallet: WalletAddress,
401 mode: MarginMode,
402 version: i64,
403 ) {
404 let cached_version = {
406 let cache = self.tiers.read().await;
407 cache.get(&wallet).map(|t| t.version).unwrap_or(0)
408 };
409
410 if version <= cached_version {
411 debug!(
412 "Ignoring stale tier update: wallet={}, msg_version={}, cached_version={}",
413 wallet, version, cached_version
414 );
415 return;
416 }
417
418 let handler = self.db.clone();
420 let tier =
421 match tokio::task::spawn_blocking(move || handler.get_user_tier_sync(&wallet)).await {
422 Ok(Ok(Some(t))) => t,
423 Ok(Ok(None)) => {
424 debug!("No tier found in DB for wallet {}", wallet);
425 return;
426 }
427 Ok(Err(e)) => {
428 tracing::error!("Failed to load tier from DB for {}: {}", wallet, e);
429 return;
430 }
431 Err(e) => {
432 tracing::error!("spawn_blocking failed for tier load {}: {}", wallet, e);
433 return;
434 }
435 };
436
437 let mut cache = self.tiers.write().await;
439 let current_cached_version = cache.get(&wallet).map(|t| t.version).unwrap_or(0);
440
441 if tier.version > current_cached_version {
443 debug!(
444 "Applied margin mode update from event bus: wallet={}, mode={:?}, version={}",
445 wallet, mode, tier.version
446 );
447 cache.insert(wallet, tier);
448 } else {
449 debug!(
450 "Skipping insert after re-check: db_version={}, cached_version={}",
451 tier.version, current_cached_version
452 );
453 }
454 }
455
456 pub async fn get_version(&self, wallet: &WalletAddress) -> i64 {
458 let cache = self.tiers.read().await;
459 cache.get(wallet).map(|t| t.version).unwrap_or(0)
460 }
461
462 pub fn get_trading_limits(&self, wallet: &WalletAddress) -> TradingLimits {
468 if let Ok(cache) = self.tiers.try_read() {
469 if let Some(tier) = cache.get(wallet) {
470 return self.trading_limits_from_tier(tier);
471 }
472 }
473 self.default_trading_limits()
474 }
475
476 pub async fn get_trading_limits_async(&self, wallet: &WalletAddress) -> TradingLimits {
478 let cache = self.tiers.read().await;
479 if let Some(tier) = cache.get(wallet) {
480 self.trading_limits_from_tier(tier)
481 } else {
482 self.default_trading_limits()
483 }
484 }
485
486 pub fn exceeds_open_orders_limit(&self, wallet: &WalletAddress, current_count: usize) -> bool {
490 let limits = self.get_trading_limits(wallet);
491 if limits.max_open_orders < 0 {
492 return false;
493 }
494 current_count >= limits.max_open_orders as usize
495 }
496
497 pub fn exceeds_positions_limit(&self, wallet: &WalletAddress, current_count: usize) -> bool {
501 let limits = self.get_trading_limits(wallet);
502 if limits.max_open_positions < 0 {
503 return false; }
505 current_count >= limits.max_open_positions as usize
506 }
507
508 pub fn default_trading_limits(&self) -> TradingLimits {
509 let mut limits = TradingLimits::default();
510 limits.max_open_orders = self.config.max_open_orders_default;
511 limits.max_open_positions = self.config.max_open_positions_default;
512 limits
513 }
514
515 fn trading_limits_from_tier(&self, tier: &UserTier) -> TradingLimits {
516 TradingLimits {
517 max_open_orders: tier
518 .max_open_orders
519 .unwrap_or(self.config.max_open_orders_default),
520 max_open_positions: tier.max_open_positions,
521 orders_per_minute: tier.orders_per_minute,
522 cancels_per_minute: tier.cancels_per_minute,
523 api_requests_per_minute: tier.api_requests_per_minute,
524 }
525 }
526
527 #[cfg(any(test, feature = "test-utils"))]
530 pub async fn set_margin_mode_in_memory(
531 &self,
532 wallet: &WalletAddress,
533 mode: hypercall_types::MarginMode,
534 ) {
535 let mut cache = self.tiers.write().await;
536 if let Some(tier) = cache.get_mut(wallet) {
537 tier.margin_mode = mode;
538 } else {
539 cache.insert(
540 *wallet,
541 UserTier {
542 wallet_address: *wallet,
543 tier: "tier2".to_string(),
544 margin_mode: mode,
545 version: 1,
546 max_open_orders: None,
547 max_open_positions: 50,
548 orders_per_minute: 60,
549 cancels_per_minute: 120,
550 api_requests_per_minute: 600,
551 created_at: None,
552 updated_at: None,
553 },
554 );
555 }
556 }
557}
558
559fn new_user_tier_from_existing(tier: &UserTier) -> UserTierUpdate {
560 UserTierUpdate {
561 wallet_address: tier.wallet_address,
562 tier: tier.tier.clone(),
563 margin_mode: Some(tier.margin_mode),
564 version: Some(tier.version),
565 max_open_orders: tier.max_open_orders,
566 max_open_positions: Some(tier.max_open_positions),
567 orders_per_minute: Some(tier.orders_per_minute),
568 cancels_per_minute: Some(tier.cancels_per_minute),
569 api_requests_per_minute: Some(tier.api_requests_per_minute),
570 }
571}