1use super::persist;
10use alloy_primitives::{Address, B256, Bytes, TxKind, U256};
11use foundry_wallets::Channel;
12use mpp::{
13 client::{
14 PaymentProvider,
15 channel_ops::{
16 ChannelEntry, OpenPayloadOptions, build_credential, create_voucher_payload,
17 resolve_chain_id, resolve_escrow,
18 },
19 tempo::signing::{TempoSigningMode, sign_and_encode_async},
20 },
21 error::MppError,
22 protocol::{
23 core::{PaymentChallenge, PaymentCredential},
24 intents::SessionRequest,
25 methods::tempo::session::TempoSessionExt,
26 },
27 tempo::{Call, SessionCredentialPayload, compute_channel_id, sign_voucher},
28};
29use std::{
30 collections::HashMap,
31 sync::{Arc, Mutex, OnceLock},
32};
33
34type SharedChannelState = (Arc<Mutex<HashMap<String, ChannelEntry>>>, Arc<Mutex<bool>>);
36
37static GLOBAL_CHANNELS: OnceLock<Mutex<HashMap<String, SharedChannelState>>> = OnceLock::new();
41
42static GLOBAL_PERSISTED: OnceLock<Arc<Mutex<HashMap<String, Channel>>>> = OnceLock::new();
47
48#[derive(Clone, Debug)]
53enum PendingAction {
54 Open { key: String },
56 TopUp { key: String, old_deposit: String },
58 Voucher { key: String, old_cumulative: u128 },
60}
61
62const EXPIRING_NONCE_KEY: U256 = U256::MAX;
64
65const VALID_BEFORE_SECS: u64 = 25;
67
68const SESSION_OPEN_GAS_LIMIT: u64 = 10_000_000;
70
71const MAX_FEE_PER_GAS: u128 = 20_000_000_000;
73
74const MAX_PRIORITY_FEE_PER_GAS: u128 = 20_000_000_000;
76
77#[derive(Clone)]
83pub struct SessionProvider {
84 signer: mpp::PrivateKeySigner,
85 signing_mode: TempoSigningMode,
86 authorized_signer: Option<Address>,
87 default_deposit: Option<u128>,
88 channels: Arc<Mutex<HashMap<String, ChannelEntry>>>,
89 key_provisioned: Arc<Mutex<bool>>,
90 persisted: Arc<Mutex<HashMap<String, Channel>>>,
91 pending: Arc<Mutex<Option<PendingAction>>>,
93 key_chain_id: Option<u64>,
96 key_currencies: Vec<Address>,
99 origin: String,
100}
101
102impl std::fmt::Debug for SessionProvider {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("SessionProvider")
105 .field("signing_mode", &self.signing_mode)
106 .field("authorized_signer", &self.authorized_signer)
107 .field("default_deposit", &self.default_deposit)
108 .finish_non_exhaustive()
109 }
110}
111
112impl SessionProvider {
113 pub fn new(signer: mpp::PrivateKeySigner, origin: String) -> Self {
121 let persisted =
123 GLOBAL_PERSISTED.get_or_init(|| Arc::new(Mutex::new(persist::load_channels()))).clone();
124
125 let (channels, key_provisioned) = {
127 let global = GLOBAL_CHANNELS.get_or_init(|| Mutex::new(HashMap::new()));
128 let mut map = global.lock().unwrap();
129 map.entry(origin.clone())
130 .or_insert_with(|| {
131 let mut channels: HashMap<String, ChannelEntry> = HashMap::new();
133 for (key, ch) in persisted.lock().unwrap().iter() {
134 if ch.origin == origin
135 && let Some(entry) = persist::to_channel_entry(ch)
136 {
137 channels.insert(key.clone(), entry);
138 }
139 }
140 (Arc::new(Mutex::new(channels)), Arc::new(Mutex::new(true)))
141 })
142 .clone()
143 };
144
145 Self {
146 signer,
147 signing_mode: TempoSigningMode::Direct,
148 authorized_signer: None,
149 default_deposit: None,
150 channels,
151 key_provisioned,
152 persisted,
153 pending: Arc::new(Mutex::new(None)),
154 key_chain_id: None,
155 key_currencies: vec![],
156 origin,
157 }
158 }
159
160 pub fn with_signing_mode(mut self, mode: TempoSigningMode) -> Self {
162 self.signing_mode = mode;
163 self
164 }
165
166 pub const fn with_authorized_signer(mut self, addr: Address) -> Self {
168 self.authorized_signer = Some(addr);
169 self
170 }
171
172 pub const fn with_default_deposit(mut self, deposit: u128) -> Self {
174 self.default_deposit = Some(deposit);
175 self
176 }
177
178 pub fn with_key_filters(mut self, chain_id: Option<u64>, currencies: Vec<Address>) -> Self {
182 self.key_chain_id = chain_id;
183 self.key_currencies = currencies;
184 self
185 }
186
187 pub fn matches_challenge(&self, chain_id: Option<u64>, currency: Option<Address>) -> bool {
190 if let Some(cid) = chain_id
191 && self.key_chain_id.is_some_and(|k| k != cid)
192 {
193 return false;
194 }
195 if let Some(cur) = currency
196 && !self.key_currencies.is_empty()
197 && !self.key_currencies.contains(&cur)
198 {
199 return false;
200 }
201 true
202 }
203
204 pub fn clear_channels(&self) {
209 let origin = &self.origin;
210 let mut channels = self.channels.lock().unwrap();
212 let mut persisted = self.persisted.lock().unwrap();
213 let keys_to_remove: Vec<(String, String)> = persisted
214 .iter()
215 .filter(|(_, ch)| ch.origin == *origin)
216 .map(|(k, ch): (&String, &Channel)| (k.clone(), ch.channel_id.clone()))
217 .collect();
218 for (key, channel_id) in &keys_to_remove {
219 channels.remove(key);
220 persisted.remove(key);
221 persist::delete_channel_from_db(channel_id);
222 }
223 }
224
225 pub fn set_key_provisioned(&self, provisioned: bool) {
227 *self.key_provisioned.lock().unwrap() = provisioned;
228 }
229
230 pub fn is_key_provisioned(&self) -> bool {
232 *self.key_provisioned.lock().unwrap()
233 }
234
235 pub fn flush_pending(&self) {
239 let pending = self.pending.lock().unwrap().take();
240 if pending.is_some() {
241 persist::save_channels(&self.persisted.lock().unwrap());
242 }
243 }
244
245 pub fn commit_topup_and_track_voucher(&self) {
251 let pending = self.pending.lock().unwrap().take();
252 if let Some(PendingAction::TopUp { key, .. }) = pending {
253 let old_cumulative =
256 self.channels.lock().unwrap().get(&key).map(|e| e.cumulative_amount).unwrap_or(0);
257 *self.pending.lock().unwrap() = Some(PendingAction::Voucher { key, old_cumulative });
258 }
259 }
260
261 pub fn rollback_pending(&self) {
265 let pending = self.pending.lock().unwrap().take();
266 if let Some(action) = pending {
267 match action {
268 PendingAction::Open { key } => {
269 self.channels.lock().unwrap().remove(&key);
270 self.persisted.lock().unwrap().remove(&key);
271 }
272 PendingAction::TopUp { key, old_deposit } => {
273 if let Some(p) = self.persisted.lock().unwrap().get_mut(&key) {
274 p.deposit = old_deposit;
275 }
276 }
277 PendingAction::Voucher { key, old_cumulative } => {
278 if let Some(entry) = self.channels.lock().unwrap().get_mut(&key) {
279 entry.cumulative_amount = old_cumulative;
280 }
281 if let Some(p) = self.persisted.lock().unwrap().get_mut(&key) {
282 p.cumulative_amount = old_cumulative.to_string();
283 }
284 }
285 }
286 }
287 }
288
289 fn channel_key(
290 origin: &str,
291 payer: &Address,
292 authorized_signer: Option<Address>,
293 payee: &Address,
294 currency: &Address,
295 escrow: &Address,
296 chain_id: u64,
297 ) -> String {
298 let origin_hash = &alloy_primitives::keccak256(origin.as_bytes()).to_string()[..18];
301 let signer = authorized_signer.unwrap_or(*payer);
302 format!("{origin_hash}:{chain_id}:{payer}:{signer}:{payee}:{currency}:{escrow}")
303 .to_lowercase()
304 }
305
306 fn resolve_deposit(&self, suggested: Option<&str>) -> Result<u128, MppError> {
307 let suggested_val = suggested.and_then(|s| s.parse::<u128>().ok());
308
309 if let (Some(sv), Some(local)) = (suggested_val, self.default_deposit)
312 && sv > local
313 {
314 let _ = sh_warn!(
315 "server-suggested deposit ({sv}) exceeds local default ({local}); \
316 set MPP_DEPOSIT to override"
317 );
318 }
319
320 let amount = self.default_deposit.or(suggested_val);
321
322 amount.ok_or_else(|| {
323 MppError::InvalidConfig("no deposit amount: set default_deposit".to_string())
324 })
325 }
326
327 async fn create_open_tx(
328 &self,
329 payer: Address,
330 options: OpenPayloadOptions,
331 ) -> Result<(ChannelEntry, SessionCredentialPayload), MppError> {
332 use alloy_sol_types::SolCall as _;
333
334 let authorized_signer = options.authorized_signer.unwrap_or(payer);
335 let salt = B256::random();
336
337 let channel_id = compute_channel_id(
338 payer,
339 options.payee,
340 options.currency,
341 salt,
342 authorized_signer,
343 options.escrow_contract,
344 options.chain_id,
345 );
346
347 alloy_sol_types::sol! {
348 interface ITIP20 {
349 function approve(address spender, uint256 amount) external returns (bool);
350 }
351 interface IEscrow {
352 function open(
353 address payee,
354 address token,
355 uint128 deposit,
356 bytes32 salt,
357 address authorizedSigner
358 ) external;
359 }
360 }
361
362 let approve_data =
363 ITIP20::approveCall::new((options.escrow_contract, U256::from(options.deposit)))
364 .abi_encode();
365
366 let open_data = IEscrow::openCall::new((
367 options.payee,
368 options.currency,
369 options.deposit,
370 salt,
371 authorized_signer,
372 ))
373 .abi_encode();
374
375 let calls = vec![
376 Call {
377 to: TxKind::Call(options.currency),
378 value: U256::ZERO,
379 input: Bytes::from(approve_data),
380 },
381 Call {
382 to: TxKind::Call(options.escrow_contract),
383 value: U256::ZERO,
384 input: Bytes::from(open_data),
385 },
386 ];
387
388 let valid_before = {
389 let now = std::time::SystemTime::now()
390 .duration_since(std::time::UNIX_EPOCH)
391 .unwrap_or_default()
392 .as_secs();
393 Some(now + VALID_BEFORE_SECS)
394 };
395
396 let tx = mpp::client::tempo::charge::tx_builder::build_tempo_tx(
397 mpp::client::tempo::charge::tx_builder::TempoTxOptions {
398 calls,
399 chain_id: options.chain_id,
400 fee_token: options.currency,
401 nonce: 0,
402 nonce_key: EXPIRING_NONCE_KEY,
403 gas_limit: SESSION_OPEN_GAS_LIMIT,
404 max_fee_per_gas: MAX_FEE_PER_GAS,
405 max_priority_fee_per_gas: MAX_PRIORITY_FEE_PER_GAS,
406 fee_payer: options.fee_payer,
407 valid_before,
408 key_authorization: (!*self.key_provisioned.lock().unwrap())
409 .then(|| self.signing_mode.key_authorization().cloned())
410 .flatten(),
411 },
412 );
413
414 let signed_tx = sign_and_encode_async(tx, &self.signer, &self.signing_mode).await?;
415
416 let voucher = sign_voucher(
417 &self.signer,
418 channel_id,
419 options.initial_amount,
420 options.escrow_contract,
421 options.chain_id,
422 )
423 .await?;
424
425 let entry = ChannelEntry {
426 channel_id,
427 salt,
428 cumulative_amount: options.initial_amount,
429 escrow_contract: options.escrow_contract,
430 chain_id: options.chain_id,
431 opened: true,
432 };
433
434 let signed_tx_hex = alloy_primitives::hex::encode_prefixed(&signed_tx);
435 let voucher_sig_hex = alloy_primitives::hex::encode_prefixed(&voucher);
436
437 Ok((
438 entry,
439 SessionCredentialPayload::Open {
440 payload_type: "transaction".to_string(),
441 channel_id: channel_id.to_string(),
442 transaction: signed_tx_hex,
443 authorized_signer: Some(format!("{authorized_signer}")),
444 cumulative_amount: options.initial_amount.to_string(),
445 signature: voucher_sig_hex,
446 },
447 ))
448 }
449
450 async fn create_topup_tx(
451 &self,
452 entry: &ChannelEntry,
453 additional_deposit: u128,
454 currency: Address,
455 fee_payer: bool,
456 ) -> Result<SessionCredentialPayload, MppError> {
457 use alloy_sol_types::SolCall as _;
458
459 alloy_sol_types::sol! {
460 interface ITIP20 {
461 function approve(address spender, uint256 amount) external returns (bool);
462 }
463 interface IEscrow {
464 function topUp(bytes32 channelId, uint256 additionalDeposit) external;
465 }
466 }
467
468 let approve_data =
469 ITIP20::approveCall::new((entry.escrow_contract, U256::from(additional_deposit)))
470 .abi_encode();
471 let topup_data =
472 IEscrow::topUpCall::new((entry.channel_id, U256::from(additional_deposit)))
473 .abi_encode();
474
475 let calls = vec![
476 Call {
477 to: TxKind::Call(currency),
478 value: U256::ZERO,
479 input: Bytes::from(approve_data),
480 },
481 Call {
482 to: TxKind::Call(entry.escrow_contract),
483 value: U256::ZERO,
484 input: Bytes::from(topup_data),
485 },
486 ];
487
488 let valid_before = {
489 let now = std::time::SystemTime::now()
490 .duration_since(std::time::UNIX_EPOCH)
491 .unwrap_or_default()
492 .as_secs();
493 Some(now + VALID_BEFORE_SECS)
494 };
495
496 let tx = mpp::client::tempo::charge::tx_builder::build_tempo_tx(
497 mpp::client::tempo::charge::tx_builder::TempoTxOptions {
498 calls,
499 chain_id: entry.chain_id,
500 fee_token: currency,
501 nonce: 0,
502 nonce_key: EXPIRING_NONCE_KEY,
503 gas_limit: SESSION_OPEN_GAS_LIMIT,
504 max_fee_per_gas: MAX_FEE_PER_GAS,
505 max_priority_fee_per_gas: MAX_PRIORITY_FEE_PER_GAS,
506 fee_payer,
507 valid_before,
508 key_authorization: None,
509 },
510 );
511
512 let signed_tx = sign_and_encode_async(tx, &self.signer, &self.signing_mode).await?;
513
514 Ok(SessionCredentialPayload::TopUp {
515 payload_type: "transaction".to_string(),
516 channel_id: entry.channel_id.to_string(),
517 transaction: alloy_primitives::hex::encode_prefixed(&signed_tx),
518 additional_deposit: additional_deposit.to_string(),
519 })
520 }
521}
522
523impl SessionProvider {
524 async fn pay_charge(
526 &self,
527 challenge: &PaymentChallenge,
528 ) -> Result<PaymentCredential, MppError> {
529 use mpp::client::tempo::charge::{SignOptions, TempoCharge};
530
531 let charge = TempoCharge::from_challenge(challenge)?;
532
533 let signing_mode = if *self.key_provisioned.lock().unwrap() {
537 match &self.signing_mode {
538 TempoSigningMode::Keychain { wallet, version, .. } => TempoSigningMode::Keychain {
539 wallet: *wallet,
540 key_authorization: None,
541 version: *version,
542 },
543 other => other.clone(),
544 }
545 } else {
546 self.signing_mode.clone()
547 };
548
549 let options = SignOptions { signing_mode: Some(signing_mode), ..Default::default() };
550 let signed = charge.sign_with_options(&self.signer, options).await?;
551 Ok(signed.into_credential())
552 }
553}
554
555impl PaymentProvider for SessionProvider {
556 fn supports(&self, method: &str, intent: &str) -> bool {
557 method == "tempo" && (intent == "session" || intent == "charge")
558 }
559
560 async fn pay(&self, challenge: &PaymentChallenge) -> Result<PaymentCredential, MppError> {
561 if challenge.intent.as_str() == "charge" {
562 return self.pay_charge(challenge).await;
563 }
564 self.pay_session(challenge).await
565 }
566}
567
568impl SessionProvider {
569 async fn pay_session(
570 &self,
571 challenge: &PaymentChallenge,
572 ) -> Result<PaymentCredential, MppError> {
573 let session_req: SessionRequest = challenge.request.decode().map_err(|e| {
574 MppError::InvalidConfig(format!("failed to decode session request: {e}"))
575 })?;
576
577 let chain_id = resolve_chain_id(challenge);
578 let escrow_contract = resolve_escrow(challenge, chain_id, None)?;
579 let payee: Address = session_req
580 .recipient
581 .as_deref()
582 .ok_or_else(|| {
583 MppError::InvalidConfig("session challenge missing recipient".to_string())
584 })?
585 .parse()
586 .map_err(|_e| MppError::InvalidConfig("invalid recipient address".to_string()))?;
587 let currency: Address = session_req
588 .currency
589 .parse()
590 .map_err(|_e| MppError::InvalidConfig("invalid currency address".to_string()))?;
591 let amount: u128 = session_req.parse_amount()?;
592
593 let payer = self.signing_mode.from_address(self.signer.address());
594
595 let key = Self::channel_key(
596 &self.origin,
597 &payer,
598 self.authorized_signer,
599 &payee,
600 ¤cy,
601 &escrow_contract,
602 chain_id,
603 );
604
605 let voucher_info = {
606 let mut channels = self.channels.lock().unwrap();
607 if let Some(entry) = channels.get_mut(&key)
608 && entry.opened
609 {
610 let deposit = self
611 .persisted
612 .lock()
613 .unwrap()
614 .get(&key)
615 .and_then(|p| p.deposit.parse::<u128>().ok())
616 .unwrap_or(u128::MAX);
617
618 if entry.cumulative_amount + amount > deposit {
619 Some(Err((entry.clone(), deposit)))
620 } else {
621 Some(Ok(entry.clone()))
624 }
625 } else {
626 None
627 }
628 };
629
630 if let Some(result) = voucher_info {
631 match result {
632 Err((entry, deposit)) => {
633 let additional =
634 self.resolve_deposit(session_req.suggested_deposit.as_deref())?;
635 tracing::debug!(
636 cumulative = entry.cumulative_amount,
637 amount,
638 deposit,
639 additional,
640 "channel deposit exhausted, topping up"
641 );
642
643 let payload = self
644 .create_topup_tx(&entry, additional, currency, session_req.fee_payer())
645 .await?;
646
647 let old_deposit = {
649 let mut persisted = self.persisted.lock().unwrap();
650 if let Some(p) = persisted.get_mut(&key) {
651 let old = p.deposit.clone();
652 let old_val: u128 = old.parse().unwrap_or(0);
653 p.deposit = (old_val + additional).to_string();
654 old
655 } else {
656 "0".to_string()
657 }
658 };
659 *self.pending.lock().unwrap() =
660 Some(PendingAction::TopUp { key: key.clone(), old_deposit });
661
662 return Ok(build_credential(challenge, payload, chain_id, payer));
663 }
664 Ok(entry) => {
665 let old_cumulative = entry.cumulative_amount;
666 let new_cumulative = old_cumulative + amount;
667 let payload = create_voucher_payload(
668 &self.signer,
669 entry.channel_id,
670 new_cumulative,
671 escrow_contract,
672 chain_id,
673 )
674 .await?;
675
676 {
678 let mut channels = self.channels.lock().unwrap();
679 if let Some(e) = channels.get_mut(&key) {
680 e.cumulative_amount = new_cumulative;
681 }
682 }
683
684 let updated_entry = ChannelEntry { cumulative_amount: new_cumulative, ..entry };
688 let mut persisted = self.persisted.lock().unwrap();
689 persist::upsert_channel_in_memory(&mut persisted, &key, &updated_entry);
690 drop(persisted);
691
692 if self.pending.lock().unwrap().is_none() {
695 *self.pending.lock().unwrap() =
696 Some(PendingAction::Voucher { key, old_cumulative });
697 }
698
699 return Ok(build_credential(challenge, payload, chain_id, payer));
700 }
701 }
702 }
703
704 let deposit = self.resolve_deposit(session_req.suggested_deposit.as_deref())?;
706
707 let (entry, payload) = self
708 .create_open_tx(
709 payer,
710 OpenPayloadOptions {
711 authorized_signer: self.authorized_signer,
712 escrow_contract,
713 payee,
714 currency,
715 deposit,
716 initial_amount: amount,
717 chain_id,
718 fee_payer: session_req.fee_payer(),
719 },
720 )
721 .await?;
722
723 self.channels.lock().unwrap().insert(key.clone(), entry.clone());
725 let authorized_signer = self.authorized_signer.unwrap_or(payer);
726 self.persisted.lock().unwrap().insert(
727 key.clone(),
728 persist::from_channel_entry(
729 &entry,
730 deposit,
731 &self.origin,
732 &payer,
733 &payee,
734 ¤cy,
735 &authorized_signer,
736 ),
737 );
738 *self.pending.lock().unwrap() = Some(PendingAction::Open { key });
739 Ok(build_credential(challenge, payload, chain_id, payer))
740 }
741}
742
#[cfg(test)]
mod tests {
    use super::*;
    use mpp::client::tempo::signing::KeychainVersion;
    use tempo_primitives::transaction::{
        KeyAuthorization, PrimitiveSignature, SignatureType, SignedKeyAuthorization,
    };

