1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 // inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs>
4 #![allow(clippy::undocumented_unsafe_blocks)]
5 #![cfg_attr(feature = "alloc", feature(allocator_api))]
6 #[cfg(not(windows))]
7 mod pthread_mtx {
8     #[cfg(feature = "alloc")]
9     use core::alloc::AllocError;
10     use core::{
11         cell::UnsafeCell,
12         marker::PhantomPinned,
13         mem::MaybeUninit,
14         ops::{Deref, DerefMut},
15         pin::Pin,
16     };
17     use pin_init::*;
18     use std::convert::Infallible;
19 
20     #[pin_data(PinnedDrop)]
21     pub struct PThreadMutex<T> {
22         #[pin]
23         raw: UnsafeCell<libc::pthread_mutex_t>,
24         data: UnsafeCell<T>,
25         #[pin]
26         pin: PhantomPinned,
27     }
28 
29     unsafe impl<T: Send> Send for PThreadMutex<T> {}
30     unsafe impl<T: Send> Sync for PThreadMutex<T> {}
31 
32     #[pinned_drop]
33     impl<T> PinnedDrop for PThreadMutex<T> {
drop(self: Pin<&mut Self>)34         fn drop(self: Pin<&mut Self>) {
35             unsafe {
36                 libc::pthread_mutex_destroy(self.raw.get());
37             }
38         }
39     }
40 
41     #[derive(Debug)]
42     pub enum Error {
43         #[expect(dead_code)]
44         IO(std::io::Error),
45         Alloc,
46     }
47 
48     impl From<Infallible> for Error {
from(e: Infallible) -> Self49         fn from(e: Infallible) -> Self {
50             match e {}
51         }
52     }
53 
54     #[cfg(feature = "alloc")]
55     impl From<AllocError> for Error {
from(_: AllocError) -> Self56         fn from(_: AllocError) -> Self {
57             Self::Alloc
58         }
59     }
60 
61     impl<T> PThreadMutex<T> {
new(data: T) -> impl PinInit<Self, Error>62         pub fn new(data: T) -> impl PinInit<Self, Error> {
63             fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> {
64                 let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| {
65                     // we can cast, because `UnsafeCell` has the same layout as T.
66                     let slot: *mut libc::pthread_mutex_t = slot.cast();
67                     let mut attr = MaybeUninit::uninit();
68                     let attr = attr.as_mut_ptr();
69                     // SAFETY: ptr is valid
70                     let ret = unsafe { libc::pthread_mutexattr_init(attr) };
71                     if ret != 0 {
72                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
73                     }
74                     // SAFETY: attr is initialized
75                     let ret = unsafe {
76                         libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL)
77                     };
78                     if ret != 0 {
79                         // SAFETY: attr is initialized
80                         unsafe { libc::pthread_mutexattr_destroy(attr) };
81                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
82                     }
83                     // SAFETY: slot is valid
84                     unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) };
85                     // SAFETY: attr and slot are valid ptrs and attr is initialized
86                     let ret = unsafe { libc::pthread_mutex_init(slot, attr) };
87                     // SAFETY: attr was initialized
88                     unsafe { libc::pthread_mutexattr_destroy(attr) };
89                     if ret != 0 {
90                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
91                     }
92                     Ok(())
93                 };
94                 // SAFETY: mutex has been initialized
95                 unsafe { pin_init_from_closure(init) }
96             }
97             try_pin_init!(Self {
98             data: UnsafeCell::new(data),
99             raw <- init_raw(),
100             pin: PhantomPinned,
101         }? Error)
102         }
103 
lock(&self) -> PThreadMutexGuard<'_, T>104         pub fn lock(&self) -> PThreadMutexGuard<'_, T> {
105             // SAFETY: raw is always initialized
106             unsafe { libc::pthread_mutex_lock(self.raw.get()) };
107             PThreadMutexGuard { mtx: self }
108         }
109     }
110 
111     pub struct PThreadMutexGuard<'a, T> {
112         mtx: &'a PThreadMutex<T>,
113     }
114 
115     impl<T> Drop for PThreadMutexGuard<'_, T> {
drop(&mut self)116         fn drop(&mut self) {
117             // SAFETY: raw is always initialized
118             unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) };
119         }
120     }
121 
122     impl<T> Deref for PThreadMutexGuard<'_, T> {
123         type Target = T;
124 
deref(&self) -> &Self::Target125         fn deref(&self) -> &Self::Target {
126             unsafe { &*self.mtx.data.get() }
127         }
128     }
129 
130     impl<T> DerefMut for PThreadMutexGuard<'_, T> {
deref_mut(&mut self) -> &mut Self::Target131         fn deref_mut(&mut self) -> &mut Self::Target {
132             unsafe { &mut *self.mtx.data.get() }
133         }
134     }
135 }
136 
/// Stress test for `PThreadMutex`.
///
/// Twenty worker threads each increment a shared `PThreadMutex<usize>` one million times,
/// sleep briefly (a different duration per thread), then increment another million times.
/// The final count is printed and checked. The body only exists on non-Windows targets
/// built with the `std` or `alloc` feature.
#[cfg_attr(test, test)]
fn main() {
    #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))]
    {
        use core::pin::Pin;
        use pin_init::*;
        use pthread_mtx::*;
        use std::{
            sync::Arc,
            thread::{sleep, Builder},
            time::Duration,
        };

        let mtx: Pin<Arc<PThreadMutex<usize>>> = Arc::try_pin_init(PThreadMutex::new(0)).unwrap();
        let thread_count = 20;
        let workload = 1_000_000;

        // Spawns worker `i`, which runs two increment rounds separated by a short pause.
        let spawn_worker = |i: usize| {
            let shared = Pin::clone(&mtx);
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    let run_round = || (0..workload).for_each(|_| *shared.lock() += 1);
                    run_round();
                    println!("{i} halfway");
                    sleep(Duration::from_millis((i as u64) * 10));
                    run_round();
                    println!("{i} finished");
                })
                .expect("should not fail")
        };
        let handles: Vec<_> = (0..thread_count).map(spawn_worker).collect();

        handles
            .into_iter()
            .for_each(|handle| handle.join().expect("thread panicked"));

        let total = *mtx.lock();
        println!("{total:?}");
        assert_eq!(total, workload * thread_count * 2);
    }
}
179