Skip to main content

foundry_common/provider/mpp/
session.rs

1//! Tempo session payment provider with expiring nonces.
2//!
3//! Custom implementation that mirrors `tempoxyz/wallet`'s approach: uses
4//! expiring nonces (`nonce=0`, `nonceKey=MAX`, `validBefore=now+25s`) for
5//! channel open transactions instead of fetching sequential nonces via
6//! `eth_getTransactionCount`. This avoids the chicken-and-egg problem when
7//! the RPC endpoint is itself 402-gated.
8
9use super::persist;
10use alloy_primitives::{Address, B256, Bytes, TxKind, U256};
11use foundry_wallets::Channel;
12use mpp::{
13    client::{
14        PaymentProvider,
15        channel_ops::{
16            ChannelEntry, OpenPayloadOptions, build_credential, create_voucher_payload,
17            resolve_chain_id, resolve_escrow,
18        },
19        tempo::signing::{TempoSigningMode, sign_and_encode_async},
20    },
21    error::MppError,
22    protocol::{
23        core::{PaymentChallenge, PaymentCredential},
24        intents::SessionRequest,
25        methods::tempo::session::TempoSessionExt,
26    },
27    tempo::{Call, SessionCredentialPayload, compute_channel_id, sign_voucher},
28};
29use std::{
30    collections::HashMap,
31    sync::{Arc, Mutex, OnceLock},
32};
33
34/// Shared per-origin in-memory channel state: (channels, key_provisioned).
35type SharedChannelState = (Arc<Mutex<HashMap<String, ChannelEntry>>>, Arc<Mutex<bool>>);
36
37/// Process-wide channel state registry, keyed by origin URL.
38///
39/// Stores per-origin in-memory channel maps and key provisioning state.
40static GLOBAL_CHANNELS: OnceLock<Mutex<HashMap<String, SharedChannelState>>> = OnceLock::new();
41
42/// Process-wide persisted channel state, shared across ALL origins.
43///
44/// Using a single map ensures saves from different origins don't clobber
45/// each other's state.
46static GLOBAL_PERSISTED: OnceLock<Arc<Mutex<HashMap<String, Channel>>>> = OnceLock::new();
47
/// Channel-state change from the latest payment that the server has not yet
/// confirmed.
///
/// Disk persistence is deferred until the server accepts the payment, so a
/// failed open/top-up/voucher can be undone instead of leaving local state
/// ahead of what the server actually recorded.
#[derive(Clone, Debug)]
enum PendingAction {
    /// A freshly opened channel; rolled back by removing `key` entirely.
    Open { key: String },
    /// A deposit increase; rolled back by restoring the persisted
    /// `old_deposit` (a decimal string, as stored in the persisted channel).
    TopUp { key: String, old_deposit: String },
    /// An advanced voucher; rolled back by restoring `old_cumulative` as the
    /// channel's `cumulative_amount`.
    Voucher { key: String, old_cumulative: u128 },
}
61
62/// Expiring nonce key (U256::MAX) — matches the charge flow.
63const EXPIRING_NONCE_KEY: U256 = U256::MAX;
64
65/// Validity window (in seconds) for expiring nonce transactions.
66const VALID_BEFORE_SECS: u64 = 25;
67
68/// Default gas limit for session open transactions.
69const SESSION_OPEN_GAS_LIMIT: u64 = 10_000_000;
70
71/// Max fee per gas (20 gwei — Tempo's fixed base fee).
72const MAX_FEE_PER_GAS: u128 = 20_000_000_000;
73
74/// Max priority fee per gas.
75const MAX_PRIORITY_FEE_PER_GAS: u128 = 20_000_000_000;
76
77/// Tempo session provider using expiring nonces.
78///
79/// Unlike mpp-rs's `TempoSessionProvider` which fetches sequential nonces
80/// (requiring a non-gated RPC), this provider uses expiring nonces for
81/// channel open transactions — matching how `tempoxyz/wallet` works.
82#[derive(Clone)]
83pub struct SessionProvider {
84    signer: mpp::PrivateKeySigner,
85    signing_mode: TempoSigningMode,
86    authorized_signer: Option<Address>,
87    default_deposit: Option<u128>,
88    channels: Arc<Mutex<HashMap<String, ChannelEntry>>>,
89    key_provisioned: Arc<Mutex<bool>>,
90    persisted: Arc<Mutex<HashMap<String, Channel>>>,
91    /// Tracks uncommitted open/top-up state for deferred persistence.
92    pending: Arc<Mutex<Option<PendingAction>>>,
93    /// Chain ID from the key entry in `keys.toml` that was used to initialize
94    /// this provider. Used to reject challenges for a different chain.
95    key_chain_id: Option<u64>,
96    /// Currencies from the key's spending limits. Used to reject challenges
97    /// for currencies the key cannot pay with.
98    key_currencies: Vec<Address>,
99    origin: String,
100}
101
102impl std::fmt::Debug for SessionProvider {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        f.debug_struct("SessionProvider")
105            .field("signing_mode", &self.signing_mode)
106            .field("authorized_signer", &self.authorized_signer)
107            .field("default_deposit", &self.default_deposit)
108            .finish_non_exhaustive()
109    }
110}
111
112impl SessionProvider {
113    /// Create a new session provider with the given signer and RPC origin URL.
114    ///
115    /// Channel state is shared process-wide: all `SessionProvider` instances
116    /// share the same in-memory channels and persisted state. This prevents
117    /// concurrent providers (e.g. multiple `forge script` providers for the
118    /// same URL) from reading stale `cumulative_amount` values from disk and
119    /// producing duplicate vouchers.
120    pub fn new(signer: mpp::PrivateKeySigner, origin: String) -> Self {
121        // Global persisted map shared across all origins.
122        let persisted =
123            GLOBAL_PERSISTED.get_or_init(|| Arc::new(Mutex::new(persist::load_channels()))).clone();
124
125        // Per-origin in-memory channel map + key provisioning state.
126        let (channels, key_provisioned) = {
127            let global = GLOBAL_CHANNELS.get_or_init(|| Mutex::new(HashMap::new()));
128            let mut map = global.lock().unwrap();
129            map.entry(origin.clone())
130                .or_insert_with(|| {
131                    // Hydrate only channels belonging to this origin.
132                    let mut channels: HashMap<String, ChannelEntry> = HashMap::new();
133                    for (key, ch) in persisted.lock().unwrap().iter() {
134                        if ch.origin == origin
135                            && let Some(entry) = persist::to_channel_entry(ch)
136                        {
137                            channels.insert(key.clone(), entry);
138                        }
139                    }
140                    (Arc::new(Mutex::new(channels)), Arc::new(Mutex::new(true)))
141                })
142                .clone()
143        };
144
145        Self {
146            signer,
147            signing_mode: TempoSigningMode::Direct,
148            authorized_signer: None,
149            default_deposit: None,
150            channels,
151            key_provisioned,
152            persisted,
153            pending: Arc::new(Mutex::new(None)),
154            key_chain_id: None,
155            key_currencies: vec![],
156            origin,
157        }
158    }
159
160    /// Set the signing mode (direct or keychain).
161    pub fn with_signing_mode(mut self, mode: TempoSigningMode) -> Self {
162        self.signing_mode = mode;
163        self
164    }
165
166    /// Set the authorized signer address for keychain mode.
167    pub const fn with_authorized_signer(mut self, addr: Address) -> Self {
168        self.authorized_signer = Some(addr);
169        self
170    }
171
172    /// Set the default deposit amount.
173    pub const fn with_default_deposit(mut self, deposit: u128) -> Self {
174        self.default_deposit = Some(deposit);
175        self
176    }
177
178    /// Set the chain ID and currencies from the key entry used to initialize
179    /// this provider. Used to reject challenges for incompatible chains/currencies.
180    /// When `chain_id` is `None` (e.g. env var key), chain filtering is skipped.
181    pub fn with_key_filters(mut self, chain_id: Option<u64>, currencies: Vec<Address>) -> Self {
182        self.key_chain_id = chain_id;
183        self.key_currencies = currencies;
184        self
185    }
186
187    /// Check whether this provider's key is compatible with the given
188    /// chain ID and currency from a 402 challenge.
189    pub fn matches_challenge(&self, chain_id: Option<u64>, currency: Option<Address>) -> bool {
190        if let Some(cid) = chain_id
191            && self.key_chain_id.is_some_and(|k| k != cid)
192        {
193            return false;
194        }
195        if let Some(cur) = currency
196            && !self.key_currencies.is_empty()
197            && !self.key_currencies.contains(&cur)
198        {
199            return false;
200        }
201        true
202    }
203
204    /// Clear channels belonging to this origin (e.g. after server 410).
205    ///
206    /// Only removes channels whose `origin` matches `self.origin`, preserving
207    /// channels for other RPC endpoints.
208    pub fn clear_channels(&self) {
209        let origin = &self.origin;
210        // Lock order: channels → persisted (consistent with pay_session)
211        let mut channels = self.channels.lock().unwrap();
212        let mut persisted = self.persisted.lock().unwrap();
213        let keys_to_remove: Vec<(String, String)> = persisted
214            .iter()
215            .filter(|(_, ch)| ch.origin == *origin)
216            .map(|(k, ch): (&String, &Channel)| (k.clone(), ch.channel_id.clone()))
217            .collect();
218        for (key, channel_id) in &keys_to_remove {
219            channels.remove(key);
220            persisted.remove(key);
221            persist::delete_channel_from_db(channel_id);
222        }
223    }
224
225    /// Mark whether the access key has been provisioned on-chain.
226    pub fn set_key_provisioned(&self, provisioned: bool) {
227        *self.key_provisioned.lock().unwrap() = provisioned;
228    }
229
230    /// Check whether the access key has been provisioned on-chain.
231    pub fn is_key_provisioned(&self) -> bool {
232        *self.key_provisioned.lock().unwrap()
233    }
234
235    /// Persist any pending open/top-up/voucher state to disk.
236    ///
237    /// Called by the transport after the server confirms acceptance.
238    pub fn flush_pending(&self) {
239        let pending = self.pending.lock().unwrap().take();
240        if pending.is_some() {
241            persist::save_channels(&self.persisted.lock().unwrap());
242        }
243    }
244
245    /// Commit a pending top-up (deposit increase) without flushing to disk.
246    ///
247    /// Called by the transport when the server returns 204 (top-up accepted).
248    /// The deposit increase is now committed, but the follow-up voucher is
249    /// tracked as a new pending action.
250    pub fn commit_topup_and_track_voucher(&self) {
251        let pending = self.pending.lock().unwrap().take();
252        if let Some(PendingAction::TopUp { key, .. }) = pending {
253            // Top-up is now committed — read the current cumulative_amount
254            // so we can roll back just the voucher increment if needed.
255            let old_cumulative =
256                self.channels.lock().unwrap().get(&key).map(|e| e.cumulative_amount).unwrap_or(0);
257            *self.pending.lock().unwrap() = Some(PendingAction::Voucher { key, old_cumulative });
258        }
259    }
260
261    /// Roll back pending open/top-up/voucher state on failure.
262    ///
263    /// Called by the transport when the server rejects the payment or times out.
264    pub fn rollback_pending(&self) {
265        let pending = self.pending.lock().unwrap().take();
266        if let Some(action) = pending {
267            match action {
268                PendingAction::Open { key } => {
269                    self.channels.lock().unwrap().remove(&key);
270                    self.persisted.lock().unwrap().remove(&key);
271                }
272                PendingAction::TopUp { key, old_deposit } => {
273                    if let Some(p) = self.persisted.lock().unwrap().get_mut(&key) {
274                        p.deposit = old_deposit;
275                    }
276                }
277                PendingAction::Voucher { key, old_cumulative } => {
278                    if let Some(entry) = self.channels.lock().unwrap().get_mut(&key) {
279                        entry.cumulative_amount = old_cumulative;
280                    }
281                    if let Some(p) = self.persisted.lock().unwrap().get_mut(&key) {
282                        p.cumulative_amount = old_cumulative.to_string();
283                    }
284                }
285            }
286        }
287    }
288
289    fn channel_key(
290        origin: &str,
291        payer: &Address,
292        authorized_signer: Option<Address>,
293        payee: &Address,
294        currency: &Address,
295        escrow: &Address,
296        chain_id: u64,
297    ) -> String {
298        // Use first 8 bytes of origin hash to scope the key without persisting
299        // the full URL (which may contain secrets in query params).
300        let origin_hash = &alloy_primitives::keccak256(origin.as_bytes()).to_string()[..18];
301        let signer = authorized_signer.unwrap_or(*payer);
302        format!("{origin_hash}:{chain_id}:{payer}:{signer}:{payee}:{currency}:{escrow}")
303            .to_lowercase()
304    }
305
306    fn resolve_deposit(&self, suggested: Option<&str>) -> Result<u128, MppError> {
307        let suggested_val = suggested.and_then(|s| s.parse::<u128>().ok());
308
309        // Local config takes priority. Warn when server suggests more so users
310        // can bump MPP_DEPOSIT if the default is too low.
311        if let (Some(sv), Some(local)) = (suggested_val, self.default_deposit)
312            && sv > local
313        {
314            let _ = sh_warn!(
315                "server-suggested deposit ({sv}) exceeds local default ({local}); \
316                 set MPP_DEPOSIT to override"
317            );
318        }
319
320        let amount = self.default_deposit.or(suggested_val);
321
322        amount.ok_or_else(|| {
323            MppError::InvalidConfig("no deposit amount: set default_deposit".to_string())
324        })
325    }
326
327    async fn create_open_tx(
328        &self,
329        payer: Address,
330        options: OpenPayloadOptions,
331    ) -> Result<(ChannelEntry, SessionCredentialPayload), MppError> {
332        use alloy_sol_types::SolCall as _;
333
334        let authorized_signer = options.authorized_signer.unwrap_or(payer);
335        let salt = B256::random();
336
337        let channel_id = compute_channel_id(
338            payer,
339            options.payee,
340            options.currency,
341            salt,
342            authorized_signer,
343            options.escrow_contract,
344            options.chain_id,
345        );
346
347        alloy_sol_types::sol! {
348            interface ITIP20 {
349                function approve(address spender, uint256 amount) external returns (bool);
350            }
351            interface IEscrow {
352                function open(
353                    address payee,
354                    address token,
355                    uint128 deposit,
356                    bytes32 salt,
357                    address authorizedSigner
358                ) external;
359            }
360        }
361
362        let approve_data =
363            ITIP20::approveCall::new((options.escrow_contract, U256::from(options.deposit)))
364                .abi_encode();
365
366        let open_data = IEscrow::openCall::new((
367            options.payee,
368            options.currency,
369            options.deposit,
370            salt,
371            authorized_signer,
372        ))
373        .abi_encode();
374
375        let calls = vec![
376            Call {
377                to: TxKind::Call(options.currency),
378                value: U256::ZERO,
379                input: Bytes::from(approve_data),
380            },
381            Call {
382                to: TxKind::Call(options.escrow_contract),
383                value: U256::ZERO,
384                input: Bytes::from(open_data),
385            },
386        ];
387
388        let valid_before = {
389            let now = std::time::SystemTime::now()
390                .duration_since(std::time::UNIX_EPOCH)
391                .unwrap_or_default()
392                .as_secs();
393            Some(now + VALID_BEFORE_SECS)
394        };
395
396        let tx = mpp::client::tempo::charge::tx_builder::build_tempo_tx(
397            mpp::client::tempo::charge::tx_builder::TempoTxOptions {
398                calls,
399                chain_id: options.chain_id,
400                fee_token: options.currency,
401                nonce: 0,
402                nonce_key: EXPIRING_NONCE_KEY,
403                gas_limit: SESSION_OPEN_GAS_LIMIT,
404                max_fee_per_gas: MAX_FEE_PER_GAS,
405                max_priority_fee_per_gas: MAX_PRIORITY_FEE_PER_GAS,
406                fee_payer: options.fee_payer,
407                valid_before,
408                key_authorization: (!*self.key_provisioned.lock().unwrap())
409                    .then(|| self.signing_mode.key_authorization().cloned())
410                    .flatten(),
411            },
412        );
413
414        let signed_tx = sign_and_encode_async(tx, &self.signer, &self.signing_mode).await?;
415
416        let voucher = sign_voucher(
417            &self.signer,
418            channel_id,
419            options.initial_amount,
420            options.escrow_contract,
421            options.chain_id,
422        )
423        .await?;
424
425        let entry = ChannelEntry {
426            channel_id,
427            salt,
428            cumulative_amount: options.initial_amount,
429            escrow_contract: options.escrow_contract,
430            chain_id: options.chain_id,
431            opened: true,
432        };
433
434        let signed_tx_hex = alloy_primitives::hex::encode_prefixed(&signed_tx);
435        let voucher_sig_hex = alloy_primitives::hex::encode_prefixed(&voucher);
436
437        Ok((
438            entry,
439            SessionCredentialPayload::Open {
440                payload_type: "transaction".to_string(),
441                channel_id: channel_id.to_string(),
442                transaction: signed_tx_hex,
443                authorized_signer: Some(format!("{authorized_signer}")),
444                cumulative_amount: options.initial_amount.to_string(),
445                signature: voucher_sig_hex,
446            },
447        ))
448    }
449
450    async fn create_topup_tx(
451        &self,
452        entry: &ChannelEntry,
453        additional_deposit: u128,
454        currency: Address,
455        fee_payer: bool,
456    ) -> Result<SessionCredentialPayload, MppError> {
457        use alloy_sol_types::SolCall as _;
458
459        alloy_sol_types::sol! {
460            interface ITIP20 {
461                function approve(address spender, uint256 amount) external returns (bool);
462            }
463            interface IEscrow {
464                function topUp(bytes32 channelId, uint256 additionalDeposit) external;
465            }
466        }
467
468        let approve_data =
469            ITIP20::approveCall::new((entry.escrow_contract, U256::from(additional_deposit)))
470                .abi_encode();
471        let topup_data =
472            IEscrow::topUpCall::new((entry.channel_id, U256::from(additional_deposit)))
473                .abi_encode();
474
475        let calls = vec![
476            Call {
477                to: TxKind::Call(currency),
478                value: U256::ZERO,
479                input: Bytes::from(approve_data),
480            },
481            Call {
482                to: TxKind::Call(entry.escrow_contract),
483                value: U256::ZERO,
484                input: Bytes::from(topup_data),
485            },
486        ];
487
488        let valid_before = {
489            let now = std::time::SystemTime::now()
490                .duration_since(std::time::UNIX_EPOCH)
491                .unwrap_or_default()
492                .as_secs();
493            Some(now + VALID_BEFORE_SECS)
494        };
495
496        let tx = mpp::client::tempo::charge::tx_builder::build_tempo_tx(
497            mpp::client::tempo::charge::tx_builder::TempoTxOptions {
498                calls,
499                chain_id: entry.chain_id,
500                fee_token: currency,
501                nonce: 0,
502                nonce_key: EXPIRING_NONCE_KEY,
503                gas_limit: SESSION_OPEN_GAS_LIMIT,
504                max_fee_per_gas: MAX_FEE_PER_GAS,
505                max_priority_fee_per_gas: MAX_PRIORITY_FEE_PER_GAS,
506                fee_payer,
507                valid_before,
508                key_authorization: None,
509            },
510        );
511
512        let signed_tx = sign_and_encode_async(tx, &self.signer, &self.signing_mode).await?;
513
514        Ok(SessionCredentialPayload::TopUp {
515            payload_type: "transaction".to_string(),
516            channel_id: entry.channel_id.to_string(),
517            transaction: alloy_primitives::hex::encode_prefixed(&signed_tx),
518            additional_deposit: additional_deposit.to_string(),
519        })
520    }
521}
522
523impl SessionProvider {
524    /// Handle a charge intent by building and signing a TIP-20 transfer transaction.
525    async fn pay_charge(
526        &self,
527        challenge: &PaymentChallenge,
528    ) -> Result<PaymentCredential, MppError> {
529        use mpp::client::tempo::charge::{SignOptions, TempoCharge};
530
531        let charge = TempoCharge::from_challenge(challenge)?;
532
533        // Strip key_authorization from the signing mode when the key is already
534        // provisioned on-chain. Otherwise the payment tx includes a redundant
535        // key provisioning call that fails with "access key already exists".
536        let signing_mode = if *self.key_provisioned.lock().unwrap() {
537            match &self.signing_mode {
538                TempoSigningMode::Keychain { wallet, version, .. } => TempoSigningMode::Keychain {
539                    wallet: *wallet,
540                    key_authorization: None,
541                    version: *version,
542                },
543                other => other.clone(),
544            }
545        } else {
546            self.signing_mode.clone()
547        };
548
549        let options = SignOptions { signing_mode: Some(signing_mode), ..Default::default() };
550        let signed = charge.sign_with_options(&self.signer, options).await?;
551        Ok(signed.into_credential())
552    }
553}
554
555impl PaymentProvider for SessionProvider {
556    fn supports(&self, method: &str, intent: &str) -> bool {
557        method == "tempo" && (intent == "session" || intent == "charge")
558    }
559
560    async fn pay(&self, challenge: &PaymentChallenge) -> Result<PaymentCredential, MppError> {
561        if challenge.intent.as_str() == "charge" {
562            return self.pay_charge(challenge).await;
563        }
564        self.pay_session(challenge).await
565    }
566}
567
568impl SessionProvider {
569    async fn pay_session(
570        &self,
571        challenge: &PaymentChallenge,
572    ) -> Result<PaymentCredential, MppError> {
573        let session_req: SessionRequest = challenge.request.decode().map_err(|e| {
574            MppError::InvalidConfig(format!("failed to decode session request: {e}"))
575        })?;
576
577        let chain_id = resolve_chain_id(challenge);
578        let escrow_contract = resolve_escrow(challenge, chain_id, None)?;
579        let payee: Address = session_req
580            .recipient
581            .as_deref()
582            .ok_or_else(|| {
583                MppError::InvalidConfig("session challenge missing recipient".to_string())
584            })?
585            .parse()
586            .map_err(|_e| MppError::InvalidConfig("invalid recipient address".to_string()))?;
587        let currency: Address = session_req
588            .currency
589            .parse()
590            .map_err(|_e| MppError::InvalidConfig("invalid currency address".to_string()))?;
591        let amount: u128 = session_req.parse_amount()?;
592
593        let payer = self.signing_mode.from_address(self.signer.address());
594
595        let key = Self::channel_key(
596            &self.origin,
597            &payer,
598            self.authorized_signer,
599            &payee,
600            &currency,
601            &escrow_contract,
602            chain_id,
603        );
604
605        let voucher_info = {
606            let mut channels = self.channels.lock().unwrap();
607            if let Some(entry) = channels.get_mut(&key)
608                && entry.opened
609            {
610                let deposit = self
611                    .persisted
612                    .lock()
613                    .unwrap()
614                    .get(&key)
615                    .and_then(|p| p.deposit.parse::<u128>().ok())
616                    .unwrap_or(u128::MAX);
617
618                if entry.cumulative_amount + amount > deposit {
619                    Some(Err((entry.clone(), deposit)))
620                } else {
621                    // Clone without incrementing — only commit after
622                    // create_voucher_payload succeeds.
623                    Some(Ok(entry.clone()))
624                }
625            } else {
626                None
627            }
628        };
629
630        if let Some(result) = voucher_info {
631            match result {
632                Err((entry, deposit)) => {
633                    let additional =
634                        self.resolve_deposit(session_req.suggested_deposit.as_deref())?;
635                    tracing::debug!(
636                        cumulative = entry.cumulative_amount,
637                        amount,
638                        deposit,
639                        additional,
640                        "channel deposit exhausted, topping up"
641                    );
642
643                    let payload = self
644                        .create_topup_tx(&entry, additional, currency, session_req.fee_payer())
645                        .await?;
646
647                    // Update in-memory state but defer persistence until server confirms.
648                    let old_deposit = {
649                        let mut persisted = self.persisted.lock().unwrap();
650                        if let Some(p) = persisted.get_mut(&key) {
651                            let old = p.deposit.clone();
652                            let old_val: u128 = old.parse().unwrap_or(0);
653                            p.deposit = (old_val + additional).to_string();
654                            old
655                        } else {
656                            "0".to_string()
657                        }
658                    };
659                    *self.pending.lock().unwrap() =
660                        Some(PendingAction::TopUp { key: key.clone(), old_deposit });
661
662                    return Ok(build_credential(challenge, payload, chain_id, payer));
663                }
664                Ok(entry) => {
665                    let old_cumulative = entry.cumulative_amount;
666                    let new_cumulative = old_cumulative + amount;
667                    let payload = create_voucher_payload(
668                        &self.signer,
669                        entry.channel_id,
670                        new_cumulative,
671                        escrow_contract,
672                        chain_id,
673                    )
674                    .await?;
675
676                    // Payload succeeded — now commit the cumulative increment.
677                    {
678                        let mut channels = self.channels.lock().unwrap();
679                        if let Some(e) = channels.get_mut(&key) {
680                            e.cumulative_amount = new_cumulative;
681                        }
682                    }
683
684                    // Update in-memory persisted state but never write to disk
685                    // here — flush_pending() handles persistence after server
686                    // confirms acceptance.
687                    let updated_entry = ChannelEntry { cumulative_amount: new_cumulative, ..entry };
688                    let mut persisted = self.persisted.lock().unwrap();
689                    persist::upsert_channel_in_memory(&mut persisted, &key, &updated_entry);
690                    drop(persisted);
691
692                    // Track the voucher so we can roll back cumulative_amount
693                    // if the server rejects.
694                    if self.pending.lock().unwrap().is_none() {
695                        *self.pending.lock().unwrap() =
696                            Some(PendingAction::Voucher { key, old_cumulative });
697                    }
698
699                    return Ok(build_credential(challenge, payload, chain_id, payer));
700                }
701            }
702        }
703
704        // No existing channel — open with expiring nonces
705        let deposit = self.resolve_deposit(session_req.suggested_deposit.as_deref())?;
706
707        let (entry, payload) = self
708            .create_open_tx(
709                payer,
710                OpenPayloadOptions {
711                    authorized_signer: self.authorized_signer,
712                    escrow_contract,
713                    payee,
714                    currency,
715                    deposit,
716                    initial_amount: amount,
717                    chain_id,
718                    fee_payer: session_req.fee_payer(),
719                },
720            )
721            .await?;
722
723        // Update in-memory state but defer disk persistence until server confirms.
724        self.channels.lock().unwrap().insert(key.clone(), entry.clone());
725        let authorized_signer = self.authorized_signer.unwrap_or(payer);
726        self.persisted.lock().unwrap().insert(
727            key.clone(),
728            persist::from_channel_entry(
729                &entry,
730                deposit,
731                &self.origin,
732                &payer,
733                &payee,
734                &currency,
735                &authorized_signer,
736            ),
737        );
738        *self.pending.lock().unwrap() = Some(PendingAction::Open { key });
739        Ok(build_credential(challenge, payload, chain_id, payer))
740    }
741}
742
#[cfg(test)]
mod tests {
    use super::*;
    use mpp::client::tempo::signing::KeychainVersion;
    use tempo_primitives::transaction::{
        KeyAuthorization, PrimitiveSignature, SignatureType, SignedKeyAuthorization,
    };

