// SPDX-License-Identifier: Apache-2.0 OR MIT

#![allow(clippy::undocumented_unsafe_blocks)]
#![cfg_attr(feature = "alloc", feature(allocator_api))]

use core::{
    cell::{Cell, UnsafeCell},
    mem::MaybeUninit,
    ops,
    pin::Pin,
    time::Duration,
};
use pin_init::*;
use std::{
    sync::Arc,
    thread::{sleep, Builder},
};

#[expect(unused_attributes)]
mod mutex;
use mutex::*;

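/// A statically allocated value that is lazily initialized on first access.
///
/// The wrapped `T` is constructed in place by the stored pin-initializer `I`;
/// concurrent first accesses are serialized by a `SpinLock`.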
pub struct StaticInit<T, I> {
    cell: UnsafeCell<MaybeUninit<T>>,
    init: Cell<Option<I>>,
    lock: SpinLock,
    present: Cell<bool>,
}

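// `StaticInit` is shared across threads in this example; `T`'s bounds are
// forwarded, and the cells are only written while `lock` is held (see `Deref`).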
unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}
unsafe impl<T: Send, I> Send for StaticInit<T, I> {}

impl<T, I: PinInit<T>> StaticInit<T, I> {
    pub const fn new(init: I) -> Self {
        Self {
            cell: UnsafeCell::new(MaybeUninit::uninit()),
            init: Cell::new(Some(init)),
            lock: SpinLock::new(),
            present: Cell::new(false),
        }
    }
}

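// Lazy initialization using double-checked locking: the fast path only reads
// `present`; the slow path acquires the spinlock, rechecks, and then runs the
// stored pin-initializer directly into `cell`.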
impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
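        // Fast path: the value has already been initialized.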
        if self.present.get() {
            unsafe { (*self.cell.get()).assume_init_ref() }
        } else {
            println!("acquire spinlock on static init");
            let _guard = self.lock.acquire();
            println!("rechecking present...");
            std::thread::sleep(std::time::Duration::from_millis(200));
            if self.present.get() {
                return unsafe { (*self.cell.get()).assume_init_ref() };
            }
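            // Still not initialized: run the stored initializer in place.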
            println!("doing init");
            let ptr = self.cell.get().cast::<T>();
            match self.init.take() {
                Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },
                None => unsafe { core::hint::unreachable_unchecked() },
            }
            self.present.set(true);
            unsafe { (*self.cell.get()).assume_init_ref() }
        }
    }
}

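/// A pin-initializer for `CMutex<usize>` that sleeps for a second before
/// delegating to `CMutex::new(0)`, widening the initialization window so the
/// race between threads is easy to observe.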
pub struct CountInit;

unsafe impl PinInit<CMutex<usize>> for CountInit {
    unsafe fn __pinned_init(
        self,
        slot: *mut CMutex<usize>,
    ) -> Result<(), core::convert::Infallible> {
        let init = CMutex::new(0);
        std::thread::sleep(std::time::Duration::from_millis(1000));
        unsafe { init.__pinned_init(slot) }
    }
}

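/// A global counter, initialized on first dereference via `CountInit`.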
pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);

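// `main` needs an allocator for `Arc` and `Vec`, so it is a no-op unless the
// `std` or `alloc` feature is enabled.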
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {}

#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let mut handles = vec![];
    let thread_count = 20;
    let workload = 1_000;
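    // Spawn workers that bump both the lazily initialized `COUNT` and the local
    // `mtx`; their first accesses race to initialize `COUNT`.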
    for i in 0..thread_count {
        let mtx = mtx.clone();
        handles.push(
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    for _ in 0..workload {
                        *COUNT.lock() += 1;
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        *mtx.lock() += 1;
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        *COUNT.lock() += 1;
                    }
                    println!("{i} halfway");
                    sleep(Duration::from_millis((i as u64) * 10));
                    for _ in 0..workload {
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        *mtx.lock() += 1;
                    }
                    println!("{i} finished");
                })
                .expect("should not fail"),
        );
    }
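    // Wait for every worker, then verify that all increments were recorded.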
    for h in handles {
        h.join().expect("thread panicked");
    }
    println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock());
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
}