Skip to main content

foundry_common/provider/mpp/
ws.rs

1//! MPP WebSocket transport.
2//!
3//! Implements [`PubSubConnect`] with automatic MPP 402 challenge/credential
4//! handshake at WebSocket connect time. Non-MPP servers are handled via a
5//! timeout-based fallback: if no challenge frame arrives within a short window
6//! after connecting, we assume the server is a plain JSON-RPC WebSocket.
7
8use alloy_json_rpc::PubSubItem;
9use alloy_pubsub::{ConnectionHandle, PubSubConnect};
10use alloy_transport::{Authorization, TransportErrorKind, TransportResult, utils::guess_local_url};
11use alloy_transport_ws::WsBackend;
12use futures::{SinkExt, StreamExt};
13use mpp::{
14    client::{
15        PaymentProvider,
16        ws::{WsClientMessage, WsServerMessage},
17    },
18    protocol::core::{PaymentChallenge, format_authorization},
19};
20use rustls::crypto::{CryptoProvider, aws_lc_rs};
21use std::{io, time::Duration};
22use tokio::time::timeout;
23use tokio_tungstenite::{
24    MaybeTlsStream, WebSocketStream, connect_async,
25    tungstenite::{
26        Message,
27        client::IntoClientRequest,
28        http::{HeaderValue, header::AUTHORIZATION},
29    },
30};
31use tracing::debug;
32
33use super::{
34    keys::DiscoverOptions,
35    transport::{LazySessionProvider, extract_challenge_chain_and_currency},
36};
37
38type TungsteniteStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
39
/// How long to wait, once the WebSocket upgrade has completed, for the
/// server's MPP challenge frame.
///
/// When no challenge shows up inside this window the endpoint is assumed to
/// be an ordinary (non-MPP) JSON-RPC WebSocket.
const MPP_CHALLENGE_TIMEOUT: Duration = Duration::from_millis(500);

