foundry_common/provider/mpp/
ws.rs1use alloy_json_rpc::PubSubItem;
9use alloy_pubsub::{ConnectionHandle, PubSubConnect};
10use alloy_transport::{Authorization, TransportErrorKind, TransportResult, utils::guess_local_url};
11use alloy_transport_ws::WsBackend;
12use futures::{SinkExt, StreamExt};
13use mpp::{
14 client::{
15 PaymentProvider,
16 ws::{WsClientMessage, WsServerMessage},
17 },
18 protocol::core::{PaymentChallenge, format_authorization},
19};
20use rustls::crypto::{CryptoProvider, aws_lc_rs};
21use std::{io, time::Duration};
22use tokio::time::timeout;
23use tokio_tungstenite::{
24 MaybeTlsStream, WebSocketStream, connect_async,
25 tungstenite::{
26 Message,
27 client::IntoClientRequest,
28 http::{HeaderValue, header::AUTHORIZATION},
29 },
30};
31use tracing::debug;
32
33use super::{
34 keys::DiscoverOptions,
35 transport::{LazySessionProvider, extract_challenge_chain_and_currency},
36};
37
/// Socket type returned by `connect_async` and driven by the MPP handshake before it is
/// handed to the pubsub backend.
type TungsteniteStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;

/// How long the handshake waits for the server's first frame. If nothing arrives in this
/// window the connection is treated as a plain (non-MPP) WebSocket.
const MPP_CHALLENGE_TIMEOUT: Duration = Duration::from_millis(500);

