1use anyhow::{anyhow, Context, Result};
2use async_trait::async_trait;
3use redis::aio::ConnectionManager;
4use redis::{Client, ConnectionAddr, ConnectionInfo, ProtocolVersion, RedisConnectionInfo, Script};
5use std::sync::Arc;
6use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
7use tokio::sync::Mutex;
8use tracing::{info, warn};
9
// --- Environment variables read by `UpstashRedisConfig::from_env` ---

/// Name of the environment variable holding the Upstash host name.
const ENDPOINT_ENV: &str = "MARKETS_SNAPSHOT_UPSTASH_ENDPOINT";
/// Name of the environment variable holding the Upstash TCP port (parsed as `u16`).
const PORT_ENV: &str = "MARKETS_SNAPSHOT_UPSTASH_PORT";
/// Name of the environment variable holding the Upstash password.
const PASSWORD_ENV: &str = "MARKETS_SNAPSHOT_UPSTASH_PASSWORD";

// --- Time budgets applied by `UpstashBatchPublisher` ---

/// Upper bound on a single source's `next_payload()` call during a tick.
const PER_SOURCE_TIMEOUT: Duration = Duration::from_secs(3);
/// Upper bound on establishing the Redis connection manager.
const UPSTASH_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Upper bound on one batch script invocation.
const UPSTASH_COMMAND_TIMEOUT: Duration = Duration::from_secs(10);
16
17#[async_trait]
19pub trait UpstashSnapshotSource: Send + Sync {
20 fn name(&self) -> &'static str;
22 fn key(&self) -> &str;
24 fn ttl_seconds(&self) -> u64;
26 async fn next_payload(&self) -> Result<Option<(Vec<u8>, u64)>>;
30 fn on_success(&self, built_at_ms: u64, elapsed: Duration);
32 fn on_skip(&self);
34 fn on_error(&self);
36 fn min_interval(&self) -> Option<Duration> {
41 None
42 }
43}
44
/// Connection settings for the Upstash Redis instance.
///
/// `Debug` is implemented by hand (not derived) so that the password is never
/// written to logs, panic messages, or failed-assertion output.
#[derive(Clone, PartialEq, Eq)]
struct UpstashRedisConfig {
    /// Host name used for the TLS connection.
    endpoint: String,
    /// TCP port of the Redis endpoint.
    port: u16,
    /// Password sent together with the `default` username when connecting.
    password: String,
}