    /// Create a dummy `SignedKeyAuthorization` for tests.
    ///
    /// The authorization is unrestricted for `Address::ZERO` and carries an
    /// all-zero 65-byte signature. The tests in this module only check whether
    /// it is present or absent, never whether it is valid.
    fn test_key_authorization() -> SignedKeyAuthorization {
        SignedKeyAuthorization {
            authorization: KeyAuthorization::unrestricted(
                4217,
                SignatureType::Secp256k1,
                Address::ZERO,
            ),
            signature: PrimitiveSignature::from_bytes(&[0u8; 65]).expect("valid dummy signature"),
        }
    }

    /// Return a copy of `mode` with `key_authorization` cleared when
    /// `provisioned` is true and `mode` is `Keychain`.
    ///
    /// All other modes, and every mode when `provisioned` is false, are
    /// returned unchanged (cloned).
    ///
    /// NOTE(review): this is a test-local copy of the stripping logic that the
    /// `test_pay_charge_*` tests are named after. Those tests exercise this
    /// helper, not the provider's own code path, so it must be kept in sync
    /// manually (or replaced by a call to the production helper if one is
    /// exposed).
    fn strip_key_auth_if_provisioned(
        mode: &TempoSigningMode,
        provisioned: bool,
    ) -> TempoSigningMode {
        if provisioned {
            match mode {
                TempoSigningMode::Keychain { wallet, version, .. } => TempoSigningMode::Keychain {
                    wallet: *wallet,
                    key_authorization: None,
                    version: *version,
                },
                other => other.clone(),
            }
        } else {
            mode.clone()
        }
    }

