Skip to main content

hypercall_api/caches/
rate_limit.rs

//! Rate limiting cache using a fixed-window algorithm.
//!
//! This module provides per-wallet rate limiting for:
//! - Order placements
//! - Order cancellations
//! - API requests
//! - RFQ submissions
//!
//! Supports both Redis (for distributed rate limiting across multiple instances)
//! and in-memory backends (for local development without Redis).
10
11use crate::boundary::read_models::TierCacheApi;
12use crate::models::TradingLimits;
13use hypercall_types::WalletAddress;
14use redis::aio::ConnectionManager;
15use redis::Script;
16use std::collections::HashMap;
17use std::sync::{Arc, RwLock as StdRwLock};
18use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
19
/// Extra seconds added on top of the window length when setting a Redis
/// counter's TTL, guarding against races at window boundaries.
const REDIS_TTL_BUFFER_SECS: u64 = 10;

/// Atomic check-and-increment executed server-side in Redis.
///
/// Inputs: `KEYS[1]` is the counter key, `ARGV[1]` the limit, `ARGV[2]` the TTL
/// in seconds. If the stored count (missing key = 0) is below the limit, the
/// key is incremented and its TTL refreshed.
///
/// Reply: `{count, allowed}` where `allowed` is 1 if admitted, 0 if rejected.
const CHECK_AND_INCREMENT_LUA: &str = r#"
local current = tonumber(redis.call('GET', KEYS[1]) or '0')
local limit = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])

if current < limit then
    current = redis.call('INCR', KEYS[1])
    redis.call('EXPIRE', KEYS[1], ttl)
    return {current, 1}
end
return {current, 0}
"#;

/// Read-only lookup of the counter at `KEYS[1]`; a missing key reports 0.
const GET_COUNT_LUA: &str = r#"
local current = redis.call('GET', KEYS[1])
if current == false then
    return 0