    /// An unrestricted key authorization carrying an all-zero placeholder signature.
    fn test_key_authorization() -> SignedKeyAuthorization {
        let authorization =
            KeyAuthorization::unrestricted(4217, SignatureType::Secp256k1, Address::ZERO);
        let signature =
            PrimitiveSignature::from_bytes(&[0u8; 65]).expect("valid dummy signature");
        SignedKeyAuthorization { authorization, signature }
    }

    /// Mirrors the signing-mode selection in `SessionProvider::pay_charge`.
    ///
    /// A keychain mode loses its key authorization once the key is provisioned;
    /// every other combination is cloned unchanged.
    ///
    /// NOTE(review): this duplicates the production logic, so the
    /// `test_pay_charge_*` tests exercise this copy rather than `pay_charge`.
    fn strip_key_auth_if_provisioned(
        mode: &TempoSigningMode,
        provisioned: bool,
    ) -> TempoSigningMode {
        match mode {
            TempoSigningMode::Keychain { wallet, version, .. } if provisioned => {
                TempoSigningMode::Keychain {
                    wallet: *wallet,
                    key_authorization: None,
                    version: *version,
                }
            }
            unchanged => unchanged.clone(),
        }
    }

    /// A random origin, so each test gets its own entry in the global channel registry.
    fn unique_origin() -> String {
        format!("https://rpc-{}.example.com", alloy_primitives::B256::random())
    }

