cast/cmd/wallet/
process_tree.rs1use std::{
2 io,
3 process::{Command, ExitStatus},
4 time::Duration,
5};
6
7use tokio::process::{Child, Command as TokioCommand};
8
9const TERMINATION_GRACE: Duration = Duration::from_millis(250);
10
11pub(super) struct ManagedChild {
12 child: Child,
13 group: PlatformProcessGroup,
14 waited: bool,
15}
16
17impl ManagedChild {
18 pub(super) fn spawn(command: Command) -> io::Result<Self> {
19 let mut command = TokioCommand::from(command);
20 let group = PlatformProcessGroup::configure(&mut command)?;
21 let child = command.kill_on_drop(true).spawn()?;
22 let group = group.attach(&child)?;
23 Ok(Self { child, group, waited: false })
24 }
25
26 pub(super) async fn wait(&mut self) -> io::Result<ExitStatus> {
27 let status = self.child.wait().await?;
28 self.waited = true;
29 Ok(status)
30 }
31
32 pub(super) async fn terminate_tree(&mut self) -> io::Result<()> {
33 if self.group.terminate()? {
34 tokio::time::sleep(TERMINATION_GRACE).await;
35 self.group.kill()?;
36 }
37
38 if self.waited {
39 return Ok(());
40 }
41
42 match tokio::time::timeout(TERMINATION_GRACE, self.child.wait()).await {
43 Ok(result) => {
44 self.waited = true;
45 result.map(|_| ())
46 }
47 Err(_) => {
48 self.child.start_kill()?;
49 self.child.wait().await?;
50 self.waited = true;
51 Ok(())
52 }
53 }
54 }
55}
56
57#[cfg(unix)]
58struct PlatformProcessGroup {
59 pgid: Option<libc::pid_t>,
60}
61
62#[cfg(unix)]
63impl PlatformProcessGroup {
64 fn configure(command: &mut TokioCommand) -> io::Result<Self> {
65 command.process_group(0);
66 Ok(Self { pgid: None })
67 }
68
69 fn attach(mut self, child: &Child) -> io::Result<Self> {
70 self.pgid = child.id().map(|id| id as libc::pid_t);
71 Ok(self)
72 }
73
74 fn terminate(&mut self) -> io::Result<bool> {
75 let Some(pgid) = self.pgid else {
76 return Ok(false);
77 };
78 signal_process_group(pgid, libc::SIGTERM)
79 }
80
81 fn kill(&mut self) -> io::Result<()> {
82 let Some(pgid) = self.pgid.take() else {
83 return Ok(());
84 };
85 signal_process_group(pgid, libc::SIGKILL).map(|_| ())
86 }
87}
88
89#[cfg(unix)]
90fn signal_process_group(pgid: libc::pid_t, signal: libc::c_int) -> io::Result<bool> {
91 let rc = unsafe { libc::kill(-pgid, signal) };
93 if rc == 0 {
94 Ok(true)
95 } else {
96 let err = io::Error::last_os_error();
97 if err.raw_os_error() == Some(libc::ESRCH) { Ok(false) } else { Err(err) }
98 }
99}
100
/// Windows process-tree handle backed by a job object that the spawned child is assigned to.
#[cfg(windows)]
struct PlatformProcessGroup {
    /// The job; `None` once it has been terminated (and its handle closed).
    job: Option<WindowsJob>,
}

#[cfg(windows)]
impl PlatformProcessGroup {
    /// Creates the job object up front; the command itself needs no changes.
    fn configure(_command: &mut TokioCommand) -> io::Result<Self> {
        let job = WindowsJob::new()?;
        Ok(Self { job: Some(job) })
    }

    /// Assigns the spawned child to the job.
    ///
    /// # Errors
    /// Fails if tokio no longer exposes the child's handle, or if the assignment fails.
    fn attach(self, child: &Child) -> io::Result<Self> {
        if let Some(job) = self.job.as_ref() {
            let Some(process) = child.raw_handle() else {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    "session child exited before job assignment",
                ));
            };
            job.assign_process(process)?;
        }
        Ok(self)
    }

    /// Terminates every process in the job and releases the job.
    ///
    /// Returns `Ok(true)` when a job was present, `Ok(false)` if it was already consumed.
    fn terminate(&mut self) -> io::Result<bool> {
        match self.job.take() {
            Some(job) => job.terminate().map(|()| true),
            None => Ok(false),
        }
    }

    /// No separate forceful step: job termination in `terminate` is already final.
    fn kill(&mut self) -> io::Result<()> {
        Ok(())
    }
}
134
/// Owned handle to an unnamed Windows job object configured with
/// `JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE`; the handle is closed when this value is dropped.
#[cfg(windows)]
struct WindowsJob {
    handle: windows_sys::Win32::Foundation::HANDLE,
}

