Skip to main content

hypercall/client/wallet_client/websocket.rs

use futures::{SinkExt, StreamExt};
use hypercall_types::ws_protocol::WsMessage;
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio_tungstenite::{connect_async, tungstenite};
use tracing::{debug, error, info, warn};

use super::WalletClient;
8
9impl WalletClient {
10    pub async fn connect_websocket(
11        &self,
12        base_url: &str,
13        with_auth: bool,
14    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
15        let ws_url = if with_auth {
16            // Use wallet-based authentication for EIP-712 flow
17            format!(
18                "{}/ws?wallet={}",
19                base_url.replace("http", "ws"),
20                self.wallet_address
21            )
22        } else {
23            format!("{}/ws", base_url.replace("http", "ws"))
24        };
25
26        debug!(
27            "Connecting to websocket: {} (wallet: {})",
28            if with_auth {
29                "with auth"
30            } else {
31                "without auth"
32            },
33            self.wallet_address
34        );
35
36        let (ws_stream, response) = connect_async(&ws_url).await?;
37        info!("WebSocket connected, status: {}", response.status());
38
39        let (mut write, mut read) = ws_stream.split();
40        let (tx, mut rx) = mpsc::unbounded_channel();
41
42        // Store the sender
43        *self.ws_tx.lock().await = Some(tx.clone());
44
45        let messages = self.ws_messages.clone();
46        // Spawn task to handle incoming messages
47        tokio::spawn(async move {
48            while let Some(msg) = read.next().await {
49                match msg {
50                    Ok(tungstenite::Message::Text(text)) => {
51                        debug!("WS received: {}", text);
52                        if let Ok(ws_msg) = sonic_rs::from_str::<WsMessage>(&text) {
53                            messages.lock().await.push(ws_msg);
54                        }
55                    }
56                    Ok(tungstenite::Message::Close(_)) => {
57                        info!("WebSocket closed");
58                        break;
59                    }
60                    Ok(tungstenite::Message::Ping(_)) => {
61                        debug!("WS received ping");
62                    }
63                    Ok(tungstenite::Message::Pong(_)) => {
64                        debug!("WS received pong");
65                    }
66                    Err(e) => {
67                        error!("WebSocket error: {}", e);
68                        break;
69                    }
70                    _ => {}
71                }
72            }
73        });
74
75        // Spawn task to handle outgoing messages
76        tokio::spawn(async move {
77            while let Some(msg) = rx.recv().await {
78                if write.send(msg).await.is_err() {
79                    break;
80                }
81            }
82        });
83
84        Ok(())
85    }
86
87    pub async fn subscribe_to_streams(
88        &self,
89        channels: Vec<&str>,
90    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
91        let tx = self.ws_tx.lock().await;
92        let tx = tx.as_ref().ok_or("WebSocket not connected")?;
93
94        for channel in channels {
95            let subscribe_msg = WsMessage::Subscribe {
96                channel: channel.to_string(),
97                symbols: None,
98                expiry: None,
99                option_type: None,
100            };
101            let json = sonic_rs::to_string(&subscribe_msg)?;
102            debug!("Subscribing to: {}", channel);
103            tx.send(tungstenite::Message::Text(json))?;
104
105            // Wait a bit for subscription confirmation
106            tokio::time::sleep(Duration::from_millis(100)).await;
107        }
108
109        Ok(())
110    }
111
112    pub async fn get_messages(&self) -> Vec<WsMessage> {
113        self.ws_messages.lock().await.clone()
114    }
115
116    pub async fn clear_messages(&self) {
117        self.ws_messages.lock().await.clear();
118    }
119
120    pub async fn wait_for_message_type<F>(&self, check: F, timeout_ms: u64) -> Option<WsMessage>
121    where
122        F: Fn(&WsMessage) -> bool,
123    {
124        let start = std::time::Instant::now();
125        let timeout = Duration::from_millis(timeout_ms);
126
127        while start.elapsed() < timeout {
128            let messages = self.get_messages().await;
129            if let Some(msg) = messages.iter().find(|m| check(m)) {
130                return Some(msg.clone());
131            }
132            tokio::time::sleep(Duration::from_millis(50)).await;
133        }
134
135        None
136    }
137
138    pub async fn count_message_type<F>(&self, check: F) -> usize
139    where
140        F: Fn(&WsMessage) -> bool,
141    {
142        self.get_messages()
143            .await
144            .iter()
145            .filter(|msg| check(msg))
146            .count()
147    }
148}