// cast/cmd/wallet/process_tree.rs

1use std::{
2    io,
3    process::{Command, ExitStatus},
4    time::Duration,
5};
6
7use tokio::process::{Child, Command as TokioCommand};
8
/// Grace period used by `ManagedChild::terminate_tree`: how long the signalled
/// process tree, and afterwards the direct child, may take to exit before being
/// force-killed.
const TERMINATION_GRACE: Duration = Duration::from_micros(250_000);
10
11pub(super) struct ManagedChild {
12    child: Child,
13    group: PlatformProcessGroup,
14    waited: bool,
15}
16
17impl ManagedChild {
18    pub(super) fn spawn(command: Command) -> io::Result<Self> {
19        let mut command = TokioCommand::from(command);
20        let group = PlatformProcessGroup::configure(&mut command)?;
21        let child = command.kill_on_drop(true).spawn()?;
22        let group = group.attach(&child)?;
23        Ok(Self { child, group, waited: false })
24    }
25
26    pub(super) async fn wait(&mut self) -> io::Result<ExitStatus> {
27        let status = self.child.wait().await?;
28        self.waited = true;
29        Ok(status)
30    }
31
32    pub(super) async fn terminate_tree(&mut self) -> io::Result<()> {
33        if self.group.terminate()? {
34            tokio::time::sleep(TERMINATION_GRACE).await;
35            self.group.kill()?;
36        }
37
38        if self.waited {
39            return Ok(());
40        }
41
42        match tokio::time::timeout(TERMINATION_GRACE, self.child.wait()).await {
43            Ok(result) => {
44                self.waited = true;
45                result.map(|_| ())
46            }
47            Err(_) => {
48                self.child.start_kill()?;
49                self.child.wait().await?;
50                self.waited = true;
51                Ok(())
52            }
53        }
54    }
55}
56
57#[cfg(unix)]
58struct PlatformProcessGroup {
59    pgid: Option<libc::pid_t>,
60}
61
62#[cfg(unix)]
63impl PlatformProcessGroup {
64    fn configure(command: &mut TokioCommand) -> io::Result<Self> {
65        command.process_group(0);
66        Ok(Self { pgid: None })
67    }
68
69    fn attach(mut self, child: &Child) -> io::Result<Self> {
70        self.pgid = child.id().map(|id| id as libc::pid_t);
71        Ok(self)
72    }
73
74    fn terminate(&mut self) -> io::Result<bool> {
75        let Some(pgid) = self.pgid else {
76            return Ok(false);
77        };
78        signal_process_group(pgid, libc::SIGTERM)
79    }
80
81    fn kill(&mut self) -> io::Result<()> {
82        let Some(pgid) = self.pgid.take() else {
83            return Ok(());
84        };
85        signal_process_group(pgid, libc::SIGKILL).map(|_| ())
86    }
87}
88
89#[cfg(unix)]
90fn signal_process_group(pgid: libc::pid_t, signal: libc::c_int) -> io::Result<bool> {
91    // SAFETY: negative pid targets the process group created for the child process.
92    let rc = unsafe { libc::kill(-pgid, signal) };
93    if rc == 0 {
94        Ok(true)
95    } else {
96        let err = io::Error::last_os_error();
97        if err.raw_os_error() == Some(libc::ESRCH) { Ok(false) } else { Err(err) }
98    }
99}
100
#[cfg(windows)]
/// Windows process-tree handle backed by a job object.
struct PlatformProcessGroup {
    /// Job the child is assigned to; taken (and closed) when terminated.
    job: Option<WindowsJob>,
}

#[cfg(windows)]
impl PlatformProcessGroup {
    /// Creates the job object before spawning; the command itself is unchanged.
    fn configure(_command: &mut TokioCommand) -> io::Result<Self> {
        let job = WindowsJob::new()?;
        Ok(Self { job: Some(job) })
    }

    /// Assigns the freshly spawned child to the job, if a job is still held.
    fn attach(self, child: &Child) -> io::Result<Self> {
        if let Some(job) = self.job.as_ref() {
            let Some(handle) = child.raw_handle() else {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    "session child exited before job assignment",
                ));
            };
            job.assign_process(handle)?;
        }
        Ok(self)
    }

    /// Terminates every process in the job and releases the job.
    ///
    /// Returns `Ok(true)` if a job was terminated, `Ok(false)` if none was held.
    fn terminate(&mut self) -> io::Result<bool> {
        match self.job.take() {
            Some(job) => job.terminate().map(|()| true),
            None => Ok(false),
        }
    }

    /// No-op: `terminate` already ends the job's processes outright.
    fn kill(&mut self) -> io::Result<()> {
        Ok(())
    }
}
134
#[cfg(windows)]
/// Owned Win32 job-object handle; the handle is closed when this is dropped.
struct WindowsJob {
    handle: windows_sys::Win32::Foundation::HANDLE,
}

