1use crate::boundary::read_models::TierCacheApi;
12use crate::models::TradingLimits;
13use hypercall_types::WalletAddress;
14use redis::aio::ConnectionManager;
15use redis::Script;
16use std::collections::HashMap;
17use std::sync::{Arc, RwLock as StdRwLock};
18use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
19
/// Lua source for the check-and-increment script (wrapped in `redis::Script` by
/// `RateLimitCache::new`).
///
/// Inputs:
/// * `KEYS[1]` - the counter key.
/// * `ARGV[1]` - the limit the counter is compared against.
/// * `ARGV[2]` - TTL in seconds, applied with `EXPIRE` after every successful increment.
///
/// Returns a two-element array `{count, allowed}`:
/// * `{new_count, 1}` when the counter was below the limit and has been incremented;
/// * `{current_count, 0}` when the counter was already at or above the limit
///   (the key is left untouched in that case).
///
/// A missing key is treated as a count of `0`.
const CHECK_AND_INCREMENT_LUA: &str = r#"
local current = tonumber(redis.call('GET', KEYS[1]) or '0')
local limit = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])

if current < limit then
 current = redis.call('INCR', KEYS[1])
 redis.call('EXPIRE', KEYS[1], ttl)
 return {current, 1}
end
return {current, 0}
"#;
34
/// Extra seconds added to the window length when computing the TTL of a Redis
/// counter key: `check_and_increment_redis` passes
/// `window_secs + REDIS_TTL_BUFFER_SECS` to the script as the key's expiry.
const REDIS_TTL_BUFFER_SECS: u64 = 10;
37
/// Lua source for the read-only count script (wrapped in `redis::Script` by
/// `RateLimitCache::new`).
///
/// Reads `KEYS[1]` and returns its value as a number, or `0` when the key does
/// not exist. The key is never modified.
const GET_COUNT_LUA: &str = r#"
local current = redis.call('GET', KEYS[1])
if current == false then
 return 0
end
return tonumber(current)
"#;
46
/// In-process counter for one fixed time window.
///
/// The window begins when the counter is created (or last reset) and is
/// measured with the monotonic [`Instant`] clock. Once the window has fully
/// elapsed, the next call to [`FixedWindowCounter::maybe_reset`] (directly or
/// via [`FixedWindowCounter::try_increment`]) starts a fresh window at zero.
#[derive(Debug, Clone)]
struct FixedWindowCounter {
    /// Number of increments accepted in the current window.
    count: u32,
    /// Instant at which the current window started.
    window_start: Instant,
}

impl FixedWindowCounter {
    /// Creates a counter at zero whose window starts now.
    fn new() -> Self {
        let window_start = Instant::now();
        Self {
            count: 0,
            window_start,
        }
    }

    /// Starts a new window if at least `window_duration` has elapsed since the
    /// current one began.
    ///
    /// Returns `true` when a reset happened, `false` otherwise.
    fn maybe_reset(&mut self, window_duration: Duration) -> bool {
        let expired = self.window_start.elapsed() >= window_duration;
        if expired {
            // A fresh counter is exactly "count = 0, window starts now".
            *self = Self::new();
        }
        expired
    }

    /// Rolls the window over if it has expired, then records one more hit if
    /// the count is still below `limit`.
    ///
    /// Returns `true` when the hit was counted, `false` when the limit has
    /// already been reached for the current window.
    fn try_increment(&mut self, limit: u32, window_duration: Duration) -> bool {
        self.maybe_reset(window_duration);
        if self.count >= limit {
            return false;
        }
        self.count += 1;
        true
    }

    /// Time left until the current window ends; [`Duration::ZERO`] once the
    /// window has already elapsed.
    fn time_until_reset(&self, window_duration: Duration) -> Duration {
        window_duration.saturating_sub(self.window_start.elapsed())
    }

    /// Number of hits counted in the current window (no expiry check is done here).
    fn current_count(&self) -> u32 {
        self.count
    }
}
100
/// Kind of operation that is rate limited. Each variant has its own counter
/// and its own limit (see `RateLimitCache::get_limit_for_action`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RateLimitAction {
    /// Order placement; key segment `"order"`.
    OrderPlacement,
    /// Order cancellation; key segment `"cancel"`.
    OrderCancellation,
    /// Generic API request; key segment `"api"`.
    ApiRequest,
    /// RFQ submission; key segment `"rfq_submit"`.
    RfqSubmit,
}