/// Interval passed to `WsBackend::from_socket` when the backend is spawned.
// NOTE(review): presumably the keepalive ping period — confirm against alloy-transport-ws.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(10);
48
/// WebSocket connector that attempts an MPP payment handshake on a freshly opened socket
/// before handing it to the pubsub backend.
#[derive(Clone, Debug)]
pub struct MppWsConnect {
    /// WebSocket endpoint to connect to.
    url: String,
    /// Optional credentials sent as the `Authorization` header of the upgrade request.
    auth: Option<Authorization>,
    /// Lazily initialised session provider used to pay a challenge issued by the server.
    provider: LazySessionProvider,
}
69
70impl MppWsConnect {
71 pub fn new(url: String) -> Self {
73 let origin = url.clone();
74 let auth =
75 url::Url::parse(&url).ok().and_then(|parsed| Authorization::extract_from_url(&parsed));
76 Self { url, auth, provider: LazySessionProvider::new(origin) }
77 }
78
79 pub fn with_auth(mut self, auth: Authorization) -> Self {
81 self.auth = Some(auth);
82 self
83 }
84
85 async fn try_mpp_handshake(
91 socket: &mut TungsteniteStream,
92 provider: &LazySessionProvider,
93 ) -> TransportResult<Vec<String>> {
94 let mut buffered_messages: Vec<String> = Vec::new();
95
96 let first_msg = timeout(MPP_CHALLENGE_TIMEOUT, socket.next()).await;
98
99 let challenge_frame = match first_msg {
100 Err(_) => {
101 debug!("no MPP challenge within timeout, treating as plain WS");
103 return Ok(buffered_messages);
104 }
105 Ok(None) => {
106 return Err(TransportErrorKind::custom(io::Error::other(
107 "WebSocket closed before any message",
108 )));
109 }
110 Ok(Some(Err(e))) => {
111 return Err(TransportErrorKind::custom(e));
112 }
113 Ok(Some(Ok(msg))) => msg,
114 };
115
116 let text = match &challenge_frame {
117 Message::Text(t) => t.as_ref(),
118 Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => {
119 return Ok(buffered_messages);
120 }
121 Message::Binary(_) | Message::Close(_) => {
122 return Err(TransportErrorKind::custom(io::Error::other(
123 "unexpected binary/close frame on WS connect",
124 )));
125 }
126 };
127
128 let server_msg: WsServerMessage = match serde_json::from_str(text) {
130 Ok(m) => m,
131 Err(_) => {
132 buffered_messages.push(text.to_owned());
134 return Ok(buffered_messages);
135 }
136 };
137
138 let challenge: PaymentChallenge = match server_msg {
139 WsServerMessage::Challenge { ref challenge, .. } => {
140 serde_json::from_value(challenge.clone()).map_err(|e| {
141 TransportErrorKind::custom(io::Error::other(format!(
142 "failed to parse MPP WS challenge: {e}"
143 )))
144 })?
145 }
146 _ => {
147 buffered_messages.push(text.to_owned());
149 return Ok(buffered_messages);
150 }
151 };
152
153 debug!(id = %challenge.id, method = %challenge.method, intent = %challenge.intent, "received MPP WS challenge, paying");
154
155 let (chain_id, currency) = extract_challenge_chain_and_currency(&challenge);
157 let currency = currency.and_then(|s| s.parse().ok());
158 let session =
159 provider.get_or_init(DiscoverOptions { chain_id, currency }).map_err(|e| {
160 TransportErrorKind::custom(io::Error::other(format!(
161 "MPP key discovery failed: {e}"
162 )))
163 })?;
164
165 let credential = session.pay(&challenge).await.map_err(|e| {
166 TransportErrorKind::custom(io::Error::other(format!("MPP WS payment failed: {e}")))
167 })?;
168
169 let result = async {
172 let auth_header = format_authorization(&credential).map_err(|e| {
173 TransportErrorKind::custom(io::Error::other(format!(
174 "failed to format MPP credential: {e}"
175 )))
176 })?;
177
178 let cred_msg = WsClientMessage::Credential { credential: auth_header };
180 let cred_text = serde_json::to_string(&cred_msg).map_err(|e| {
181 TransportErrorKind::custom(io::Error::other(format!(
182 "failed to serialize credential message: {e}"
183 )))
184 })?;
185
186 socket.send(Message::Text(cred_text.into())).await.map_err(|e| {
187 TransportErrorKind::custom(io::Error::other(format!(
188 "failed to send MPP credential: {e}"
189 )))
190 })?;
191
192 let ack = timeout(Duration::from_secs(30), socket.next()).await.map_err(|_| {
194 TransportErrorKind::custom(io::Error::other(
195 "timeout waiting for MPP server acknowledgement",
196 ))
197 })?;
198
199 match ack {
200 None => {
201 return Err(TransportErrorKind::custom(io::Error::other(
202 "WebSocket closed after sending credential",
203 )));
204 }
205 Some(Err(e)) => return Err(TransportErrorKind::custom(e)),
206 Some(Ok(Message::Text(t))) => {
207 if let Ok(msg) = serde_json::from_str::<WsServerMessage>(t.as_ref()) {
208 match msg {
209 WsServerMessage::Receipt { .. } => {
210 debug!("MPP WS handshake complete (receipt received)");
211 }
212 WsServerMessage::Error { error } => {
213 return Err(TransportErrorKind::custom(io::Error::other(format!(
214 "MPP WS server error: {error}"
215 ))));
216 }
217 _ => {
218 buffered_messages.push(t.to_string());
219 }
220 }
221 } else {
222 buffered_messages.push(t.to_string());
223 }
224 }
225 Some(Ok(Message::Close(_))) => {
226 return Err(TransportErrorKind::custom(io::Error::other(
227 "WebSocket closed after sending credential",
228 )));
229 }
230 Some(Ok(_)) => {}
231 }
232
233 Ok(buffered_messages)
234 }
235 .await;
236
237 match &result {
238 Ok(_) => provider.flush_pending(),
239 Err(_) => provider.rollback_pending(),
240 }
241
242 result
243 }
244}
245
246impl PubSubConnect for MppWsConnect {
247 fn is_local(&self) -> bool {
248 guess_local_url(&self.url)
249 }
250
251 async fn connect(&self) -> TransportResult<ConnectionHandle> {
252 let mut request =
253 self.url.as_str().into_client_request().map_err(TransportErrorKind::custom)?;
254
255 if let Some(ref auth) = self.auth {
256 let mut auth_value =
257 HeaderValue::from_str(&auth.to_string()).map_err(TransportErrorKind::custom)?;
258 auth_value.set_sensitive(true);
259 request.headers_mut().insert(AUTHORIZATION, auth_value);
260 }
261
262 let _ = CryptoProvider::install_default(aws_lc_rs::default_provider());
264
265 let (mut socket, _) = connect_async(request).await.map_err(TransportErrorKind::custom)?;
266
267 let buffered = Self::try_mpp_handshake(&mut socket, &self.provider).await?;
269
270 let (handle, interface) = ConnectionHandle::new();
271
272 for msg in &buffered {
274 if let Ok(item) = serde_json::from_str::<PubSubItem>(msg) {
275 let _ = interface.send_to_frontend(item);
276 }
277 }
278
279 WsBackend::from_socket(socket, interface, KEEPALIVE_INTERVAL).spawn();
281
282 Ok(handle)
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_json_rpc::{Id, Request, RequestMeta, RequestPacket, ResponsePacket};
    use alloy_primitives::hex;
    use mpp::{
        PrivateKeySigner,
        protocol::core::{Base64UrlJson, IntentName, MethodName, parse_authorization},
    };
    use tokio::task::JoinHandle;
    use tokio_tungstenite::accept_async;
    use tower::Service;

    /// Server-side socket type handed to the mock handlers.
    type ServerSocket = WebSocketStream<tokio::net::TcpStream>;

    /// Environment variable the tests set to a random private key before paying.
    const KEY_ENV: &str = "TEMPO_PRIVATE_KEY";

    /// Serializes every test in this module that touches [`KEY_ENV`].
    ///
    /// The default test harness runs tests on parallel threads; without this lock one test
    /// could remove the variable while another is in the middle of its handshake.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// Holds [`ENV_LOCK`] while [`KEY_ENV`] is set and removes the variable on drop, so it is
    /// cleaned up even when an assertion panics.
    struct TempoKeyGuard {
        _lock: tokio::sync::MutexGuard<'static, ()>,
    }

    impl TempoKeyGuard {
        /// Takes the module-wide lock and sets [`KEY_ENV`] to a freshly generated key.
        async fn install_random() -> Self {
            let lock = ENV_LOCK.lock().await;
            let signer = PrivateKeySigner::random();
            let key_hex = hex::encode(signer.to_bytes());
            // SAFETY: writes to `KEY_ENV` from this module are serialized by `ENV_LOCK`.
            // NOTE(review): threads outside this module that read the environment at the same
            // time are not covered by the lock.
            unsafe { std::env::set_var(KEY_ENV, &key_hex) };
            Self { _lock: lock }
        }
    }

    impl Drop for TempoKeyGuard {
        fn drop(&mut self) {
            // SAFETY: `ENV_LOCK` is still held here; the guard field is dropped after this body.
            unsafe { std::env::remove_var(KEY_ENV) };
        }
    }

    /// Builds the payment challenge used by the MPP tests.
    fn test_challenge() -> PaymentChallenge {
        let request = Base64UrlJson::from_value(&serde_json::json!({
            "amount": "1000",
            "currency": "0x0000000000000000000000000000000000000000",
            "recipient": "0x0000000000000000000000000000000000000001",
            "methodDetails": {
                "chainId": 42431
            }
        }))
        .unwrap();

        PaymentChallenge {
            id: "ws-test-id".to_string(),
            realm: "test-realm".to_string(),
            method: MethodName::new("tempo"),
            intent: IntentName::new("session"),
            request,
            expires: None,
            description: None,
            digest: None,
            opaque: None,
        }
    }

    /// Wraps [`test_challenge`] in the `{"type": "challenge", ...}` text frame a server sends.
    fn challenge_frame() -> Message {
        let challenge_json = serde_json::to_value(test_challenge()).unwrap();
        let challenge_msg = serde_json::json!({
            "type": "challenge",
            "challenge": challenge_json
        });
        Message::Text(challenge_msg.to_string().into())
    }

    /// Builds a single `eth_blockNumber` request packet with id 1.
    fn block_number_request() -> RequestPacket {
        let req = Request {
            meta: RequestMeta::new("eth_blockNumber".into(), Id::Number(1)),
            params: serde_json::Value::Array(vec![]),
        };
        RequestPacket::Single(req.serialize().unwrap())
    }

    /// Asserts that `resp` is a single, successful JSON-RPC response.
    fn assert_single_success(resp: ResponsePacket) {
        match resp {
            ResponsePacket::Single(r) => assert!(r.is_success()),
            _ => panic!("expected single response"),
        }
    }

    /// Binds an ephemeral local port, accepts exactly one WebSocket connection and runs
    /// `handler` on it. Returns the `ws://` URL of the server and the task handle.
    async fn spawn_ws_server<F, Fut>(handler: F) -> (String, JoinHandle<()>)
    where
        F: FnOnce(ServerSocket) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = ()> + Send,
    {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let ws = accept_async(stream).await.unwrap();
            handler(ws).await;
        });
        (format!("ws://{addr}"), handle)
    }

    /// Answers every incoming text frame with a JSON-RPC success carrying `result`, echoing
    /// the request id, until the peer goes away.
    async fn serve_jsonrpc(mut ws: ServerSocket, result: &'static str) {
        while let Some(Ok(msg)) = ws.next().await {
            if let Message::Text(text) = msg {
                let req: serde_json::Value = serde_json::from_str(&text).unwrap();
                let id = req.get("id").unwrap().clone();
                let resp = serde_json::json!({
                    "jsonrpc": "2.0",
                    "id": id,
                    "result": result
                });
                ws.send(Message::Text(resp.to_string().into())).await.unwrap();
            }
        }
    }

    /// A server that never sends a challenge is used as a plain JSON-RPC WebSocket.
    #[tokio::test]
    async fn test_ws_no_mpp_plain_jsonrpc() {
        let (url, server) = spawn_ws_server(|ws| serve_jsonrpc(ws, "0xabc")).await;

        let connector = MppWsConnect::new(url);
        let mut frontend = connector.into_service().await.unwrap();

        let resp = frontend.call(block_number_request()).await.unwrap();
        assert_single_success(resp);

        server.abort();
    }

    /// Challenge -> credential -> receipt, followed by a regular JSON-RPC round trip.
    #[tokio::test]
    async fn test_ws_mpp_challenge_credential_receipt() {
        let (url, server) = spawn_ws_server(|mut ws| async move {
            ws.send(challenge_frame()).await.unwrap();

            // The client must answer with a credential bound to our challenge id.
            let msg = ws.next().await.unwrap().unwrap();
            let text = match msg {
                Message::Text(t) => t,
                other => panic!("expected text, got {other:?}"),
            };
            let parsed: serde_json::Value = serde_json::from_str(&text).unwrap();
            assert_eq!(parsed["type"], "credential");
            let cred_str = parsed["credential"].as_str().unwrap();
            let cred = parse_authorization(cred_str).unwrap();
            assert_eq!(cred.challenge.id, "ws-test-id");

            let receipt_msg = serde_json::json!({
                "type": "receipt",
                "receipt": { "id": "ws-test-id" }
            });
            ws.send(Message::Text(receipt_msg.to_string().into())).await.unwrap();

            serve_jsonrpc(ws, "0xpaid").await;
        })
        .await;

        let _key = TempoKeyGuard::install_random().await;

        let connector = MppWsConnect::new(url);
        let mut frontend = connector.into_service().await.unwrap();

        let resp = frontend.call(block_number_request()).await.unwrap();
        assert_single_success(resp);

        server.abort();
    }

    /// The server closes right after receiving the credential; connecting must fail.
    #[tokio::test]
    async fn test_ws_mpp_rollback_on_post_pay_close() {
        let (url, server) = spawn_ws_server(|mut ws| async move {
            ws.send(challenge_frame()).await.unwrap();

            // Swallow the credential, then hang up without acknowledging it.
            let _ = ws.next().await;

            ws.close(None).await.ok();
        })
        .await;

        let _key = TempoKeyGuard::install_random().await;

        let connector = MppWsConnect::new(url);
        let result = connector.connect().await;

        assert!(result.is_err(), "expected error when server closes after payment, got Ok");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("closed") || err.contains("WebSocket"),
            "expected close-related error, got: {err}"
        );

        server.abort();
    }
}