1use super::persist;
10use alloy_primitives::{Address, B256, Bytes, TxKind, U256};
11use foundry_wallets::Channel;
12use mpp::{
13 client::{
14 PaymentProvider,
15 channel_ops::{
16 ChannelEntry, OpenPayloadOptions, build_credential, create_voucher_payload,
17 resolve_chain_id, resolve_escrow,
18 },
19 tempo::signing::{TempoSigningMode, sign_and_encode_async},
20 },
21 error::MppError,
22 protocol::{
23 core::{PaymentChallenge, PaymentCredential},
24 intents::SessionRequest,
25 methods::tempo::session::TempoSessionExt,
26 },
27 tempo::{Call, SessionCredentialPayload, compute_channel_id, sign_voucher},
28};
29use std::{
30 collections::HashMap,
31 sync::{Arc, Mutex, OnceLock},
32};
33
34type SharedChannelState = (Arc<Mutex<HashMap<String, ChannelEntry>>>, Arc<Mutex<bool>>);
36
37static GLOBAL_CHANNELS: OnceLock<Mutex<HashMap<String, SharedChannelState>>> = OnceLock::new();
41
42static GLOBAL_PERSISTED: OnceLock<Arc<Mutex<HashMap<String, Channel>>>> = OnceLock::new();
47
/// An in-memory state change that has not been confirmed yet, together with the data
/// `rollback_pending` needs to undo it.
#[derive(Debug, Clone)]
enum PendingAction {
    /// A channel was inserted under `key`; rolling back removes it again.
    Open { key: String },
    /// The persisted deposit of `key` was raised; `old_deposit` is the previous value.
    TopUp { key: String, old_deposit: String },
    /// The cumulative amount of `key` was advanced; `old_cumulative` is the previous value.
    Voucher { key: String, old_cumulative: u128 },
}
61
62const EXPIRING_NONCE_KEY: U256 = U256::MAX;
64
65const VALID_BEFORE_SECS: u64 = 25;
67
68const SESSION_OPEN_GAS_LIMIT: u64 = 10_000_000;
70
71const MAX_FEE_PER_GAS: u128 = 20_000_000_000;
73
74const MAX_PRIORITY_FEE_PER_GAS: u128 = 20_000_000_000;
76
77#[derive(Clone)]
83pub struct SessionProvider {
84 signer: mpp::PrivateKeySigner,
85 signing_mode: TempoSigningMode,
86 authorized_signer: Option<Address>,
87 default_deposit: Option<u128>,
88 channels: Arc<Mutex<HashMap<String, ChannelEntry>>>,
89 key_provisioned: Arc<Mutex<bool>>,
90 persisted: Arc<Mutex<HashMap<String, Channel>>>,
91 pending: Arc<Mutex<Option<PendingAction>>>,
93 key_chain_id: Option<u64>,
96 key_currencies: Vec<Address>,
99 origin: String,
100}
101
102impl std::fmt::Debug for SessionProvider {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("SessionProvider")
105 .field("signing_mode", &self.signing_mode)
106 .field("authorized_signer", &self.authorized_signer)
107 .field("default_deposit", &self.default_deposit)
108 .finish_non_exhaustive()
109 }
110}
111
112impl SessionProvider {
113 pub fn new(signer: mpp::PrivateKeySigner, origin: String) -> Self {
121 let persisted =
123 GLOBAL_PERSISTED.get_or_init(|| Arc::new(Mutex::new(persist::load_channels()))).clone();
124
125 let (channels, key_provisioned) = {
127 let global = GLOBAL_CHANNELS.get_or_init(|| Mutex::new(HashMap::new()));
128 let mut map = global.lock().unwrap();
129 map.entry(origin.clone())
130 .or_insert_with(|| {
131 let mut channels: HashMap<String, ChannelEntry> = HashMap::new();
133 for (key, ch) in persisted.lock().unwrap().iter() {
134 if ch.origin == origin
135 && let Some(entry) = persist::to_channel_entry(ch)
136 {
137 channels.insert(key.clone(), entry);
138 }
139 }
140 (Arc::new(Mutex::new(channels)), Arc::new(Mutex::new(true)))
141 })
142 .clone()
143 };
144
145 Self {
146 signer,
147 signing_mode: TempoSigningMode::Direct,
148 authorized_signer: None,
149 default_deposit: None,
150 channels,
151 key_provisioned,
152 persisted,
153 pending: Arc::new(Mutex::new(None)),
154 key_chain_id: None,
155 key_currencies: vec![],
156 origin,
157 }
158 }
159
160 pub fn with_signing_mode(mut self, mode: TempoSigningMode) -> Self {
162 self.signing_mode = mode;
163 self
164 }
165
166 pub const fn with_authorized_signer(mut self, addr: Address) -> Self {
168 self.authorized_signer = Some(addr);
169 self
170 }
171
172 pub const fn with_default_deposit(mut self, deposit: u128) -> Self {
174 self.default_deposit = Some(deposit);
175 self
176 }
177
178 pub fn funding_wallet_address(&self) -> Address {
180 self.signing_mode.from_address(self.signer.address())
181 }
182
183 pub const fn key_chain_id(&self) -> Option<u64> {
185 self.key_chain_id
186 }
187
188 pub fn with_key_filters(mut self, chain_id: Option<u64>, currencies: Vec<Address>) -> Self {
192 self.key_chain_id = chain_id;
193 self.key_currencies = currencies;
194 self
195 }
196
197 pub fn matches_challenge(&self, chain_id: Option<u64>, currency: Option<Address>) -> bool {
200 if let Some(cid) = chain_id
201 && self.key_chain_id.is_some_and(|k| k != cid)
202 {
203 return false;
204 }
205 if let Some(cur) = currency
206 && !self.key_currencies.is_empty()
207 && !self.key_currencies.contains(&cur)
208 {
209 return false;
210 }
211 true
212 }
213
214 pub fn clear_channels(&self) {
219 let origin = &self.origin;
220 let mut channels = self.channels.lock().unwrap();
222 let mut persisted = self.persisted.lock().unwrap();
223 let keys_to_remove: Vec<(String, String)> = persisted
224 .iter()
225 .filter(|(_, ch)| ch.origin == *origin)
226 .map(|(k, ch): (&String, &Channel)| (k.clone(), ch.channel_id.clone()))
227 .collect();
228 for (key, channel_id) in &keys_to_remove {
229 channels.remove(key);
230 persisted.remove(key);
231 persist::delete_channel_from_db(channel_id);
232 }
233 }
234
235 pub fn set_key_provisioned(&self, provisioned: bool) {
237 *self.key_provisioned.lock().unwrap() = provisioned;
238 }
239
240 pub fn is_key_provisioned(&self) -> bool {
242 *self.key_provisioned.lock().unwrap()
243 }
244
245 pub fn flush_pending(&self) {
249 let pending = self.pending.lock().unwrap().take();
250 if pending.is_some() {
251 persist::save_channels(&self.persisted.lock().unwrap());
252 }
253 }
254
255 pub fn commit_topup_and_track_voucher(&self) {
261 let pending = self.pending.lock().unwrap().take();
262 if let Some(PendingAction::TopUp { key, .. }) = pending {
263 let old_cumulative =
266 self.channels.lock().unwrap().get(&key).map(|e| e.cumulative_amount).unwrap_or(0);
267 *self.pending.lock().unwrap() = Some(PendingAction::Voucher { key, old_cumulative });
268 }
269 }
270
271 pub fn rollback_pending(&self) {
275 let pending = self.pending.lock().unwrap().take();
276 if let Some(action) = pending {
277 match action {
278 PendingAction::Open { key } => {
279 self.channels.lock().unwrap().remove(&key);
280 self.persisted.lock().unwrap().remove(&key);
281 }
282 PendingAction::TopUp { key, old_deposit } => {
283 if let Some(p) = self.persisted.lock().unwrap().get_mut(&key) {
284 p.deposit = old_deposit;
285 }
286 }
287 PendingAction::Voucher { key, old_cumulative } => {
288 if let Some(entry) = self.channels.lock().unwrap().get_mut(&key) {
289 entry.cumulative_amount = old_cumulative;
290 }
291 if let Some(p) = self.persisted.lock().unwrap().get_mut(&key) {
292 p.cumulative_amount = old_cumulative.to_string();
293 }
294 }
295 }
296 }
297 }
298
299 fn channel_key(
300 origin: &str,
301 payer: &Address,
302 authorized_signer: Option<Address>,
303 payee: &Address,
304 currency: &Address,
305 escrow: &Address,
306 chain_id: u64,
307 ) -> String {
308 let origin_hash = &alloy_primitives::keccak256(origin.as_bytes()).to_string()[..18];
311 let signer = authorized_signer.unwrap_or(*payer);
312 format!("{origin_hash}:{chain_id}:{payer}:{signer}:{payee}:{currency}:{escrow}")
313 .to_lowercase()
314 }
315
316 fn resolve_deposit(&self, suggested: Option<&str>) -> Result<u128, MppError> {
317 let suggested_val = suggested.and_then(|s| s.parse::<u128>().ok());
318
319 if let (Some(sv), Some(local)) = (suggested_val, self.default_deposit)
322 && sv > local
323 {
324 let _ = sh_warn!(
325 "server-suggested deposit ({sv}) exceeds local default ({local}); \
326 set MPP_DEPOSIT to override"
327 );
328 }
329
330 let amount = self.default_deposit.or(suggested_val);
331
332 amount.ok_or_else(|| {
333 MppError::InvalidConfig("no deposit amount: set default_deposit".to_string())
334 })
335 }
336
337 async fn create_open_tx(
338 &self,
339 payer: Address,
340 options: OpenPayloadOptions,
341 ) -> Result<(ChannelEntry, SessionCredentialPayload), MppError> {
342 use alloy_sol_types::SolCall as _;
343
344 let authorized_signer = options.authorized_signer.unwrap_or(payer);
345 let salt = B256::random();
346
347 let channel_id = compute_channel_id(
348 payer,
349 options.payee,
350 options.currency,
351 salt,
352 authorized_signer,
353 options.escrow_contract,
354 options.chain_id,
355 );
356
357 alloy_sol_types::sol! {
358 interface ITIP20 {
359 function approve(address spender, uint256 amount) external returns (bool);
360 }
361 interface IEscrow {
362 function open(
363 address payee,
364 address token,
365 uint128 deposit,
366 bytes32 salt,
367 address authorizedSigner
368 ) external;
369 }
370 }
371
372 let approve_data =
373 ITIP20::approveCall::new((options.escrow_contract, U256::from(options.deposit)))
374 .abi_encode();
375
376 let open_data = IEscrow::openCall::new((
377 options.payee,
378 options.currency,
379 options.deposit,
380 salt,
381 authorized_signer,
382 ))
383 .abi_encode();
384
385 let calls = vec![
386 Call {
387 to: TxKind::Call(options.currency),
388 value: U256::ZERO,
389 input: Bytes::from(approve_data),
390 },
391 Call {
392 to: TxKind::Call(options.escrow_contract),
393 value: U256::ZERO,
394 input: Bytes::from(open_data),
395 },
396 ];
397
398 let valid_before = {
399 let now = std::time::SystemTime::now()
400 .duration_since(std::time::UNIX_EPOCH)
401 .unwrap_or_default()
402 .as_secs();
403 Some(now + VALID_BEFORE_SECS)
404 };
405
406 let tx = mpp::client::tempo::charge::tx_builder::build_tempo_tx(
407 mpp::client::tempo::charge::tx_builder::TempoTxOptions {
408 calls,
409 chain_id: options.chain_id,
410 fee_token: options.currency,
411 nonce: 0,
412 nonce_key: EXPIRING_NONCE_KEY,
413 gas_limit: SESSION_OPEN_GAS_LIMIT,
414 max_fee_per_gas: MAX_FEE_PER_GAS,
415 max_priority_fee_per_gas: MAX_PRIORITY_FEE_PER_GAS,
416 fee_payer: options.fee_payer,
417 valid_before,
418 key_authorization: (!*self.key_provisioned.lock().unwrap())
419 .then(|| self.signing_mode.key_authorization().cloned())
420 .flatten(),
421 },
422 );
423
424 let signed_tx = sign_and_encode_async(tx, &self.signer, &self.signing_mode).await?;
425
426 let voucher = sign_voucher(
427 &self.signer,
428 channel_id,
429 options.initial_amount,
430 options.escrow_contract,
431 options.chain_id,
432 )
433 .await?;
434
435 let entry = ChannelEntry {
436 channel_id,
437 salt,
438 cumulative_amount: options.initial_amount,
439 escrow_contract: options.escrow_contract,
440 chain_id: options.chain_id,
441 opened: true,
442 };
443
444 let signed_tx_hex = alloy_primitives::hex::encode_prefixed(&signed_tx);
445 let voucher_sig_hex = alloy_primitives::hex::encode_prefixed(&voucher);
446
447 Ok((
448 entry,
449 SessionCredentialPayload::Open {
450 payload_type: "transaction".to_string(),
451 channel_id: channel_id.to_string(),
452 transaction: signed_tx_hex,
453 authorized_signer: Some(format!("{authorized_signer}")),
454 cumulative_amount: options.initial_amount.to_string(),
455 signature: voucher_sig_hex,
456 },
457 ))
458 }
459
460 async fn create_topup_tx(
461 &self,
462 entry: &ChannelEntry,
463 additional_deposit: u128,
464 currency: Address,
465 fee_payer: bool,
466 ) -> Result<SessionCredentialPayload, MppError> {
467 use alloy_sol_types::SolCall as _;
468
469 alloy_sol_types::sol! {
470 interface ITIP20 {
471 function approve(address spender, uint256 amount) external returns (bool);
472 }
473 interface IEscrow {
474 function topUp(bytes32 channelId, uint256 additionalDeposit) external;
475 }
476 }
477
478 let approve_data =
479 ITIP20::approveCall::new((entry.escrow_contract, U256::from(additional_deposit)))
480 .abi_encode();
481 let topup_data =
482 IEscrow::topUpCall::new((entry.channel_id, U256::from(additional_deposit)))
483 .abi_encode();
484
485 let calls = vec![
486 Call {
487 to: TxKind::Call(currency),
488 value: U256::ZERO,
489 input: Bytes::from(approve_data),
490 },
491 Call {
492 to: TxKind::Call(entry.escrow_contract),
493 value: U256::ZERO,
494 input: Bytes::from(topup_data),
495 },
496 ];
497
498 let valid_before = {
499 let now = std::time::SystemTime::now()
500 .duration_since(std::time::UNIX_EPOCH)
501 .unwrap_or_default()
502 .as_secs();
503 Some(now + VALID_BEFORE_SECS)
504 };
505
506 let tx = mpp::client::tempo::charge::tx_builder::build_tempo_tx(
507 mpp::client::tempo::charge::tx_builder::TempoTxOptions {
508 calls,
509 chain_id: entry.chain_id,
510 fee_token: currency,
511 nonce: 0,
512 nonce_key: EXPIRING_NONCE_KEY,
513 gas_limit: SESSION_OPEN_GAS_LIMIT,
514 max_fee_per_gas: MAX_FEE_PER_GAS,
515 max_priority_fee_per_gas: MAX_PRIORITY_FEE_PER_GAS,
516 fee_payer,
517 valid_before,
518 key_authorization: None,
519 },
520 );
521
522 let signed_tx = sign_and_encode_async(tx, &self.signer, &self.signing_mode).await?;
523
524 Ok(SessionCredentialPayload::TopUp {
525 payload_type: "transaction".to_string(),
526 channel_id: entry.channel_id.to_string(),
527 transaction: alloy_primitives::hex::encode_prefixed(&signed_tx),
528 additional_deposit: additional_deposit.to_string(),
529 })
530 }
531}
532
533impl SessionProvider {
534 async fn pay_charge(
536 &self,
537 challenge: &PaymentChallenge,
538 ) -> Result<PaymentCredential, MppError> {
539 use mpp::client::tempo::charge::{SignOptions, TempoCharge};
540
541 let charge = TempoCharge::from_challenge(challenge)?;
542
543 let signing_mode = if *self.key_provisioned.lock().unwrap() {
547 match &self.signing_mode {
548 TempoSigningMode::Keychain { wallet, version, .. } => TempoSigningMode::Keychain {
549 wallet: *wallet,
550 key_authorization: None,
551 version: *version,
552 },
553 other => other.clone(),
554 }
555 } else {
556 self.signing_mode.clone()
557 };
558
559 let options = SignOptions { signing_mode: Some(signing_mode), ..Default::default() };
560 let signed = charge.sign_with_options(&self.signer, options).await?;
561 Ok(signed.into_credential())
562 }
563}
564
565impl PaymentProvider for SessionProvider {
566 fn supports(&self, method: &str, intent: &str) -> bool {
567 method == "tempo" && (intent == "session" || intent == "charge")
568 }
569
570 async fn pay(&self, challenge: &PaymentChallenge) -> Result<PaymentCredential, MppError> {
571 if challenge.intent.as_str() == "charge" {
572 return self.pay_charge(challenge).await;
573 }
574 self.pay_session(challenge).await
575 }
576}
577
578impl SessionProvider {
579 async fn pay_session(
580 &self,
581 challenge: &PaymentChallenge,
582 ) -> Result<PaymentCredential, MppError> {
583 let session_req: SessionRequest = challenge.request.decode().map_err(|e| {
584 MppError::InvalidConfig(format!("failed to decode session request: {e}"))
585 })?;
586
587 let chain_id = resolve_chain_id(challenge);
588 let escrow_contract = resolve_escrow(challenge, chain_id, None)?;
589 let payee: Address = session_req
590 .recipient
591 .as_deref()
592 .ok_or_else(|| {
593 MppError::InvalidConfig("session challenge missing recipient".to_string())
594 })?
595 .parse()
596 .map_err(|_e| MppError::InvalidConfig("invalid recipient address".to_string()))?;
597 let currency: Address = session_req
598 .currency
599 .parse()
600 .map_err(|_e| MppError::InvalidConfig("invalid currency address".to_string()))?;
601 let amount: u128 = session_req.parse_amount()?;
602
603 let payer = self.signing_mode.from_address(self.signer.address());
604
605 let key = Self::channel_key(
606 &self.origin,
607 &payer,
608 self.authorized_signer,
609 &payee,
610 ¤cy,
611 &escrow_contract,
612 chain_id,
613 );
614
615 let voucher_info = {
616 let mut channels = self.channels.lock().unwrap();
617 if let Some(entry) = channels.get_mut(&key)
618 && entry.opened
619 {
620 let deposit = self
621 .persisted
622 .lock()
623 .unwrap()
624 .get(&key)
625 .and_then(|p| p.deposit.parse::<u128>().ok())
626 .unwrap_or(u128::MAX);
627
628 if entry.cumulative_amount + amount > deposit {
629 Some(Err((entry.clone(), deposit)))
630 } else {
631 Some(Ok(entry.clone()))
634 }
635 } else {
636 None
637 }
638 };
639
640 if let Some(result) = voucher_info {
641 match result {
642 Err((entry, deposit)) => {
643 let additional =
644 self.resolve_deposit(session_req.suggested_deposit.as_deref())?;
645 tracing::debug!(
646 cumulative = entry.cumulative_amount,
647 amount,
648 deposit,
649 additional,
650 "channel deposit exhausted, topping up"
651 );
652
653 let payload = self
654 .create_topup_tx(&entry, additional, currency, session_req.fee_payer())
655 .await?;
656
657 let old_deposit = {
659 let mut persisted = self.persisted.lock().unwrap();
660 if let Some(p) = persisted.get_mut(&key) {
661 let old = p.deposit.clone();
662 let old_val: u128 = old.parse().unwrap_or(0);
663 p.deposit = (old_val + additional).to_string();
664 old
665 } else {
666 "0".to_string()
667 }
668 };
669 *self.pending.lock().unwrap() =
670 Some(PendingAction::TopUp { key: key.clone(), old_deposit });
671
672 return Ok(build_credential(challenge, payload, chain_id, payer));
673 }
674 Ok(entry) => {
675 let old_cumulative = entry.cumulative_amount;
676 let new_cumulative = old_cumulative + amount;
677 let payload = create_voucher_payload(
678 &self.signer,
679 entry.channel_id,
680 new_cumulative,
681 escrow_contract,
682 chain_id,
683 )
684 .await?;
685
686 {
688 let mut channels = self.channels.lock().unwrap();
689 if let Some(e) = channels.get_mut(&key) {
690 e.cumulative_amount = new_cumulative;
691 }
692 }
693
694 let updated_entry = ChannelEntry { cumulative_amount: new_cumulative, ..entry };
698 let mut persisted = self.persisted.lock().unwrap();
699 persist::upsert_channel_in_memory(&mut persisted, &key, &updated_entry);
700 drop(persisted);
701
702 if self.pending.lock().unwrap().is_none() {
705 *self.pending.lock().unwrap() =
706 Some(PendingAction::Voucher { key, old_cumulative });
707 }
708
709 return Ok(build_credential(challenge, payload, chain_id, payer));
710 }
711 }
712 }
713
714 let deposit = self.resolve_deposit(session_req.suggested_deposit.as_deref())?;
716
717 let (entry, payload) = self
718 .create_open_tx(
719 payer,
720 OpenPayloadOptions {
721 authorized_signer: self.authorized_signer,
722 escrow_contract,
723 payee,
724 currency,
725 deposit,
726 initial_amount: amount,
727 chain_id,
728 fee_payer: session_req.fee_payer(),
729 },
730 )
731 .await?;
732
733 self.channels.lock().unwrap().insert(key.clone(), entry.clone());
735 let authorized_signer = self.authorized_signer.unwrap_or(payer);
736 self.persisted.lock().unwrap().insert(
737 key.clone(),
738 persist::from_channel_entry(
739 &entry,
740 deposit,
741 &self.origin,
742 &payer,
743 &payee,
744 ¤cy,
745 &authorized_signer,
746 ),
747 );
748 *self.pending.lock().unwrap() = Some(PendingAction::Open { key });
749 Ok(build_credential(challenge, payload, chain_id, payer))
750 }
751}
752
753#[cfg(test)]
754mod tests {
755 use super::*;
756 use mpp::client::tempo::signing::KeychainVersion;
757 use tempo_primitives::transaction::{
758 KeyAuthorization, PrimitiveSignature, SignatureType, SignedKeyAuthorization,
759 };
760
761 fn test_key_authorization() -> SignedKeyAuthorization {
763 SignedKeyAuthorization {
764 authorization: KeyAuthorization::unrestricted(
765 4217,
766 SignatureType::Secp256k1,
767 Address::ZERO,
768 ),
769 signature: PrimitiveSignature::from_bytes(&[0u8; 65]).expect("valid dummy signature"),
770 }
771 }
772
773 fn strip_key_auth_if_provisioned(
774 mode: &TempoSigningMode,
775 provisioned: bool,
776 ) -> TempoSigningMode {
777 if provisioned {
778 match mode {
779 TempoSigningMode::Keychain { wallet, version, .. } => TempoSigningMode::Keychain {
780 wallet: *wallet,
781 key_authorization: None,
782 version: *version,
783 },
784 other => other.clone(),
785 }
786 } else {
787 mode.clone()
788 }
789 }
790
791 fn unique_origin() -> String {
793 format!("https://rpc-{}.example.com", alloy_primitives::B256::random())
794 }
795
/// A provider for a fresh origin starts with the key marked as provisioned.
#[test]
fn test_key_provisioned_default_is_true() {
    let provider = SessionProvider::new(mpp::PrivateKeySigner::random(), unique_origin());
    assert!(provider.is_key_provisioned());
}
802
/// The key-provisioned flag can be switched off and back on.
#[test]
fn test_set_key_provisioned() {
    let provider = SessionProvider::new(mpp::PrivateKeySigner::random(), unique_origin());

    provider.set_key_provisioned(false);
    assert!(!provider.is_key_provisioned());

    provider.set_key_provisioned(true);
    assert!(provider.is_key_provisioned());
}
812
/// With the default (provisioned) flag, a keychain mode loses its key authorization.
#[test]
fn test_pay_charge_strips_key_auth_when_provisioned() {
    let keychain_mode = TempoSigningMode::Keychain {
        wallet: Address::repeat_byte(0xAA),
        key_authorization: Some(Box::new(test_key_authorization())),
        version: KeychainVersion::V2,
    };
    let provider = SessionProvider::new(mpp::PrivateKeySigner::random(), unique_origin())
        .with_signing_mode(keychain_mode);

    let result_mode =
        strip_key_auth_if_provisioned(&provider.signing_mode, provider.is_key_provisioned());

    assert!(
        result_mode.key_authorization().is_none(),
        "key_authorization should be stripped when key is provisioned"
    );
}
833
/// Once the flag is cleared, a keychain mode keeps its key authorization.
#[test]
fn test_pay_charge_keeps_key_auth_when_not_provisioned() {
    let keychain_mode = TempoSigningMode::Keychain {
        wallet: Address::repeat_byte(0xAA),
        key_authorization: Some(Box::new(test_key_authorization())),
        version: KeychainVersion::V2,
    };
    let provider = SessionProvider::new(mpp::PrivateKeySigner::random(), unique_origin())
        .with_signing_mode(keychain_mode);
    provider.set_key_provisioned(false);

    let result_mode =
        strip_key_auth_if_provisioned(&provider.signing_mode, provider.is_key_provisioned());

    assert!(
        result_mode.key_authorization().is_some(),
        "key_authorization should be preserved when key is NOT provisioned"
    );
}
856
/// Direct signing mode has nothing to strip and must come back as `Direct`.
#[test]
fn test_pay_charge_direct_mode_unaffected() {
    let provider = SessionProvider::new(mpp::PrivateKeySigner::random(), unique_origin())
        .with_signing_mode(TempoSigningMode::Direct);

    let result_mode =
        strip_key_auth_if_provisioned(&provider.signing_mode, provider.is_key_provisioned());

    assert!(
        matches!(result_mode, TempoSigningMode::Direct),
        "Direct mode should pass through unchanged"
    );
}
871
872 #[tokio::test]
876 async fn test_concurrent_voucher_increments_are_unique() {
877 let channels: Arc<Mutex<HashMap<String, ChannelEntry>>> =
878 Arc::new(Mutex::new(HashMap::new()));
879 let key = "test-channel".to_string();
880 channels.lock().unwrap().insert(
881 key.clone(),
882 ChannelEntry {
883 channel_id: Default::default(),
884 salt: Default::default(),
885 cumulative_amount: 0,
886 escrow_contract: Address::ZERO,
887 chain_id: 42431,
888 opened: true,
889 },
890 );
891
892 let pay_lock = std::sync::Arc::new(tokio::sync::Mutex::new(()));
895 let amount: u128 = 1000;
896 let num_tasks = 20;
897 let results: Arc<Mutex<Vec<u128>>> = Arc::new(Mutex::new(Vec::new()));
898
899 let mut handles = Vec::new();
900 for _ in 0..num_tasks {
901 let channels = channels.clone();
902 let key = key.clone();
903 let results = results.clone();
904 let pay_lock = pay_lock.clone();
905 handles.push(tokio::spawn(async move {
906 let _guard = pay_lock.lock().await;
907 let cumulative = {
908 let mut ch = channels.lock().unwrap();
909 let entry = ch.get_mut(&key).unwrap();
910 entry.cumulative_amount += amount;
911 entry.cumulative_amount
912 };
913 results.lock().unwrap().push(cumulative);
914 }));
915 }
916
917 for h in handles {
918 h.await.unwrap();
919 }
920
921 let mut amounts = results.lock().unwrap().clone();
922 amounts.sort();
923 amounts.dedup();
924 assert_eq!(
925 amounts.len(),
926 num_tasks,
927 "each concurrent increment should produce a unique cumulative_amount"
928 );
929 assert_eq!(
930 *amounts.last().unwrap(),
931 amount * num_tasks as u128,
932 "final cumulative_amount should equal amount × num_tasks"
933 );
934 }
935}