1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 #![allow(clippy::undocumented_unsafe_blocks)]
4 #![cfg_attr(feature = "alloc", feature(allocator_api))]
5 
use core::{
    cell::{Cell, UnsafeCell},
    mem::MaybeUninit,
    ops,
    pin::Pin,
    sync::atomic::{AtomicBool, Ordering},
    time::Duration,
};
use pin_init::*;
use std::{
    sync::Arc,
    thread::{sleep, Builder},
};
18 
19 #[expect(unused_attributes)]
20 mod mutex;
21 use mutex::*;
22 
23 pub struct StaticInit<T, I> {
24     cell: UnsafeCell<MaybeUninit<T>>,
25     init: Cell<Option<I>>,
26     lock: SpinLock,
27     present: Cell<bool>,
28 }
29 
30 unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}
31 unsafe impl<T: Send, I> Send for StaticInit<T, I> {}
32 
33 impl<T, I: PinInit<T>> StaticInit<T, I> {
new(init: I) -> Self34     pub const fn new(init: I) -> Self {
35         Self {
36             cell: UnsafeCell::new(MaybeUninit::uninit()),
37             init: Cell::new(Some(init)),
38             lock: SpinLock::new(),
39             present: Cell::new(false),
40         }
41     }
42 }
43 
44 impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {
45     type Target = T;
deref(&self) -> &Self::Target46     fn deref(&self) -> &Self::Target {
47         if self.present.get() {
48             unsafe { (*self.cell.get()).assume_init_ref() }
49         } else {
50             println!("acquire spinlock on static init");
51             let _guard = self.lock.acquire();
52             println!("rechecking present...");
53             std::thread::sleep(std::time::Duration::from_millis(200));
54             if self.present.get() {
55                 return unsafe { (*self.cell.get()).assume_init_ref() };
56             }
57             println!("doing init");
58             let ptr = self.cell.get().cast::<T>();
59             match self.init.take() {
60                 Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },
61                 None => unsafe { core::hint::unreachable_unchecked() },
62             }
63             self.present.set(true);
64             unsafe { (*self.cell.get()).assume_init_ref() }
65         }
66     }
67 }
68 
69 pub struct CountInit;
70 
71 unsafe impl PinInit<CMutex<usize>> for CountInit {
__pinned_init( self, slot: *mut CMutex<usize>, ) -> Result<(), core::convert::Infallible>72     unsafe fn __pinned_init(
73         self,
74         slot: *mut CMutex<usize>,
75     ) -> Result<(), core::convert::Infallible> {
76         let init = CMutex::new(0);
77         std::thread::sleep(std::time::Duration::from_millis(1000));
78         unsafe { init.__pinned_init(slot) }
79     }
80 }
81 
82 pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);
83 
/// No-op entry point used when neither the `std` nor the `alloc` feature is enabled.
///
/// The threaded demo is only compiled with one of those features.
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {
    // Intentionally empty.
}
86 
/// Demo: worker threads concurrently increment a pinned `Arc<CMutex>` and the lazily
/// initialized static `COUNT`, then both totals are checked.
///
/// # Panics
///
/// Panics in any of these cases:
/// - a worker thread cannot be spawned;
/// - a worker thread panics;
/// - either counter ends up different from `workload * thread_count * 2`.
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let thread_count = 20;
    let workload = 1_000;
    // Collect eagerly so every worker is spawned before any `join` blocks.
    let handles: Vec<_> = (0..thread_count)
        .map(|i| {
            let mtx = mtx.clone();
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    for _ in 0..workload {
                        *COUNT.lock() += 1;
                        sleep(Duration::from_millis(10));
                        *mtx.lock() += 1;
                        sleep(Duration::from_millis(10));
                        *COUNT.lock() += 1;
                    }
                    println!("{i} halfway");
                    sleep(Duration::from_millis((i as u64) * 10));
                    for _ in 0..workload {
                        sleep(Duration::from_millis(10));
                        *mtx.lock() += 1;
                    }
                    println!("{i} finished");
                })
                .expect("should not fail")
        })
        .collect();
    for h in handles {
        h.join().expect("thread panicked");
    }
    println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock());
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
    // `COUNT` is bumped twice per first-phase iteration, so it must reach the same total.
    // This verifies that no increments to the lazily initialized static were lost or reset.
    assert_eq!(*COUNT.lock(), workload * thread_count * 2);
}
123