    /// Generate a unique origin URL per test.
    ///
    /// Channel state is shared process-wide per origin, so each test gets its
    /// own origin to avoid collisions.
    fn unique_origin() -> String {
        format!("https://rpc-{}.example.com", B256::random())
    }

    #[test]
    fn test_key_provisioned_default_is_true() {
        let signer = mpp::PrivateKeySigner::random();
        let provider = SessionProvider::new(signer, unique_origin());
        assert!(*provider.key_provisioned.lock().unwrap());
    }

    #[test]
    fn test_set_key_provisioned() {
        let signer = mpp::PrivateKeySigner::random();
        let provider = SessionProvider::new(signer, unique_origin());
        provider.set_key_provisioned(false);
        assert!(!*provider.key_provisioned.lock().unwrap());
        provider.set_key_provisioned(true);
        assert!(*provider.key_provisioned.lock().unwrap());
    }

    #[test]
    fn test_pay_charge_strips_key_auth_when_provisioned() {
        let signer = mpp::PrivateKeySigner::random();
        let wallet = Address::repeat_byte(0xAA);
        let signing_mode = TempoSigningMode::Keychain {
            wallet,
            key_authorization: Some(Box::new(test_key_authorization())),
            version: KeychainVersion::V2,
        };
        let provider =
            SessionProvider::new(signer, unique_origin()).with_signing_mode(signing_mode);

        let provisioned = *provider.key_provisioned.lock().unwrap();
        let result_mode = strip_key_auth_if_provisioned(&provider.signing_mode, provisioned);

        assert!(
            result_mode.key_authorization().is_none(),
            "key_authorization should be stripped when key is provisioned"
        );
    }