#[cfg(windows)]
impl WindowsJob {
    /// Creates the job object and sets the kill-on-close limit on it.
    ///
    /// # Errors
    /// Returns the OS error from `CreateJobObjectW` or `SetInformationJobObject`. In the
    /// latter case, the handle is closed before returning.
    fn new() -> io::Result<Self> {
        use windows_sys::Win32::{
            Foundation::INVALID_HANDLE_VALUE,
            System::JobObjects::{
                CreateJobObjectW, JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
                JOBOBJECT_EXTENDED_LIMIT_INFORMATION, JobObjectExtendedLimitInformation,
                SetInformationJobObject,
            },
        };

        // SAFETY: null security attributes and a null name are valid arguments (default
        // security, unnamed job).
        let raw = unsafe { CreateJobObjectW(std::ptr::null(), std::ptr::null()) };
        if raw.is_null() || raw == INVALID_HANDLE_VALUE {
            return Err(io::Error::last_os_error());
        }
        // Own the handle right away so the early return below closes it through `Drop`. The
        // returned error is evaluated before `job` is dropped, so it still carries the failing
        // call's error code.
        let job = Self { handle: raw };

        let mut limits = JOBOBJECT_EXTENDED_LIMIT_INFORMATION::default();
        limits.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
        // SAFETY: `job.handle` is a live job handle, and the pointer/length describe `limits`,
        // which outlives the call.
        let applied = unsafe {
            SetInformationJobObject(
                job.handle,
                JobObjectExtendedLimitInformation,
                std::ptr::addr_of!(limits).cast(),
                std::mem::size_of_val(&limits) as u32,
            )
        };
        if applied == 0 {
            return Err(io::Error::last_os_error());
        }

        Ok(job)
    }

    /// Adds `process` to this job.
    fn assign_process(&self, process: std::os::windows::io::RawHandle) -> io::Result<()> {
        use windows_sys::Win32::System::JobObjects::AssignProcessToJobObject;

        // SAFETY: `self.handle` is a live job handle owned by `self`; `process` is supplied by
        // the caller as the child's process handle.
        match unsafe { AssignProcessToJobObject(self.handle, process.cast()) } {
            0 => Err(io::Error::last_os_error()),
            _ => Ok(()),
        }
    }

    /// Terminates all processes in the job with exit code 1, then closes the job handle.
    fn terminate(self) -> io::Result<()> {
        use windows_sys::Win32::System::JobObjects::TerminateJobObject;

        // SAFETY: `self.handle` is a live job handle owned by `self`.
        let terminated = unsafe { TerminateJobObject(self.handle, 1) };
        if terminated == 0 {
            // The error is captured before `self` is dropped at the end of this function, so
            // `CloseHandle` in `Drop` cannot overwrite the error code.
            return Err(io::Error::last_os_error());
        }
        Ok(())
    }
}

#[cfg(windows)]
impl Drop for WindowsJob {
    fn drop(&mut self) {
        // SAFETY: `self.handle` was returned by `CreateJobObjectW`, is owned exclusively by this
        // value, and is closed exactly once, here. A close failure is ignored; nothing useful
        // can be done with it during drop.
        unsafe {
            windows_sys::Win32::Foundation::CloseHandle(self.handle);
        }
    }
}
206
/// Placeholder process-tree handle for targets that are neither Unix nor Windows. Every
/// operation is a no-op, so `ManagedChild::terminate_tree` only deals with the direct child.
#[cfg(not(any(unix, windows)))]
struct PlatformProcessGroup;

#[cfg(not(any(unix, windows)))]
impl PlatformProcessGroup {
    /// Leaves the command untouched.
    fn configure(_command: &mut TokioCommand) -> io::Result<Self> {
        Ok(PlatformProcessGroup)
    }

    /// Nothing to attach the child to.
    fn attach(self, _child: &Child) -> io::Result<Self> {
        Ok(self)
    }

    /// Never sends anything, so it always reports `false`.
    fn terminate(&mut self) -> io::Result<bool> {
        Ok(false)
    }

    /// Nothing to kill.
    fn kill(&mut self) -> io::Result<()> {
        Ok(())
    }
}
228
#[cfg(all(test, unix))]
mod tests {
    use super::*;

    /// A process backgrounded by the session shell must not survive `terminate_tree`, even
    /// when the shell itself has already exited and been reaped.
    #[test]
    fn cleanup_terminates_background_grandchild() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let dir = tempfile::tempdir().unwrap();
            let leak_marker = dir.path().join("session-child-leaked");
            let marker_arg = leak_marker.to_string_lossy();

            // The shell exits immediately; the backgrounded subshell would create the marker
            // after one second if it escaped cleanup.
            let mut shell = Command::new("sh");
            shell
                .arg("-c")
                .arg("(sleep 1; touch \"$1\") &")
                .arg("session-child")
                .arg(&*marker_arg);

            let mut session = ManagedChild::spawn(shell).unwrap();
            session.wait().await.unwrap();
            session.terminate_tree().await.unwrap();

            // Outlast the grandchild's `sleep 1` so a leaked process would have had time to
            // create the marker.
            tokio::time::sleep(Duration::from_millis(1200)).await;

            assert!(!leak_marker.exists(), "background grandchild escaped session cleanup");
        });
    }
}