Skip to main content

foundry_common/provider/mpp/
session.rs

1//! Tempo session payment provider with expiring nonces.
2//!
3//! Custom implementation that mirrors `tempoxyz/wallet`'s approach: uses
4//! expiring nonces (`nonce=0`, `nonceKey=MAX`, `validBefore=now+25s`) for
5//! channel open transactions instead of fetching sequential nonces via
6//! `eth_getTransactionCount`. This avoids the chicken-and-egg problem when
7//! the RPC endpoint is itself 402-gated.
8
9use super::persist;
10use alloy_primitives::{Address, B256, Bytes, TxKind, U256};
11use foundry_wallets::Channel;
12use mpp::{
13    client::{
14        PaymentProvider,
15        channel_ops::{
16            ChannelEntry, OpenPayloadOptions, build_credential, create_voucher_payload,
17            resolve_chain_id, resolve_escrow,
18        },
19        tempo::signing::{TempoSigningMode, sign_and_encode_async},
20    },
21    error::MppError,
22    protocol::{
23        core::{PaymentChallenge, PaymentCredential},
24        intents::SessionRequest,
25        methods::tempo::session::TempoSessionExt,
26    },
27    tempo::{Call, SessionCredentialPayload, compute_channel_id, sign_voucher},
28};
29use std::{
30    collections::HashMap,
31    sync::{Arc, Mutex, OnceLock},
32};
33
34/// Shared per-origin in-memory channel state: (channels, key_provisioned).
35type SharedChannelState = (Arc<Mutex<HashMap<String, ChannelEntry>>>, Arc<Mutex<bool>>);
36
37/// Process-wide channel state registry, keyed by origin URL.
38///
39/// Stores per-origin in-memory channel maps and key provisioning state.
40static GLOBAL_CHANNELS: OnceLock<Mutex<HashMap<String, SharedChannelState>>> = OnceLock::new();
41
42/// Process-wide persisted channel state, shared across ALL origins.
43///
44/// Using a single map ensures saves from different origins don't clobber
45/// each other's state.
46static GLOBAL_PERSISTED: OnceLock<Arc<Mutex<HashMap<String, Channel>>>> = OnceLock::new();
47
/// Channel-state change from the latest payment that the server has not yet
/// confirmed.
///
/// Disk persistence waits until the transport reports success, so a failed
/// open or top-up never leaves local state ahead of what the server accepted.
/// Each variant carries enough to undo the in-memory change.
#[derive(Clone, Debug)]
enum PendingAction {
    /// Channel `key` was opened locally; undo by removing it.
    Open { key: String },
    /// Channel `key` had its deposit raised; undo by restoring `old_deposit`.
    TopUp { key: String, old_deposit: String },
    /// Channel `key` had its `cumulative_amount` advanced; undo by restoring
    /// `old_cumulative`.
    Voucher { key: String, old_cumulative: u128 },
}
61
62/// Expiring nonce key (U256::MAX) — matches the charge flow.
63const EXPIRING_NONCE_KEY: U256 = U256::MAX;
64
/// How long (seconds from now) an expiring-nonce transaction stays valid.
const VALID_BEFORE_SECS: u64 = 25;

/// Gas limit used for session open (and top-up) transactions.
const SESSION_OPEN_GAS_LIMIT: u64 = 10 * 1_000_000;

/// Max fee per gas: 20 gwei, Tempo's fixed base fee.
const MAX_FEE_PER_GAS: u128 = 20 * 1_000_000_000;

