1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 #![allow(clippy::undocumented_unsafe_blocks)]
4 #![cfg_attr(feature = "alloc", feature(allocator_api))]
5 #![allow(clippy::missing_safety_doc)]
6 
7 use core::{
8     cell::{Cell, UnsafeCell},
9     marker::PhantomPinned,
10     ops::{Deref, DerefMut},
11     pin::Pin,
12     sync::atomic::{AtomicBool, Ordering},
13 };
14 use std::{
15     sync::Arc,
16     thread::{self, park, sleep, Builder, Thread},
17     time::Duration,
18 };
19 
20 use pin_init::*;
21 #[expect(unused_attributes)]
22 #[path = "./linked_list.rs"]
23 pub mod linked_list;
24 use linked_list::*;
25 
/// A minimal test-and-test-and-set spin lock.
///
/// The lock is held for as long as the [`SpinLockGuard`] returned by
/// [`SpinLock::acquire`] is alive; dropping the guard releases it.
pub struct SpinLock {
    /// `true` while some guard holds the lock.
    inner: AtomicBool,
}

impl SpinLock {
    /// Spins until the lock is acquired, yielding to the scheduler while it is
    /// held by someone else, and returns a guard that releases it on drop.
    #[inline]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        // `Acquire` on success pairs with the `Release` store in
        // `SpinLockGuard::drop`, so writes made by the previous holder are visible.
        while self
            .inner
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            // Wait with plain loads until the lock looks free instead of retrying
            // the read-modify-write operation in a tight loop.
            while self.inner.load(Ordering::Relaxed) {
                thread::yield_now();
            }
        }
        SpinLockGuard(self)
    }

    /// Creates a new, unlocked spin lock.
    #[inline]
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