impl RateLimitAction {
    /// Short, stable identifier for this action, used as a segment of the
    /// Redis counter key.
    fn as_key_str(&self) -> &'static str {
        match *self {
            Self::OrderPlacement => "order",
            Self::OrderCancellation => "cancel",
            Self::ApiRequest => "api",
            Self::RfqSubmit => "rfq_submit",
        }
    }
}
127
/// Snapshot of a wallet's rate-limit state for one action, as returned by
/// `RateLimitCache::check_and_increment` and `RateLimitCache::get_info`.
#[derive(Debug, Clone)]
pub struct RateLimitInfo {
    /// Maximum number of actions allowed per window.
    pub limit: u32,
    /// Actions still available in the current window (`limit` minus the
    /// current count, saturating at zero).
    pub remaining: u32,
    /// Time at which the current window ends, in seconds since the Unix epoch.
    pub reset_at: u64,
}
138
/// Failure modes of a rate-limit check.
#[derive(Debug, Clone)]
pub enum RateLimitError {
    /// The limit for the current window has already been reached.
    Exceeded {
        /// Seconds until the current window ends (the producers in this file
        /// always report at least 1).
        retry_after_secs: u32,
        /// The limit that was hit.
        limit: u32,
        /// The action that was being rate limited.
        action: RateLimitAction,
    },
    /// The rate-limit backend could not be consulted (produced in this file
    /// when a Redis script invocation fails).
    ServiceUnavailable {
        /// Human-readable description of the underlying failure.
        message: String,
    },
}
157
158impl std::fmt::Display for RateLimitError {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 match self {
161 RateLimitError::Exceeded {
162 retry_after_secs,
163 limit,
164 action,
165 } => {
166 write!(
167 f,
168 "Rate limit exceeded for {:?}: {} per minute, retry after {} seconds",
169 action, limit, retry_after_secs
170 )
171 }
172 RateLimitError::ServiceUnavailable { message } => {
173 write!(f, "Rate limit service unavailable: {}", message)
174 }
175 }
176 }
177}
178
/// Marks `RateLimitError` as a standard error type. No trait methods are
/// overridden, so `source()` keeps its default implementation.
impl std::error::Error for RateLimitError {}
180
/// Which storage backend a `RateLimitCache` is using for its counters
/// (reported by `RateLimitCache::backend`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitBackend {
    /// Counters live in Redis (a connection was established at construction).
    Redis,
    /// Counters live in per-process hash maps.
    InMemory,
}
189
/// Per-wallet, per-action fixed-window rate limiter.
///
/// When `redis` is `Some`, counters are kept in Redis and the in-memory maps
/// are not consulted; when it is `None`, the in-memory maps are used instead.
pub struct RateLimitCache {
    /// Redis connection, present only if connecting succeeded in `new`.
    redis: Option<ConnectionManager>,
    /// In-memory counters for `RateLimitAction::OrderPlacement`.
    order_counts: StdRwLock<HashMap<WalletAddress, FixedWindowCounter>>,
    /// In-memory counters for `RateLimitAction::OrderCancellation`.
    cancel_counts: StdRwLock<HashMap<WalletAddress, FixedWindowCounter>>,
    /// In-memory counters for `RateLimitAction::ApiRequest`.
    api_counts: StdRwLock<HashMap<WalletAddress, FixedWindowCounter>>,
    /// In-memory counters for `RateLimitAction::RfqSubmit`.
    rfq_submit_counts: StdRwLock<HashMap<WalletAddress, FixedWindowCounter>>,
    /// Source of each wallet's `TradingLimits` (queried on every check).
    tier_cache: Arc<dyn TierCacheApi>,
    /// Length of one rate-limit window (set to 60 seconds in `new`).
    window_duration: Duration,
    /// Script built from `CHECK_AND_INCREMENT_LUA`.
    check_and_increment_script: Script,
    /// Script built from `GET_COUNT_LUA`.
    get_count_script: Script,
}
215
216impl RateLimitCache {
217 pub async fn new(tier_cache: Arc<dyn TierCacheApi>, redis_url: Option<&str>) -> Self {
222 let redis = if let Some(url) = redis_url {
223 match Self::connect_redis(url).await {
224 Ok(conn) => {
225 tracing::info!("RateLimitCache using Redis backend at {}", url);
226 Some(conn)
227 }
228 Err(e) => {
229 tracing::warn!(
230 "Failed to connect to Redis at {}, using in-memory backend: {}",
231 url,
232 e
233 );
234 None
235 }
236 }
237 } else {
238 tracing::info!("REDIS_URL not set, using in-memory rate limiting");
239 None
240 };
241
242 Self {
243 redis,
244 order_counts: StdRwLock::new(HashMap::new()),
245 cancel_counts: StdRwLock::new(HashMap::new()),
246 api_counts: StdRwLock::new(HashMap::new()),
247 rfq_submit_counts: StdRwLock::new(HashMap::new()),
248 tier_cache,
249 window_duration: Duration::from_secs(60),
250 check_and_increment_script: Script::new(CHECK_AND_INCREMENT_LUA),
251 get_count_script: Script::new(GET_COUNT_LUA),
252 }
253 }
254
255 async fn connect_redis(url: &str) -> Result<ConnectionManager, redis::RedisError> {
257 let client = redis::Client::open(url)?;
258 ConnectionManager::new(client).await
259 }
260
261 pub fn backend(&self) -> RateLimitBackend {
263 if self.redis.is_some() {
264 RateLimitBackend::Redis
265 } else {
266 RateLimitBackend::InMemory
267 }
268 }
269
270 fn redis_key(wallet: &WalletAddress, action: RateLimitAction, window_start: u64) -> String {
272 format!(
273 "ratelimit:{}:{}:{}",
274 wallet,
275 action.as_key_str(),
276 window_start
277 )
278 }
279
/// Rounds `now_secs` down to the nearest multiple of `window_secs`, i.e. the
/// start of the fixed window containing `now_secs`.
///
/// Panics if `window_secs` is zero.
fn get_window_start(now_secs: u64, window_secs: u64) -> u64 {
    now_secs - now_secs % window_secs
}
284
285 fn get_counters(
287 &self,
288 action: RateLimitAction,
289 ) -> &StdRwLock<HashMap<WalletAddress, FixedWindowCounter>> {
290 match action {
291 RateLimitAction::OrderPlacement => &self.order_counts,
292 RateLimitAction::OrderCancellation => &self.cancel_counts,
293 RateLimitAction::ApiRequest => &self.api_counts,
294 RateLimitAction::RfqSubmit => &self.rfq_submit_counts,
295 }
296 }
297
298 fn get_limit_for_action(limits: &TradingLimits, action: RateLimitAction) -> u32 {
300 match action {
301 RateLimitAction::OrderPlacement => limits.orders_per_minute as u32,
302 RateLimitAction::OrderCancellation => limits.cancels_per_minute as u32,
303 RateLimitAction::ApiRequest => limits.api_requests_per_minute as u32,
304 RateLimitAction::RfqSubmit => 10,
310 }
311 }
312
313 pub async fn check_and_increment(
317 &self,
318 wallet: &WalletAddress,
319 action: RateLimitAction,
320 ) -> Result<RateLimitInfo, RateLimitError> {
321 let limits = self.tier_cache.get_trading_limits_async(wallet).await;
324 let limit = Self::get_limit_for_action(&limits, action);
325
326 if let Some(ref redis) = self.redis {
327 self.check_and_increment_redis(redis.clone(), wallet, action, limit)
328 .await
329 } else {
330 self.check_and_increment_memory(wallet, action, limit)
331 }
332 }
333
334 async fn check_and_increment_redis(
336 &self,
337 mut redis: ConnectionManager,
338 wallet: &WalletAddress,
339 action: RateLimitAction,
340 limit: u32,
341 ) -> Result<RateLimitInfo, RateLimitError> {
342 let now = SystemTime::now()
343 .duration_since(UNIX_EPOCH)
344 .unwrap()
345 .as_secs();
346
347 let window_secs = self.window_duration.as_secs();
348 let window_start = Self::get_window_start(now, window_secs);
349 let window_end = window_start + window_secs;
350 let key = Self::redis_key(wallet, action, window_start);
351
352 let ttl = window_secs + REDIS_TTL_BUFFER_SECS;
354
355 let result: Result<(i64, i64), redis::RedisError> = self
357 .check_and_increment_script
358 .key(&key)
359 .arg(limit)
360 .arg(ttl)
361 .invoke_async(&mut redis)
362 .await;
363
364 match result {
365 Ok((count, allowed)) => {
366 if allowed == 1 {
367 Ok(RateLimitInfo {
368 limit,
369 remaining: limit.saturating_sub(count as u32),
370 reset_at: window_end,
371 })
372 } else {
373 let retry_after = (window_end.saturating_sub(now)).max(1) as u32;
374 Err(RateLimitError::Exceeded {
375 retry_after_secs: retry_after,
376 limit,
377 action,
378 })
379 }
380 }
381 Err(e) => {
382 tracing::error!("Redis rate limit check failed: {}", e);
383 Err(RateLimitError::ServiceUnavailable {
386 message: format!("Redis error: {}", e),
387 })
388 }
389 }
390 }
391
392 fn check_and_increment_memory(
394 &self,
395 wallet: &WalletAddress,
396 action: RateLimitAction,
397 limit: u32,
398 ) -> Result<RateLimitInfo, RateLimitError> {
399 let counters = self.get_counters(action);
400
401 let mut map = counters.write().unwrap();
403 let counter = map.entry(*wallet).or_insert_with(FixedWindowCounter::new);
404
405 if counter.try_increment(limit, self.window_duration) {
406 let remaining = limit.saturating_sub(counter.current_count());
407 let reset_at = SystemTime::now()
408 .duration_since(UNIX_EPOCH)
409 .unwrap()
410 .as_secs()
411 + counter.time_until_reset(self.window_duration).as_secs();
412
413 Ok(RateLimitInfo {
414 limit,
415 remaining,
416 reset_at,
417 })
418 } else {
419 let retry_after = counter.time_until_reset(self.window_duration).as_secs() as u32;
420 Err(RateLimitError::Exceeded {
421 retry_after_secs: retry_after.max(1),
422 limit,
423 action,
424 })
425 }
426 }
427
428 pub async fn get_info(
430 &self,
431 wallet: &WalletAddress,
432 action: RateLimitAction,
433 ) -> Result<RateLimitInfo, RateLimitError> {
434 let limits = self.tier_cache.get_trading_limits_async(wallet).await;
435 let limit = Self::get_limit_for_action(&limits, action);
436
437 if let Some(ref redis) = self.redis {
438 self.get_info_redis(redis.clone(), wallet, action, limit)
439 .await
440 } else {
441 Ok(self.get_info_memory(wallet, action, limit))
442 }
443 }
444
445 async fn get_info_redis(
447 &self,
448 mut redis: ConnectionManager,
449 wallet: &WalletAddress,
450 action: RateLimitAction,
451 limit: u32,
452 ) -> Result<RateLimitInfo, RateLimitError> {
453 let now = SystemTime::now()
454 .duration_since(UNIX_EPOCH)
455 .unwrap()
456 .as_secs();
457
458 let window_secs = self.window_duration.as_secs();
459 let window_start = Self::get_window_start(now, window_secs);
460 let window_end = window_start + window_secs;
461 let key = Self::redis_key(wallet, action, window_start);
462
463 let result: Result<i64, redis::RedisError> = self
464 .get_count_script
465 .key(&key)
466 .invoke_async(&mut redis)
467 .await;
468
469 match result {
470 Ok(count) => {
471 let remaining = limit.saturating_sub(count as u32);
472 Ok(RateLimitInfo {
473 limit,
474 remaining,
475 reset_at: window_end,
476 })
477 }
478 Err(e) => {
479 tracing::error!("Redis rate limit get_info failed: {}", e);
480 Err(RateLimitError::ServiceUnavailable {
481 message: format!("Redis error: {}", e),
482 })
483 }
484 }
485 }
486
487 fn get_info_memory(
489 &self,
490 wallet: &WalletAddress,
491 action: RateLimitAction,
492 limit: u32,
493 ) -> RateLimitInfo {
494 let counters = self.get_counters(action);
495
496 let (current_count, time_until_reset) = {
497 let mut map = counters.write().unwrap();
498 if let Some(counter) = map.get_mut(wallet) {
499 counter.maybe_reset(self.window_duration);
500 (
501 counter.current_count(),
502 counter.time_until_reset(self.window_duration),
503 )
504 } else {
505 (0, self.window_duration)
506 }
507 };
508
509 let remaining = limit.saturating_sub(current_count);
510 let reset_at = SystemTime::now()
511 .duration_since(UNIX_EPOCH)
512 .unwrap()
513 .as_secs()
514 + time_until_reset.as_secs();
515
516 RateLimitInfo {
517 limit,
518 remaining,
519 reset_at,
520 }
521 }
522
523 pub fn cleanup_expired(&self) {
526 let threshold = self.window_duration * 2;
527
528 let cleanup_map = |map: &StdRwLock<HashMap<WalletAddress, FixedWindowCounter>>| {
529 let mut guard = map.write().unwrap();
530 guard.retain(|_, counter| counter.window_start.elapsed() < threshold);
531 };
532
533 cleanup_map(&self.order_counts);
534 cleanup_map(&self.cancel_counts);
535 cleanup_map(&self.api_counts);
536 }
537
538 pub fn tracked_wallet_count(&self) -> usize {
540 let order_len = self.order_counts.read().unwrap().len();
541 let cancel_len = self.cancel_counts.read().unwrap().len();
542 let api_len = self.api_counts.read().unwrap().len();
543 order_len.max(cancel_len).max(api_len)
544 }
545}
546
/// Unit tests for the pieces that need neither Redis nor a tier cache: the
/// in-memory counter, error formatting, limit selection and key/window helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Each accepted increment raises the count by exactly one.
    #[test]
    fn test_fixed_window_counter_increments() {
        let mut counter = FixedWindowCounter::new();
        let window = Duration::from_secs(60);

        assert!(counter.try_increment(10, window));
        assert_eq!(counter.current_count(), 1);

        assert!(counter.try_increment(10, window));
        assert_eq!(counter.current_count(), 2);
    }

    /// Once the limit is reached, further increments are refused and the
    /// count stays at the limit.
    #[test]
    fn test_fixed_window_counter_blocks_at_limit() {
        let mut counter = FixedWindowCounter::new();
        let window = Duration::from_secs(60);

        for _ in 0..10 {
            assert!(counter.try_increment(10, window));
        }

        assert!(!counter.try_increment(10, window));
        assert_eq!(counter.current_count(), 10);
    }

    /// `Display` includes the limit and retry delay for `Exceeded`, and the
    /// word "unavailable" for `ServiceUnavailable`.
    #[test]
    fn test_rate_limit_error_display() {
        let err = RateLimitError::Exceeded {
            retry_after_secs: 30,
            limit: 60,
            action: RateLimitAction::OrderPlacement,
        };

        let display = format!("{}", err);
        assert!(display.contains("60"));
        assert!(display.contains("30"));

        let err = RateLimitError::ServiceUnavailable {
            message: "Redis error".to_string(),
        };
        let display = format!("{}", err);
        assert!(display.contains("unavailable"));
    }

    /// Every action maps to its own limit; RFQ submit uses the fixed limit
    /// rather than a `TradingLimits` field.
    #[test]
    fn test_get_limit_for_action() {
        let limits = TradingLimits {
            max_open_orders: 100,
            max_open_positions: 50,
            orders_per_minute: 60,
            cancels_per_minute: 120,
            api_requests_per_minute: 600,
        };

        assert_eq!(
            RateLimitCache::get_limit_for_action(&limits, RateLimitAction::OrderPlacement),
            60
        );
        assert_eq!(
            RateLimitCache::get_limit_for_action(&limits, RateLimitAction::OrderCancellation),
            120
        );
        assert_eq!(
            RateLimitCache::get_limit_for_action(&limits, RateLimitAction::ApiRequest),
            600
        );
        assert_eq!(
            RateLimitCache::get_limit_for_action(&limits, RateLimitAction::RfqSubmit),
            10
        );
    }

    /// The Redis key carries the prefix, the action segment and the window start.
    #[test]
    fn test_redis_key_generation() {
        use std::str::FromStr;
        let wallet = WalletAddress::from_str("0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266").unwrap();
        let key = RateLimitCache::redis_key(&wallet, RateLimitAction::OrderPlacement, 1704067200);
        assert!(key.starts_with("ratelimit:"));
        assert!(key.contains(":order:"));
        assert!(key.contains("1704067200"));
    }

    /// Window start is the timestamp rounded down to a multiple of the window.
    #[test]
    fn test_get_window_start() {
        assert_eq!(RateLimitCache::get_window_start(1704067234, 60), 1704067200);
        assert_eq!(RateLimitCache::get_window_start(1704067200, 60), 1704067200);
        assert_eq!(RateLimitCache::get_window_start(1704067259, 60), 1704067200);
        assert_eq!(RateLimitCache::get_window_start(1704067260, 60), 1704067260);
    }

    /// Key segments are part of the Redis key format and must stay stable for
    /// every action variant.
    #[test]
    fn test_action_key_str() {
        assert_eq!(RateLimitAction::OrderPlacement.as_key_str(), "order");
        assert_eq!(RateLimitAction::OrderCancellation.as_key_str(), "cancel");
        assert_eq!(RateLimitAction::ApiRequest.as_key_str(), "api");
        assert_eq!(RateLimitAction::RfqSubmit.as_key_str(), "rfq_submit");
    }
}