1use alloy_json_rpc::{RequestPacket, ResponsePacket};
8use alloy_transport::{TransportError, TransportErrorKind, TransportFut, TransportResult};
9use mpp::{
10 client::PaymentProvider,
11 protocol::core::{
12 AUTHORIZATION_HEADER, WWW_AUTHENTICATE_HEADER, format_authorization,
13 parse_www_authenticate_all,
14 },
15};
16use reqwest::{StatusCode, header::HeaderMap};
17use std::{
18 collections::HashMap,
19 fmt,
20 sync::{Mutex, OnceLock},
21 task,
22 time::Duration,
23};
24use tokio::sync::OwnedMutexGuard;
25use tower::Service;
26use tracing::{Instrument, debug, debug_span, trace};
27use url::Url;
28
29use super::{
30 keys::{DiscoverOptions, discover_mpp_config},
31 session::SessionProvider,
32};
33
/// Deposit handed to `SessionProvider::with_default_deposit` when the `MPP_DEPOSIT`
/// environment variable is unset or does not parse as a `u128`.
// NOTE(review): the unit (presumably the token's smallest unit) is not visible here — confirm.
const DEFAULT_DEPOSIT: u128 = 100_000;

/// Per-request timeout applied to every request sent after a 402 was received: paid retries and
/// the unauthenticated fresh-challenge probe. The initial unauthenticated request does not set it.
const MPP_RETRY_TIMEOUT: Duration = Duration::from_secs(120);
39
40fn default_deposit() -> u128 {
42 std::env::var("MPP_DEPOSIT").ok().and_then(|s| s.parse().ok()).unwrap_or(DEFAULT_DEPOSIT)
43}
44
45fn format_http_diagnostics(headers: &HeaderMap) -> String {
46 const DIAGNOSTIC_HEADERS: &[&str] = &["x-request-id", "cf-ray", "server", "report-to", "nel"];
47
48 let pairs: Vec<String> = DIAGNOSTIC_HEADERS
49 .iter()
50 .filter_map(|name| {
51 headers.get(*name).and_then(|value| value.to_str().ok().map(|v| (*name, v)))
52 })
53 .map(|(name, value)| format!("{name}: {value}"))
54 .collect();
55
56 if pairs.is_empty() {
57 String::new()
58 } else {
59 format!("\n\nHTTP diagnostics:\n{}", pairs.join("\n"))
60 }
61}
62
/// Process-wide registry of payment locks, keyed by the `origin` string given to
/// `LazySessionProvider::new`.
///
/// Every `LazySessionProvider` created for the same origin receives a clone of the same
/// `tokio::sync::Mutex`. `do_request` holds that lock for the whole payment flow, so payments for
/// one origin are serialized even across separate transport instances.
// NOTE(review): entries are never removed in the code shown here.
static GLOBAL_PAY_LOCKS: OnceLock<Mutex<HashMap<String, std::sync::Arc<tokio::sync::Mutex<()>>>>> =
    OnceLock::new();