    /// A provider in keychain mode whose signing mode carries a key authorization.
    fn keychain_provider() -> SessionProvider {
        let signing_mode = TempoSigningMode::Keychain {
            wallet: Address::repeat_byte(0xAA),
            key_authorization: Some(Box::new(test_key_authorization())),
            version: KeychainVersion::V2,
        };
        SessionProvider::new(mpp::PrivateKeySigner::random(), unique_origin())
            .with_signing_mode(signing_mode)
    }

    /// Reads the shared key-provisioned flag straight from the provider.
    fn provisioned(provider: &SessionProvider) -> bool {
        *provider.key_provisioned.lock().unwrap()
    }

    #[test]
    fn test_key_provisioned_default_is_true() {
        let provider = SessionProvider::new(mpp::PrivateKeySigner::random(), unique_origin());
        assert!(provisioned(&provider));
    }

    #[test]
    fn test_set_key_provisioned() {
        let provider = SessionProvider::new(mpp::PrivateKeySigner::random(), unique_origin());
        for expected in [false, true] {
            provider.set_key_provisioned(expected);
            assert_eq!(provisioned(&provider), expected);
        }
    }

    #[test]
    fn test_pay_charge_strips_key_auth_when_provisioned() {
        let provider = keychain_provider();

        let result_mode =
            strip_key_auth_if_provisioned(&provider.signing_mode, provisioned(&provider));

        assert!(
            result_mode.key_authorization().is_none(),
            "key_authorization should be stripped when key is provisioned"
        );
    }

