Skip to main content

hypercall/nats/
mod.rs

1pub mod deserialize;
2pub mod replay_loop;
3pub mod standby_handler;
4#[cfg(test)]
5mod tests;
6
7use async_nats::jetstream;
8use std::sync::Arc;
9use tokio::sync::Mutex;
10use tracing::{error, info, warn};
11
// Current (versioned) command framing:
// `[seq: u64 BE][type: u8][command_version: u8][command_data bytes]`.

/// Header length of a versioned command payload: 8 sequence bytes, 1 type
/// byte and 1 command-version byte.
pub const COMMAND_PAYLOAD_PREFIX_LEN: usize = 10;
/// Command-version byte written for versioned command payloads.
pub const COMMAND_WIRE_VERSION_V1: u8 = 1;

// Legacy command framing, which has no command-version byte:
// `[seq: u64 BE][type: u8][command_data bytes]`.

/// Header length of a legacy command payload: 8 sequence bytes and 1 type byte.
pub const LEGACY_COMMAND_PAYLOAD_PREFIX_LEN: usize = 9;
/// Command version reported for payloads parsed as legacy framing.
pub const LEGACY_COMMAND_WIRE_VERSION: u8 = 0;
16
/// Returns the JetStream stream name for engine commands in `env`.
///
/// The environment label is upper-cased and appended to `ENGINE_COMMANDS_`.
fn stream_name(env: &str) -> String {
    let mut name = String::from("ENGINE_COMMANDS_");
    name.push_str(&env.to_uppercase());
    name
}
20
/// Returns the NATS subject that engine commands for `env` are published on.
///
/// The environment label is lower-cased and appended to `engine.commands.`.
fn subject(env: &str) -> String {
    let mut subject = String::from("engine.commands.");
    subject.push_str(&env.to_lowercase());
    subject
}
24
/// Returns the JetStream stream name for balance updates in `env`.
///
/// The environment label is upper-cased and appended to
/// `ENGINE_BALANCE_UPDATES_`.
fn balance_update_stream_name(env: &str) -> String {
    let mut name = String::from("ENGINE_BALANCE_UPDATES_");
    name.push_str(&env.to_uppercase());
    name
}
28
/// Returns the NATS subject that balance updates for `env` are published on.
///
/// The environment label is lower-cased and appended to
/// `engine.balance_updates.`.
fn balance_update_subject(env: &str) -> String {
    let mut subject = String::from("engine.balance_updates.");
    subject.push_str(&env.to_lowercase());
    subject
}
32
33fn decode_balance_update_wire(data: &[u8]) -> anyhow::Result<hypercall_types::BalanceUpdate> {
34    if data.first() != Some(&hypercall_types::WIRE_FORMAT_VERSION) {
35        anyhow::bail!("balance update payload has unsupported wire version");
36    }
37    rmp_serde::from_slice(&data[1..])
38        .map_err(|error| anyhow::anyhow!("balance update payload deserialize failed: {error}"))
39}
40
41async fn get_or_create_command_stream(
42    js: &jetstream::Context,
43    config: &NatsConfig,
44) -> anyhow::Result<jetstream::stream::Stream> {
45    let desired_config = command_stream_config(config);
46    let mut stream = js.get_or_create_stream(desired_config.clone()).await?;
47    let current_config = stream.cached_info().config.clone();
48
49    if current_config.subjects != desired_config.subjects
50        || current_config.retention != desired_config.retention
51        || current_config.max_age != desired_config.max_age
52        || current_config.storage != desired_config.storage
53    {
54        let mut updated_config = current_config;
55        updated_config.subjects = desired_config.subjects.clone();
56        updated_config.retention = desired_config.retention;
57        updated_config.max_age = desired_config.max_age;
58        updated_config.storage = desired_config.storage;
59        js.update_stream(&updated_config).await?;
60        stream.info().await?;
61        info!(
62            stream = %desired_config.name,
63            max_age_secs = config.stream_max_age_secs,
64            "Updated NATS JetStream command stream config"
65        );
66    }
67
68    Ok(stream)
69}
70
71async fn get_or_create_balance_update_stream(
72    js: &jetstream::Context,
73    config: &NatsConfig,
74) -> anyhow::Result<jetstream::stream::Stream> {
75    let desired_config = balance_update_stream_config(config);
76    let mut stream = js.get_or_create_stream(desired_config.clone()).await?;
77    let current_config = stream.cached_info().config.clone();
78
79    if current_config.subjects != desired_config.subjects
80        || current_config.retention != desired_config.retention
81        || current_config.max_age != desired_config.max_age
82        || current_config.storage != desired_config.storage
83    {
84        let mut updated_config = current_config;
85        updated_config.subjects = desired_config.subjects.clone();
86        updated_config.retention = desired_config.retention;
87        updated_config.max_age = desired_config.max_age;
88        updated_config.storage = desired_config.storage;
89        js.update_stream(&updated_config).await?;
90        stream.info().await?;
91        info!(
92            stream = %desired_config.name,
93            max_age_secs = config.stream_max_age_secs,
94            "Updated NATS JetStream balance update stream config"
95        );
96    }
97
98    Ok(stream)
99}
100
101fn command_stream_config(config: &NatsConfig) -> jetstream::stream::Config {
102    jetstream::stream::Config {
103        name: stream_name(&config.env),
104        subjects: vec![subject(&config.env)],
105        retention: jetstream::stream::RetentionPolicy::Limits,
106        max_age: std::time::Duration::from_secs(config.stream_max_age_secs),
107        storage: jetstream::stream::StorageType::File,
108        ..Default::default()
109    }
110}
111
112fn balance_update_stream_config(config: &NatsConfig) -> jetstream::stream::Config {
113    jetstream::stream::Config {
114        name: balance_update_stream_name(&config.env),
115        subjects: vec![balance_update_subject(&config.env)],
116        retention: jetstream::stream::RetentionPolicy::Limits,
117        max_age: std::time::Duration::from_secs(config.stream_max_age_secs),
118        storage: jetstream::stream::StorageType::File,
119        ..Default::default()
120    }
121}
122
123#[derive(Debug, Clone, Copy, PartialEq, Eq)]
124#[repr(u8)]
125pub enum CommandType {
126    Order = 0,
127    PriceUpdate = 1,
128    IvUpdate = 2,
129    MarketAction = 3,
130    LiquidationState = 4,
131    TierUpdate = 5,
132    HypercorePositionUpdate = 6,
133    MmpConfigUpdate = 7,
134    RfqExecute = 8,
135    TradingModeUpdate = 9,
136    TickExpiry = 10,
137    DepositUpdate = 11,
138    LiquidationBonusUpdate = 12,
139    ApproveAgent = 13,
140    RevokeAgent = 14,
141    NonceAdvance = 15,
142    HypercoreEquityUpdate = 16,
143    OptionDepositUpdate = 17,
144    OptionWithdrawalUpdate = 18,
145    CashWithdrawalUpdate = 19,
146    SetPmSettlementPoolConfig = 20,
147    AccruePmSettlementInterest = 22,
148    ApplyPmSettlementRepayment = 23,
149    JournalPmRecoveryPlan = 24,
150    MarkPmRecoveryActionSubmitted = 25,
151    ResolvePmRecoveryAction = 26,
152    RecordPmVaultDeposit = 27,
153    RequestPmVaultWithdrawal = 28,
154}
155
156impl CommandType {
157    pub fn from_u8(v: u8) -> Option<Self> {
158        match v {
159            0 => Some(Self::Order),
160            1 => Some(Self::PriceUpdate),
161            2 => Some(Self::IvUpdate),
162            3 => Some(Self::MarketAction),
163            4 => Some(Self::LiquidationState),
164            5 => Some(Self::TierUpdate),
165            6 => Some(Self::HypercorePositionUpdate),
166            7 => Some(Self::MmpConfigUpdate),
167            8 => Some(Self::RfqExecute),
168            9 => Some(Self::TradingModeUpdate),
169            10 => Some(Self::TickExpiry),
170            11 => Some(Self::DepositUpdate),
171            12 => Some(Self::LiquidationBonusUpdate),
172            13 => Some(Self::ApproveAgent),
173            14 => Some(Self::RevokeAgent),
174            15 => Some(Self::NonceAdvance),
175            16 => Some(Self::HypercoreEquityUpdate),
176            17 => Some(Self::OptionDepositUpdate),
177            18 => Some(Self::OptionWithdrawalUpdate),
178            19 => Some(Self::CashWithdrawalUpdate),
179            20 => Some(Self::SetPmSettlementPoolConfig),
180            21 => None,
181            22 => Some(Self::AccruePmSettlementInterest),
182            23 => Some(Self::ApplyPmSettlementRepayment),
183            24 => Some(Self::JournalPmRecoveryPlan),
184            25 => Some(Self::MarkPmRecoveryActionSubmitted),
185            26 => Some(Self::ResolvePmRecoveryAction),
186            27 => Some(Self::RecordPmVaultDeposit),
187            28 => Some(Self::RequestPmVaultWithdrawal),
188            _ => None,
189        }
190    }
191
192    pub fn wire_version(self) -> u8 {
193        COMMAND_WIRE_VERSION_V1
194    }
195}
196
197#[derive(Debug, Clone, PartialEq, Eq)]
198pub(crate) struct ParsedCommandPayload {
199    pub seq: u64,
200    pub command_type: CommandType,
201    pub command_version: u8,
202    pub command_data: Vec<u8>,
203}
204
205pub(crate) fn parse_command_payload(payload: &[u8]) -> anyhow::Result<ParsedCommandPayload> {
206    if payload.len() < LEGACY_COMMAND_PAYLOAD_PREFIX_LEN {
207        anyhow::bail!("NATS message too short ({})", payload.len());
208    }
209
210    let seq = u64::from_be_bytes(payload[..8].try_into().unwrap());
211    let command_type = CommandType::from_u8(payload[8])
212        .ok_or_else(|| anyhow::anyhow!("Unknown command type byte {}", payload[8]))?;
213
214    if payload.len() < COMMAND_PAYLOAD_PREFIX_LEN {
215        return Ok(ParsedCommandPayload {
216            seq,
217            command_type,
218            command_version: LEGACY_COMMAND_WIRE_VERSION,
219            command_data: payload[LEGACY_COMMAND_PAYLOAD_PREFIX_LEN..].to_vec(),
220        });
221    }
222
223    if payload.len() > COMMAND_PAYLOAD_PREFIX_LEN
224        && payload[9] == hypercall_types::WIRE_FORMAT_VERSION
225        && payload[10] != hypercall_types::WIRE_FORMAT_VERSION
226    {
227        return Ok(ParsedCommandPayload {
228            seq,
229            command_type,
230            command_version: LEGACY_COMMAND_WIRE_VERSION,
231            command_data: payload[LEGACY_COMMAND_PAYLOAD_PREFIX_LEN..].to_vec(),
232        });
233    }
234
235    Ok(ParsedCommandPayload {
236        seq,
237        command_type,
238        command_version: payload[9],
239        command_data: payload[COMMAND_PAYLOAD_PREFIX_LEN..].to_vec(),
240    })
241}
242
243pub struct NatsConfig {
244    pub url: String,
245    pub env: String,
246    pub stream_max_age_secs: u64,
247}
248
249impl NatsConfig {
250    pub fn from_env() -> Option<Self> {
251        let url = std::env::var("NATS_URL").ok()?;
252        let env = std::env::var("NATS_ENV").unwrap_or_else(|_| "default".to_string());
253        let max_age = std::env::var("NATS_STREAM_MAX_AGE_SECS")
254            .ok()
255            .and_then(|v| v.parse().ok())
256            .unwrap_or(86400);
257        Some(Self {
258            url,
259            env,
260            stream_max_age_secs: max_age,
261        })
262    }
263
264    pub async fn latest_stream_sequence(&self) -> anyhow::Result<u64> {
265        let client = async_nats::connect(&self.url).await?;
266        let js = jetstream::new(client);
267        let mut stream = get_or_create_command_stream(&js, self).await?;
268        Ok(stream.info().await?.state.last_sequence)
269    }
270}
271
272/// Publishes engine commands to NATS JetStream.
273/// Used by the primary engine after each apply().
274#[derive(Clone)]
275pub struct NatsPublisher {
276    js: jetstream::Context,
277    subject: String,
278    seq: Arc<std::sync::atomic::AtomicU64>,
279    published: Arc<std::sync::atomic::AtomicU64>,
280}
281
282#[derive(Clone)]
283pub struct NatsBalanceUpdatePublisher {
284    js: jetstream::Context,
285    subject: String,
286    last_acked_stream_sequence: Arc<std::sync::atomic::AtomicU64>,
287    last_acked_balance_update_seq: Arc<std::sync::atomic::AtomicU64>,
288}
289
290impl NatsPublisher {
291    pub async fn connect(config: &NatsConfig) -> anyhow::Result<Self> {
292        let client = async_nats::connect(&config.url).await?;
293        let js = jetstream::new(client);
294
295        let sname = stream_name(&config.env);
296        let subj = subject(&config.env);
297
298        let mut stream = get_or_create_command_stream(&js, config).await?;
299
300        let next_command_seq = stream.info().await?.state.last_sequence.saturating_add(1);
301
302        info!(url = %config.url, env = %config.env, stream = %sname, "Connected to NATS JetStream");
303
304        Ok(Self {
305            js,
306            subject: subj,
307            seq: Arc::new(std::sync::atomic::AtomicU64::new(next_command_seq)),
308            published: Arc::new(std::sync::atomic::AtomicU64::new(0)),
309        })
310    }
311
312    /// Publish a command to the stream. Called after apply() in the engine loop.
313    /// Wire format: [seq: u64 BE][type: u8][command_version: u8][command_data bytes]
314    pub async fn publish(&self, command_type: CommandType, command_data: &[u8]) {
315        let seq = self.seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
316        self.published
317            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
318        let mut payload = Vec::with_capacity(COMMAND_PAYLOAD_PREFIX_LEN + command_data.len());
319        payload.extend_from_slice(&seq.to_be_bytes());
320        payload.push(command_type as u8);
321        payload.push(command_type.wire_version());
322        payload.extend_from_slice(command_data);
323
324        match self.js.publish(self.subject.clone(), payload.into()).await {
325            Ok(ack) => {
326                if let Err(e) = ack.await {
327                    error!(%e, seq, "NATS publish ack failed");
328                }
329            }
330            Err(e) => {
331                error!(%e, seq, "NATS publish failed");
332            }
333        }
334    }
335
336    pub fn published_count(&self) -> u64 {
337        self.published.load(std::sync::atomic::Ordering::Relaxed)
338    }
339}
340
341impl NatsBalanceUpdatePublisher {
342    pub async fn connect(config: &NatsConfig) -> anyhow::Result<Self> {
343        let client = async_nats::connect(&config.url).await?;
344        let js = jetstream::new(client);
345        let sname = balance_update_stream_name(&config.env);
346        let subj = balance_update_subject(&config.env);
347        let mut stream = get_or_create_balance_update_stream(&js, config).await?;
348        let last_sequence = stream.info().await?.state.last_sequence;
349        let last_acked_balance_update_seq = if last_sequence == 0 {
350            0
351        } else {
352            let message = stream
353                .get_raw_message(last_sequence)
354                .await
355                .map_err(|error| {
356                    anyhow::anyhow!(
357                        "failed to read latest NATS balance update at stream seq {}: {}",
358                        last_sequence,
359                        error
360                    )
361                })?;
362            let update = decode_balance_update_wire(&message.payload).map_err(|error| {
363                anyhow::anyhow!(
364                    "failed to decode latest NATS balance update at stream seq {}: {}",
365                    last_sequence,
366                    error
367                )
368            })?;
369            update.balance_update_seq
370        };
371
372        info!(
373            url = %config.url,
374            env = %config.env,
375            stream = %sname,
376            "Connected to NATS JetStream balance update stream"
377        );
378
379        Ok(Self {
380            js,
381            subject: subj,
382            last_acked_stream_sequence: Arc::new(std::sync::atomic::AtomicU64::new(last_sequence)),
383            last_acked_balance_update_seq: Arc::new(std::sync::atomic::AtomicU64::new(
384                last_acked_balance_update_seq,
385            )),
386        })
387    }
388
389    pub async fn publish(&self, update: &hypercall_types::BalanceUpdate) {
390        let payload = hypercall_types::serialize_to_wire_bytes(update);
391        match self.js.publish(self.subject.clone(), payload.into()).await {
392            Ok(ack) => match ack.await {
393                Ok(publish_ack) => {
394                    self.last_acked_stream_sequence
395                        .store(publish_ack.sequence, std::sync::atomic::Ordering::Relaxed);
396                    self.last_acked_balance_update_seq.store(
397                        update.balance_update_seq,
398                        std::sync::atomic::Ordering::Relaxed,
399                    );
400                }
401                Err(e) => {
402                    error!(
403                        %e,
404                        balance_update_seq = update.balance_update_seq,
405                        "NATS balance update publish ack failed"
406                    );
407                }
408            },
409            Err(e) => {
410                error!(
411                    %e,
412                    balance_update_seq = update.balance_update_seq,
413                    "NATS balance update publish failed"
414                );
415            }
416        }
417    }
418
419    pub fn last_acked_stream_sequence(&self) -> u64 {
420        self.last_acked_stream_sequence
421            .load(std::sync::atomic::Ordering::Relaxed)
422    }
423
424    pub fn last_acked_balance_update_seq(&self) -> u64 {
425        self.last_acked_balance_update_seq
426            .load(std::sync::atomic::Ordering::Relaxed)
427    }
428}
429
430/// Subscribes to the engine command stream and replays commands.
431/// Used by the standby engine to stay caught up with the primary.
432pub struct NatsReplaySubscriber {
433    js: jetstream::Context,
434    stream_name: String,
435    last_stream_seq: Arc<Mutex<i64>>,
436}
437
438pub struct ReplayMessage {
439    pub stream_seq: u64,
440    pub seq: u64,
441    pub command_type: CommandType,
442    pub command_version: u8,
443    pub command_data: Vec<u8>,
444    msg: async_nats::jetstream::Message,
445}
446
447impl ReplayMessage {
448    pub async fn ack(&self) -> anyhow::Result<()> {
449        self.msg
450            .ack()
451            .await
452            .map_err(|error| anyhow::anyhow!("NATS ack failed: {}", error))?;
453        Ok(())
454    }
455}
456
457impl NatsReplaySubscriber {
458    pub async fn connect(
459        config: &NatsConfig,
460        start_from_stream_sequence: i64,
461    ) -> anyhow::Result<Self> {
462        let client = async_nats::connect(&config.url).await?;
463        let js = jetstream::new(client);
464        let sname = stream_name(&config.env);
465
466        info!(
467            url = %config.url,
468            env = %config.env,
469            stream = %sname,
470            start_from_stream_sequence,
471            "Standby connected to NATS JetStream"
472        );
473
474        Ok(Self {
475            js,
476            stream_name: sname,
477            last_stream_seq: Arc::new(Mutex::new(start_from_stream_sequence)),
478        })
479    }
480
481    /// Start consuming commands from the stream.
482    /// `start_from_stream_sequence` is a JetStream stream sequence cursor. Use
483    /// `-1` to consume the full stream.
484    /// Returns replay messages starting after the last processed stream sequence.
485    pub async fn subscribe(
486        &self,
487    ) -> anyhow::Result<
488        std::pin::Pin<Box<dyn futures::Stream<Item = anyhow::Result<ReplayMessage>> + Send>>,
489    > {
490        use async_nats::jetstream::consumer::pull::Config as ConsumerConfig;
491        use futures::StreamExt;
492
493        let stream = self.js.get_stream(self.stream_name.clone()).await?;
494
495        let last_processed = self.last_stream_sequence_processed().await;
496        let deliver_policy = if last_processed >= 0 {
497            jetstream::consumer::DeliverPolicy::ByStartSequence {
498                start_sequence: last_processed as u64 + 1,
499            }
500        } else {
501            jetstream::consumer::DeliverPolicy::All
502        };
503
504        // Start from the first unprocessed stream sequence. The filter below is
505        // kept as a defensive guard, but the server should not deliver old
506        // messages during standby catch-up.
507        let consumer = stream
508            .create_consumer(ConsumerConfig {
509                deliver_policy,
510                ..Default::default()
511            })
512            .await?;
513
514        let last_seq = self.last_stream_seq.clone();
515
516        let messages = consumer.messages().await?;
517
518        let filtered = messages.filter_map(move |msg_result| {
519            let last_seq = last_seq.clone();
520            async move {
521                match msg_result {
522                    Ok(msg) => {
523                        let stream_seq = match msg.info() {
524                            Ok(info) => info.stream_sequence,
525                            Err(error) => {
526                                warn!(error = %error, "Cannot read NATS message stream sequence, skipping");
527                                msg.ack().await.ok();
528                                return None;
529                            }
530                        };
531
532                        let parsed = match parse_command_payload(&msg.payload) {
533                            Ok(parsed) => parsed,
534                            Err(error) => {
535                                warn!(stream_seq, error = %error, "Skipping invalid NATS command payload");
536                                msg.ack().await.ok();
537                                return None;
538                            }
539                        };
540
541                        let last = last_seq.lock().await;
542                        if (stream_seq as i64) <= *last {
543                            msg.ack().await.ok();
544                            return None;
545                        }
546                        drop(last);
547
548                        Some(Ok(ReplayMessage {
549                            stream_seq,
550                            seq: parsed.seq,
551                            command_type: parsed.command_type,
552                            command_version: parsed.command_version,
553                            command_data: parsed.command_data,
554                            msg,
555                        }))
556                    }
557                    Err(e) => Some(Err(anyhow::anyhow!("NATS message error: {}", e))),
558                }
559            }
560        });
561
562        Ok(Box::pin(filtered))
563    }
564
565    pub async fn last_stream_sequence_processed(&self) -> i64 {
566        *self.last_stream_seq.lock().await
567    }
568
569    pub async fn latest_stream_sequence(&self) -> anyhow::Result<u64> {
570        let mut stream = self.js.get_stream(self.stream_name.clone()).await?;
571        Ok(stream.info().await?.state.last_sequence)
572    }
573
574    pub async fn mark_processed(&self, seq: u64) {
575        let mut last = self.last_stream_seq.lock().await;
576        if (seq as i64) > *last {
577            *last = seq as i64;
578        }
579    }
580}
581
582// Unit tests and integration tests are in tests.rs