#[cfg(windows)]
impl WindowsJob {
    /// Creates an unnamed job with `JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE` set, so
    /// closing the job handle kills the processes assigned to it.
    fn new() -> io::Result<Self> {
        use windows_sys::Win32::{
            Foundation::INVALID_HANDLE_VALUE,
            System::JobObjects::{
                CreateJobObjectW, JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
                JOBOBJECT_EXTENDED_LIMIT_INFORMATION, JobObjectExtendedLimitInformation,
                SetInformationJobObject,
            },
        };

        // SAFETY: null security attributes and a null name request an unnamed job
        // with default security; no pointers are retained by the call.
        let raw = unsafe { CreateJobObjectW(std::ptr::null(), std::ptr::null()) };
        if raw.is_null() || raw == INVALID_HANDLE_VALUE {
            return Err(io::Error::last_os_error());
        }
        // From here on the handle is owned, so `Drop` closes it on any early return.
        let job = Self { handle: raw };

        let mut limits = JOBOBJECT_EXTENDED_LIMIT_INFORMATION::default();
        limits.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
        let limits_size = std::mem::size_of_val(&limits) as u32;
        // SAFETY: `limits` is alive for the whole call and `limits_size` is its
        // exact size, matching the extended-limit information class.
        let ok = unsafe {
            SetInformationJobObject(
                job.handle,
                JobObjectExtendedLimitInformation,
                std::ptr::addr_of!(limits).cast(),
                limits_size,
            )
        };
        if ok == 0 {
            // The OS error is read before `job` is dropped (which closes the handle).
            return Err(io::Error::last_os_error());
        }

        Ok(job)
    }

    /// Adds the process identified by `process` to this job.
    fn assign_process(&self, process: std::os::windows::io::RawHandle) -> io::Result<()> {
        use windows_sys::Win32::System::JobObjects::AssignProcessToJobObject;

        // SAFETY: `self.handle` is a live job handle owned by `self`; `process` is
        // a process handle supplied by the caller.
        match unsafe { AssignProcessToJobObject(self.handle, process.cast()) } {
            0 => Err(io::Error::last_os_error()),
            _ => Ok(()),
        }
    }

    /// Terminates all processes in the job (exit code 1) and closes the handle.
    fn terminate(self) -> io::Result<()> {
        use windows_sys::Win32::System::JobObjects::TerminateJobObject;

        // SAFETY: `self.handle` is a live job handle owned by `self`.
        let ok = unsafe { TerminateJobObject(self.handle, 1) };
        let outcome = if ok == 0 { Err(io::Error::last_os_error()) } else { Ok(()) };
        // Close the handle only after the error code (if any) has been captured.
        drop(self);
        outcome
    }
}

#[cfg(windows)]
impl Drop for WindowsJob {
    fn drop(&mut self) {
        // SAFETY: `handle` came from `CreateJobObjectW` and is owned by this value,
        // which is dropped once.
        unsafe {
            windows_sys::Win32::Foundation::CloseHandle(self.handle);
        }
    }
}
206
#[cfg(not(any(unix, windows)))]
/// Fallback for platforms without process-group or job support: every group
/// operation is a no-op.
struct PlatformProcessGroup;

#[cfg(not(any(unix, windows)))]
impl PlatformProcessGroup {
    /// Leaves the command untouched.
    fn configure(_command: &mut TokioCommand) -> io::Result<Self> {
        Ok(PlatformProcessGroup)
    }

    /// Nothing to attach to.
    fn attach(self, _child: &Child) -> io::Result<Self> {
        Ok(self)
    }

    /// Reports that no group signal was sent.
    fn terminate(&mut self) -> io::Result<bool> {
        Ok(false)
    }

    /// Nothing to kill.
    fn kill(&mut self) -> io::Result<()> {
        Ok(())
    }
}
228
#[cfg(all(test, unix))]
mod tests {
    use super::*;

    /// A grandchild left running in the background by the session shell must be
    /// killed by `terminate_tree`, even after the direct child has exited.
    #[test]
    fn cleanup_terminates_background_grandchild() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let scratch = tempfile::tempdir().unwrap();
            // The grandchild creates this file after ~1s unless it is killed first.
            let leak_marker = scratch.path().join("session-child-leaked");

            let mut shell = Command::new("sh");
            shell
                .arg("-c")
                .arg("(sleep 1; touch \"$1\") &")
                .arg("session-child")
                .arg(&*leak_marker.to_string_lossy());

            let mut session = ManagedChild::spawn(shell).unwrap();
            // `sh` exits right after backgrounding the subshell.
            session.wait().await.unwrap();
            session.terminate_tree().await.unwrap();
            // Wait past the point at which a surviving grandchild would have written.
            tokio::time::sleep(Duration::from_millis(1200)).await;

            assert!(!leak_marker.exists(), "background grandchild escaped session cleanup");
        });
    }
}