end
return tonumber(current)
"#;
46
/// State of one in-memory fixed window: how many hits it has admitted and
/// the monotonic instant at which it opened.
#[derive(Debug, Clone)]
struct FixedWindowCounter {
    /// Hits admitted since `window_start`.
    count: u32,
    /// When the current window began (monotonic clock).
    window_start: Instant,
}
53
54impl FixedWindowCounter {
55    fn new() -> Self {
56        Self {
57            count: 0,
58            window_start: Instant::now(),
59        }
60    }
61
62    /// Check if the window has expired and reset if necessary.
63    /// Returns true if the window was reset.
64    fn maybe_reset(&mut self, window_duration: Duration) -> bool {
65        if self.window_start.elapsed() >= window_duration {
66            self.count = 0;
67            self.window_start = Instant::now();
68            true
69        } else {
70            false
71        }
72    }
73
74    /// Increment counter if under limit, return true if allowed.
75    fn try_increment(&mut self, limit: u32, window_duration: Duration) -> bool {
76        self.maybe_reset(window_duration);
77        if self.count < limit {
78            self.count += 1;
79            true
80        } else {
81            false
82        }
83    }
84
85    /// Get time remaining until window reset.
86    fn time_until_reset(&self, window_duration: Duration) -> Duration {
87        let elapsed = self.window_start.elapsed();
88        if elapsed >= window_duration {
89            Duration::ZERO
90        } else {
91            window_duration - elapsed
92        }
93    }
94
95    /// Get current count.
96    fn current_count(&self) -> u32 {
97        self.count
98    }
99}
100
/// Category of request being throttled. Each category has its own limit
/// and its own per-wallet counter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RateLimitAction {
    /// Placing an order.
    OrderPlacement,
    /// Cancelling an order.
    OrderCancellation,
    /// General API request.
    ApiRequest,
    /// Submitting an RFQ. Throttled per wallet so that a single taker cannot
    /// fan-out-DOS every connected QP with rapid-fire /rfq/request calls.
    /// Default: 10/minute (see get_limit_for_action).
    RfqSubmit,
}
115
116impl RateLimitAction {
117    /// Get the string key for this action (used in Redis keys).
118    fn as_key_str(&self) -> &'static str {
119        match self {
120            RateLimitAction::OrderPlacement => "order",
121            RateLimitAction::OrderCancellation => "cancel",
122            RateLimitAction::ApiRequest => "api",
123            RateLimitAction::RfqSubmit => "rfq_submit",
124        }
125    }
126}
127
/// A wallet's standing against one rate limit, as reported by a check.
#[derive(Debug, Clone)]
pub struct RateLimitInfo {
    /// Maximum requests permitted per window for this action.
    pub limit: u32,
    /// Requests still available in the current window.
    pub remaining: u32,
    /// Unix timestamp (seconds) at which the current window resets.
    pub reset_at: u64,
}
138
139/// Error returned when rate limit is exceeded or service is unavailable.
140#[derive(Debug, Clone)]
141pub enum RateLimitError {
142    /// Rate limit exceeded - client should retry after the specified time
143    Exceeded {
144        /// Seconds until the limit resets
145        retry_after_secs: u32,
146        /// The limit that was exceeded
147        limit: u32,
148        /// Type of action that was rate limited
149        action: RateLimitAction,
150    },
151    /// Redis service unavailable - fail closed
152    ServiceUnavailable {
153        /// Error message for logging
154        message: String,
155    },
156}
157
158impl std::fmt::Display for RateLimitError {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        match self {
161            RateLimitError::Exceeded {
162                retry_after_secs,
163                limit,
164                action,
165            } => {
166                write!(
167                    f,
168                    "Rate limit exceeded for {:?}: {} per minute, retry after {} seconds",
169                    action, limit, retry_after_secs
170                )
171            }
172            RateLimitError::ServiceUnavailable { message } => {
173                write!(f, "Rate limit service unavailable: {}", message)
174            }
175        }
176    }
177}
178
179impl std::error::Error for RateLimitError {}
180
/// Where rate-limit counters are stored.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitBackend {
    /// Counters kept in Redis and shared by every API server instance.
    Redis,
    /// Counters kept in this process only.
    InMemory,
}
189
190/// Rate limit cache using fixed window algorithm.
191///
192/// Supports both Redis (distributed) and in-memory (single-instance) backends.
193/// When Redis is configured, rate limits are shared across all API server instances.
194/// When Redis is not configured, falls back to in-memory rate limiting.
195pub struct RateLimitCache {
196    /// Redis connection manager (None if using in-memory backend)
197    redis: Option<ConnectionManager>,
198    /// In-memory counters for order placements per wallet
199    order_counts: StdRwLock<HashMap<WalletAddress, FixedWindowCounter>>,
200    /// In-memory counters for cancellations per wallet
201    cancel_counts: StdRwLock<HashMap<WalletAddress, FixedWindowCounter>>,
202    /// In-memory counters for API requests per wallet
203    api_counts: StdRwLock<HashMap<WalletAddress, FixedWindowCounter>>,
204    /// In-memory counters for RFQ submissions per wallet
205    rfq_submit_counts: StdRwLock<HashMap<WalletAddress, FixedWindowCounter>>,
206    /// Reference to tier cache for getting limits (TierCache handles its own locking)
207    tier_cache: Arc<dyn TierCacheApi>,
208    /// Window duration (default: 60 seconds)
209    window_duration: Duration,
210    /// Lua script for atomic check-and-increment
211    check_and_increment_script: Script,
212    /// Lua script for getting current count
213    get_count_script: Script,
214}
215
216impl RateLimitCache {
217    /// Create a new rate limit cache.
218    ///
219    /// If `redis_url` is provided and connection succeeds, uses Redis backend.
220    /// Otherwise falls back to in-memory backend.
221    pub async fn new(tier_cache: Arc<dyn TierCacheApi>, redis_url: Option<&str>) -> Self {
222        let redis = if let Some(url) = redis_url {
223            match Self::connect_redis(url).await {
224                Ok(conn) => {
225                    tracing::info!("RateLimitCache using Redis backend at {}", url);
226                    Some(conn)
227                }
228                Err(e) => {
229                    tracing::warn!(
230                        "Failed to connect to Redis at {}, using in-memory backend: {}",
231                        url,
232                        e
233                    );
234                    None
235                }
236            }
237        } else {
238            tracing::info!("REDIS_URL not set, using in-memory rate limiting");
239            None
240        };
241
242        Self {
243            redis,
244            order_counts: StdRwLock::new(HashMap::new()),
245            cancel_counts: StdRwLock::new(HashMap::new()),
246            api_counts: StdRwLock::new(HashMap::new()),
247            rfq_submit_counts: StdRwLock::new(HashMap::new()),
248            tier_cache,
249            window_duration: Duration::from_secs(60),
250            check_and_increment_script: Script::new(CHECK_AND_INCREMENT_LUA),
251            get_count_script: Script::new(GET_COUNT_LUA),
252        }
253    }
254
255    /// Connect to Redis and return a connection manager.
256    async fn connect_redis(url: &str) -> Result<ConnectionManager, redis::RedisError> {
257        let client = redis::Client::open(url)?;
258        ConnectionManager::new(client).await
259    }
260
261    /// Get the backend type being used.
262    pub fn backend(&self) -> RateLimitBackend {
263        if self.redis.is_some() {
264            RateLimitBackend::Redis
265        } else {
266            RateLimitBackend::InMemory
267        }
268    }
269
270    /// Generate Redis key for a wallet/action/window.
271    fn redis_key(wallet: &WalletAddress, action: RateLimitAction, window_start: u64) -> String {
272        format!(
273            "ratelimit:{}:{}:{}",
274            wallet,
275            action.as_key_str(),
276            window_start
277        )
278    }
279
280    /// Get the window start timestamp (truncated to minute boundary).
281    fn get_window_start(now_secs: u64, window_secs: u64) -> u64 {
282        (now_secs / window_secs) * window_secs
283    }
284
285    /// Get the counters map for the given action type (in-memory backend).
286    fn get_counters(
287        &self,
288        action: RateLimitAction,
289    ) -> &StdRwLock<HashMap<WalletAddress, FixedWindowCounter>> {
290        match action {
291            RateLimitAction::OrderPlacement => &self.order_counts,
292            RateLimitAction::OrderCancellation => &self.cancel_counts,
293            RateLimitAction::ApiRequest => &self.api_counts,
294            RateLimitAction::RfqSubmit => &self.rfq_submit_counts,
295        }
296    }
297
298    /// Get the limit for the given action from trading limits.
299    fn get_limit_for_action(limits: &TradingLimits, action: RateLimitAction) -> u32 {
300        match action {
301            RateLimitAction::OrderPlacement => limits.orders_per_minute as u32,
302            RateLimitAction::OrderCancellation => limits.cancels_per_minute as u32,
303            RateLimitAction::ApiRequest => limits.api_requests_per_minute as u32,
304            // RFQ submit: 10/minute per wallet. Each submit fans out to
305            // every connected QP, so a higher rate would amplify into
306            // proportionally more WS traffic + quote generation load.
307            // Not yet in TradingLimits — hardcoded until we have a
308            // tier-aware RFQ limit field.
309            RateLimitAction::RfqSubmit => 10,
310        }
311    }
312
313    /// Check and increment the rate limit counter for a wallet/action.
314    ///
315    /// Returns Ok(RateLimitInfo) if allowed, Err(RateLimitError) if exceeded or unavailable.
316    pub async fn check_and_increment(
317        &self,
318        wallet: &WalletAddress,
319        action: RateLimitAction,
320    ) -> Result<RateLimitInfo, RateLimitError> {
321        // Wait for the tier-cache lock instead of using the sync try_read path,
322        // where transient lock contention falls through to missing-tier defaults.
323        let limits = self.tier_cache.get_trading_limits_async(wallet).await;
324        let limit = Self::get_limit_for_action(&limits, action);
325
326        if let Some(ref redis) = self.redis {
327            self.check_and_increment_redis(redis.clone(), wallet, action, limit)
328                .await
329        } else {
330            self.check_and_increment_memory(wallet, action, limit)
331        }
332    }
333
334    /// Check and increment using Redis backend.
335    async fn check_and_increment_redis(
336        &self,
337        mut redis: ConnectionManager,
338        wallet: &WalletAddress,
339        action: RateLimitAction,
340        limit: u32,
341    ) -> Result<RateLimitInfo, RateLimitError> {
342        let now = SystemTime::now()
343            .duration_since(UNIX_EPOCH)
344            .unwrap()
345            .as_secs();
346
347        let window_secs = self.window_duration.as_secs();
348        let window_start = Self::get_window_start(now, window_secs);
349        let window_end = window_start + window_secs;
350        let key = Self::redis_key(wallet, action, window_start);
351
352        // TTL is window duration plus buffer to prevent race conditions
353        let ttl = window_secs + REDIS_TTL_BUFFER_SECS;
354
355        // Execute Lua script atomically
356        let result: Result<(i64, i64), redis::RedisError> = self
357            .check_and_increment_script
358            .key(&key)
359            .arg(limit)
360            .arg(ttl)
361            .invoke_async(&mut redis)
362            .await;
363
364        match result {
365            Ok((count, allowed)) => {
366                if allowed == 1 {
367                    Ok(RateLimitInfo {
368                        limit,
369                        remaining: limit.saturating_sub(count as u32),
370                        reset_at: window_end,
371                    })
372                } else {
373                    let retry_after = (window_end.saturating_sub(now)).max(1) as u32;
374                    Err(RateLimitError::Exceeded {
375                        retry_after_secs: retry_after,
376                        limit,
377                        action,
378                    })
379                }
380            }
381            Err(e) => {
382                tracing::error!("Redis rate limit check failed: {}", e);
383                // Fail closed: reject requests when rate limiting is unavailable.
384                // This is a security decision to prevent abuse during Redis outages.
385                Err(RateLimitError::ServiceUnavailable {
386                    message: format!("Redis error: {}", e),
387                })
388            }
389        }
390    }
391
392    /// Check and increment using in-memory backend.
393    fn check_and_increment_memory(
394        &self,
395        wallet: &WalletAddress,
396        action: RateLimitAction,
397        limit: u32,
398    ) -> Result<RateLimitInfo, RateLimitError> {
399        let counters = self.get_counters(action);
400
401        // Use write lock to update counters
402        let mut map = counters.write().unwrap();
403        let counter = map.entry(*wallet).or_insert_with(FixedWindowCounter::new);
404
405        if counter.try_increment(limit, self.window_duration) {
406            let remaining = limit.saturating_sub(counter.current_count());
407            let reset_at = SystemTime::now()
408                .duration_since(UNIX_EPOCH)
409                .unwrap()
410                .as_secs()
411                + counter.time_until_reset(self.window_duration).as_secs();
412
413            Ok(RateLimitInfo {
414                limit,
415                remaining,
416                reset_at,
417            })
418        } else {
419            let retry_after = counter.time_until_reset(self.window_duration).as_secs() as u32;
420            Err(RateLimitError::Exceeded {
421                retry_after_secs: retry_after.max(1),
422                limit,
423                action,
424            })
425        }
426    }
427
428    /// Get current rate limit info for a wallet/action without incrementing.
429    pub async fn get_info(
430        &self,
431        wallet: &WalletAddress,
432        action: RateLimitAction,
433    ) -> Result<RateLimitInfo, RateLimitError> {
434        let limits = self.tier_cache.get_trading_limits_async(wallet).await;
435        let limit = Self::get_limit_for_action(&limits, action);
436
437        if let Some(ref redis) = self.redis {
438            self.get_info_redis(redis.clone(), wallet, action, limit)
439                .await
440        } else {
441            Ok(self.get_info_memory(wallet, action, limit))
442        }
443    }
444
445    /// Get info using Redis backend.
446    async fn get_info_redis(
447        &self,
448        mut redis: ConnectionManager,
449        wallet: &WalletAddress,
450        action: RateLimitAction,
451        limit: u32,
452    ) -> Result<RateLimitInfo, RateLimitError> {
453        let now = SystemTime::now()
454            .duration_since(UNIX_EPOCH)
455            .unwrap()
456            .as_secs();
457
458        let window_secs = self.window_duration.as_secs();
459        let window_start = Self::get_window_start(now, window_secs);
460        let window_end = window_start + window_secs;
461        let key = Self::redis_key(wallet, action, window_start);
462
463        let result: Result<i64, redis::RedisError> = self
464            .get_count_script
465            .key(&key)
466            .invoke_async(&mut redis)
467            .await;
468
469        match result {
470            Ok(count) => {
471                let remaining = limit.saturating_sub(count as u32);
472                Ok(RateLimitInfo {
473                    limit,
474                    remaining,
475                    reset_at: window_end,
476                })
477            }
478            Err(e) => {
479                tracing::error!("Redis rate limit get_info failed: {}", e);
480                Err(RateLimitError::ServiceUnavailable {
481                    message: format!("Redis error: {}", e),
482                })
483            }
484        }
485    }
486
487    /// Get info using in-memory backend.
488    fn get_info_memory(
489        &self,
490        wallet: &WalletAddress,
491        action: RateLimitAction,
492        limit: u32,
493    ) -> RateLimitInfo {
494        let counters = self.get_counters(action);
495
496        let (current_count, time_until_reset) = {
497            let mut map = counters.write().unwrap();
498            if let Some(counter) = map.get_mut(wallet) {
499                counter.maybe_reset(self.window_duration);
500                (
501                    counter.current_count(),
502                    counter.time_until_reset(self.window_duration),
503                )
504            } else {
505                (0, self.window_duration)
506            }
507        };
508
509        let remaining = limit.saturating_sub(current_count);
510        let reset_at = SystemTime::now()
511            .duration_since(UNIX_EPOCH)
512            .unwrap()
513            .as_secs()
514            + time_until_reset.as_secs();
515
516        RateLimitInfo {
517            limit,
518            remaining,
519            reset_at,
520        }
521    }
522
523    /// Clean up expired entries to prevent memory growth (in-memory backend only).
524    /// Call this periodically (e.g., every 5 minutes).
525    pub fn cleanup_expired(&self) {
526        let threshold = self.window_duration * 2;
527
528        let cleanup_map = |map: &StdRwLock<HashMap<WalletAddress, FixedWindowCounter>>| {
529            let mut guard = map.write().unwrap();
530            guard.retain(|_, counter| counter.window_start.elapsed() < threshold);
531        };
532
533        cleanup_map(&self.order_counts);
534        cleanup_map(&self.cancel_counts);
535        cleanup_map(&self.api_counts);
536    }
537
538    /// Get the number of tracked wallets (for metrics, in-memory backend only).
539    pub fn tracked_wallet_count(&self) -> usize {
540        let order_len = self.order_counts.read().unwrap().len();
541        let cancel_len = self.cancel_counts.read().unwrap().len();
542        let api_len = self.api_counts.read().unwrap().len();
543        order_len.max(cancel_len).max(api_len)
544    }
545}
546
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fixed_window_counter_increments() {
        let mut counter = FixedWindowCounter::new();
        let window = Duration::from_secs(60);

        assert!(counter.try_increment(10, window));
        assert_eq!(counter.current_count(), 1);

        assert!(counter.try_increment(10, window));
        assert_eq!(counter.current_count(), 2);
    }

    #[test]
    fn test_fixed_window_counter_blocks_at_limit() {
        let mut counter = FixedWindowCounter::new();
        let window = Duration::from_secs(60);

        // Fill up to limit
        for _ in 0..10 {
            assert!(counter.try_increment(10, window));
        }

        // 11th request should fail
        assert!(!counter.try_increment(10, window));
        assert_eq!(counter.current_count(), 10);
    }

    #[test]
    fn test_fixed_window_counter_resets_expired_window() {
        let mut counter = FixedWindowCounter::new();
        // A zero-length window is always expired, so each call opens a fresh window.
        assert!(counter.try_increment(1, Duration::ZERO));
        assert!(counter.try_increment(1, Duration::ZERO));
        assert_eq!(counter.current_count(), 1);
        assert_eq!(counter.time_until_reset(Duration::ZERO), Duration::ZERO);
    }

    #[test]
    fn test_fixed_window_counter_zero_limit_blocks() {
        let mut counter = FixedWindowCounter::new();
        assert!(!counter.try_increment(0, Duration::from_secs(60)));
        assert_eq!(counter.current_count(), 0);
    }

    #[test]
    fn test_fixed_window_counter_time_until_reset_is_bounded() {
        let counter = FixedWindowCounter::new();
        let window = Duration::from_secs(60);
        let remaining = counter.time_until_reset(window);
        assert!(remaining <= window);
        // Generous slack: the window was opened moments ago.
        assert!(remaining > Duration::from_secs(30));
    }

    #[test]
    fn test_rate_limit_error_display() {
        let err = RateLimitError::Exceeded {
            retry_after_secs: 30,
            limit: 60,
            action: RateLimitAction::OrderPlacement,
        };

        let display = format!("{}", err);
        assert!(display.contains("60"));
        assert!(display.contains("30"));

        let err = RateLimitError::ServiceUnavailable {
            message: "Redis error".to_string(),
        };
        let display = format!("{}", err);
        assert!(display.contains("unavailable"));
    }

    #[test]
    fn test_get_limit_for_action() {
        let limits = TradingLimits {
            max_open_orders: 100,
            max_open_positions: 50,
            orders_per_minute: 60,
            cancels_per_minute: 120,
            api_requests_per_minute: 600,
        };

        assert_eq!(
            RateLimitCache::get_limit_for_action(&limits, RateLimitAction::OrderPlacement),
            60
        );
        assert_eq!(
            RateLimitCache::get_limit_for_action(&limits, RateLimitAction::OrderCancellation),
            120
        );
        assert_eq!(
            RateLimitCache::get_limit_for_action(&limits, RateLimitAction::ApiRequest),
            600
        );
        // RFQ submit is not tier-driven yet; it is fixed at 10 per window.
        assert_eq!(
            RateLimitCache::get_limit_for_action(&limits, RateLimitAction::RfqSubmit),
            10
        );
    }

    #[test]
    fn test_redis_key_generation() {
        use std::str::FromStr;
        let wallet = WalletAddress::from_str("0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266").unwrap();
        let key = RateLimitCache::redis_key(&wallet, RateLimitAction::OrderPlacement, 1704067200);
        assert!(key.starts_with("ratelimit:"));
        assert!(key.contains(":order:"));
        assert!(key.contains("1704067200"));

        // Different actions must never share a counter key.
        let rfq_key = RateLimitCache::redis_key(&wallet, RateLimitAction::RfqSubmit, 1704067200);
        assert!(rfq_key.contains(":rfq_submit:"));
        assert_ne!(key, rfq_key);
    }

    #[test]
    fn test_get_window_start() {
        // 1704067234 should round down to 1704067200 (60-second windows)
        assert_eq!(RateLimitCache::get_window_start(1704067234, 60), 1704067200);
        assert_eq!(RateLimitCache::get_window_start(1704067200, 60), 1704067200);
        assert_eq!(RateLimitCache::get_window_start(1704067259, 60), 1704067200);
        assert_eq!(RateLimitCache::get_window_start(1704067260, 60), 1704067260);
    }

    #[test]
    fn test_action_key_str() {
        assert_eq!(RateLimitAction::OrderPlacement.as_key_str(), "order");
        assert_eq!(RateLimitAction::OrderCancellation.as_key_str(), "cancel");
        assert_eq!(RateLimitAction::ApiRequest.as_key_str(), "api");
        assert_eq!(RateLimitAction::RfqSubmit.as_key_str(), "rfq_submit");
    }
}