/// Max priority fee per gas (also 20 gwei).
const MAX_PRIORITY_FEE_PER_GAS: u128 = 20 * 1_000_000_000;
76
77/// Tempo session provider using expiring nonces.
78///
79/// Unlike mpp-rs's `TempoSessionProvider` which fetches sequential nonces
80/// (requiring a non-gated RPC), this provider uses expiring nonces for
81/// channel open transactions — matching how `tempoxyz/wallet` works.
82#[derive(Clone)]
83pub struct SessionProvider {
84    signer: mpp::PrivateKeySigner,
85    signing_mode: TempoSigningMode,
86    authorized_signer: Option<Address>,
87    default_deposit: Option<u128>,
88    channels: Arc<Mutex<HashMap<String, ChannelEntry>>>,
89    key_provisioned: Arc<Mutex<bool>>,
90    persisted: Arc<Mutex<HashMap<String, Channel>>>,
91    /// Tracks uncommitted open/top-up state for deferred persistence.
92    pending: Arc<Mutex<Option<PendingAction>>>,
93    /// Chain ID from the key entry in `keys.toml` that was used to initialize
94    /// this provider. Used to reject challenges for a different chain.
95    key_chain_id: Option<u64>,
96    /// Currencies from the key's spending limits. Used to reject challenges
97    /// for currencies the key cannot pay with.
98    key_currencies: Vec<Address>,
99    origin: String,
100}
101
102impl std::fmt::Debug for SessionProvider {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        f.debug_struct("SessionProvider")
105            .field("signing_mode", &self.signing_mode)
106            .field("authorized_signer", &self.authorized_signer)
107            .field("default_deposit", &self.default_deposit)
108            .finish_non_exhaustive()
109    }
110}
111
112impl SessionProvider {
113    /// Create a new session provider with the given signer and RPC origin URL.
114    ///
115    /// Channel state is shared process-wide: all `SessionProvider` instances
116    /// share the same in-memory channels and persisted state. This prevents
117    /// concurrent providers (e.g. multiple `forge script` providers for the
118    /// same URL) from reading stale `cumulative_amount` values from disk and
119    /// producing duplicate vouchers.
120    pub fn new(signer: mpp::PrivateKeySigner, origin: String) -> Self {
121        // Global persisted map shared across all origins.
122        let persisted =
123            GLOBAL_PERSISTED.get_or_init(|| Arc::new(Mutex::new(persist::load_channels()))).clone();
124
125        // Per-origin in-memory channel map + key provisioning state.
126        let (channels, key_provisioned) = {
127            let global = GLOBAL_CHANNELS.get_or_init(|| Mutex::new(HashMap::new()));
128            let mut map = global.lock().unwrap();
129            map.entry(origin.clone())
130                .or_insert_with(|| {
131                    // Hydrate only channels belonging to this origin.
132                    let mut channels: HashMap<String, ChannelEntry> = HashMap::new();
133                    for (key, ch) in persisted.lock().unwrap().iter() {
134                        if ch.origin == origin
135                            && let Some(entry) = persist::to_channel_entry(ch)
136                        {
137                            channels.insert(key.clone(), entry);
138                        }
139                    }
140                    (Arc::new(Mutex::new(channels)), Arc::new(Mutex::new(true)))
141                })
142                .clone()
143        };
144
145        Self {
146            signer,
147            signing_mode: TempoSigningMode::Direct,
148            authorized_signer: None,
149            default_deposit: None,
150            channels,
151            key_provisioned,
152            persisted,
153            pending: Arc::new(Mutex::new(None)),
154            key_chain_id: None,
155            key_currencies: vec![],
156            origin,
157        }
158    }
159
160    /// Set the signing mode (direct or keychain).
161    pub fn with_signing_mode(mut self, mode: TempoSigningMode) -> Self {
162        self.signing_mode = mode;
163        self
164    }
165
166    /// Set the authorized signer address for keychain mode.
167    pub const fn with_authorized_signer(mut self, addr: Address) -> Self {
168        self.authorized_signer = Some(addr);
169        self
170    }
171
172    /// Set the default deposit amount.
173    pub const fn with_default_deposit(mut self, deposit: u128) -> Self {
174        self.default_deposit = Some(deposit);
175        self
176    }
177
178    /// Address that funds payments for this provider.
179    pub fn funding_wallet_address(&self) -> Address {
180        self.signing_mode.from_address(self.signer.address())
181    }
182
183    /// Chain ID from the selected wallet key, when known.
184    pub const fn key_chain_id(&self) -> Option<u64> {
185        self.key_chain_id
186    }
187
188    /// Set the chain ID and currencies from the key entry used to initialize
189    /// this provider. Used to reject challenges for incompatible chains/currencies.
190    /// When `chain_id` is `None` (e.g. env var key), chain filtering is skipped.
191    pub fn with_key_filters(mut self, chain_id: Option<u64>, currencies: Vec<Address>) -> Self {
192        self.key_chain_id = chain_id;
193        self.key_currencies = currencies;
194        self
195    }
196
197    /// Check whether this provider's key is compatible with the given
198    /// chain ID and currency from a 402 challenge.
199    pub fn matches_challenge(&self, chain_id: Option<u64>, currency: Option<Address>) -> bool {
200        if let Some(cid) = chain_id
201            && self.key_chain_id.is_some_and(|k| k != cid)
202        {
203            return false;
204        }
205        if let Some(cur) = currency
206            && !self.key_currencies.is_empty()
207            && !self.key_currencies.contains(&cur)
208        {
209            return false;
210        }
211        true
212    }
213
214    /// Clear channels belonging to this origin (e.g. after server 410).
215    ///
216    /// Only removes channels whose `origin` matches `self.origin`, preserving
217    /// channels for other RPC endpoints.
218    pub fn clear_channels(&self) {
219        let origin = &self.origin;
220        // Lock order: channels → persisted (consistent with pay_session)
221        let mut channels = self.channels.lock().unwrap();
222        let mut persisted = self.persisted.lock().unwrap();
223        let keys_to_remove: Vec<(String, String)> = persisted
224            .iter()
225            .filter(|(_, ch)| ch.origin == *origin)
226            .map(|(k, ch): (&String, &Channel)| (k.clone(), ch.channel_id.clone()))
227            .collect();
228        for (key, channel_id) in &keys_to_remove {
229            channels.remove(key);
230            persisted.remove(key);
231            persist::delete_channel_from_db(channel_id);
232        }
233    }
234
235    /// Mark whether the access key has been provisioned on-chain.
236    pub fn set_key_provisioned(&self, provisioned: bool) {
237        *self.key_provisioned.lock().unwrap() = provisioned;
238    }
239
240    /// Check whether the access key has been provisioned on-chain.
241    pub fn is_key_provisioned(&self) -> bool {
242        *self.key_provisioned.lock().unwrap()
243    }
244
245    /// Persist any pending open/top-up/voucher state to disk.
246    ///
247    /// Called by the transport after the server confirms acceptance.
248    pub fn flush_pending(&self) {
249        let pending = self.pending.lock().unwrap().take();
250        if pending.is_some() {
251            persist::save_channels(&self.persisted.lock().unwrap());
252        }
253    }
254
255    /// Commit a pending top-up (deposit increase) without flushing to disk.
256    ///
257    /// Called by the transport when the server returns 204 (top-up accepted).
258    /// The deposit increase is now committed, but the follow-up voucher is
259    /// tracked as a new pending action.
260    pub fn commit_topup_and_track_voucher(&self) {
261        let pending = self.pending.lock().unwrap().take();
262        if let Some(PendingAction::TopUp { key, .. }) = pending {
263            // Top-up is now committed — read the current cumulative_amount
264            // so we can roll back just the voucher increment if needed.
265            let old_cumulative =
266                self.channels.lock().unwrap().get(&key).map(|e| e.cumulative_amount).unwrap_or(0);
267            *self.pending.lock().unwrap() = Some(PendingAction::Voucher { key, old_cumulative });
268        }
269    }
270
271    /// Roll back pending open/top-up/voucher state on failure.
272    ///
273    /// Called by the transport when the server rejects the payment or times out.
274    pub fn rollback_pending(&self) {
275        let pending = self.pending.lock().unwrap().take();
276        if let Some(action) = pending {
277            match action {
278                PendingAction::Open { key } => {
279                    self.channels.lock().unwrap().remove(&key);
280                    self.persisted.lock().unwrap().remove(&key);
281                }
282                PendingAction::TopUp { key, old_deposit } => {
283                    if let Some(p) = self.persisted.lock().unwrap().get_mut(&key) {
284                        p.deposit = old_deposit;
285                    }
286                }
287                PendingAction::Voucher { key, old_cumulative } => {
288                    if let Some(entry) = self.channels.lock().unwrap().get_mut(&key) {
289                        entry.cumulative_amount = old_cumulative;
290                    }
291                    if let Some(p) = self.persisted.lock().unwrap().get_mut(&key) {
292                        p.cumulative_amount = old_cumulative.to_string();
293                    }
294                }
295            }
296        }
297    }
298
299    fn channel_key(
300        origin: &str,
301        payer: &Address,
302        authorized_signer: Option<Address>,
303        payee: &Address,
304        currency: &Address,
305        escrow: &Address,
306        chain_id: u64,
307    ) -> String {
308        // Use first 8 bytes of origin hash to scope the key without persisting
309        // the full URL (which may contain secrets in query params).
310        let origin_hash = &alloy_primitives::keccak256(origin.as_bytes()).to_string()[..18];
311        let signer = authorized_signer.unwrap_or(*payer);
312        format!("{origin_hash}:{chain_id}:{payer}:{signer}:{payee}:{currency}:{escrow}")
313            .to_lowercase()
314    }
315
316    fn resolve_deposit(&self, suggested: Option<&str>) -> Result<u128, MppError> {
317        let suggested_val = suggested.and_then(|s| s.parse::<u128>().ok());
318
319        // Local config takes priority. Warn when server suggests more so users
320        // can bump MPP_DEPOSIT if the default is too low.
321        if let (Some(sv), Some(local)) = (suggested_val, self.default_deposit)
322            && sv > local
323        {
324            let _ = sh_warn!(
325                "server-suggested deposit ({sv}) exceeds local default ({local}); \
326                 set MPP_DEPOSIT to override"
327            );
328        }
329
330        let amount = self.default_deposit.or(suggested_val);
331
332        amount.ok_or_else(|| {
333            MppError::InvalidConfig("no deposit amount: set default_deposit".to_string())
334        })
335    }
336
337    async fn create_open_tx(
338        &self,
339        payer: Address,
340        options: OpenPayloadOptions,
341    ) -> Result<(ChannelEntry, SessionCredentialPayload), MppError> {
342        use alloy_sol_types::SolCall as _;
343
344        let authorized_signer = options.authorized_signer.unwrap_or(payer);
345        let salt = B256::random();
346
347        let channel_id = compute_channel_id(
348            payer,
349            options.payee,
350            options.currency,
351            salt,
352            authorized_signer,
353            options.escrow_contract,
354            options.chain_id,
355        );
356
357        alloy_sol_types::sol! {
358            interface ITIP20 {
359                function approve(address spender, uint256 amount) external returns (bool);
360            }
361            interface IEscrow {
362                function open(
363                    address payee,
364                    address token,
365                    uint128 deposit,
366                    bytes32 salt,
367                    address authorizedSigner
368                ) external;
369            }
370        }
371
372        let approve_data =
373            ITIP20::approveCall::new((options.escrow_contract, U256::from(options.deposit)))
374                .abi_encode();
375
376        let open_data = IEscrow::openCall::new((
377            options.payee,
378            options.currency,
379            options.deposit,
380            salt,
381            authorized_signer,
382        ))
383        .abi_encode();
384
385        let calls = vec![
386            Call {
387                to: TxKind::Call(options.currency),
388                value: U256::ZERO,
389                input: Bytes::from(approve_data),
390            },
391            Call {
392                to: TxKind::Call(options.escrow_contract),
393                value: U256::ZERO,
394                input: Bytes::from(open_data),
395            },
396        ];
397
398        let valid_before = {
399            let now = std::time::SystemTime::now()
400                .duration_since(std::time::UNIX_EPOCH)
401                .unwrap_or_default()
402                .as_secs();
403            Some(now + VALID_BEFORE_SECS)
404        };
405
406        let tx = mpp::client::tempo::charge::tx_builder::build_tempo_tx(
407            mpp::client::tempo::charge::tx_builder::TempoTxOptions {
408                calls,
409                chain_id: options.chain_id,
410                fee_token: options.currency,
411                nonce: 0,
412                nonce_key: EXPIRING_NONCE_KEY,
413                gas_limit: SESSION_OPEN_GAS_LIMIT,
414                max_fee_per_gas: MAX_FEE_PER_GAS,
415                max_priority_fee_per_gas: MAX_PRIORITY_FEE_PER_GAS,
416                fee_payer: options.fee_payer,
417                valid_before,
418                key_authorization: (!*self.key_provisioned.lock().unwrap())
419                    .then(|| self.signing_mode.key_authorization().cloned())
420                    .flatten(),
421            },
422        );
423
424        let signed_tx = sign_and_encode_async(tx, &self.signer, &self.signing_mode).await?;
425
426        let voucher = sign_voucher(
427            &self.signer,
428            channel_id,
429            options.initial_amount,
430            options.escrow_contract,
431            options.chain_id,
432        )
433        .await?;
434
435        let entry = ChannelEntry {
436            channel_id,
437            salt,
438            cumulative_amount: options.initial_amount,
439            escrow_contract: options.escrow_contract,
440            chain_id: options.chain_id,
441            opened: true,
442        };
443
444        let signed_tx_hex = alloy_primitives::hex::encode_prefixed(&signed_tx);
445        let voucher_sig_hex = alloy_primitives::hex::encode_prefixed(&voucher);
446
447        Ok((
448            entry,
449            SessionCredentialPayload::Open {
450                payload_type: "transaction".to_string(),
451                channel_id: channel_id.to_string(),
452                transaction: signed_tx_hex,
453                authorized_signer: Some(format!("{authorized_signer}")),
454                cumulative_amount: options.initial_amount.to_string(),
455                signature: voucher_sig_hex,
456            },
457        ))
458    }
459
460    async fn create_topup_tx(
461        &self,
462        entry: &ChannelEntry,
463        additional_deposit: u128,
464        currency: Address,
465        fee_payer: bool,
466    ) -> Result<SessionCredentialPayload, MppError> {
467        use alloy_sol_types::SolCall as _;
468
469        alloy_sol_types::sol! {
470            interface ITIP20 {
471                function approve(address spender, uint256 amount) external returns (bool);
472            }
473            interface IEscrow {
474                function topUp(bytes32 channelId, uint256 additionalDeposit) external;
475            }
476        }
477
478        let approve_data =
479            ITIP20::approveCall::new((entry.escrow_contract, U256::from(additional_deposit)))
480                .abi_encode();
481        let topup_data =
482            IEscrow::topUpCall::new((entry.channel_id, U256::from(additional_deposit)))
483                .abi_encode();
484
485        let calls = vec![
486            Call {
487                to: TxKind::Call(currency),
488                value: U256::ZERO,
489                input: Bytes::from(approve_data),
490            },
491            Call {
492                to: TxKind::Call(entry.escrow_contract),
493                value: U256::ZERO,
494                input: Bytes::from(topup_data),
495            },
496        ];
497
498        let valid_before = {
499            let now = std::time::SystemTime::now()
500                .duration_since(std::time::UNIX_EPOCH)
501                .unwrap_or_default()
502                .as_secs();
503            Some(now + VALID_BEFORE_SECS)
504        };
505
506        let tx = mpp::client::tempo::charge::tx_builder::build_tempo_tx(
507            mpp::client::tempo::charge::tx_builder::TempoTxOptions {
508                calls,
509                chain_id: entry.chain_id,
510                fee_token: currency,
511                nonce: 0,
512                nonce_key: EXPIRING_NONCE_KEY,
513                gas_limit: SESSION_OPEN_GAS_LIMIT,
514                max_fee_per_gas: MAX_FEE_PER_GAS,
515                max_priority_fee_per_gas: MAX_PRIORITY_FEE_PER_GAS,
516                fee_payer,
517                valid_before,
518                key_authorization: None,
519            },
520        );
521
522        let signed_tx = sign_and_encode_async(tx, &self.signer, &self.signing_mode).await?;
523
524        Ok(SessionCredentialPayload::TopUp {
525            payload_type: "transaction".to_string(),
526            channel_id: entry.channel_id.to_string(),
527            transaction: alloy_primitives::hex::encode_prefixed(&signed_tx),
528            additional_deposit: additional_deposit.to_string(),
529        })
530    }
531}
532
533impl SessionProvider {
534    /// Handle a charge intent by building and signing a TIP-20 transfer transaction.
535    async fn pay_charge(
536        &self,
537        challenge: &PaymentChallenge,
538    ) -> Result<PaymentCredential, MppError> {
539        use mpp::client::tempo::charge::{SignOptions, TempoCharge};
540
541        let charge = TempoCharge::from_challenge(challenge)?;
542
543        // Strip key_authorization from the signing mode when the key is already
544        // provisioned on-chain. Otherwise the payment tx includes a redundant
545        // key provisioning call that fails with "access key already exists".
546        let signing_mode = if *self.key_provisioned.lock().unwrap() {
547            match &self.signing_mode {
548                TempoSigningMode::Keychain { wallet, version, .. } => TempoSigningMode::Keychain {
549                    wallet: *wallet,
550                    key_authorization: None,
551                    version: *version,
552                },
553                other => other.clone(),
554            }
555        } else {
556            self.signing_mode.clone()
557        };
558
559        let options = SignOptions { signing_mode: Some(signing_mode), ..Default::default() };
560        let signed = charge.sign_with_options(&self.signer, options).await?;
561        Ok(signed.into_credential())
562    }
563}
564
565impl PaymentProvider for SessionProvider {
566    fn supports(&self, method: &str, intent: &str) -> bool {
567        method == "tempo" && (intent == "session" || intent == "charge")
568    }
569
570    async fn pay(&self, challenge: &PaymentChallenge) -> Result<PaymentCredential, MppError> {
571        if challenge.intent.as_str() == "charge" {
572            return self.pay_charge(challenge).await;
573        }
574        self.pay_session(challenge).await
575    }
576}
577
578impl SessionProvider {
579    async fn pay_session(
580        &self,
581        challenge: &PaymentChallenge,
582    ) -> Result<PaymentCredential, MppError> {
583        let session_req: SessionRequest = challenge.request.decode().map_err(|e| {
584            MppError::InvalidConfig(format!("failed to decode session request: {e}"))
585        })?;
586
587        let chain_id = resolve_chain_id(challenge);
588        let escrow_contract = resolve_escrow(challenge, chain_id, None)?;
589        let payee: Address = session_req
590            .recipient
591            .as_deref()
592            .ok_or_else(|| {
593                MppError::InvalidConfig("session challenge missing recipient".to_string())
594            })?
595            .parse()
596            .map_err(|_e| MppError::InvalidConfig("invalid recipient address".to_string()))?;
597        let currency: Address = session_req
598            .currency
599            .parse()
600            .map_err(|_e| MppError::InvalidConfig("invalid currency address".to_string()))?;
601        let amount: u128 = session_req.parse_amount()?;
602
603        let payer = self.signing_mode.from_address(self.signer.address());
604
605        let key = Self::channel_key(
606            &self.origin,
607            &payer,
608            self.authorized_signer,
609            &payee,
610            &currency,
611            &escrow_contract,
612            chain_id,
613        );
614
615        let voucher_info = {
616            let mut channels = self.channels.lock().unwrap();
617            if let Some(entry) = channels.get_mut(&key)
618                && entry.opened
619            {
620                let deposit = self
621                    .persisted
622                    .lock()
623                    .unwrap()
624                    .get(&key)
625                    .and_then(|p| p.deposit.parse::<u128>().ok())
626                    .unwrap_or(u128::MAX);
627
628                if entry.cumulative_amount + amount > deposit {
629                    Some(Err((entry.clone(), deposit)))
630                } else {
631                    // Clone without incrementing — only commit after
632                    // create_voucher_payload succeeds.
633                    Some(Ok(entry.clone()))
634                }
635            } else {
636                None
637            }
638        };
639
640        if let Some(result) = voucher_info {
641            match result {
642                Err((entry, deposit)) => {
643                    let additional =
644                        self.resolve_deposit(session_req.suggested_deposit.as_deref())?;
645                    tracing::debug!(
646                        cumulative = entry.cumulative_amount,
647                        amount,
648                        deposit,
649                        additional,
650                        "channel deposit exhausted, topping up"
651                    );
652
653                    let payload = self
654                        .create_topup_tx(&entry, additional, currency, session_req.fee_payer())
655                        .await?;
656
657                    // Update in-memory state but defer persistence until server confirms.
658                    let old_deposit = {
659                        let mut persisted = self.persisted.lock().unwrap();
660                        if let Some(p) = persisted.get_mut(&key) {
661                            let old = p.deposit.clone();
662                            let old_val: u128 = old.parse().unwrap_or(0);
663                            p.deposit = (old_val + additional).to_string();
664                            old
665                        } else {
666                            "0".to_string()
667                        }
668                    };
669                    *self.pending.lock().unwrap() =
670                        Some(PendingAction::TopUp { key: key.clone(), old_deposit });
671
672                    return Ok(build_credential(challenge, payload, chain_id, payer));
673                }
674                Ok(entry) => {
675                    let old_cumulative = entry.cumulative_amount;
676                    let new_cumulative = old_cumulative + amount;
677                    let payload = create_voucher_payload(
678                        &self.signer,
679                        entry.channel_id,
680                        new_cumulative,
681                        escrow_contract,
682                        chain_id,
683                    )
684                    .await?;
685
686                    // Payload succeeded — now commit the cumulative increment.
687                    {
688                        let mut channels = self.channels.lock().unwrap();
689                        if let Some(e) = channels.get_mut(&key) {
690                            e.cumulative_amount = new_cumulative;
691                        }
692                    }
693
694                    // Update in-memory persisted state but never write to disk
695                    // here — flush_pending() handles persistence after server
696                    // confirms acceptance.
697                    let updated_entry = ChannelEntry { cumulative_amount: new_cumulative, ..entry };
698                    let mut persisted = self.persisted.lock().unwrap();
699                    persist::upsert_channel_in_memory(&mut persisted, &key, &updated_entry);
700                    drop(persisted);
701
702                    // Track the voucher so we can roll back cumulative_amount
703                    // if the server rejects.
704                    if self.pending.lock().unwrap().is_none() {
705                        *self.pending.lock().unwrap() =
706                            Some(PendingAction::Voucher { key, old_cumulative });
707                    }
708
709                    return Ok(build_credential(challenge, payload, chain_id, payer));
710                }
711            }
712        }
713
714        // No existing channel — open with expiring nonces
715        let deposit = self.resolve_deposit(session_req.suggested_deposit.as_deref())?;
716
717        let (entry, payload) = self
718            .create_open_tx(
719                payer,
720                OpenPayloadOptions {
721                    authorized_signer: self.authorized_signer,
722                    escrow_contract,
723                    payee,
724                    currency,
725                    deposit,
726                    initial_amount: amount,
727                    chain_id,
728                    fee_payer: session_req.fee_payer(),
729                },
730            )
731            .await?;
732
733        // Update in-memory state but defer disk persistence until server confirms.
734        self.channels.lock().unwrap().insert(key.clone(), entry.clone());
735        let authorized_signer = self.authorized_signer.unwrap_or(payer);
736        self.persisted.lock().unwrap().insert(
737            key.clone(),
738            persist::from_channel_entry(
739                &entry,
740                deposit,
741                &self.origin,
742                &payer,
743                &payee,
744                &currency,
745                &authorized_signer,
746            ),
747        );
748        *self.pending.lock().unwrap() = Some(PendingAction::Open { key });
749        Ok(build_credential(challenge, payload, chain_id, payer))
750    }
751}
752
#[cfg(test)]
mod tests {
    use super::*;
    use mpp::client::tempo::signing::KeychainVersion;
    use tempo_primitives::transaction::{
        KeyAuthorization, PrimitiveSignature, SignatureType, SignedKeyAuthorization,
    };