69
/// `MppHttpTransport` whose session provider is discovered lazily, on the first 402 challenge.
pub type LazyMppHttpTransport = MppHttpTransport<LazySessionProvider>;
73
/// Payment-provider resolver that builds its `SessionProvider` on first resolution.
///
/// Clones share the same cached provider (`inner` is behind an `Arc`) and the same per-origin
/// payment lock.
#[derive(Clone, Debug)]
pub struct LazySessionProvider {
    // The session provider; `None` until `get_or_init` succeeds for the first time.
    inner: std::sync::Arc<Mutex<Option<SessionProvider>>>,
    // Lock shared with every provider created for the same origin (from `GLOBAL_PAY_LOCKS`).
    pay_lock: std::sync::Arc<tokio::sync::Mutex<()>>,
    // Passed to `SessionProvider::new` and used as the `GLOBAL_PAY_LOCKS` key.
    origin: String,
}
83
84impl LazySessionProvider {
85 pub(super) fn new(origin: String) -> Self {
86 let pay_lock = {
87 let global = GLOBAL_PAY_LOCKS.get_or_init(|| Mutex::new(HashMap::new()));
88 global
89 .lock()
90 .unwrap()
91 .entry(origin.clone())
92 .or_insert_with(|| std::sync::Arc::new(tokio::sync::Mutex::new(())))
93 .clone()
94 };
95 Self { inner: std::sync::Arc::new(Mutex::new(None)), pay_lock, origin }
96 }
97
98 fn set_key_provisioned(&self, provisioned: bool) {
99 if let Some(p) = self.inner.lock().unwrap().as_ref() {
100 p.set_key_provisioned(provisioned);
101 }
102 }
103
104 fn clear_channels(&self) {
105 if let Some(p) = self.inner.lock().unwrap().as_ref() {
106 p.clear_channels();
107 }
108 }
109
110 pub(super) fn flush_pending(&self) {
111 if let Some(p) = self.inner.lock().unwrap().as_ref() {
112 p.flush_pending();
113 }
114 }
115
116 pub(super) fn rollback_pending(&self) {
117 if let Some(p) = self.inner.lock().unwrap().as_ref() {
118 p.rollback_pending();
119 }
120 }
121
122 fn commit_topup_and_track_voucher(&self) {
123 if let Some(p) = self.inner.lock().unwrap().as_ref() {
124 p.commit_topup_and_track_voucher();
125 }
126 }
127
128 pub(super) fn get_or_init(&self, opts: DiscoverOptions) -> TransportResult<SessionProvider> {
129 let mut guard = self.inner.lock().unwrap();
130 if let Some(ref provider) = *guard {
131 return Ok(provider.clone());
132 }
133
134 let config = discover_mpp_config(opts).ok_or_else(|| {
135 TransportErrorKind::custom(std::io::Error::other(
136 "RPC endpoint returned HTTP 402 Payment Required. \
137 This endpoint requires payment via the Machine Payments Protocol (MPP).\n\n\
138 To configure MPP, install the Tempo wallet CLI and create a key:\n\
139 \n curl -sSL https://tempo.xyz/install.sh | bash\
140 \n tempo wallet login\
141 \n\nSee https://docs.tempo.xyz for more information.",
142 ))
143 })?;
144
145 let signer: mpp::PrivateKeySigner = config.key.parse().map_err(|e| {
146 TransportErrorKind::custom(std::io::Error::other(format!("invalid MPP key: {e}")))
147 })?;
148
149 let signing_mode = if let Some(wallet) = config.wallet_address {
150 let key_authorization = config
151 .key_authorization
152 .as_ref()
153 .map(|hex_str| {
154 crate::tempo::decode_key_authorization(hex_str).map(Box::new).map_err(|e| {
155 TransportErrorKind::custom(std::io::Error::other(format!(
156 "invalid MPP key_authorization: {e}"
157 )))
158 })
159 })
160 .transpose()?;
161
162 mpp::client::tempo::signing::TempoSigningMode::Keychain {
163 wallet,
164 key_authorization,
165 version: mpp::client::tempo::signing::KeychainVersion::V2,
166 }
167 } else {
168 mpp::client::tempo::signing::TempoSigningMode::Direct
169 };
170
171 let mut provider = SessionProvider::new(signer, self.origin.clone())
172 .with_signing_mode(signing_mode)
173 .with_default_deposit(default_deposit())
174 .with_key_filters(config.chain_id, config.currencies);
175
176 if let Some(addr) = config.key_address {
177 provider = provider.with_authorized_signer(addr);
178 }
179
180 *guard = Some(provider.clone());
181 Ok(provider)
182 }
183}
184
/// HTTP JSON-RPC transport that answers MPP `402 Payment Required` challenges.
///
/// It pays through `provider` and retries the request with an `Authorization` credential.
#[derive(Clone, Debug)]
pub struct MppHttpTransport<P> {
    // HTTP client used for the initial request and every retry.
    client: reqwest::Client,
    // JSON-RPC endpoint; every request is a POST to this URL.
    url: Url,
    // Resolves the payment provider used when a challenge arrives.
    provider: P,
}
195
196impl MppHttpTransport<LazySessionProvider> {
197 pub fn lazy(client: reqwest::Client, url: Url) -> Self {
203 let origin = url.to_string();
204 Self { client, url, provider: LazySessionProvider::new(origin) }
205 }
206}
207
208impl<P> MppHttpTransport<P> {
209 pub const fn new(client: reqwest::Client, url: Url, provider: P) -> Self {
211 Self { client, url, provider }
212 }
213
214 pub const fn client(&self) -> &reqwest::Client {
216 &self.client
217 }
218}
219
220#[allow(private_bounds)]
221impl<P: ResolveProvider + Clone + Send + Sync + 'static> MppHttpTransport<P>
222where
223 P::Provider: Send + Sync + 'static,
224{
225 async fn do_request(self, req: RequestPacket) -> TransportResult<ResponsePacket> {
226 let body = serde_json::to_vec(&req).map_err(TransportErrorKind::custom)?;
227 let headers = req.headers();
228
229 let resp = self
230 .client
231 .post(self.url.clone())
232 .headers(headers.clone())
233 .header("content-type", "application/json")
234 .body(body.clone())
235 .send()
236 .await
237 .map_err(TransportErrorKind::custom)?;
238
239 if resp.status() != StatusCode::PAYMENT_REQUIRED {
240 return Self::handle_response(resp).await;
241 }
242
243 let _pay_guard = self.provider.lock_pay().await;
248
249 let (resolved, challenge) = Self::select_challenge(&resp, &self.provider)?;
250
251 debug!(id = %challenge.id, method = %challenge.method, intent = %challenge.intent, "received MPP 402 challenge, paying");
252
253 let credential = resolved.pay(&challenge).await.map_err(|e| {
254 TransportErrorKind::custom(std::io::Error::other(format!("MPP payment failed: {e}")))
255 })?;
256
257 let auth_header = format_authorization(&credential).map_err(|e| {
258 TransportErrorKind::custom(std::io::Error::other(format!(
259 "failed to format MPP credential: {e}"
260 )))
261 })?;
262
263 let retry_resp = self
266 .client
267 .post(self.url.clone())
268 .timeout(MPP_RETRY_TIMEOUT)
269 .headers(headers.clone())
270 .header("content-type", "application/json")
271 .header(AUTHORIZATION_HEADER, &auth_header)
272 .body(body.clone())
273 .send()
274 .await
275 .map_err(|e| {
276 self.provider.rollback_pending();
277 TransportErrorKind::custom(e)
278 })?;
279
280 if retry_resp.status() == StatusCode::NO_CONTENT {
282 debug!("MPP topUp accepted (204), retrying with voucher");
283
284 self.provider.commit_topup_and_track_voucher();
287
288 let resolved = self.provider.resolve()?;
289 let voucher_resp = self.pay_and_retry(&challenge, &resolved, &headers, &body).await?;
290
291 let result = Self::handle_response(voucher_resp).await;
292 if result.is_ok() {
293 self.provider.set_key_provisioned(true);
294 self.provider.flush_pending();
295 } else {
296 self.provider.rollback_pending();
297 }
298 return result;
299 }
300
301 if retry_resp.status() == StatusCode::GONE {
303 debug!("MPP channel not found (410), clearing stale local state");
304 self.provider.rollback_pending();
305 self.provider.clear_channels();
306
307 return Err(TransportErrorKind::custom(std::io::Error::other(
308 "MPP channel not found on server (410 Gone). \
309 The server may have restarted or the channel was closed externally.\n\
310 Local channel state has been cleared. Re-run to open a new channel.",
311 )));
312 }
313
314 if retry_resp.status() == StatusCode::PAYMENT_REQUIRED {
316 let diagnostics = format_http_diagnostics(retry_resp.headers());
317 let retry_body = retry_resp.bytes().await.map_err(TransportErrorKind::custom)?;
318 let retry_text = String::from_utf8_lossy(&retry_body);
319
320 let problem: Option<mpp::error::PaymentErrorDetails> =
323 serde_json::from_slice(&retry_body).ok();
324 let problem_type = problem.as_ref().map(|p| p.problem_type.as_str()).unwrap_or("");
325 let detail = problem.as_ref().map(|p| p.detail.as_str()).unwrap_or("");
326
327 let is_stale_voucher = problem_type.ends_with("/stale-voucher")
331 || detail.contains("cumulativeAmount must be strictly greater");
332 if is_stale_voucher {
333 debug!("MPP voucher stale, retrying with fresh voucher");
334 let resolved = self.provider.resolve()?;
335 if resolved.supports(challenge.method.as_str(), challenge.intent.as_str()) {
336 let final_resp =
337 self.pay_and_retry(&challenge, &resolved, &headers, &body).await?;
338
339 let result = Self::handle_response(final_resp).await;
340 if result.is_ok() {
341 self.provider.flush_pending();
342 } else {
343 self.provider.rollback_pending();
344 }
345 return result;
346 }
347 }
348
349 let needs_key_provisioning = problem_type.ends_with("/key-not-provisioned")
357 || detail.contains("access key does not exist")
358 || detail.contains("key is not provisioned");
359
360 let needs_verification_retry = (problem_type.ends_with("/verification-failed")
361 || detail.contains("verification-failed"))
362 && self.provider.is_key_provisioned();
363
364 if needs_key_provisioning || needs_verification_retry {
365 debug!(
366 problem_type,
367 "MPP 402 key not provisioned/verification-failed, retrying with key_authorization"
368 );
369 self.provider.set_key_provisioned(false);
370 self.provider.rollback_pending();
371
372 let (resolved, fresh_challenge) =
373 self.fetch_fresh_challenge(&headers, &body).await?;
374
375 let final_resp =
376 self.pay_and_retry(&fresh_challenge, &resolved, &headers, &body).await?;
377
378 let result = Self::handle_response(final_resp).await;
379 if result.is_ok() {
380 self.provider.set_key_provisioned(true);
381 self.provider.flush_pending();
382 } else {
383 self.provider.rollback_pending();
384 }
385 return result;
386 }
387
388 self.provider.rollback_pending();
389 return Err(TransportErrorKind::http_error(
390 StatusCode::PAYMENT_REQUIRED.as_u16(),
391 format!("{retry_text}{diagnostics}"),
392 ));
393 }
394
395 let result = Self::handle_response(retry_resp).await;
396 if result.is_ok() {
397 self.provider.set_key_provisioned(true);
398 self.provider.flush_pending();
399 } else {
400 self.provider.rollback_pending();
401 }
402 result
403 }
404
405 async fn pay_and_retry(
407 &self,
408 challenge: &mpp::protocol::core::PaymentChallenge,
409 provider: &P::Provider,
410 headers: &reqwest::header::HeaderMap,
411 body: &[u8],
412 ) -> TransportResult<reqwest::Response> {
413 let credential = provider.pay(challenge).await.map_err(|e| {
414 self.provider.rollback_pending();
415 TransportErrorKind::custom(std::io::Error::other(format!("MPP payment failed: {e}")))
416 })?;
417
418 let auth_header = format_authorization(&credential).map_err(|e| {
419 self.provider.rollback_pending();
420 TransportErrorKind::custom(std::io::Error::other(format!(
421 "failed to format MPP credential: {e}"
422 )))
423 })?;
424
425 self.client
426 .post(self.url.clone())
427 .timeout(MPP_RETRY_TIMEOUT)
428 .headers(headers.clone())
429 .header("content-type", "application/json")
430 .header(AUTHORIZATION_HEADER, auth_header)
431 .body(body.to_vec())
432 .send()
433 .await
434 .map_err(|e| {
435 self.provider.rollback_pending();
436 TransportErrorKind::custom(e)
437 })
438 }
439
440 async fn fetch_fresh_challenge(
446 &self,
447 headers: &reqwest::header::HeaderMap,
448 body: &[u8],
449 ) -> TransportResult<(P::Provider, mpp::protocol::core::PaymentChallenge)> {
450 let fresh_resp = self
451 .client
452 .post(self.url.clone())
453 .timeout(MPP_RETRY_TIMEOUT)
454 .headers(headers.clone())
455 .header("content-type", "application/json")
456 .body(body.to_vec())
457 .send()
458 .await
459 .map_err(TransportErrorKind::custom)?;
460
461 if fresh_resp.status() != StatusCode::PAYMENT_REQUIRED {
462 let result = Self::handle_response(fresh_resp).await;
464 return Err(result.err().unwrap_or_else(|| {
465 TransportErrorKind::custom(std::io::Error::other(
466 "unexpected success on unauthenticated fresh probe",
467 ))
468 }));
469 }
470
471 Self::select_challenge(&fresh_resp, &self.provider)
472 }
473
474 fn select_challenge(
477 resp: &reqwest::Response,
478 provider: &P,
479 ) -> TransportResult<(P::Provider, mpp::protocol::core::PaymentChallenge)> {
480 let www_auth_values: Vec<&str> = resp
481 .headers()
482 .get_all(WWW_AUTHENTICATE_HEADER)
483 .iter()
484 .filter_map(|v| v.to_str().ok())
485 .collect();
486
487 if www_auth_values.is_empty() {
488 return Err(TransportErrorKind::custom(std::io::Error::other(format!(
489 "402 response missing WWW-Authenticate header{}",
490 format_http_diagnostics(resp.headers())
491 ))));
492 }
493
494 let challenges: Vec<_> = parse_www_authenticate_all(www_auth_values)
495 .into_iter()
496 .filter_map(|r| r.ok())
497 .collect();
498
499 let mut last_resolve_err: Option<TransportError> = None;
500 let resolved_pair = challenges.iter().find_map(|c| {
501 let (chain_id, currency) = extract_challenge_chain_and_currency(c);
502 let currency = currency.and_then(|s| s.parse().ok());
503 match provider.resolve_for(DiscoverOptions { chain_id, currency }) {
504 Ok(p) => p.supports(c.method.as_str(), c.intent.as_str()).then_some((p, c.clone())),
505 Err(e) => {
506 last_resolve_err = Some(e);
507 None
508 }
509 }
510 });
511
512 resolved_pair.ok_or_else(|| {
513 if let Some(err) = last_resolve_err {
514 return err;
515 }
516 let offered: Vec<_> =
517 challenges.iter().map(|c| format!("{}.{}", c.method, c.intent)).collect();
518 TransportErrorKind::custom(std::io::Error::other(format!(
519 "no supported MPP challenge; server offered [{}]",
520 offered.join(", "),
521 )))
522 })
523 }
524
525 async fn handle_response(resp: reqwest::Response) -> TransportResult<ResponsePacket> {
526 let status = resp.status();
527 debug!(%status, "received response from MPP transport");
528 let diagnostics = format_http_diagnostics(resp.headers());
529
530 let body = resp.bytes().await.map_err(TransportErrorKind::custom)?;
531
532 if tracing::enabled!(tracing::Level::TRACE) {
533 trace!(body = %String::from_utf8_lossy(&body), "response body");
534 } else {
535 debug!(bytes = body.len(), "retrieved response body");
536 }
537
538 if !status.is_success() {
539 return Err(TransportErrorKind::http_error(
540 status.as_u16(),
541 format!("{}{diagnostics}", String::from_utf8_lossy(&body)),
542 ));
543 }
544
545 serde_json::from_slice(&body)
546 .map_err(|err| TransportError::deser_err(err, String::from_utf8_lossy(&body)))
547 }
548}
549
550pub(super) fn extract_challenge_chain_and_currency(
552 c: &mpp::protocol::core::PaymentChallenge,
553) -> (Option<u64>, Option<String>) {
554 if c.method.as_str() == "tempo" {
555 let val = c.request.decode_value().ok();
556 let chain_id = val.as_ref().and_then(|v| v.get("methodDetails")?.get("chainId")?.as_u64());
557 let currency = val.as_ref().and_then(|v| v.get("currency")?.as_str().map(String::from));
558 (chain_id, currency)
559 } else {
560 (None, None)
561 }
562}
563
/// Source of a `PaymentProvider` for a challenge, plus hooks for the pending-payment state that
/// `MppHttpTransport::do_request` drives.
///
/// All state hooks default to no-ops, `is_key_provisioned` defaults to `true`, and `lock_pay`
/// defaults to no lock. The blanket impl for plain `PaymentProvider`s relies on these defaults.
pub(crate) trait ResolveProvider {
    /// Provider that actually produces payment credentials.
    type Provider: PaymentProvider;
    /// Resolves a provider using `DiscoverOptions::default()`.
    fn resolve(&self) -> TransportResult<Self::Provider> {
        self.resolve_for(Default::default())
    }
    /// Resolves a provider for a challenge with the given chain id / currency options.
    fn resolve_for(&self, opts: DiscoverOptions) -> TransportResult<Self::Provider>;
    /// Records whether the signing key is believed to be provisioned on the server.
    fn set_key_provisioned(&self, _provisioned: bool) {}
    /// Whether the key is believed provisioned; gates the verification-failed retry.
    fn is_key_provisioned(&self) -> bool {
        true
    }
    /// Discards local channel state; the transport calls this after a `410 Gone`.
    fn clear_channels(&self) {}
    /// Called after a paid request succeeds.
    fn flush_pending(&self) {}
    /// Called when a payment attempt fails or is abandoned.
    fn rollback_pending(&self) {}
    /// Called when the server accepts a top-up with `204 No Content`, before the voucher retry.
    fn commit_topup_and_track_voucher(&self) {}
    /// Acquires a guard that serializes payment flows; `None` means no locking.
    fn lock_pay(&self) -> impl std::future::Future<Output = Option<OwnedMutexGuard<()>>> + Send {
        async { None }
    }
}
586
587impl<P: PaymentProvider + Clone> ResolveProvider for P {
588 type Provider = P;
589 fn resolve_for(&self, _opts: DiscoverOptions) -> TransportResult<P> {
590 Ok(self.clone())
591 }
592}
593
594impl ResolveProvider for LazySessionProvider {
595 type Provider = SessionProvider;
596 fn resolve_for(&self, opts: DiscoverOptions) -> TransportResult<SessionProvider> {
597 let provider = self.get_or_init(opts.clone())?;
598 if !provider.matches_challenge(opts.chain_id, opts.currency) {
602 return Err(TransportErrorKind::custom(std::io::Error::other(
603 "cached provider does not match challenge chain/currency",
604 )));
605 }
606 Ok(provider)
607 }
608 fn set_key_provisioned(&self, provisioned: bool) {
609 Self::set_key_provisioned(self, provisioned)
610 }
611 fn is_key_provisioned(&self) -> bool {
612 self.inner.lock().unwrap().as_ref().is_none_or(|p| p.is_key_provisioned())
613 }
614 fn clear_channels(&self) {
615 Self::clear_channels(self)
616 }
617 fn flush_pending(&self) {
618 Self::flush_pending(self)
619 }
620 fn rollback_pending(&self) {
621 Self::rollback_pending(self)
622 }
623 fn commit_topup_and_track_voucher(&self) {
624 Self::commit_topup_and_track_voucher(self)
625 }
626 fn lock_pay(&self) -> impl std::future::Future<Output = Option<OwnedMutexGuard<()>>> + Send {
627 let lock = self.pay_lock.clone();
628 async move { Some(lock.lock_owned().await) }
629 }
630}
631
632impl<P> fmt::Display for MppHttpTransport<P> {
633 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
634 write!(f, "MppHttpTransport({})", self.url)
635 }
636}
637
638#[allow(private_bounds)]
639impl<P: ResolveProvider + Clone + Send + Sync + fmt::Debug + 'static> Service<RequestPacket>
640 for MppHttpTransport<P>
641where
642 P::Provider: Send + Sync + 'static,
643{
644 type Response = ResponsePacket;
645 type Error = TransportError;
646 type Future = TransportFut<'static>;
647
648 #[inline]
649 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
650 task::Poll::Ready(Ok(()))
651 }
652
653 #[inline]
654 fn call(&mut self, req: RequestPacket) -> Self::Future {
655 let this = self.clone();
656 let span = debug_span!("MppHttpTransport", url = %this.url);
657 Box::pin(this.do_request(req).instrument(span.or_current()))
658 }
659}
660
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::runtime_transport::RuntimeTransportBuilder;
    use alloy_json_rpc::{Id, Request, RequestMeta};
    use axum::{
        extract::State, http::StatusCode as AxumStatusCode, response::IntoResponse, routing::post,
    };
    use mpp::{
        MppError,
        protocol::core::{
            Base64UrlJson, IntentName, MethodName, PaymentChallenge, PaymentCredential,
            format_www_authenticate, parse_authorization,
        },
    };

    /// Stub provider: supports `tempo` with the `session` or `charge` intent, and answers every
    /// challenge with a fixed voucher credential that echoes the challenge.
    #[derive(Clone, Debug)]
    struct MockPaymentProvider;

    impl PaymentProvider for MockPaymentProvider {
        fn supports(&self, method: &str, intent: &str) -> bool {
            method == "tempo" && (intent == "session" || intent == "charge")
        }

        fn pay(
            &self,
            challenge: &PaymentChallenge,
        ) -> impl std::future::Future<Output = Result<PaymentCredential, MppError>> + Send {
            let echo = challenge.to_echo();
            async move {
                Ok(PaymentCredential::with_source(
                    echo,
                    "test-source".to_string(),
                    serde_json::json!({"action": "voucher", "channelId": "0xtest", "cumulativeAmount": "1000", "signature": "0xtest"}),
                ))
            }
        }
    }

    /// Builds a `tempo`/`session` challenge with id `test-id-42` (chain 42431, currency
    /// `0x20c0`) and returns it together with its formatted `WWW-Authenticate` value.
    fn test_challenge() -> (PaymentChallenge, String) {
        let request = Base64UrlJson::from_value(&serde_json::json!({
            "amount": "1000",
            "currency": "0x20c0",
            "recipient": "0xpayee",
            "methodDetails": {
                "chainId": 42431
            }
        }))
        .unwrap();

        let challenge = PaymentChallenge {
            id: "test-id-42".to_string(),
            realm: "test-realm".to_string(),
            method: MethodName::new("tempo"),
            intent: IntentName::new("session"),
            request,
            expires: None,
            description: None,
            digest: None,
            opaque: None,
        };

        let www_auth = format_www_authenticate(&challenge).unwrap();
        (challenge, www_auth)
    }

    /// A single `eth_blockNumber` request with id 1 and no params.
    fn test_request() -> RequestPacket {
        let req: Request<serde_json::Value> = Request {
            meta: RequestMeta::new("eth_blockNumber".into(), Id::Number(1)),
            params: serde_json::Value::Array(vec![]),
        };
        RequestPacket::Single(req.serialize().unwrap())
    }

    /// Serves `app` on an ephemeral port on 127.0.0.1 and returns its base URL and the server
    /// task handle; each test aborts the handle when done.
    async fn spawn_server(app: axum::Router) -> (String, tokio::task::JoinHandle<()>) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        (format!("http://{addr}"), handle)
    }

    /// Without a 402, the transport returns the server's JSON-RPC response unchanged.
    #[tokio::test]
    async fn test_mpp_transport_no_402() {
        let app = axum::Router::new().route(
            "/",
            post(|| async {
                axum::Json(serde_json::json!({
                    "jsonrpc": "2.0",
                    "id": 1,
                    "result": "0x123"
                }))
            }),
        );

        let (base_url, handle) = spawn_server(app).await;
        let mut transport = MppHttpTransport::new(
            reqwest::Client::new(),
            Url::parse(&base_url).unwrap(),
            MockPaymentProvider,
        );

        let resp = tower::Service::call(&mut transport, test_request()).await.unwrap();
        match resp {
            ResponsePacket::Single(r) => assert!(r.is_success()),
            _ => panic!("expected single response"),
        }

        handle.abort();
    }

    /// The server answers 402 until an `Authorization` header is present, then succeeds.
    ///
    /// The handler asserts that the credential echoes the challenge id and method and carries a
    /// source.
    #[tokio::test]
    async fn test_mpp_transport_402_then_success() {
        let (_, www_auth) = test_challenge();
        let state = AppState { www_auth };

        #[derive(Clone)]
        struct AppState {
            www_auth: String,
        }

        let app =
            axum::Router::new()
                .route(
                    "/",
                    post(
                        |State(state): State<AppState>,
                         req: axum::http::Request<axum::body::Body>| async move {
                            if let Some(auth) = req.headers().get("authorization") {
                                let auth_str = auth.to_str().unwrap();
                                let credential = parse_authorization(auth_str).unwrap();
                                assert_eq!(credential.challenge.id, "test-id-42");
                                assert_eq!(credential.challenge.method.as_str(), "tempo");
                                assert!(credential.source.is_some());

                                (
                                    AxumStatusCode::OK,
                                    axum::Json(serde_json::json!({
                                        "jsonrpc": "2.0",
                                        "id": 1,
                                        "result": "0xvalidated"
                                    })),
                                )
                                    .into_response()
                            } else {
                                (
                                    AxumStatusCode::PAYMENT_REQUIRED,
                                    [("www-authenticate", state.www_auth)],
                                    "Payment Required",
                                )
                                    .into_response()
                            }
                        },
                    ),
                )
                .with_state(state);

        let (base_url, handle) = spawn_server(app).await;
        let mut transport = MppHttpTransport::new(
            reqwest::Client::new(),
            Url::parse(&base_url).unwrap(),
            MockPaymentProvider,
        );

        let resp = tower::Service::call(&mut transport, test_request()).await.unwrap();
        match resp {
            ResponsePacket::Single(r) => assert!(r.is_success()),
            _ => panic!("expected single response"),
        }

        handle.abort();
    }

    /// A 402 without any `WWW-Authenticate` header is reported as such.
    #[tokio::test]
    async fn test_mpp_transport_402_missing_www_authenticate() {
        let app = axum::Router::new()
            .route("/", post(|| async { (AxumStatusCode::PAYMENT_REQUIRED, "pay up") }));

        let (base_url, handle) = spawn_server(app).await;
        let mut transport = MppHttpTransport::new(
            reqwest::Client::new(),
            Url::parse(&base_url).unwrap(),
            MockPaymentProvider,
        );

        let err = tower::Service::call(&mut transport, test_request()).await.unwrap_err();
        assert!(
            err.to_string().contains("WWW-Authenticate"),
            "expected WWW-Authenticate error, got: {err}"
        );

        handle.abort();
    }

    /// Through the runtime transport with no reachable MPP configuration, a 402 must surface
    /// either the setup-instructions error or the "no supported MPP challenge" error.
    #[tokio::test]
    async fn test_plain_http_402_shows_mpp_setup_instructions() {
        let (_, www_auth) = test_challenge();

        let app = axum::Router::new().route(
            "/",
            post(move || {
                let www_auth = www_auth.clone();
                async move {
                    (
                        AxumStatusCode::PAYMENT_REQUIRED,
                        [("www-authenticate", www_auth)],
                        "Payment Required",
                    )
                }
            }),
        );

        let (base_url, handle) = spawn_server(app).await;

        // SAFETY: `set_var`/`remove_var` are only sound when no other thread touches the
        // environment concurrently.
        // NOTE(review): the test harness runs tests on parallel threads, so this is not
        // guaranteed. `TEMPO_PRIVATE_KEY` is not restored afterwards either.
        unsafe {
            std::env::set_var("TEMPO_HOME", "/nonexistent/path");
            std::env::remove_var("TEMPO_PRIVATE_KEY");
        }

        let transport = RuntimeTransportBuilder::new(Url::parse(&base_url).unwrap()).build();
        let err = transport.request(test_request()).await.unwrap_err();
        let msg = err.to_string();

        assert!(
            msg.contains("402 Payment Required") || msg.contains("no supported MPP challenge"),
            "expected MPP setup instructions or 'no supported MPP challenge' in error, got: {msg}"
        );

        handle.abort();
        // NOTE(review): skipped if an assertion above panics, leaking TEMPO_HOME to other tests.
        unsafe { std::env::remove_var("TEMPO_HOME") };
    }

    /// `SessionProvider` supports `tempo` `session` and `charge`, and nothing else tested here.
    #[test]
    fn test_session_provider_supports_charge_and_session() {
        let signer = mpp::PrivateKeySigner::random();
        let provider =
            super::super::session::SessionProvider::new(signer, "https://rpc.example.com".into());

        assert!(provider.supports("tempo", "session"));
        assert!(provider.supports("tempo", "charge"));
        assert!(!provider.supports("stripe", "charge"));
        assert!(!provider.supports("tempo", "subscribe"));
    }

    /// Paying a challenge with an offline provider fails, but not with a "not supported" error.
    // NOTE(review): despite the name, `test_challenge()` builds a `session`-intent challenge,
    // not a `charge` one — confirm which path this is meant to cover.
    #[tokio::test]
    async fn test_session_provider_pay_charge_parses_challenge() {
        let signer = mpp::PrivateKeySigner::random();
        let provider =
            super::super::session::SessionProvider::new(signer, "https://rpc.example.com".into());

        let (challenge, _) = test_challenge();
        let err = provider.pay(&challenge).await.unwrap_err();
        assert!(
            !err.to_string().contains("not supported"),
            "expected charge path to be wired up, got: {err}"
        );
    }

    /// Chain id / currency extraction for a full `tempo` challenge, a non-`tempo` challenge,
    /// and a `tempo` challenge without `methodDetails`.
    #[test]
    fn challenge_chain_and_currency_extraction() {
        let extract = |headers: Vec<&str>| -> Vec<(Option<u64>, Option<String>)> {
            let challenges: Vec<_> =
                parse_www_authenticate_all(headers).into_iter().filter_map(|r| r.ok()).collect();
            challenges.iter().map(extract_challenge_chain_and_currency).collect()
        };

        let b64 = |v: serde_json::Value| -> String {
            Base64UrlJson::from_value(&v).unwrap().raw().to_string()
        };

        let tempo_header = format!(
            r#"Payment id="abc", realm="api", method="tempo", intent="charge", request="{}""#,
            b64(
                serde_json::json!({"amount":"1000","currency":"0x20c0","methodDetails":{"chainId":42431},"recipient":"0xabc"})
            )
        );
        assert_eq!(extract(vec![&tempo_header]), vec![(Some(42431), Some("0x20c0".into()))]);

        let stripe_header = format!(
            r#"Payment id="xyz", realm="api", method="stripe", intent="charge", request="{}""#,
            b64(serde_json::json!({"amount":"100"}))
        );
        assert_eq!(extract(vec![&stripe_header]), vec![(None, None)]);

        let no_details = format!(
            r#"Payment id="def", realm="api", method="tempo", intent="charge", request="{}""#,
            b64(serde_json::json!({"amount":"1000","currency":"0x20c0","recipient":"0xabc"}))
        );
        assert_eq!(extract(vec![&no_details]), vec![(None, Some("0x20c0".into()))]);
    }
}