Skip to main content

hypercall/read_cache/
tier.rs

1use crate::rsm::MarginMode;
2use anyhow::{anyhow, Result};
3use hypercall_db::TierWriter;
4use hypercall_db::UserTierRecord as UserTier;
5use hypercall_db::UserTierUpdate;
6use hypercall_types::api_models::{TradingLimits, UserTierData};
7use hypercall_types::WalletAddress;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use tracing::{debug, info};
12
13fn default_user_tier(wallet: WalletAddress) -> UserTierData {
14    UserTierData {
15        wallet_address: wallet,
16        tier: "tier2".to_string(),
17    }
18}
19
/// Default open-order cap applied when no per-wallet value is available.
const DEFAULT_MAX_OPEN_ORDERS: i32 = 100;
/// Default open-position cap applied when a wallet has no cached row.
const DEFAULT_MAX_OPEN_POSITIONS: i32 = 50;

/// Fallback trading limits used by the tier cache.
///
/// The limit checks in this module treat a negative value as "unlimited".
#[derive(Clone, Copy, Debug)]
pub struct TierCacheConfig {
    /// Open-order cap used when a wallet has no cached row or its row leaves
    /// `max_open_orders` unset.
    pub max_open_orders_default: i32,
    /// Open-position cap used when a wallet has no cached row.
    pub max_open_positions_default: i32,
}