    /// Create a dummy `SignedKeyAuthorization` for tests.
    ///
    /// The authorization is unrestricted for chain id 4217 with a zero address and an
    /// all-zero 65-byte signature; it is only used to populate `key_authorization`.
    fn test_key_authorization() -> SignedKeyAuthorization {
        SignedKeyAuthorization {
            authorization: KeyAuthorization::unrestricted(
                4217,
                SignatureType::Secp256k1,
                Address::ZERO,
            ),
            signature: PrimitiveSignature::from_bytes(&[0u8; 65]).expect("valid dummy signature"),
        }
    }

    /// Test-local statement of the key-authorization stripping rule.
    ///
    /// When `provisioned` is true and `mode` is `Keychain`, returns the same keychain
    /// mode with `key_authorization` cleared (wallet and version preserved). Every other
    /// combination returns a clone of `mode` unchanged.
    ///
    /// NOTE(review): the `test_pay_charge_*` tests below exercise this copy, not the
    /// provider's own stripping logic (not visible here). Keep the two in sync, or
    /// better, call the production helper directly so these tests cover real code.
    fn strip_key_auth_if_provisioned(
        mode: &TempoSigningMode,
        provisioned: bool,
    ) -> TempoSigningMode {
        if provisioned {
            match mode {
                TempoSigningMode::Keychain { wallet, version, .. } => TempoSigningMode::Keychain {
                    wallet: *wallet,
                    key_authorization: None,
                    version: *version,
                },
                other => other.clone(),
            }
        } else {
            mode.clone()
        }
    }

