1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2
3 // inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs>
4 #![allow(clippy::undocumented_unsafe_blocks)]
5 #![cfg_attr(feature = "alloc", feature(allocator_api))]
6 #[cfg(not(windows))]
7 mod pthread_mtx {
8 #[cfg(feature = "alloc")]
9 use core::alloc::AllocError;
10 use core::{
11 cell::UnsafeCell,
12 marker::PhantomPinned,
13 mem::MaybeUninit,
14 ops::{Deref, DerefMut},
15 pin::Pin,
16 };
17 use pin_init::*;
18 use std::convert::Infallible;
19
20 #[pin_data(PinnedDrop)]
21 pub struct PThreadMutex<T> {
22 #[pin]
23 raw: UnsafeCell<libc::pthread_mutex_t>,
24 data: UnsafeCell<T>,
25 #[pin]
26 pin: PhantomPinned,
27 }
28
29 unsafe impl<T: Send> Send for PThreadMutex<T> {}
30 unsafe impl<T: Send> Sync for PThreadMutex<T> {}
31
32 #[pinned_drop]
33 impl<T> PinnedDrop for PThreadMutex<T> {
drop(self: Pin<&mut Self>)34 fn drop(self: Pin<&mut Self>) {
35 unsafe {
36 libc::pthread_mutex_destroy(self.raw.get());
37 }
38 }
39 }
40
41 #[derive(Debug)]
42 pub enum Error {
43 #[expect(dead_code)]
44 IO(std::io::Error),
45 Alloc,
46 }
47
48 impl From<Infallible> for Error {
from(e: Infallible) -> Self49 fn from(e: Infallible) -> Self {
50 match e {}
51 }
52 }
53
54 #[cfg(feature = "alloc")]
55 impl From<AllocError> for Error {
from(_: AllocError) -> Self56 fn from(_: AllocError) -> Self {
57 Self::Alloc
58 }
59 }
60
61 impl<T> PThreadMutex<T> {
new(data: T) -> impl PinInit<Self, Error>62 pub fn new(data: T) -> impl PinInit<Self, Error> {
63 fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> {
64 let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| {
65 // we can cast, because `UnsafeCell` has the same layout as T.
66 let slot: *mut libc::pthread_mutex_t = slot.cast();
67 let mut attr = MaybeUninit::uninit();
68 let attr = attr.as_mut_ptr();
69 // SAFETY: ptr is valid
70 let ret = unsafe { libc::pthread_mutexattr_init(attr) };
71 if ret != 0 {
72 return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
73 }
74 // SAFETY: attr is initialized
75 let ret = unsafe {
76 libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL)
77 };
78 if ret != 0 {
79 // SAFETY: attr is initialized
80 unsafe { libc::pthread_mutexattr_destroy(attr) };
81 return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
82 }
83 // SAFETY: slot is valid
84 unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) };
85 // SAFETY: attr and slot are valid ptrs and attr is initialized
86 let ret = unsafe { libc::pthread_mutex_init(slot, attr) };
87 // SAFETY: attr was initialized
88 unsafe { libc::pthread_mutexattr_destroy(attr) };
89 if ret != 0 {
90 return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
91 }
92 Ok(())
93 };
94 // SAFETY: mutex has been initialized
95 unsafe { pin_init_from_closure(init) }
96 }
97 try_pin_init!(Self {
98 data: UnsafeCell::new(data),
99 raw <- init_raw(),
100 pin: PhantomPinned,
101 }? Error)
102 }
103
lock(&self) -> PThreadMutexGuard<'_, T>104 pub fn lock(&self) -> PThreadMutexGuard<'_, T> {
105 // SAFETY: raw is always initialized
106 unsafe { libc::pthread_mutex_lock(self.raw.get()) };
107 PThreadMutexGuard { mtx: self }
108 }
109 }
110
111 pub struct PThreadMutexGuard<'a, T> {
112 mtx: &'a PThreadMutex<T>,
113 }
114
115 impl<T> Drop for PThreadMutexGuard<'_, T> {
drop(&mut self)116 fn drop(&mut self) {
117 // SAFETY: raw is always initialized
118 unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) };
119 }
120 }
121
122 impl<T> Deref for PThreadMutexGuard<'_, T> {
123 type Target = T;
124
deref(&self) -> &Self::Target125 fn deref(&self) -> &Self::Target {
126 unsafe { &*self.mtx.data.get() }
127 }
128 }
129
130 impl<T> DerefMut for PThreadMutexGuard<'_, T> {
deref_mut(&mut self) -> &mut Self::Target131 fn deref_mut(&mut self) -> &mut Self::Target {
132 unsafe { &mut *self.mtx.data.get() }
133 }
134 }
135 }
136
/// Stress test: many threads increment a shared counter through `PThreadMutex`,
/// then the final total is checked against the expected number of increments.
#[cfg_attr(test, test)]
fn main() {
    #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))]
    {
        use core::pin::Pin;
        use pin_init::*;
        use pthread_mtx::*;
        use std::{
            sync::Arc,
            thread::{sleep, Builder},
            time::Duration,
        };

        let counter: Pin<Arc<PThreadMutex<usize>>> =
            Arc::try_pin_init(PThreadMutex::new(0)).unwrap();
        let threads: usize = 20;
        // Increments performed by each worker in each of its two phases.
        let per_phase: usize = 1_000_000;

        // Spawn every worker up front; each one does two rounds of increments with a
        // staggered pause between them so the threads interleave differently.
        let workers: Vec<_> = (0..threads)
            .map(|idx| {
                let shared = Pin::clone(&counter);
                Builder::new()
                    .name(format!("worker #{idx}"))
                    .spawn(move || {
                        (0..per_phase).for_each(|_| *shared.lock() += 1);
                        println!("{idx} halfway");
                        sleep(Duration::from_millis((idx as u64) * 10));
                        (0..per_phase).for_each(|_| *shared.lock() += 1);
                        println!("{idx} finished");
                    })
                    .expect("should not fail")
            })
            .collect();

        workers
            .into_iter()
            .for_each(|worker| worker.join().expect("thread panicked"));

        println!("{:?}", &*counter.lock());
        assert_eq!(*counter.lock(), per_phase * threads * 2);
    }
}
179