1use anyhow::{Context, Result};
14use dashmap::DashMap;
15use serde::{Deserialize, Serialize};
16use std::sync::Arc;
17use tokio::sync::mpsc;
18use tracing::{debug, info, warn};
19use web_push::{ContentEncoding, SubscriptionInfo, VapidSignatureBuilder, WebPushMessageBuilder};
20
21use hypercall_db::{
22 PushSubscriptionReader, PushSubscriptionRecord, PushSubscriptionWriter,
23 UpsertPushSubscriptionInput,
24};
25
/// Category of a push notification.
///
/// Each category has an on/off flag in a subscription's preferences JSON (see
/// `preference_key`), and the key doubles as the `type` label on the enqueue metric.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum NotificationType {
    /// An order was filled.
    Fill,
    /// The account's liquidation status changed.
    Liquidation,
    /// Settlement-related notification.
    Settlement,
}

impl NotificationType {
    /// Name of this category's boolean flag inside a subscription's preferences object.
    fn preference_key(self) -> &'static str {
        match self {
            NotificationType::Fill => "fills",
            NotificationType::Liquidation => "liquidations",
            NotificationType::Settlement => "settlements",
        }
    }
}
48
49pub fn default_preferences() -> serde_json::Value {
51 serde_json::json!({
52 "fills": true,
53 "liquidations": true,
54 "settlements": true,
55 })
56}
57
58fn is_enabled(preferences: &serde_json::Value, notif_type: NotificationType) -> bool {
62 match preferences.get(notif_type.preference_key()) {
63 None => true, Some(v) => v.as_bool().unwrap_or(false), }
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct PushPayload {
76 pub title: String,
77 pub body: String,
78 pub tag: String,
80 #[serde(skip_serializing_if = "Option::is_none")]
82 pub url: Option<String>,
83}
84
85struct PushJob {
88 wallet: String,
89 notif_type: NotificationType,
90 payload: PushPayload,
91}
92
/// Maximum number of wallets kept in the subscription cache. Once reached, the whole
/// cache is cleared before the next insert rather than evicting individual entries.
const CACHE_MAX_ENTRIES: usize = 50_000;

/// Maximum number of push subscriptions one wallet may register. The type is `i64` because
/// the value is compared with the database's subscription count.
const MAX_SUBSCRIPTIONS_PER_WALLET: i64 = 10;

/// Push-service origins this service is willing to POST to.
///
/// Checked when a subscription is registered and again before every send. Each entry
/// ends in `/` so that look-alike hosts (e.g. `https://fcm.googleapis.com.evil.example/`)
/// cannot match by prefix.
/// NOTE(review): confirm this list covers every browser push service you need to support.
const ALLOWED_PUSH_ENDPOINT_PREFIXES: &[&str] = &[
    "https://fcm.googleapis.com/",
    "https://updates.push.services.mozilla.com/",
    "https://wns.windows.com/",
    "https://web.push.apple.com/",
];

/// Capacity of the bounded job queue; when it is full, new notifications are dropped
/// instead of blocking the caller.
const PUSH_QUEUE_CAPACITY: usize = 1000;

/// Upper bound on push HTTP deliveries in flight at the same time.
const MAX_CONCURRENT_SENDS: usize = 16;
121
122pub struct PushNotificationService {
123 db: Arc<dyn PushSubscriptionWriter>,
124 cache: Arc<DashMap<String, Vec<PushSubscriptionRecord>>>,
128 job_tx: mpsc::Sender<PushJob>,
130}
131
132impl PushNotificationService {
133 pub fn new(db: Arc<dyn PushSubscriptionWriter>, vapid_private_pem: &[u8]) -> Result<Self> {
135 VapidSignatureBuilder::from_pem_no_sub(std::io::Cursor::new(vapid_private_pem))
137 .context("VAPID private key is not valid PEM")?;
138
139 let client = Arc::new(reqwest::Client::new());
140
141 let cache: Arc<DashMap<String, Vec<PushSubscriptionRecord>>> = Arc::new(DashMap::new());
142 let vapid_key = Arc::new(vapid_private_pem.to_vec());
143
144 let (job_tx, job_rx) = mpsc::channel::<PushJob>(PUSH_QUEUE_CAPACITY);
145
146 let sender_db = db.clone();
148 let sender_client = client.clone();
149 let sender_cache = cache.clone();
150 let sender_key = vapid_key.clone();
151 tokio::spawn(push_sender_loop(
152 job_rx,
153 sender_db,
154 sender_client,
155 sender_cache,
156 sender_key,
157 ));
158
159 Ok(Self { db, cache, job_tx })
160 }
161
162 pub async fn subscribe(
169 &self,
170 wallet: &str,
171 endpoint: &str,
172 auth_key: &str,
173 p256dh_key: &str,
174 preferences: Option<serde_json::Value>,
175 ) -> Result<PushSubscriptionRecord> {
176 if !ALLOWED_PUSH_ENDPOINT_PREFIXES
178 .iter()
179 .any(|prefix| endpoint.starts_with(prefix))
180 {
181 anyhow::bail!("Push endpoint URL is not from a recognized push service");
182 }
183
184 let wallet_lower = wallet.to_lowercase();
185
186 let existing_count = self
188 .db
189 .count_push_subscriptions(&wallet_lower)
190 .await
191 .context("Failed to count existing push subscriptions")?;
192
193 if existing_count >= MAX_SUBSCRIPTIONS_PER_WALLET {
194 let is_upsert = self
196 .db
197 .push_subscription_exists(&wallet_lower, endpoint)
198 .await
199 .unwrap_or(false);
200 if !is_upsert {
201 anyhow::bail!(
202 "Maximum push subscriptions ({MAX_SUBSCRIPTIONS_PER_WALLET}) reached for this wallet"
203 );
204 }
205 }
206
207 let prefs = preferences.unwrap_or_else(default_preferences);
208
209 let row = self
210 .db
211 .upsert_push_subscription(UpsertPushSubscriptionInput {
212 wallet_address: wallet_lower.clone(),
213 endpoint: endpoint.to_string(),
214 auth_key: auth_key.to_string(),
215 p256dh_key: p256dh_key.to_string(),
216 preferences: prefs,
217 })
218 .await
219 .context("Failed to upsert push subscription")?;
220
221 self.cache.remove(&wallet_lower);
222
223 info!(wallet = wallet, "Push subscription registered");
224 Ok(row)
225 }
226
227 pub async fn update_preferences(
229 &self,
230 wallet: &str,
231 endpoint: &str,
232 preferences: serde_json::Value,
233 ) -> Result<bool> {
234 let wallet_lower = wallet.to_lowercase();
235
236 let updated = self
237 .db
238 .update_push_preferences(&wallet_lower, endpoint, preferences)
239 .await
240 .context("Failed to update push preferences")?;
241
242 self.cache.remove(&wallet_lower);
243
244 Ok(updated)
245 }
246
247 pub async fn unsubscribe(&self, wallet: &str, endpoint: &str) -> Result<bool> {
249 let wallet_lower = wallet.to_lowercase();
250
251 let deleted = self
252 .db
253 .delete_push_subscription(&wallet_lower, endpoint)
254 .await
255 .context("Failed to delete push subscription")?;
256
257 self.cache.remove(&wallet_lower);
258
259 Ok(deleted)
260 }
261
262 pub fn send_fill_notification(
268 &self,
269 wallet: String,
270 action: &str,
271 size: impl std::fmt::Display,
272 symbol: String,
273 price: impl std::fmt::Display,
274 trade_id: u64,
275 realized_pnl: Option<rust_decimal::Decimal>,
276 ) {
277 let pnl_str = match realized_pnl {
278 Some(pnl) if !pnl.is_zero() => format!(" (PnL: ${pnl})"),
279 _ => String::new(),
280 };
281 self.enqueue(PushJob {
282 wallet,
283 notif_type: NotificationType::Fill,
284 payload: PushPayload {
285 title: "Order Filled".to_string(),
286 body: format!("{action} {size} {symbol} @ ${price}{pnl_str}"),
287 tag: format!("fill-{trade_id}"),
288 url: None,
289 },
290 });
291 }
292
293 pub fn send_liquidation_notification(
295 &self,
296 wallet: String,
297 previous_state: impl std::fmt::Display,
298 new_state: impl std::fmt::Display,
299 title: &str,
300 ) {
301 self.enqueue(PushJob {
302 wallet,
303 notif_type: NotificationType::Liquidation,
304 payload: PushPayload {
305 title: title.to_string(),
306 body: format!("Account status: {previous_state} -> {new_state}"),
307 tag: "liquidation".to_string(),
308 url: None,
309 },
310 });
311 }
312
313 fn enqueue(&self, job: PushJob) {
314 metrics::counter!("ht_push_enqueued_total", "type" => job.notif_type.preference_key())
315 .increment(1);
316 if self.job_tx.try_send(job).is_err() {
317 metrics::counter!("ht_push_dropped_total").increment(1);
318 warn!("Push notification queue full, dropping message");
319 }
320 }
321}
322
323async fn get_subscriptions(
329 wallet: &str,
330 db: &dyn PushSubscriptionReader,
331 cache: &DashMap<String, Vec<PushSubscriptionRecord>>,
332) -> Vec<PushSubscriptionRecord> {
333 let wallet_lower = wallet.to_lowercase();
334
335 if let Some(cached) = cache.get(&wallet_lower) {
337 return cached.clone();
338 }
339
340 let subs = match db.get_push_subscriptions(&wallet_lower).await {
342 Ok(s) => s,
343 Err(e) => {
344 warn!(wallet, "Failed to load push subscriptions: {e}");
345 return vec![];
346 }
347 };
348
349 if cache.len() >= CACHE_MAX_ENTRIES {
351 warn!(
352 entries = cache.len(),
353 "Push subscription cache exceeded max size, clearing"
354 );
355 cache.clear();
356 }
357
358 cache.insert(wallet_lower, subs.clone());
359 subs
360}
361
362async fn remove_stale(
363 id: i64,
364 wallet_lower: &str,
365 db: &dyn PushSubscriptionWriter,
366 cache: &DashMap<String, Vec<PushSubscriptionRecord>>,
367) {
368 let _ = db.delete_push_subscription_by_id(id).await;
369 cache.remove(wallet_lower);
370 debug!(id, "Removed stale push subscription");
371}
372
/// Rewrites standard base64 into the unpadded URL-safe alphabet.
///
/// `+` becomes `-` and `/` becomes `_`; all other characters pass through unchanged, so
/// input that is already base64url comes back as-is. Only trailing `=` padding is
/// removed. The sender applies this to subscription keys before building a
/// `SubscriptionInfo`.
fn to_base64url(s: &str) -> String {
    let swapped: String = s
        .chars()
        .map(|c| match c {
            '+' => '-',
            '/' => '_',
            other => other,
        })
        .collect();
    swapped.trim_end_matches('=').to_owned()
}
380
381async fn push_sender_loop(
384 mut rx: mpsc::Receiver<PushJob>,
385 db: Arc<dyn PushSubscriptionWriter>,
386 client: Arc<reqwest::Client>,
387 cache: Arc<DashMap<String, Vec<PushSubscriptionRecord>>>,
388 vapid_key: Arc<Vec<u8>>,
389) {
390 let semaphore = Arc::new(tokio::sync::Semaphore::new(MAX_CONCURRENT_SENDS));
391
392 while let Some(job) = rx.recv().await {
393 let subs = get_subscriptions(&job.wallet, db.as_ref(), &cache).await;
394 if subs.is_empty() {
395 continue;
396 }
397
398 let json = match serde_json::to_vec(&job.payload) {
399 Ok(j) => j,
400 Err(e) => {
401 warn!("Failed to serialize push payload: {e}");
402 continue;
403 }
404 };
405
406 let wallet_lower = job.wallet.to_lowercase();
407
408 for sub in subs {
409 if !is_enabled(&sub.preferences, job.notif_type) {
410 continue;
411 }
412
413 if !ALLOWED_PUSH_ENDPOINT_PREFIXES
415 .iter()
416 .any(|prefix| sub.endpoint.starts_with(prefix))
417 {
418 warn!(endpoint = %sub.endpoint, "Skipping push to unrecognized endpoint");
419 continue;
420 }
421
422 let permit = match semaphore.clone().acquire_owned().await {
424 Ok(p) => p,
425 Err(_) => break, };
427
428 let client = client.clone();
429 let vapid_key = vapid_key.clone();
430 let json = json.clone();
431 let db = db.clone();
432 let cache = cache.clone();
433 let wallet_lower = wallet_lower.clone();
434 let wallet_for_log = job.wallet.clone();
435
436 tokio::spawn(async move {
437 let _permit = permit; let p256dh = to_base64url(&sub.p256dh_key);
441 let auth = to_base64url(&sub.auth_key);
442 let info = SubscriptionInfo::new(&sub.endpoint, &p256dh, &auth);
443
444 let partial_builder = match VapidSignatureBuilder::from_pem_no_sub(
445 std::io::Cursor::new(vapid_key.as_ref()),
446 ) {
447 Ok(b) => b,
448 Err(e) => {
449 warn!("Failed to create VAPID builder: {e}");
450 return;
451 }
452 };
453
454 let sig = match partial_builder.add_sub_info(&info).build() {
455 Ok(sig) => sig,
456 Err(e) => {
457 warn!("Failed to build VAPID signature: {e}");
458 return;
459 }
460 };
461
462 let mut builder = WebPushMessageBuilder::new(&info);
463 builder.set_payload(ContentEncoding::Aes128Gcm, &json);
464 builder.set_vapid_signature(sig);
465
466 let message = match builder.build() {
467 Ok(m) => m,
468 Err(e) => {
469 warn!("Failed to build push message: {e}");
470 return;
471 }
472 };
473
474 let http_request = web_push::request_builder::build_request::<Vec<u8>>(message);
476 let (parts, body) = http_request.into_parts();
477 let url = parts.uri.to_string();
478
479 let mut req = client.post(&url);
480 for (name, value) in &parts.headers {
481 if let Ok(v) = value.to_str() {
482 req = req.header(name.as_str(), v);
483 }
484 }
485 req = req.body(body);
486
487 let send_start = std::time::Instant::now();
488 match req.send().await {
489 Ok(resp) => {
490 let status = resp.status().as_u16();
491 if (200..300).contains(&(status as usize)) {
492 metrics::counter!("ht_push_sent_total").increment(1);
493 metrics::histogram!("ht_push_send_duration_seconds")
494 .record(send_start.elapsed().as_secs_f64());
495 debug!(wallet = wallet_for_log, endpoint = %sub.endpoint, status, "Push notification sent");
496 } else {
497 metrics::counter!("ht_push_errors_total").increment(1);
498 let body = resp.text().await.unwrap_or_default();
499 warn!(
500 wallet = wallet_for_log,
501 endpoint = %sub.endpoint,
502 status,
503 body = %body,
504 "Push send failed"
505 );
506 if status == 404 || status == 410 {
507 metrics::counter!("ht_push_stale_removed_total").increment(1);
508 remove_stale(sub.id, &wallet_lower, db.as_ref(), &cache).await;
509 }
510 }
511 }
512 Err(e) => {
513 metrics::counter!("ht_push_errors_total").increment(1);
514 warn!(
515 wallet = wallet_for_log,
516 endpoint = %sub.endpoint,
517 error = %e,
518 "Push HTTP request failed"
519 );
520 }
521 }
522 });
523 }
524 }
525
526 info!("Push sender loop exited");
527}
528
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    #[test]
    fn cache_lookup_no_subscribers_baseline() {
        let cache: DashMap<String, Vec<PushSubscriptionRecord>> = DashMap::new();
        cache.insert("0xdeadbeef".to_string(), vec![]);

        let iterations = 100_000;
        let start = Instant::now();

        for _ in 0..iterations {
            let result = cache.get("0xdeadbeef");
            assert!(result.is_some());
            assert!(result.unwrap().is_empty());
        }

        let elapsed = start.elapsed();
        let per_op_ns = elapsed.as_nanos() / iterations as u128;

        eprintln!(
            "push_service cache lookup: {per_op_ns}ns/op ({iterations} iterations in {:?})",
            elapsed
        );
    }

    #[test]
    fn cache_miss_populates_entry() {
        let cache: DashMap<String, Vec<PushSubscriptionRecord>> = DashMap::new();

        assert!(cache.get("0xunknown").is_none());

        cache.insert("0xunknown".to_string(), vec![]);
        assert!(cache.get("0xunknown").is_some());
    }

    #[test]
    fn is_enabled_defaults_to_true() {
        let prefs = serde_json::json!({});
        assert!(is_enabled(&prefs, NotificationType::Fill));
        assert!(is_enabled(&prefs, NotificationType::Liquidation));
        assert!(is_enabled(&prefs, NotificationType::Settlement));
    }

    #[test]
    fn is_enabled_respects_false() {
        let prefs = serde_json::json!({
            "fills": true,
            "liquidations": false,
            "settlements": true,
        });
        assert!(is_enabled(&prefs, NotificationType::Fill));
        assert!(!is_enabled(&prefs, NotificationType::Liquidation));
        assert!(is_enabled(&prefs, NotificationType::Settlement));
    }

    #[test]
    fn default_preferences_enables_all() {
        let prefs = default_preferences();
        assert!(is_enabled(&prefs, NotificationType::Fill));
        assert!(is_enabled(&prefs, NotificationType::Liquidation));
        assert!(is_enabled(&prefs, NotificationType::Settlement));
    }

    #[test]
    fn notification_type_preference_keys() {
        assert_eq!(NotificationType::Fill.preference_key(), "fills");
        assert_eq!(
            NotificationType::Liquidation.preference_key(),
            "liquidations"
        );
        assert_eq!(NotificationType::Settlement.preference_key(), "settlements");
    }

    #[test]
    fn cache_eviction_on_max_entries() {
        let cache: DashMap<String, Vec<PushSubscriptionRecord>> = DashMap::new();

        for i in 0..CACHE_MAX_ENTRIES {
            cache.insert(format!("0x{i:040x}"), vec![]);
        }
        assert_eq!(cache.len(), CACHE_MAX_ENTRIES);

        if cache.len() >= CACHE_MAX_ENTRIES {
            cache.clear();
        }
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn is_enabled_rejects_non_bool_values() {
        let prefs = serde_json::json!({"fills": "yes", "liquidations": 1});
        assert!(!is_enabled(&prefs, NotificationType::Fill));
        assert!(!is_enabled(&prefs, NotificationType::Liquidation));
        assert!(is_enabled(&prefs, NotificationType::Settlement));
    }

    #[test]
    fn is_enabled_handles_null_preferences() {
        let prefs = serde_json::Value::Null;
        assert!(is_enabled(&prefs, NotificationType::Fill));
        assert!(is_enabled(&prefs, NotificationType::Liquidation));
        assert!(is_enabled(&prefs, NotificationType::Settlement));
    }

    #[test]
    fn try_send_drops_when_full() {
        let (tx, _rx) = mpsc::channel::<PushJob>(1);

        let result1 = tx.try_send(PushJob {
            wallet: "0x1".to_string(),
            notif_type: NotificationType::Fill,
            payload: PushPayload {
                title: "t".to_string(),
                body: "b".to_string(),
                tag: "tag".to_string(),
                url: None,
            },
        });
        assert!(result1.is_ok());

        let result2 = tx.try_send(PushJob {
            wallet: "0x2".to_string(),
            notif_type: NotificationType::Fill,
            payload: PushPayload {
                title: "t".to_string(),
                body: "b".to_string(),
                tag: "tag".to_string(),
                url: None,
            },
        });
        assert!(result2.is_err());
    }

    #[test]
    fn allowed_endpoint_prefixes_accept_known_services() {
        let valid = [
            "https://fcm.googleapis.com/fcm/send/abc123",
            "https://updates.push.services.mozilla.com/wpush/v2/abc",
            "https://wns.windows.com/w/?token=abc",
            "https://web.push.apple.com/abc",
        ];
        for url in &valid {
            assert!(
                ALLOWED_PUSH_ENDPOINT_PREFIXES
                    .iter()
                    .any(|p| url.starts_with(p)),
                "Expected {url} to be allowed"
            );
        }
    }

    #[test]
    fn allowed_endpoint_prefixes_reject_arbitrary_urls() {
        let invalid = [
            "https://evil.com/steal",
            "http://fcm.googleapis.com/http-not-https",
            "https://example.com",
            "file:///etc/passwd",
            "",
        ];
        for url in &invalid {
            assert!(
                !ALLOWED_PUSH_ENDPOINT_PREFIXES
                    .iter()
                    .any(|p| url.starts_with(p)),
                "Expected {url} to be rejected"
            );
        }
    }

    /// Keys must map from the standard base64 alphabet to the URL-safe one, losing only
    /// trailing padding; already URL-safe input must pass through untouched.
    #[test]
    fn to_base64url_normalizes_alphabet_and_padding() {
        assert_eq!(to_base64url("ab+/cd=="), "ab-_cd");
        assert_eq!(to_base64url("already-url_safe"), "already-url_safe");
        assert_eq!(to_base64url("a=b="), "a=b");
        assert_eq!(to_base64url(""), "");
    }

    /// `url: None` must be omitted from the JSON sent to clients, not serialized as `null`.
    #[test]
    fn push_payload_omits_url_when_none() {
        let payload = PushPayload {
            title: "t".to_string(),
            body: "b".to_string(),
            tag: "tag".to_string(),
            url: None,
        };
        let json = serde_json::to_value(&payload).unwrap();
        assert!(json.get("url").is_none());

        let with_url = PushPayload {
            url: Some("/portfolio".to_string()),
            ..payload
        };
        let json = serde_json::to_value(&with_url).unwrap();
        assert_eq!(json["url"], "/portfolio");
    }

    /// Every notification type's preference key must be present, and `true`, in the defaults.
    #[test]
    fn default_preferences_has_a_true_flag_for_every_type() {
        let prefs = default_preferences();
        for notif_type in [
            NotificationType::Fill,
            NotificationType::Liquidation,
            NotificationType::Settlement,
        ] {
            assert_eq!(
                prefs.get(notif_type.preference_key()),
                Some(&serde_json::Value::Bool(true))
            );
        }
    }
}