/// Interval between keepalive pings on the post-handshake connection
/// (the same value alloy-transport-ws uses by default).
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(10_000);
48
49/// WebSocket connector with automatic MPP payment at connect time.
50///
51/// Implements [`PubSubConnect`] so it can be used as a drop-in replacement for
52/// alloy's `WsConnect`. On connect, it:
53///
54/// 1. Opens a WebSocket connection.
55/// 2. Waits briefly for an MPP challenge frame from the server.
56/// 3. If a challenge arrives, performs the payment handshake using [`LazySessionProvider`] (same
57///    payment logic as the HTTP transport).
58/// 4. Spawns a backend loop that bridges the authenticated WebSocket to the alloy
59///    [`PubSubFrontend`](alloy_pubsub::PubSubFrontend).
60///
61/// Non-MPP servers work transparently — the timeout expires and the backend
62/// proceeds with normal JSON-RPC message forwarding.
63#[derive(Clone, Debug)]
64pub struct MppWsConnect {
65    url: String,
66    auth: Option<Authorization>,
67    provider: LazySessionProvider,
68}
69
70impl MppWsConnect {
71    /// Create a new MPP WebSocket connector for the given URL.
72    pub fn new(url: String) -> Self {
73        let origin = url.clone();
74        let auth =
75            url::Url::parse(&url).ok().and_then(|parsed| Authorization::extract_from_url(&parsed));
76        Self { url, auth, provider: LazySessionProvider::new(origin) }
77    }
78
79    /// Set the authorization header (e.g. JWT bearer token).
80    pub fn with_auth(mut self, auth: Authorization) -> Self {
81        self.auth = Some(auth);
82        self
83    }
84
85    /// Attempt the MPP handshake on an already-connected WebSocket.
86    ///
87    /// Waits up to [`MPP_CHALLENGE_TIMEOUT`] for a challenge frame. If one
88    /// arrives, pays it and sends the credential back. Returns any buffered
89    /// non-challenge messages that arrived during the handshake window.
90    async fn try_mpp_handshake(
91        socket: &mut TungsteniteStream,
92        provider: &LazySessionProvider,
93    ) -> TransportResult<Vec<String>> {
94        let mut buffered_messages: Vec<String> = Vec::new();
95
96        // Wait briefly for a challenge frame.
97        let first_msg = timeout(MPP_CHALLENGE_TIMEOUT, socket.next()).await;
98
99        let challenge_frame = match first_msg {
100            Err(_) => {
101                // Timeout — not an MPP server.
102                debug!("no MPP challenge within timeout, treating as plain WS");
103                return Ok(buffered_messages);
104            }
105            Ok(None) => {
106                return Err(TransportErrorKind::custom(io::Error::other(
107                    "WebSocket closed before any message",
108                )));
109            }
110            Ok(Some(Err(e))) => {
111                return Err(TransportErrorKind::custom(e));
112            }
113            Ok(Some(Ok(msg))) => msg,
114        };
115
116        let text = match &challenge_frame {
117            Message::Text(t) => t.as_ref(),
118            Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => {
119                return Ok(buffered_messages);
120            }
121            Message::Binary(_) | Message::Close(_) => {
122                return Err(TransportErrorKind::custom(io::Error::other(
123                    "unexpected binary/close frame on WS connect",
124                )));
125            }
126        };
127
128        // Try to parse as an MPP server message.
129        let server_msg: WsServerMessage = match serde_json::from_str(text) {
130            Ok(m) => m,
131            Err(_) => {
132                // Not an MPP message — buffer it for the backend.
133                buffered_messages.push(text.to_owned());
134                return Ok(buffered_messages);
135            }
136        };
137
138        let challenge: PaymentChallenge = match server_msg {
139            WsServerMessage::Challenge { ref challenge, .. } => {
140                serde_json::from_value(challenge.clone()).map_err(|e| {
141                    TransportErrorKind::custom(io::Error::other(format!(
142                        "failed to parse MPP WS challenge: {e}"
143                    )))
144                })?
145            }
146            _ => {
147                // Non-challenge MPP message — buffer it.
148                buffered_messages.push(text.to_owned());
149                return Ok(buffered_messages);
150            }
151        };
152
153        debug!(id = %challenge.id, method = %challenge.method, intent = %challenge.intent, "received MPP WS challenge, paying");
154
155        // Resolve the payment provider (lazily discovers keys on first use).
156        let (chain_id, currency) = extract_challenge_chain_and_currency(&challenge);
157        let currency = currency.and_then(|s| s.parse().ok());
158        let session =
159            provider.get_or_init(DiscoverOptions { chain_id, currency }).map_err(|e| {
160                TransportErrorKind::custom(io::Error::other(format!(
161                    "MPP key discovery failed: {e}"
162                )))
163            })?;
164
165        let credential = session.pay(&challenge).await.map_err(|e| {
166            TransportErrorKind::custom(io::Error::other(format!("MPP WS payment failed: {e}")))
167        })?;
168
169        // Everything after pay() must rollback on failure — wrap so we can't
170        // miss an error path.
171        let result = async {
172            let auth_header = format_authorization(&credential).map_err(|e| {
173                TransportErrorKind::custom(io::Error::other(format!(
174                    "failed to format MPP credential: {e}"
175                )))
176            })?;
177
178            // Send credential as a WS message.
179            let cred_msg = WsClientMessage::Credential { credential: auth_header };
180            let cred_text = serde_json::to_string(&cred_msg).map_err(|e| {
181                TransportErrorKind::custom(io::Error::other(format!(
182                    "failed to serialize credential message: {e}"
183                )))
184            })?;
185
186            socket.send(Message::Text(cred_text.into())).await.map_err(|e| {
187                TransportErrorKind::custom(io::Error::other(format!(
188                    "failed to send MPP credential: {e}"
189                )))
190            })?;
191
192            // Wait for server acknowledgement (receipt or data).
193            let ack = timeout(Duration::from_secs(30), socket.next()).await.map_err(|_| {
194                TransportErrorKind::custom(io::Error::other(
195                    "timeout waiting for MPP server acknowledgement",
196                ))
197            })?;
198
199            match ack {
200                None => {
201                    return Err(TransportErrorKind::custom(io::Error::other(
202                        "WebSocket closed after sending credential",
203                    )));
204                }
205                Some(Err(e)) => return Err(TransportErrorKind::custom(e)),
206                Some(Ok(Message::Text(t))) => {
207                    if let Ok(msg) = serde_json::from_str::<WsServerMessage>(t.as_ref()) {
208                        match msg {
209                            WsServerMessage::Receipt { .. } => {
210                                debug!("MPP WS handshake complete (receipt received)");
211                            }
212                            WsServerMessage::Error { error } => {
213                                return Err(TransportErrorKind::custom(io::Error::other(format!(
214                                    "MPP WS server error: {error}"
215                                ))));
216                            }
217                            _ => {
218                                buffered_messages.push(t.to_string());
219                            }
220                        }
221                    } else {
222                        buffered_messages.push(t.to_string());
223                    }
224                }
225                Some(Ok(Message::Close(_))) => {
226                    return Err(TransportErrorKind::custom(io::Error::other(
227                        "WebSocket closed after sending credential",
228                    )));
229                }
230                Some(Ok(_)) => {}
231            }
232
233            Ok(buffered_messages)
234        }
235        .await;
236
237        match &result {
238            Ok(_) => provider.flush_pending(),
239            Err(_) => provider.rollback_pending(),
240        }
241
242        result
243    }
244}
245
246impl PubSubConnect for MppWsConnect {
247    fn is_local(&self) -> bool {
248        guess_local_url(&self.url)
249    }
250
251    async fn connect(&self) -> TransportResult<ConnectionHandle> {
252        let mut request =
253            self.url.as_str().into_client_request().map_err(TransportErrorKind::custom)?;
254
255        if let Some(ref auth) = self.auth {
256            let mut auth_value =
257                HeaderValue::from_str(&auth.to_string()).map_err(TransportErrorKind::custom)?;
258            auth_value.set_sensitive(true);
259            request.headers_mut().insert(AUTHORIZATION, auth_value);
260        }
261
262        // Install the default rustls crypto provider (required by rustls 0.23+).
263        let _ = CryptoProvider::install_default(aws_lc_rs::default_provider());
264
265        let (mut socket, _) = connect_async(request).await.map_err(TransportErrorKind::custom)?;
266
267        // Attempt MPP handshake (timeout-based fallback for non-MPP servers).
268        let buffered = Self::try_mpp_handshake(&mut socket, &self.provider).await?;
269
270        let (handle, interface) = ConnectionHandle::new();
271
272        // Replay any messages that arrived during the handshake window.
273        for msg in &buffered {
274            if let Ok(item) = serde_json::from_str::<PubSubItem>(msg) {
275                let _ = interface.send_to_frontend(item);
276            }
277        }
278
279        // Reuse alloy's WsBackend for the post-handshake JSON-RPC loop.
280        WsBackend::from_socket(socket, interface, KEEPALIVE_INTERVAL).spawn();
281
282        Ok(handle)
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_json_rpc::{Id, Request, RequestMeta, RequestPacket, ResponsePacket};
    use alloy_primitives::hex;
    use mpp::{
        PrivateKeySigner,
        protocol::core::{Base64UrlJson, IntentName, MethodName, parse_authorization},
    };
    use std::{
        ffi::OsString,
        sync::{Mutex, MutexGuard},
    };
    use tokio::task::JoinHandle;
    use tokio_tungstenite::accept_async;
    use tower::Service;

    /// Env var these tests set so `LazySessionProvider` can initialize a signer.
    const PRIVATE_KEY_ENV: &str = "TEMPO_PRIVATE_KEY";

    /// Serializes the tests in this module that mutate [`PRIVATE_KEY_ENV`].
    ///
    /// `#[tokio::test]`s run in parallel; without this, one test could remove
    /// the key while another is still connecting.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Sets [`PRIVATE_KEY_ENV`] to a fresh random key while alive.
    ///
    /// On drop (including when a test panics) the previous value is restored,
    /// and then `ENV_LOCK` is released.
    struct PrivateKeyEnvGuard {
        previous: Option<OsString>,
        // Declared last so it is released after `Drop::drop` restores the env.
        _lock: MutexGuard<'static, ()>,
    }

    impl PrivateKeyEnvGuard {
        fn set_random() -> Self {
            // A panicking test poisons the lock, but its guard already restored
            // the env on unwind, so the poison carries no information here.
            let lock = ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            let previous = std::env::var_os(PRIVATE_KEY_ENV);
            let key_hex = hex::encode(PrivateKeySigner::random().to_bytes());
            // SAFETY: env mutation by this module's tests is serialized by
            // `ENV_LOCK`. This is best-effort: tests elsewhere in the crate
            // that touch the environment are not covered by this lock.
            unsafe { std::env::set_var(PRIVATE_KEY_ENV, &key_hex) };
            Self { previous, _lock: lock }
        }
    }

    impl Drop for PrivateKeyEnvGuard {
        fn drop(&mut self) {
            // SAFETY: `ENV_LOCK` is still held here; `_lock` drops after this body.
            unsafe {
                match self.previous.take() {
                    Some(value) => std::env::set_var(PRIVATE_KEY_ENV, value),
                    None => std::env::remove_var(PRIVATE_KEY_ENV),
                }
            }
        }
    }

    fn test_challenge() -> PaymentChallenge {
        let request = Base64UrlJson::from_value(&serde_json::json!({
            "amount": "1000",
            "currency": "0x0000000000000000000000000000000000000000",
            "recipient": "0x0000000000000000000000000000000000000001",
            "methodDetails": {
                "chainId": 42431
            }
        }))
        .unwrap();

        PaymentChallenge {
            id: "ws-test-id".to_string(),
            realm: "test-realm".to_string(),
            method: MethodName::new("tempo"),
            intent: IntentName::new("session"),
            request,
            expires: None,
            description: None,
            digest: None,
            opaque: None,
        }
    }

    /// Spawn a WS server on localhost, returns (ws_url, join_handle).
    /// `handler` receives the server-side socket for full control.
    async fn spawn_ws_server<F, Fut>(handler: F) -> (String, JoinHandle<()>)
    where
        F: FnOnce(WebSocketStream<tokio::net::TcpStream>) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = ()> + Send,
    {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let ws = accept_async(stream).await.unwrap();
            handler(ws).await;
        });
        (format!("ws://{addr}"), handle)
    }

    /// Plain WS server (no MPP) — connect and send a JSON-RPC request.
    #[tokio::test]
    async fn test_ws_no_mpp_plain_jsonrpc() {
        let (url, server) = spawn_ws_server(|mut ws| async move {
            // Wait for a JSON-RPC request, echo a response.
            while let Some(Ok(msg)) = ws.next().await {
                if let Message::Text(text) = msg {
                    let req: serde_json::Value = serde_json::from_str(&text).unwrap();
                    let id = req.get("id").unwrap().clone();
                    let resp = serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": id,
                        "result": "0xabc"
                    });
                    ws.send(Message::Text(resp.to_string().into())).await.unwrap();
                }
            }
        })
        .await;

        let connector = MppWsConnect::new(url);
        let mut frontend = connector.into_service().await.unwrap();

        let req = Request {
            meta: RequestMeta::new("eth_blockNumber".into(), Id::Number(1)),
            params: serde_json::Value::Array(vec![]),
        };
        let packet = RequestPacket::Single(req.serialize().unwrap());
        let resp = frontend.call(packet).await.unwrap();

        match resp {
            ResponsePacket::Single(r) => assert!(r.is_success()),
            _ => panic!("expected single response"),
        }

        server.abort();
    }

    /// MPP server sends challenge → client pays → server sends receipt.
    #[tokio::test]
    async fn test_ws_mpp_challenge_credential_receipt() {
        let challenge = test_challenge();
        let challenge_json = serde_json::to_value(&challenge).unwrap();

        let (url, server) = spawn_ws_server(move |mut ws| async move {
            // Send challenge.
            let challenge_msg = serde_json::json!({
                "type": "challenge",
                "challenge": challenge_json
            });
            ws.send(Message::Text(challenge_msg.to_string().into())).await.unwrap();

            // Receive credential.
            let msg = ws.next().await.unwrap().unwrap();
            let text = match msg {
                Message::Text(t) => t,
                other => panic!("expected text, got {other:?}"),
            };
            let parsed: serde_json::Value = serde_json::from_str(&text).unwrap();
            assert_eq!(parsed["type"], "credential");
            // Verify it's a valid MPP credential.
            let cred_str = parsed["credential"].as_str().unwrap();
            let cred = parse_authorization(cred_str).unwrap();
            assert_eq!(cred.challenge.id, "ws-test-id");

            // Send receipt.
            let receipt_msg = serde_json::json!({
                "type": "receipt",
                "receipt": { "id": "ws-test-id" }
            });
            ws.send(Message::Text(receipt_msg.to_string().into())).await.unwrap();

            // Now serve JSON-RPC.
            while let Some(Ok(msg)) = ws.next().await {
                if let Message::Text(text) = msg {
                    let req: serde_json::Value = serde_json::from_str(&text).unwrap();
                    let id = req.get("id").unwrap().clone();
                    let resp = serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": id,
                        "result": "0xpaid"
                    });
                    ws.send(Message::Text(resp.to_string().into())).await.unwrap();
                }
            }
        })
        .await;

        // Set a random private key so LazySessionProvider can initialize.
        // Bound to a named variable so the guard lives until the end of the test.
        let _env = PrivateKeyEnvGuard::set_random();

        let connector = MppWsConnect::new(url);
        let mut frontend = connector.into_service().await.unwrap();

        let req = Request {
            meta: RequestMeta::new("eth_blockNumber".into(), Id::Number(1)),
            params: serde_json::Value::Array(vec![]),
        };
        let packet = RequestPacket::Single(req.serialize().unwrap());
        let resp = frontend.call(packet).await.unwrap();

        match resp {
            ResponsePacket::Single(r) => assert!(r.is_success()),
            _ => panic!("expected single response"),
        }

        server.abort();
    }

    /// MPP server sends challenge, client pays, server closes → rollback.
    #[tokio::test]
    async fn test_ws_mpp_rollback_on_post_pay_close() {
        let challenge = test_challenge();
        let challenge_json = serde_json::to_value(&challenge).unwrap();

        let (url, server) = spawn_ws_server(move |mut ws| async move {
            // Send challenge.
            let challenge_msg = serde_json::json!({
                "type": "challenge",
                "challenge": challenge_json
            });
            ws.send(Message::Text(challenge_msg.to_string().into())).await.unwrap();

            // Receive credential (client paid).
            let _ = ws.next().await;

            // Close without sending receipt — simulates post-pay failure.
            ws.close(None).await.ok();
        })
        .await;

        let _env = PrivateKeyEnvGuard::set_random();

        let connector = MppWsConnect::new(url);
        let result = connector.connect().await;

        // Connect must fail — server closed after credential was sent.
        assert!(result.is_err(), "expected error when server closes after payment, got Ok");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("closed") || err.contains("WebSocket"),
            "expected close-related error, got: {err}"
        );

        server.abort();
    }
}