impl Default for SpinLock {
    /// Equivalent to [`SpinLock::new`]: an unlocked spin lock.
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

/// RAII guard showing that a [`SpinLock`] is held; releases the lock on drop.
#[must_use = "the spin lock is released as soon as the guard is dropped"]
pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    #[inline]
    fn drop(&mut self) {
        // `Release` publishes every write made while the lock was held to the next
        // thread whose `compare_exchange` in `acquire` succeeds.
        self.0.inner.store(false, Ordering::Release);
    }
}
62 
/// A sleeping mutex whose blocked threads queue on an intrusive wait list.
///
/// The bookkeeping fields (`wait_list`, `locked`) are only read or written while
/// `spin_lock` is held. `data` is accessed through a [`CMutexGuard`] (or through
/// [`CMutex::get_data_mut`] given exclusive access).
#[pin_data]
pub struct CMutex<T> {
    /// Head of the list of [`WaitEntry`]s owned by threads blocked in `lock`.
    #[pin]
    wait_list: ListHead,
    /// Protects `wait_list` and `locked`.
    spin_lock: SpinLock,
    /// Whether a `CMutexGuard` currently holds this mutex.
    locked: Cell<bool>,
    /// The protected value, pin-initialized in place by `new`.
    #[pin]
    data: UnsafeCell<T>,
}
72 
73 impl<T> CMutex<T> {
74     #[inline]
new(val: impl PinInit<T>) -> impl PinInit<Self>75     pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
76         pin_init!(CMutex {
77             wait_list <- ListHead::new(),
78             spin_lock: SpinLock::new(),
79             locked: Cell::new(false),
80             data <- unsafe {
81                 pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
82                     val.__pinned_init(slot.cast::<T>())
83                 })
84             },
85         })
86     }
87 
88     #[inline]
lock(&self) -> Pin<CMutexGuard<'_, T>>89     pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
90         let mut sguard = self.spin_lock.acquire();
91         if self.locked.get() {
92             stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
93             // println!("wait list length: {}", self.wait_list.size());
94             while self.locked.get() {
95                 drop(sguard);
96                 park();
97                 sguard = self.spin_lock.acquire();
98             }
99             // This does have an effect, as the ListHead inside wait_entry implements Drop!
100             #[expect(clippy::drop_non_drop)]
101             drop(wait_entry);
102         }
103         self.locked.set(true);
104         unsafe {
105             Pin::new_unchecked(CMutexGuard {
106                 mtx: self,
107                 _pin: PhantomPinned,
108             })
109         }
110     }
111 
112     #[allow(dead_code)]
get_data_mut(self: Pin<&mut Self>) -> &mut T113     pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
114         // SAFETY: we have an exclusive reference and thus nobody has access to data.
115         unsafe { &mut *self.data.get() }
116     }
117 }
118 
// SAFETY: a `CMutex<T>` owns its `T`, so sending the mutex to another thread sends
// the `T` with it, which requires `T: Send`. NOTE(review): the manual impl is
// presumably needed because `ListHead` is not auto-`Send` (raw links) — confirm in
// `linked_list.rs`.
unsafe impl<T: Send> Send for CMutex<T> {}
// SAFETY: sharing `&CMutex<T>` lets each thread obtain `&mut T` through a guard,
// one at a time (guarded by `locked`), which requires `T: Send` but not `T: Sync`.
// The non-`Sync` bookkeeping (`Cell<bool>` and the wait list) is only accessed
// while `spin_lock` is held.
unsafe impl<T: Send> Sync for CMutex<T> {}
121 
/// Proof that the current thread holds a [`CMutex`]; unlocks it on drop.
///
/// The fields are private, and the only constructor in this file is
/// [`CMutex::lock`], which returns the guard wrapped in `Pin`.
pub struct CMutexGuard<'a, T> {
    /// The mutex this guard has locked.
    mtx: &'a CMutex<T>,
    /// Makes the guard type `!Unpin`.
    // NOTE(review): `Pin<CMutexGuard>`'s `DerefMut` is gated on `T: Unpin`, not on the
    // guard itself, so the exact purpose of this marker is not evident here — confirm.
    _pin: PhantomPinned,
}
126 
127 impl<T> Drop for CMutexGuard<'_, T> {
128     #[inline]
drop(&mut self)129     fn drop(&mut self) {
130         let sguard = self.mtx.spin_lock.acquire();
131         self.mtx.locked.set(false);
132         if let Some(list_field) = self.mtx.wait_list.next() {
133             let wait_entry = list_field.as_ptr().cast::<WaitEntry>();
134             unsafe { (*wait_entry).thread.unpark() };
135         }
136         drop(sguard);
137     }
138 }
139 
140 impl<T> Deref for CMutexGuard<'_, T> {
141     type Target = T;
142 
143     #[inline]
deref(&self) -> &Self::Target144     fn deref(&self) -> &Self::Target {
145         unsafe { &*self.mtx.data.get() }
146     }
147 }
148 
149 impl<T> DerefMut for CMutexGuard<'_, T> {
150     #[inline]
deref_mut(&mut self) -> &mut Self::Target151     fn deref_mut(&mut self) -> &mut Self::Target {
152         unsafe { &mut *self.mtx.data.get() }
153     }
154 }
155 
/// A blocked thread's node in a [`CMutex`] wait list, pinned on that thread's stack.
///
/// `#[repr(C)]` with `wait_list` as the first field is load-bearing: the unlocking
/// thread casts a pointer to the `ListHead` back into a pointer to the `WaitEntry`.
#[pin_data]
#[repr(C)]
struct WaitEntry {
    /// Link into the mutex's wait list; must remain the first field (see above).
    #[pin]
    wait_list: ListHead,
    /// Handle used by the unlocking thread to unpark this waiter.
    thread: Thread,
}
163 
164 impl WaitEntry {
165     #[inline]
insert_new(list: &ListHead) -> impl PinInit<Self> + '_166     fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
167         pin_init!(Self {
168             thread: thread::current(),
169             wait_list <- ListHead::insert_prev(list),
170         })
171     }
172 }
173 
// The demo `main` needs the `std` or `alloc` feature (it builds the mutex with
// `Arc::pin_init`). Without either feature, the example is a no-op.
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {}
176 
#[allow(dead_code)]
#[cfg_attr(test, test)]
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    // Each worker increments the shared counter `workload` times, sleeps for a
    // staggered interval, and then increments it `workload` more times.
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let thread_count = 20;
    let workload = if cfg!(miri) { 100 } else { 1_000 };

    let handles: Vec<_> = (0..thread_count)
        .map(|i| {
            let mtx = mtx.clone();
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    let bump = || {
                        for _ in 0..workload {
                            *mtx.lock() += 1;
                        }
                    };
                    bump();
                    println!("{i} halfway");
                    sleep(Duration::from_millis((i as u64) * 10));
                    bump();
                    println!("{i} finished");
                })
                .expect("should not fail")
        })
        .collect();

    for handle in handles {
        handle.join().expect("thread panicked");
    }

    println!("{:?}", &*mtx.lock());
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
}
210