    /// Generate a unique origin URL per test to avoid shared state collisions.
    ///
    /// Channel state is registered per origin in process-wide statics, so tests that
    /// share an origin would observe each other's state.
    fn unique_origin() -> String {
        format!("https://rpc-{}.example.com", alloy_primitives::B256::random())
    }

    /// A freshly constructed provider reports its key as provisioned.
    #[test]
    fn test_key_provisioned_default_is_true() {
        let signer = mpp::PrivateKeySigner::random();
        let provider = SessionProvider::new(signer, unique_origin());
        assert!(*provider.key_provisioned.lock().unwrap());
    }

    /// `set_key_provisioned` toggles the flag in both directions.
    #[test]
    fn test_set_key_provisioned() {
        let signer = mpp::PrivateKeySigner::random();
        let provider = SessionProvider::new(signer, unique_origin());
        provider.set_key_provisioned(false);
        assert!(!*provider.key_provisioned.lock().unwrap());
        provider.set_key_provisioned(true);
        assert!(*provider.key_provisioned.lock().unwrap());
    }

    /// With the default (provisioned) state, a keychain mode loses its key authorization.
    #[test]
    fn test_pay_charge_strips_key_auth_when_provisioned() {
        let signer = mpp::PrivateKeySigner::random();
        let wallet = Address::repeat_byte(0xAA);
        let signing_mode = TempoSigningMode::Keychain {
            wallet,
            key_authorization: Some(Box::new(test_key_authorization())),
            version: KeychainVersion::V2,
        };
        let provider =
            SessionProvider::new(signer, unique_origin()).with_signing_mode(signing_mode);

        let provisioned = *provider.key_provisioned.lock().unwrap();
        let result_mode = strip_key_auth_if_provisioned(&provider.signing_mode, provisioned);

        assert!(
            result_mode.key_authorization().is_none(),
            "key_authorization should be stripped when key is provisioned"
        );
    }

