1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2
3 #![allow(clippy::undocumented_unsafe_blocks)]
4 #![cfg_attr(feature = "alloc", feature(allocator_api))]
5 #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
6 #![allow(unused_imports)]
7
use core::{
    cell::{Cell, UnsafeCell},
    mem::MaybeUninit,
    ops,
    pin::Pin,
    sync::atomic::{AtomicBool, Ordering},
    time::Duration,
};
15 use pin_init::*;
16 #[cfg(feature = "std")]
17 use std::{
18 sync::Arc,
19 thread::{sleep, Builder},
20 };
21
22 #[allow(unused_attributes)]
23 mod mutex;
24 use mutex::*;
25
26 pub struct StaticInit<T, I> {
27 cell: UnsafeCell<MaybeUninit<T>>,
28 init: Cell<Option<I>>,
29 lock: SpinLock,
30 present: Cell<bool>,
31 }
32
33 unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}
34 unsafe impl<T: Send, I> Send for StaticInit<T, I> {}
35
36 impl<T, I: PinInit<T>> StaticInit<T, I> {
new(init: I) -> Self37 pub const fn new(init: I) -> Self {
38 Self {
39 cell: UnsafeCell::new(MaybeUninit::uninit()),
40 init: Cell::new(Some(init)),
41 lock: SpinLock::new(),
42 present: Cell::new(false),
43 }
44 }
45 }
46
47 impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {
48 type Target = T;
deref(&self) -> &Self::Target49 fn deref(&self) -> &Self::Target {
50 if self.present.get() {
51 unsafe { (*self.cell.get()).assume_init_ref() }
52 } else {
53 println!("acquire spinlock on static init");
54 let _guard = self.lock.acquire();
55 println!("rechecking present...");
56 std::thread::sleep(std::time::Duration::from_millis(200));
57 if self.present.get() {
58 return unsafe { (*self.cell.get()).assume_init_ref() };
59 }
60 println!("doing init");
61 let ptr = self.cell.get().cast::<T>();
62 match self.init.take() {
63 Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },
64 None => unsafe { core::hint::unreachable_unchecked() },
65 }
66 self.present.set(true);
67 unsafe { (*self.cell.get()).assume_init_ref() }
68 }
69 }
70 }
71
72 pub struct CountInit;
73
74 unsafe impl PinInit<CMutex<usize>> for CountInit {
__pinned_init( self, slot: *mut CMutex<usize>, ) -> Result<(), core::convert::Infallible>75 unsafe fn __pinned_init(
76 self,
77 slot: *mut CMutex<usize>,
78 ) -> Result<(), core::convert::Infallible> {
79 let init = CMutex::new(0);
80 std::thread::sleep(std::time::Duration::from_millis(1000));
81 unsafe { init.__pinned_init(slot) }
82 }
83 }
84
85 pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);
86
/// Demo: many worker threads increment both the lazily-initialized `COUNT`
/// static and a heap-pinned `CMutex`, then both totals are checked.
fn main() {
    #[cfg(feature = "std")]
    {
        let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
        let thread_count = 20;
        let workload = 1_000;
        let handles: Vec<_> = (0..thread_count)
            .map(|i| {
                let mtx = mtx.clone();
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        // Phase 1: touch COUNT (the first access triggers its
                        // lazy init) interleaved with the local mutex.
                        for _ in 0..workload {
                            *COUNT.lock() += 1;
                            sleep(Duration::from_millis(10));
                            *mtx.lock() += 1;
                            sleep(Duration::from_millis(10));
                            *COUNT.lock() += 1;
                        }
                        println!("{i} halfway");
                        // Stagger the second phase per thread.
                        sleep(Duration::from_millis((i as u64) * 10));
                        for _ in 0..workload {
                            sleep(Duration::from_millis(10));
                            *mtx.lock() += 1;
                        }
                        println!("{i} finished");
                    })
                    .expect("should not fail")
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
        println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock());
        // Each worker adds 2 * workload to each counter.
        assert_eq!(*mtx.lock(), workload * thread_count * 2);
        assert_eq!(*COUNT.lock(), workload * thread_count * 2);
    }
}
125