impl Default for TierCacheConfig {
    /// Build the config from `DEFAULT_MAX_OPEN_ORDERS` (100) and
    /// `DEFAULT_MAX_OPEN_POSITIONS` (50).
    fn default() -> Self {
        TierCacheConfig {
            max_open_positions_default: DEFAULT_MAX_OPEN_POSITIONS,
            max_open_orders_default: DEFAULT_MAX_OPEN_ORDERS,
        }
    }
}
37
38pub struct TierCache {
39    tiers: Arc<RwLock<HashMap<WalletAddress, UserTier>>>,
40    db: Arc<dyn TierWriter>,
41    config: TierCacheConfig,
42}
43
44impl TierCache {
45    pub fn new(db: Arc<dyn TierWriter>) -> Result<Self> {
46        Self::new_with_config(db, TierCacheConfig::default())
47    }
48
49    pub fn new_with_config(db: Arc<dyn TierWriter>, config: TierCacheConfig) -> Result<Self> {
50        Ok(Self {
51            tiers: Arc::new(RwLock::new(HashMap::new())),
52            db,
53            config,
54        })
55    }
56
57    /// Load all tier configurations from database
58    pub async fn load_from_db(&self) -> Result<()> {
59        let handler = self.db.clone();
60        let tiers = tokio::task::spawn_blocking(move || handler.get_all_user_tiers_sync())
61            .await
62            .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??;
63        let mut cache = self.tiers.write().await;
64        for tier in tiers {
65            cache.insert(tier.wallet_address, tier);
66        }
67        info!("Loaded {} user tiers from database", cache.len());
68        Ok(())
69    }
70
71    /// Get tier for wallet (returns default if not found)
72    pub async fn get_tier(&self, wallet: &WalletAddress) -> Option<UserTierData> {
73        let cache = self.tiers.read().await;
74        cache.get(wallet).map(|t| UserTierData {
75            wallet_address: t.wallet_address,
76            tier: t.tier.clone(),
77        })
78    }
79
80    pub async fn get_tier_record(&self, wallet: &WalletAddress) -> Result<Option<UserTier>> {
81        let handler = self.db.clone();
82        let wallet_owned = *wallet;
83        tokio::task::spawn_blocking(move || handler.get_user_tier_sync(&wallet_owned))
84            .await
85            .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))?
86    }
87
88    pub async fn restore_tier_record(
89        &self,
90        wallet: &WalletAddress,
91        previous_tier: Option<&UserTier>,
92    ) -> Result<()> {
93        match previous_tier {
94            Some(previous) => self.set_tier(new_user_tier_from_existing(previous)).await,
95            None => self.delete_tier(wallet).await,
96        }
97    }
98
99    /// Get tier synchronously (for use in unified_engine)
100    /// This is non-blocking and returns default if cache is locked or tier not found
101    pub fn get_tier_sync(&self, wallet: &WalletAddress) -> UserTierData {
102        if let Some(tier_data) = self.try_get_from_cache(wallet) {
103            return tier_data;
104        }
105
106        default_user_tier(*wallet)
107    }
108
109    fn try_get_from_cache(&self, wallet: &WalletAddress) -> Option<UserTierData> {
110        // Use try_read to avoid blocking
111        if let Ok(cache) = self.tiers.try_read() {
112            cache.get(wallet).map(|t| UserTierData {
113                wallet_address: t.wallet_address,
114                tier: t.tier.clone(),
115            })
116        } else {
117            None
118        }
119    }
120
121    /// Set or update tier
122    pub async fn set_tier(&self, update: UserTierUpdate) -> Result<()> {
123        debug!(
124            "Setting tier for wallet {} to {}",
125            update.wallet_address, update.tier
126        );
127
128        let wallet = update.wallet_address;
129        let tier_name = update.tier.clone();
130        let handler = self.db.clone();
131        tokio::task::spawn_blocking(move || handler.save_user_tier_sync(&update))
132            .await
133            .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??;
134
135        // Update cache
136        let handler = self.db.clone();
137        let tier = tokio::task::spawn_blocking(move || handler.get_user_tier_sync(&wallet))
138            .await
139            .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??
140            .expect("Tier should exist after save");
141
142        let mut cache = self.tiers.write().await;
143        let wallet_address = tier.wallet_address;
144        cache.insert(wallet_address, tier);
145
146        info!(
147            "Updated tier for wallet {} to {}",
148            wallet_address, tier_name
149        );
150
151        Ok(())
152    }
153
154    /// Reset tier limits to defaults while preserving the explicit margin mode row.
155    pub async fn delete_tier(&self, wallet: &WalletAddress) -> Result<()> {
156        debug!("Resetting tier for wallet {} to defaults", wallet);
157
158        let handler = self.db.clone();
159        let wallet_owned = *wallet;
160        tokio::task::spawn_blocking(move || handler.delete_user_tier_sync(&wallet_owned))
161            .await
162            .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??;
163
164        let handler = self.db.clone();
165        let wallet_owned = *wallet;
166        let tier = tokio::task::spawn_blocking(move || handler.get_user_tier_sync(&wallet_owned))
167            .await
168            .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??
169            .ok_or_else(|| anyhow!("missing user_tiers row after default tier reset"))?;
170
171        let mut cache = self.tiers.write().await;
172        cache.insert(*wallet, tier);
173
174        info!("Reset tier for wallet {} to default tier2", wallet);
175
176        Ok(())
177    }
178
179    /// Check if wallet is tier1 (restricted to long-only)
180    pub fn is_tier1(&self, wallet: &WalletAddress) -> bool {
181        let tier_data = self.get_tier_sync(wallet);
182        tier_data.tier == "tier1"
183    }
184
185    /// Check if wallet is tier2 (unrestricted)
186    pub fn is_tier2(&self, wallet: &WalletAddress) -> bool {
187        let tier_data = self.get_tier_sync(wallet);
188        tier_data.tier == "tier2"
189    }
190
191    // === Margin Mode Methods ===
192
193    /// Get margin mode for wallet.
194    ///
195    /// This is the async version for use in handlers.
196    pub async fn get_margin_mode(&self, wallet: &WalletAddress) -> Result<MarginMode> {
197        let cache = self.tiers.read().await;
198        // Missing user_tiers rows intentionally mean Standard margin. That is a
199        // product default, not a fabricated financial input: Standard is the
200        // most restrictive supported regime, while Portfolio margin is an
201        // explicit opt-in stored in user_tiers and replayed into the engine.
202        // Persisted rows must still parse successfully, because a corrupt
203        // explicit margin mode would route risk checks through the wrong model.
204        Ok(cache
205            .get(wallet)
206            .map(|tier| tier.margin_mode)
207            .unwrap_or(MarginMode::Standard))
208    }
209
210    /// Get the existing margin mode for a wallet without fabricating a default.
211    ///
212    /// Returns `Ok(None)` only when no tier row exists. Invalid stored values
213    /// still fail closed because the persisted margin mode is corrupted.
214    pub async fn get_existing_margin_mode(
215        &self,
216        wallet: &WalletAddress,
217    ) -> Result<Option<MarginMode>> {
218        let handler = self.db.clone();
219        let wallet_owned = *wallet;
220        tokio::task::spawn_blocking(move || handler.get_existing_margin_mode_sync(&wallet_owned))
221            .await
222            .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))?
223    }
224
225    /// Get an existing margin mode synchronously without fabricating a default.
226    pub fn get_existing_margin_mode_sync(
227        &self,
228        wallet: &WalletAddress,
229    ) -> Result<Option<MarginMode>> {
230        let cache = self.tiers.try_read().map_err(|_| {
231            anyhow!(
232                "tier cache lock busy while reading margin mode for {}",
233                wallet
234            )
235        })?;
236        let Some(tier) = cache.get(wallet) else {
237            return Ok(None);
238        };
239        Ok(Some(tier.margin_mode))
240    }
241
242    /// Get margin mode synchronously (for use in unified_engine).
243    ///
244    /// Non-blocking and fail-closed if the cache is locked. Missing rows
245    /// intentionally resolve to Standard; see `get_margin_mode`.
246    pub fn get_margin_mode_sync(&self, wallet: &WalletAddress) -> Result<MarginMode> {
247        let cache = self.tiers.try_read().map_err(|_| {
248            anyhow!(
249                "tier cache lock busy while reading margin mode for {}",
250                wallet
251            )
252        })?;
253        Ok(cache
254            .get(wallet)
255            .map(|tier| tier.margin_mode)
256            .unwrap_or(MarginMode::Standard))
257    }
258
259    /// Set margin mode for a wallet.
260    ///
261    /// Creates the tier entry if it doesn't exist.
262    /// Returns the new version for use in event messages.
263    pub async fn set_margin_mode(&self, wallet: &WalletAddress, mode: MarginMode) -> Result<i64> {
264        debug!("Setting margin mode for wallet {} to {:?}", wallet, mode);
265
266        // Atomically update margin mode and increment version in DB
267        let handler = self.db.clone();
268        let wallet_owned = *wallet;
269        let new_version =
270            tokio::task::spawn_blocking(move || handler.set_margin_mode_sync(&wallet_owned, mode))
271                .await
272                .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??;
273
274        // Reload from DB (outside lock to avoid holding lock during I/O)
275        let handler = self.db.clone();
276        let wallet_owned = *wallet;
277        let tier_opt =
278            tokio::task::spawn_blocking(move || handler.get_user_tier_sync(&wallet_owned))
279                .await
280                .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??;
281        if let Some(tier) = tier_opt {
282            // Only insert if this is still the newest version
283            // (a message from another process could have arrived with a higher version)
284            let mut cache = self.tiers.write().await;
285            let current_cached_version = cache.get(wallet).map(|t| t.version).unwrap_or(0);
286            if tier.version > current_cached_version {
287                cache.insert(*wallet, tier);
288            }
289        }
290
291        info!(
292            "Updated margin mode for wallet {} to {:?}, version={}",
293            wallet, mode, new_version
294        );
295
296        Ok(new_version)
297    }
298
299    /// Create a margin mode row if one does not already exist.
300    ///
301    /// Returns `Some(version)` when a row was inserted. Existing rows are left
302    /// unchanged so callers can safely use this for bootstrap paths.
303    pub async fn insert_margin_mode_if_missing(
304        &self,
305        wallet: &WalletAddress,
306        mode: MarginMode,
307    ) -> Result<Option<i64>> {
308        debug!(
309            "Inserting margin mode for wallet {} to {:?} if missing",
310            wallet, mode
311        );
312
313        let handler = self.db.clone();
314        let wallet_owned = *wallet;
315        let inserted_version = tokio::task::spawn_blocking(move || {
316            handler.insert_margin_mode_if_missing_sync(&wallet_owned, mode)
317        })
318        .await
319        .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??;
320
321        if inserted_version.is_some() {
322            let handler = self.db.clone();
323            let wallet_owned = *wallet;
324            let tier_opt =
325                tokio::task::spawn_blocking(move || handler.get_user_tier_sync(&wallet_owned))
326                    .await
327                    .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))??;
328
329            if let Some(tier) = tier_opt {
330                let mut cache = self.tiers.write().await;
331                cache.insert(*wallet, tier);
332            }
333        }
334
335        Ok(inserted_version)
336    }
337
338    /// Extract all wallet margin modes from the cache.
339    ///
340    /// Used at startup to seed engine-owned wallet_margin_modes so the engine
341    /// never needs the fallback try_read() path.
342    pub fn get_all_margin_modes_sync(&self) -> Result<HashMap<WalletAddress, MarginMode>> {
343        let cache = self
344            .tiers
345            .try_read()
346            .map_err(|_| anyhow!("tier cache lock busy during margin-mode seed"))?;
347        Ok(cache
348            .iter()
349            .map(|(wallet, tier)| (*wallet, tier.margin_mode))
350            .collect())
351    }
352
353    /// Extract all wallet trading limits from the cache.
354    ///
355    /// Returns an error if the lock cannot be acquired (e.g. contention during
356    /// engine seed). Callers must not silently default to empty.
357    pub fn get_all_trading_limits_sync(
358        &self,
359    ) -> Result<HashMap<WalletAddress, TradingLimits>, String> {
360        let cache = self
361            .tiers
362            .try_read()
363            .map_err(|_| "tier cache lock busy during trading-limits seed".to_string())?;
364        Ok(cache
365            .iter()
366            .map(|(wallet, tier)| (*wallet, self.trading_limits_from_tier(tier)))
367            .collect())
368    }
369
370    /// Extract all wallet tier strings from the cache.
371    ///
372    /// Returns an error if the lock cannot be acquired.
373    pub fn get_all_tiers_sync(&self) -> Result<HashMap<WalletAddress, String>, String> {
374        let cache = self
375            .tiers
376            .try_read()
377            .map_err(|_| "tier cache lock busy during tier seed".to_string())?;
378        Ok(cache
379            .iter()
380            .map(|(wallet, tier)| (*wallet, tier.tier.clone()))
381            .collect())
382    }
383
384    /// Check if wallet is in Standard margin mode.
385    pub fn is_standard_margin(&self, wallet: &WalletAddress) -> Result<bool> {
386        Ok(self.get_margin_mode_sync(wallet)?.is_standard())
387    }
388
389    /// Check if wallet is in Portfolio margin mode.
390    pub fn is_portfolio_margin(&self, wallet: &WalletAddress) -> Result<bool> {
391        Ok(self.get_margin_mode_sync(wallet)?.is_portfolio())
392    }
393
394    /// Apply a margin mode update from the event bus (for cross-process cache sync).
395    ///
396    /// This is called when we receive a TierUpdate message.
397    /// Only applies the update if the version is newer than the cached version.
398    pub async fn apply_margin_mode_update(
399        &self,
400        wallet: WalletAddress,
401        mode: MarginMode,
402        version: i64,
403    ) {
404        // Early check with read lock to avoid unnecessary DB reads for stale messages
405        let cached_version = {
406            let cache = self.tiers.read().await;
407            cache.get(&wallet).map(|t| t.version).unwrap_or(0)
408        };
409
410        if version <= cached_version {
411            debug!(
412                "Ignoring stale tier update: wallet={}, msg_version={}, cached_version={}",
413                wallet, version, cached_version
414            );
415            return;
416        }
417
418        // Fetch from DB (outside lock to avoid holding lock during I/O)
419        let handler = self.db.clone();
420        let tier =
421            match tokio::task::spawn_blocking(move || handler.get_user_tier_sync(&wallet)).await {
422                Ok(Ok(Some(t))) => t,
423                Ok(Ok(None)) => {
424                    debug!("No tier found in DB for wallet {}", wallet);
425                    return;
426                }
427                Ok(Err(e)) => {
428                    tracing::error!("Failed to load tier from DB for {}: {}", wallet, e);
429                    return;
430                }
431                Err(e) => {
432                    tracing::error!("spawn_blocking failed for tier load {}: {}", wallet, e);
433                    return;
434                }
435            };
436
437        // Re-check version under write lock to prevent TOCTOU race
438        let mut cache = self.tiers.write().await;
439        let current_cached_version = cache.get(&wallet).map(|t| t.version).unwrap_or(0);
440
441        // Only insert if still newer (another thread may have updated in the meantime)
442        if tier.version > current_cached_version {
443            debug!(
444                "Applied margin mode update from event bus: wallet={}, mode={:?}, version={}",
445                wallet, mode, tier.version
446            );
447            cache.insert(wallet, tier);
448        } else {
449            debug!(
450                "Skipping insert after re-check: db_version={}, cached_version={}",
451                tier.version, current_cached_version
452            );
453        }
454    }
455
456    /// Get the current version for a wallet's tier configuration.
457    pub async fn get_version(&self, wallet: &WalletAddress) -> i64 {
458        let cache = self.tiers.read().await;
459        cache.get(wallet).map(|t| t.version).unwrap_or(0)
460    }
461
462    // === Trading Limits Methods ===
463
464    /// Get trading limits for a wallet.
465    ///
466    /// Returns the wallet's configured limits, or defaults if not found.
467    pub fn get_trading_limits(&self, wallet: &WalletAddress) -> TradingLimits {
468        if let Ok(cache) = self.tiers.try_read() {
469            if let Some(tier) = cache.get(wallet) {
470                return self.trading_limits_from_tier(tier);
471            }
472        }
473        self.default_trading_limits()
474    }
475
476    /// Get trading limits asynchronously.
477    pub async fn get_trading_limits_async(&self, wallet: &WalletAddress) -> TradingLimits {
478        let cache = self.tiers.read().await;
479        if let Some(tier) = cache.get(wallet) {
480            self.trading_limits_from_tier(tier)
481        } else {
482            self.default_trading_limits()
483        }
484    }
485
486    /// Check if the number of open orders exceeds the limit.
487    /// Returns true if the limit is exceeded.
488    /// Returns false if limit is negative (unlimited).
489    pub fn exceeds_open_orders_limit(&self, wallet: &WalletAddress, current_count: usize) -> bool {
490        let limits = self.get_trading_limits(wallet);
491        if limits.max_open_orders < 0 {
492            return false;
493        }
494        current_count >= limits.max_open_orders as usize
495    }
496
497    /// Check if the number of positions exceeds the limit.
498    /// Returns true if the limit is exceeded.
499    /// Returns false if limit is -1 (unlimited).
500    pub fn exceeds_positions_limit(&self, wallet: &WalletAddress, current_count: usize) -> bool {
501        let limits = self.get_trading_limits(wallet);
502        if limits.max_open_positions < 0 {
503            return false; // -1 means unlimited
504        }
505        current_count >= limits.max_open_positions as usize
506    }
507
508    pub fn default_trading_limits(&self) -> TradingLimits {
509        let mut limits = TradingLimits::default();
510        limits.max_open_orders = self.config.max_open_orders_default;
511        limits.max_open_positions = self.config.max_open_positions_default;
512        limits
513    }
514
515    fn trading_limits_from_tier(&self, tier: &UserTier) -> TradingLimits {
516        TradingLimits {
517            max_open_orders: tier
518                .max_open_orders
519                .unwrap_or(self.config.max_open_orders_default),
520            max_open_positions: tier.max_open_positions,
521            orders_per_minute: tier.orders_per_minute,
522            cancels_per_minute: tier.cancels_per_minute,
523            api_requests_per_minute: tier.api_requests_per_minute,
524        }
525    }
526
527    /// Set the margin mode for a wallet in the in-memory cache only (no DB write).
528    /// Use this in tests that need to simulate PM vs SM accounts.
529    #[cfg(any(test, feature = "test-utils"))]
530    pub async fn set_margin_mode_in_memory(
531        &self,
532        wallet: &WalletAddress,
533        mode: hypercall_types::MarginMode,
534    ) {
535        let mut cache = self.tiers.write().await;
536        if let Some(tier) = cache.get_mut(wallet) {
537            tier.margin_mode = mode;
538        } else {
539            cache.insert(
540                *wallet,
541                UserTier {
542                    wallet_address: *wallet,
543                    tier: "tier2".to_string(),
544                    margin_mode: mode,
545                    version: 1,
546                    max_open_orders: None,
547                    max_open_positions: 50,
548                    orders_per_minute: 60,
549                    cancels_per_minute: 120,
550                    api_requests_per_minute: 600,
551                    created_at: None,
552                    updated_at: None,
553                },
554            );
555        }
556    }
557}
558
559fn new_user_tier_from_existing(tier: &UserTier) -> UserTierUpdate {
560    UserTierUpdate {
561        wallet_address: tier.wallet_address,
562        tier: tier.tier.clone(),
563        margin_mode: Some(tier.margin_mode),
564        version: Some(tier.version),
565        max_open_orders: tier.max_open_orders,
566        max_open_positions: Some(tier.max_open_positions),
567        orders_per_minute: Some(tier.orders_per_minute),
568        cancels_per_minute: Some(tier.cancels_per_minute),
569        api_requests_per_minute: Some(tier.api_requests_per_minute),
570    }
571}