    /// When the key is marked as not provisioned, the key authorization is kept.
    #[test]
    fn test_pay_charge_keeps_key_auth_when_not_provisioned() {
        let signer = mpp::PrivateKeySigner::random();
        let wallet = Address::repeat_byte(0xAA);
        let signing_mode = TempoSigningMode::Keychain {
            wallet,
            key_authorization: Some(Box::new(test_key_authorization())),
            version: KeychainVersion::V2,
        };
        let provider =
            SessionProvider::new(signer, unique_origin()).with_signing_mode(signing_mode);

        provider.set_key_provisioned(false);

        let provisioned = *provider.key_provisioned.lock().unwrap();
        let result_mode = strip_key_auth_if_provisioned(&provider.signing_mode, provisioned);

        assert!(
            result_mode.key_authorization().is_some(),
            "key_authorization should be preserved when key is NOT provisioned"
        );
    }

    /// `Direct` signing mode passes through the stripping rule unchanged.
    #[test]
    fn test_pay_charge_direct_mode_unaffected() {
        let signer = mpp::PrivateKeySigner::random();
        let provider = SessionProvider::new(signer, unique_origin())
            .with_signing_mode(TempoSigningMode::Direct);

        let provisioned = *provider.key_provisioned.lock().unwrap();
        let result_mode = strip_key_auth_if_provisioned(&provider.signing_mode, provisioned);

        assert!(
            matches!(result_mode, TempoSigningMode::Direct),
            "Direct mode should pass through unchanged"
        );
    }