impl std::fmt::Debug for UpstashRedisConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UpstashRedisConfig")
            .field("endpoint", &self.endpoint)
            .field("port", &self.port)
            // Secret: show that a value is present without revealing it.
            .field("password", &"<redacted>")
            .finish()
    }
}
54
55impl UpstashRedisConfig {
56 fn from_env() -> Result<Option<Self>> {
57 Self::from_env_parts(
58 std::env::var(ENDPOINT_ENV).ok(),
59 std::env::var(PORT_ENV).ok(),
60 std::env::var(PASSWORD_ENV).ok(),
61 )
62 }
63
64 fn from_env_parts(
65 endpoint: Option<String>,
66 port: Option<String>,
67 password: Option<String>,
68 ) -> Result<Option<Self>> {
69 if endpoint.is_none() && port.is_none() && password.is_none() {
70 return Ok(None);
71 }
72
73 let endpoint = require_non_empty(endpoint, ENDPOINT_ENV)?;
74 let port_raw = require_non_empty(port, PORT_ENV)?;
75 let password = require_non_empty(password, PASSWORD_ENV)?;
76
77 let port = port_raw
78 .parse::<u16>()
79 .map_err(|e| anyhow!("Invalid {} '{}': {}", PORT_ENV, port_raw, e))?;
80
81 Ok(Some(Self {
82 endpoint,
83 port,
84 password,
85 }))
86 }
87
88 fn build_client(&self) -> Result<Client> {
89 let connection_info = ConnectionInfo {
90 addr: ConnectionAddr::TcpTls {
91 host: self.endpoint.clone(),
92 port: self.port,
93 insecure: false,
94 tls_params: None,
95 },
96 redis: RedisConnectionInfo {
97 db: 0,
98 username: Some("default".to_string()),
99 password: Some(self.password.clone()),
100 protocol: ProtocolVersion::RESP2,
101 },
102 };
103
104 Client::open(connection_info).context("Failed to build Upstash Redis client")
105 }
106}
107
108fn require_non_empty(value: Option<String>, env_name: &str) -> Result<String> {
109 let value = value.ok_or_else(|| {
110 anyhow!(
111 "{} is required when Upstash publishing is configured",
112 env_name
113 )
114 })?;
115 let value = value.trim().to_string();
116 if value.is_empty() {
117 return Err(anyhow!("{} must not be empty", env_name));
118 }
119 Ok(value)
120}
121
122pub fn system_time_to_millis(value: SystemTime) -> Result<u64> {
123 Ok(value
124 .duration_since(UNIX_EPOCH)
125 .map_err(|e| anyhow!("Invalid snapshot build time: {}", e))?
126 .as_millis() as u64)
127}
128
129pub struct UpstashBatchPublisher {
134 client: Client,
135 connection: Mutex<Option<ConnectionManager>>,
136 sources: Vec<Arc<dyn UpstashSnapshotSource>>,
137 script: Script,
138}
139
140impl UpstashBatchPublisher {
141 pub fn from_env() -> Result<Option<Self>> {
144 let Some(config) = UpstashRedisConfig::from_env()? else {
145 return Ok(None);
146 };
147
148 let client = config.build_client()?;
149 Ok(Some(Self {
150 client,
151 connection: Mutex::new(None),
152 sources: Vec::new(),
153 script: Script::new(
154 "for i = 1, #KEYS do redis.call('SET', KEYS[i], ARGV[2*i-1], 'EX', tonumber(ARGV[2*i])) end return #KEYS",
155 ),
156 }))
157 }
158
159 pub fn client(&self) -> Client {
162 self.client.clone()
163 }
164
165 pub fn with_sources(mut self, sources: Vec<Arc<dyn UpstashSnapshotSource>>) -> Self {
167 self.sources = sources;
168 self
169 }
170
171 async fn tick(&self) {
172 let mut entries: Vec<(&Arc<dyn UpstashSnapshotSource>, Vec<u8>, u64)> = Vec::new();
176
177 for source in &self.sources {
178 match tokio::time::timeout(PER_SOURCE_TIMEOUT, source.next_payload()).await {
179 Ok(Ok(Some((payload, built_at_ms)))) => {
180 entries.push((source, payload, built_at_ms));
181 }
182 Ok(Ok(None)) => {
183 source.on_skip();
184 }
185 Ok(Err(error)) => {
186 source.on_error();
187 warn!(
188 source = source.name(),
189 error = %error,
190 "Upstash source failed to produce payload; skipping"
191 );
192 }
193 Err(_) => {
194 source.on_error();
195 warn!(
196 source = source.name(),
197 "Upstash source timed out after {:?}; skipping", PER_SOURCE_TIMEOUT
198 );
199 }
200 }
201 }
202
203 if entries.is_empty() {
204 return;
205 }
206
207 let start = Instant::now();
209 let result = self.eval_batch(&entries).await;
210 let elapsed = start.elapsed();
211
212 match result {
213 Ok(()) => {
214 for (source, _, built_at_ms) in &entries {
215 source.on_success(*built_at_ms, elapsed);
216 }
217 metrics::counter!("ht_upstash_batch_publish_total", "status" => "success")
218 .increment(1);
219 metrics::histogram!("ht_upstash_batch_publish_seconds")
220 .record(elapsed.as_secs_f64());
221 }
222 Err(error) => {
223 for (source, _, _) in &entries {
224 source.on_error();
225 }
226 metrics::counter!("ht_upstash_batch_publish_total", "status" => "error")
227 .increment(1);
228 warn!(
229 error = %error,
230 keys = entries.len(),
231 "Upstash batch EVAL failed; will retry next tick"
232 );
233 }
234 }
235 }
236
237 async fn eval_batch(
238 &self,
239 entries: &[(&Arc<dyn UpstashSnapshotSource>, Vec<u8>, u64)],
240 ) -> Result<()> {
241 let connection = self.ensure_connection().await?;
242
243 let keys: Vec<&str> = entries.iter().map(|(src, _, _)| src.key()).collect();
244 let mut invocation = self.script.prepare_invoke();
245 for key in &keys {
246 invocation.key(*key);
247 }
248 for (source, payload, _) in entries {
250 invocation.arg(payload.as_slice());
251 invocation.arg(source.ttl_seconds());
252 }
253
254 let timeout_result = tokio::time::timeout(
255 UPSTASH_COMMAND_TIMEOUT,
256 invocation.invoke_async::<i64>(&mut connection.clone()),
257 )
258 .await;
259
260 match timeout_result {
261 Ok(Ok(_count)) => Ok(()),
262 Ok(Err(redis_err)) => {
263 let mut conn = self.connection.lock().await;
264 *conn = None;
265 Err(anyhow!(redis_err).context("Upstash batch EVAL failed"))
266 }
267 Err(_) => {
268 let mut conn = self.connection.lock().await;
269 *conn = None;
270 Err(anyhow!(
271 "Upstash batch EVAL timed out after {:?}",
272 UPSTASH_COMMAND_TIMEOUT
273 ))
274 }
275 }
276 }
277
278 async fn ensure_connection(&self) -> Result<ConnectionManager> {
279 let mut guard = self.connection.lock().await;
280 if let Some(ref cm) = *guard {
281 return Ok(cm.clone());
282 }
283
284 let manager = tokio::time::timeout(
285 UPSTASH_CONNECT_TIMEOUT,
286 ConnectionManager::new(self.client.clone()),
287 )
288 .await
289 .map_err(|_| {
290 anyhow!(
291 "Timed out connecting to Upstash Redis after {:?}",
292 UPSTASH_CONNECT_TIMEOUT
293 )
294 })?
295 .context("Failed to connect to Upstash Redis")?;
296
297 info!("UpstashBatchPublisher connected to Upstash");
298 *guard = Some(manager.clone());
299 Ok(manager)
300 }
301}
302
303impl UpstashBatchPublisher {
304 pub async fn run_with_shutdown(
305 self: std::sync::Arc<Self>,
306 mut shutdown: tokio::sync::broadcast::Receiver<()>,
307 ) -> anyhow::Result<()> {
308 use futures::FutureExt;
309 use std::panic::AssertUnwindSafe;
310
311 let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
312 metrics::counter!("ht_upstash_batch_loop_total", "status" => "started").increment(1);
313
314 loop {
315 tokio::select! {
316 _ = shutdown.recv() => {
317 metrics::counter!("ht_upstash_batch_loop_total", "status" => "shutdown").increment(1);
318 tracing::debug!("UpstashBatchPublisher received shutdown signal");
319 break;
320 }
321 _ = interval.tick() => {
322 metrics::counter!("ht_upstash_batch_loop_total", "status" => "tick").increment(1);
323 match AssertUnwindSafe(self.tick()).catch_unwind().await {
324 Ok(()) => {}
325 Err(panic_info) => {
326 {
327 let mut conn = self.connection.lock().await;
328 *conn = None;
329 }
330 metrics::counter!("ht_upstash_batch_loop_total", "status" => "panic").increment(1);
331 let panic_msg = if let Some(message) = panic_info.downcast_ref::<&str>() {
332 *message
333 } else if let Some(message) = panic_info.downcast_ref::<String>() {
334 message.as_str()
335 } else {
336 "unknown panic payload"
337 };
338 tracing::error!(panic = panic_msg, "UpstashBatchPublisher panicked; resetting connection and continuing");
339 }
340 }
341 }
342 }
343 }
344 Ok(())
345 }
346}