    #[test]
    fn test_pay_charge_keeps_key_auth_when_not_provisioned() {
        let signer = mpp::PrivateKeySigner::random();
        let wallet = Address::repeat_byte(0xAA);
        let signing_mode = TempoSigningMode::Keychain {
            wallet,
            key_authorization: Some(Box::new(test_key_authorization())),
            version: KeychainVersion::V2,
        };
        let provider =
            SessionProvider::new(signer, unique_origin()).with_signing_mode(signing_mode);

        provider.set_key_provisioned(false);

        let provisioned = *provider.key_provisioned.lock().unwrap();
        let result_mode = strip_key_auth_if_provisioned(&provider.signing_mode, provisioned);

        assert!(
            result_mode.key_authorization().is_some(),
            "key_authorization should be preserved when key is NOT provisioned"
        );
    }

    #[test]
    fn test_pay_charge_direct_mode_unaffected() {
        let signer = mpp::PrivateKeySigner::random();
        let provider = SessionProvider::new(signer, unique_origin())
            .with_signing_mode(TempoSigningMode::Direct);

        let provisioned = *provider.key_provisioned.lock().unwrap();
        let result_mode = strip_key_auth_if_provisioned(&provider.signing_mode, provisioned);

        assert!(
            matches!(result_mode, TempoSigningMode::Direct),
            "Direct mode should pass through unchanged"
        );
    }

