// SPDX-License-Identifier: Apache-2.0 OR MIT

#![allow(clippy::undocumented_unsafe_blocks)]
#![cfg_attr(feature = "alloc", feature(allocator_api))]
#![allow(clippy::missing_safety_doc)]

use core::{
    cell::{Cell, UnsafeCell},
    marker::PhantomPinned,
    ops::{Deref, DerefMut},
    pin::Pin,
    sync::atomic::{AtomicBool, Ordering},
};
use std::{
    sync::Arc,
    thread::{self, park, sleep, Builder, Thread},
    time::Duration,
};

use pin_init::*;
#[expect(unused_attributes)]
#[path = "./linked_list.rs"]
pub mod linked_list;
use linked_list::*;

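/// A minimal test-and-test-and-set spin lock over an [`AtomicBool`], used here to protect the
/// mutex book-keeping (the `locked` flag and the wait list).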
pub struct SpinLock {
    inner: AtomicBool,
}

impl SpinLock {
    #[inline]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        while self
            .inner
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            while self.inner.load(Ordering::Relaxed) {
                thread::yield_now();
            }
        }
        SpinLockGuard(self)
    }

    #[inline]
    #[allow(clippy::new_without_default)]
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

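/// Guard proving that the [`SpinLock`] is held; releases the lock when dropped.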
pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    #[inline]
    fn drop(&mut self) {
        self.0.inner.store(false, Ordering::Release);
    }
}

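/// A sleeping mutex in the style of a C mutex: threads that find the lock taken enqueue
/// themselves on an intrusive wait list and park until the current owner wakes them.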
#[pin_data]
pub struct CMutex<T> {
    #[pin]
    wait_list: ListHead,
    spin_lock: SpinLock,
    locked: Cell<bool>,
    #[pin]
    data: UnsafeCell<T>,
}

impl<T> CMutex<T> {
    #[inline]
    pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
        pin_init!(CMutex {
            wait_list <- ListHead::new(),
            spin_lock: SpinLock::new(),
            locked: Cell::new(false),
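            // SAFETY: `UnsafeCell<T>` is `#[repr(transparent)]`, so initializing the `T` behind
            // the casted pointer initializes the whole cell.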
            data <- unsafe {
                pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
                    val.__pinned_init(slot.cast::<T>())
                })
            },
        })
    }

    #[inline]
    pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
        let mut sguard = self.spin_lock.acquire();
        if self.locked.get() {
            stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
            // println!("wait list length: {}", self.wait_list.size());
            while self.locked.get() {
                drop(sguard);
                park();
                sguard = self.spin_lock.acquire();
            }
            // This does have an effect, as the ListHead inside wait_entry implements Drop!
            #[expect(clippy::drop_non_drop)]
            drop(wait_entry);
        }
        self.locked.set(true);
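        // SAFETY: the lock is now held by us; the guard is handed out pinned and is only ever
        // accessed through the `Pin`.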
        unsafe {
            Pin::new_unchecked(CMutexGuard {
                mtx: self,
                _pin: PhantomPinned,
            })
        }
    }

    #[allow(dead_code)]
    pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
        // SAFETY: we have an exclusive reference and thus nobody has access to data.
        unsafe { &mut *self.data.get() }
    }
}

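// SAFETY: the mutex hands out access to `T` to one thread at a time, so sending it to another
// thread only requires `T: Send`.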
unsafe impl<T: Send> Send for CMutex<T> {}
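// SAFETY: access to the shared `data` is serialized by the lock, so `&CMutex<T>` can be shared
// between threads whenever `T: Send`.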
unsafe impl<T: Send> Sync for CMutex<T> {}

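/// Guard granting access to the data behind a [`CMutex`]; dropping it unlocks the mutex and
/// wakes the first parked waiter, if any.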
pub struct CMutexGuard<'a, T> {
    mtx: &'a CMutex<T>,
    _pin: PhantomPinned,
}

impl<T> Drop for CMutexGuard<'_, T> {
    #[inline]
    fn drop(&mut self) {
        let sguard = self.mtx.spin_lock.acquire();
        self.mtx.locked.set(false);
        if let Some(list_field) = self.mtx.wait_list.next() {
            let wait_entry = list_field.as_ptr().cast::<WaitEntry>();
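            // SAFETY: `WaitEntry` is `#[repr(C)]` with the `ListHead` as its first field, so a
            // pointer to that field is also a valid pointer to the enclosing entry.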
            unsafe { (*wait_entry).thread.unpark() };
        }
        drop(sguard);
    }
}

impl<T> Deref for CMutexGuard<'_, T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
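        // SAFETY: while the guard exists the mutex is locked, so we have shared access to `data`.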
        unsafe { &*self.mtx.data.get() }
    }
}

impl<T> DerefMut for CMutexGuard<'_, T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
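        // SAFETY: while the guard exists the mutex is locked, and `&mut self` guarantees that
        // this is the only place currently accessing `data`.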
        unsafe { &mut *self.mtx.data.get() }
    }
}

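/// One parked thread waiting for a [`CMutex`]. `#[repr(C)]` keeps the `ListHead` at offset zero,
/// so a pointer to the list node can be cast back to the whole entry.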
#[pin_data]
#[repr(C)]
struct WaitEntry {
    #[pin]
    wait_list: ListHead,
    thread: Thread,
}

impl WaitEntry {
    #[inline]
    fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
        pin_init!(Self {
            thread: thread::current(),
            wait_list <- ListHead::insert_prev(list),
        })
    }
}

#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {}

#[allow(dead_code)]
#[cfg_attr(test, test)]
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let mut handles = vec![];
    let thread_count = 20;
    let workload = if cfg!(miri) { 100 } else { 1_000 };
    for i in 0..thread_count {
        let mtx = mtx.clone();
        handles.push(
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    for _ in 0..workload {
                        *mtx.lock() += 1;
                    }
                    println!("{i} halfway");
                    sleep(Duration::from_millis((i as u64) * 10));
                    for _ in 0..workload {
                        *mtx.lock() += 1;
                    }
                    println!("{i} finished");
                })
                .expect("should not fail"),
        );
    }
    for h in handles {
        h.join().expect("thread panicked");
    }
    println!("{:?}", &*mtx.lock());
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
}