    #[test]
    fn test_pay_charge_keeps_key_auth_when_not_provisioned() {
        let provider = keychain_provider();
        provider.set_key_provisioned(false);

        let result_mode =
            strip_key_auth_if_provisioned(&provider.signing_mode, provisioned(&provider));

        assert!(
            result_mode.key_authorization().is_some(),
            "key_authorization should be preserved when key is NOT provisioned"
        );
    }

    #[test]
    fn test_pay_charge_direct_mode_unaffected() {
        let provider = SessionProvider::new(mpp::PrivateKeySigner::random(), unique_origin())
            .with_signing_mode(TempoSigningMode::Direct);

        let result_mode =
            strip_key_auth_if_provisioned(&provider.signing_mode, provisioned(&provider));

        assert!(
            matches!(result_mode, TempoSigningMode::Direct),
            "Direct mode should pass through unchanged"
        );
    }

    /// Simulates serialized voucher updates.
    ///
    /// With one async lock around each read-modify-write, every task must
    /// observe a distinct cumulative amount. The final amount must equal
    /// `amount * num_tasks`.
    #[tokio::test]
    async fn test_concurrent_voucher_increments_are_unique() {
        let key = "test-channel".to_string();
        let channels: Arc<Mutex<HashMap<String, ChannelEntry>>> =
            Arc::new(Mutex::new(HashMap::from([(
                key.clone(),
                ChannelEntry {
                    channel_id: Default::default(),
                    salt: Default::default(),
                    cumulative_amount: 0,
                    escrow_contract: Address::ZERO,
                    chain_id: 42431,
                    opened: true,
                },
            )])));

        let pay_lock = Arc::new(tokio::sync::Mutex::new(()));
        let amount: u128 = 1000;
        let num_tasks: usize = 20;
        let results: Arc<Mutex<Vec<u128>>> = Arc::new(Mutex::new(Vec::new()));

        let handles: Vec<_> = (0..num_tasks)
            .map(|_| {
                let (channels, key, results, pay_lock) =
                    (channels.clone(), key.clone(), results.clone(), pay_lock.clone());
                tokio::spawn(async move {
                    let _guard = pay_lock.lock().await;
                    let cumulative = {
                        let mut map = channels.lock().unwrap();
                        let entry = map.get_mut(&key).unwrap();
                        entry.cumulative_amount += amount;
                        entry.cumulative_amount
                    };
                    results.lock().unwrap().push(cumulative);
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        let unique: std::collections::BTreeSet<u128> =
            results.lock().unwrap().iter().copied().collect();
        assert_eq!(
            unique.len(),
            num_tasks,
            "each concurrent increment should produce a unique cumulative_amount"
        );
        assert_eq!(
            *unique.last().unwrap(),
            amount * num_tasks as u128,
            "final cumulative_amount should equal amount × num_tasks"
        );
    }
}