    /// Verify that a payment serialization lock (mirroring `lock_pay()` in
    /// `LazySessionProvider`) prevents concurrent voucher increments from
    /// producing duplicate cumulative amounts.
    ///
    /// Each task models the real read → await → write cycle:
    /// 1. It reads the channel's current `cumulative_amount`.
    /// 2. It yields to the scheduler, standing in for the async signing/network
    ///    round-trip that happens before the new amount is committed.
    /// 3. It writes back the incremented value.
    ///
    /// Without `pay_lock`, the yield lets other tasks read the same stale value, and
    /// the uniqueness assertion fails. With the lock held across the whole cycle,
    /// every task observes a distinct value.
    #[tokio::test]
    async fn test_concurrent_voucher_increments_are_unique() {
        let channels: Arc<Mutex<HashMap<String, ChannelEntry>>> =
            Arc::new(Mutex::new(HashMap::new()));
        let key = "test-channel".to_string();
        channels.lock().unwrap().insert(
            key.clone(),
            ChannelEntry {
                channel_id: Default::default(),
                salt: Default::default(),
                cumulative_amount: 0,
                escrow_contract: Address::ZERO,
                chain_id: 42431,
                opened: true,
            },
        );

        // Mirrors the `pay_lock` tokio::sync::Mutex used in LazySessionProvider
        // to serialize the 402 → pay → retry cycle.
        let pay_lock = Arc::new(tokio::sync::Mutex::new(()));
        let amount: u128 = 1000;
        let num_tasks: usize = 20;

        let handles: Vec<_> = (0..num_tasks)
            .map(|_| {
                let channels = channels.clone();
                let key = key.clone();
                let pay_lock = pay_lock.clone();
                tokio::spawn(async move {
                    let _guard = pay_lock.lock().await;
                    // Read the current state. The std guard is confined to this
                    // block so it is never held across the await below.
                    let current = {
                        let ch = channels.lock().unwrap();
                        ch[&key].cumulative_amount
                    };
                    // Simulate the async work between reading channel state and
                    // committing the new voucher amount. Without `pay_lock`, other
                    // tasks would read the same `current` here.
                    tokio::task::yield_now().await;
                    let next = current + amount;
                    {
                        let mut ch = channels.lock().unwrap();
                        ch.get_mut(&key).unwrap().cumulative_amount = next;
                    }
                    next
                })
            })
            .collect();

        let mut amounts = Vec::with_capacity(num_tasks);
        for h in handles {
            amounts.push(h.await.unwrap());
        }

        amounts.sort_unstable();
        amounts.dedup();
        assert_eq!(
            amounts.len(),
            num_tasks,
            "each concurrent increment should produce a unique cumulative_amount"
        );

        // `usize` -> `u128` is a lossless widening cast.
        let expected_total = amount * num_tasks as u128;
        assert_eq!(
            *amounts.last().unwrap(),
            expected_total,
            "final cumulative_amount should equal amount × num_tasks"
        );
        assert_eq!(
            channels.lock().unwrap()[&key].cumulative_amount,
            expected_total,
            "stored cumulative_amount should match the last issued voucher"
        );
    }
}