    /// Verify that a payment serialization lock (mirroring `lock_pay()` in
    /// `LazySessionProvider`) prevents concurrent voucher increments from
    /// producing duplicate cumulative amounts.
    ///
    /// The read and the write of `cumulative_amount` are deliberately split by
    /// an `.await`, as in the real 402 → pay → retry cycle. Without
    /// `pay_lock`, other tasks would run at that await point, read the same
    /// stale amount and emit duplicate vouchers, even on a current-thread
    /// runtime. Previously the whole increment ran under the `std` mutex with
    /// no await, so the test passed regardless of `pay_lock`.
    #[tokio::test]
    async fn test_concurrent_voucher_increments_are_unique() {
        let channels: Arc<Mutex<HashMap<String, ChannelEntry>>> =
            Arc::new(Mutex::new(HashMap::new()));
        let key = "test-channel".to_string();
        channels.lock().unwrap().insert(
            key.clone(),
            ChannelEntry {
                channel_id: Default::default(),
                salt: Default::default(),
                cumulative_amount: 0,
                escrow_contract: Address::ZERO,
                chain_id: 42431,
                opened: true,
            },
        );

        // Mirrors the `pay_lock` tokio::sync::Mutex used in LazySessionProvider
        // to serialize the 402 → pay → retry cycle.
        let pay_lock = Arc::new(tokio::sync::Mutex::new(()));
        let amount: u128 = 1000;
        let num_tasks: usize = 20;
        let results: Arc<Mutex<Vec<u128>>> = Arc::new(Mutex::new(Vec::new()));

        let mut handles = Vec::with_capacity(num_tasks);
        for _ in 0..num_tasks {
            let channels = channels.clone();
            let key = key.clone();
            let results = results.clone();
            let pay_lock = pay_lock.clone();
            handles.push(tokio::spawn(async move {
                let _guard = pay_lock.lock().await;

                // Read the current amount. The std guard is dropped at the end
                // of this block so it is never held across the await below.
                let next = {
                    let ch = channels.lock().unwrap();
                    ch[&key].cumulative_amount + amount
                };

                // Stand-in for the signing / network round-trip that happens
                // between reading channel state and committing it.
                tokio::task::yield_now().await;

                channels.lock().unwrap().get_mut(&key).unwrap().cumulative_amount = next;
                results.lock().unwrap().push(next);
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let expected_total = amount * num_tasks as u128;

        let mut amounts = results.lock().unwrap().clone();
        amounts.sort_unstable();
        amounts.dedup();
        assert_eq!(
            amounts.len(),
            num_tasks,
            "each concurrent increment should produce a unique cumulative_amount"
        );
        assert_eq!(
            *amounts.last().unwrap(),
            expected_total,
            "final cumulative_amount should equal amount × num_tasks"
        );
        assert_eq!(
            channels.lock().unwrap()[&key].cumulative_amount,
            expected_total,
            "stored channel state should reflect every increment"
        );
    }
}