1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 // inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs> 4 #![allow(clippy::undocumented_unsafe_blocks)] 5 #![cfg_attr(feature = "alloc", feature(allocator_api))] 6 #[cfg(not(windows))] 7 mod pthread_mtx { 8 #[cfg(feature = "alloc")] 9 use core::alloc::AllocError; 10 use core::{ 11 cell::UnsafeCell, 12 marker::PhantomPinned, 13 mem::MaybeUninit, 14 ops::{Deref, DerefMut}, 15 pin::Pin, 16 }; 17 use pin_init::*; 18 use std::convert::Infallible; 19 20 #[pin_data(PinnedDrop)] 21 pub struct PThreadMutex<T> { 22 #[pin] 23 raw: UnsafeCell<libc::pthread_mutex_t>, 24 data: UnsafeCell<T>, 25 #[pin] 26 pin: PhantomPinned, 27 } 28 29 unsafe impl<T: Send> Send for PThreadMutex<T> {} 30 unsafe impl<T: Send> Sync for PThreadMutex<T> {} 31 32 #[pinned_drop] 33 impl<T> PinnedDrop for PThreadMutex<T> { 34 fn drop(self: Pin<&mut Self>) { 35 unsafe { 36 libc::pthread_mutex_destroy(self.raw.get()); 37 } 38 } 39 } 40 41 #[derive(Debug)] 42 pub enum Error { 43 #[expect(dead_code)] 44 IO(std::io::Error), 45 Alloc, 46 } 47 48 impl From<Infallible> for Error { 49 fn from(e: Infallible) -> Self { 50 match e {} 51 } 52 } 53 54 #[cfg(feature = "alloc")] 55 impl From<AllocError> for Error { 56 fn from(_: AllocError) -> Self { 57 Self::Alloc 58 } 59 } 60 61 impl<T> PThreadMutex<T> { 62 pub fn new(data: T) -> impl PinInit<Self, Error> { 63 fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> { 64 let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| { 65 // we can cast, because `UnsafeCell` has the same layout as T. 66 let slot: *mut libc::pthread_mutex_t = slot.cast(); 67 let mut attr = MaybeUninit::uninit(); 68 let attr = attr.as_mut_ptr(); 69 // SAFETY: ptr is valid 70 let ret = unsafe { libc::pthread_mutexattr_init(attr) }; 71 if ret != 0 { 72 return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); 73 } 74 // SAFETY: attr is initialized 75 let ret = unsafe { 76 libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL) 77 }; 78 if ret != 0 { 79 // SAFETY: attr is initialized 80 unsafe { libc::pthread_mutexattr_destroy(attr) }; 81 return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); 82 } 83 // SAFETY: slot is valid 84 unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) }; 85 // SAFETY: attr and slot are valid ptrs and attr is initialized 86 let ret = unsafe { libc::pthread_mutex_init(slot, attr) }; 87 // SAFETY: attr was initialized 88 unsafe { libc::pthread_mutexattr_destroy(attr) }; 89 if ret != 0 { 90 return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); 91 } 92 Ok(()) 93 }; 94 // SAFETY: mutex has been initialized 95 unsafe { pin_init_from_closure(init) } 96 } 97 try_pin_init!(Self { 98 data: UnsafeCell::new(data), 99 raw <- init_raw(), 100 pin: PhantomPinned, 101 }? Error) 102 } 103 104 pub fn lock(&self) -> PThreadMutexGuard<'_, T> { 105 // SAFETY: raw is always initialized 106 unsafe { libc::pthread_mutex_lock(self.raw.get()) }; 107 PThreadMutexGuard { mtx: self } 108 } 109 } 110 111 pub struct PThreadMutexGuard<'a, T> { 112 mtx: &'a PThreadMutex<T>, 113 } 114 115 impl<T> Drop for PThreadMutexGuard<'_, T> { 116 fn drop(&mut self) { 117 // SAFETY: raw is always initialized 118 unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) }; 119 } 120 } 121 122 impl<T> Deref for PThreadMutexGuard<'_, T> { 123 type Target = T; 124 125 fn deref(&self) -> &Self::Target { 126 unsafe { &*self.mtx.data.get() } 127 } 128 } 129 130 impl<T> DerefMut for PThreadMutexGuard<'_, T> { 131 fn deref_mut(&mut self) -> &mut Self::Target { 132 unsafe { &mut *self.mtx.data.get() } 133 } 134 } 135 } 136 137 #[cfg_attr(test, test)] 138 fn main() { 139 #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))] 140 { 141 use core::pin::Pin; 142 use pin_init::*; 143 use pthread_mtx::*; 144 use std::{ 145 sync::Arc, 146 thread::{sleep, Builder}, 147 time::Duration, 148 }; 149 let mtx: Pin<Arc<PThreadMutex<usize>>> = Arc::try_pin_init(PThreadMutex::new(0)).unwrap(); 150 let mut handles = vec![]; 151 let thread_count = 20; 152 let workload = 1_000_000; 153 for i in 0..thread_count { 154 let mtx = mtx.clone(); 155 handles.push( 156 Builder::new() 157 .name(format!("worker #{i}")) 158 .spawn(move || { 159 for _ in 0..workload { 160 *mtx.lock() += 1; 161 } 162 println!("{i} halfway"); 163 sleep(Duration::from_millis((i as u64) * 10)); 164 for _ in 0..workload { 165 *mtx.lock() += 1; 166 } 167 println!("{i} finished"); 168 }) 169 .expect("should not fail"), 170 ); 171 } 172 for h in handles { 173 h.join().expect("thread panicked"); 174 } 175 println!("{:?}", &*mtx.lock()); 176 assert_eq!(*mtx.lock(), workload * thread_count * 2); 177 } 178 } 179