xref: /linux/rust/pin-init/examples/static_init.rs (revision 0074281bb6316108e0cff094bd4db78ab3eee236)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 #![allow(clippy::undocumented_unsafe_blocks)]
4 #![cfg_attr(feature = "alloc", feature(allocator_api))]
5 #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
6 #![allow(unused_imports)]
7 
8 use core::{
9     cell::{Cell, UnsafeCell},
10     mem::MaybeUninit,
11     ops,
12     pin::Pin,
13     time::Duration,
14 };
15 use pin_init::*;
16 #[cfg(feature = "std")]
17 use std::{
18     sync::Arc,
19     thread::{sleep, Builder},
20 };
21 
22 #[allow(unused_attributes)]
23 mod mutex;
24 use mutex::*;
25 
26 pub struct StaticInit<T, I> {
27     cell: UnsafeCell<MaybeUninit<T>>,
28     init: Cell<Option<I>>,
29     lock: SpinLock,
30     present: Cell<bool>,
31 }
32 
33 unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}
34 unsafe impl<T: Send, I> Send for StaticInit<T, I> {}
35 
36 impl<T, I: PinInit<T>> StaticInit<T, I> {
new(init: I) -> Self37     pub const fn new(init: I) -> Self {
38         Self {
39             cell: UnsafeCell::new(MaybeUninit::uninit()),
40             init: Cell::new(Some(init)),
41             lock: SpinLock::new(),
42             present: Cell::new(false),
43         }
44     }
45 }
46 
47 impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {
48     type Target = T;
deref(&self) -> &Self::Target49     fn deref(&self) -> &Self::Target {
50         if self.present.get() {
51             unsafe { (*self.cell.get()).assume_init_ref() }
52         } else {
53             println!("acquire spinlock on static init");
54             let _guard = self.lock.acquire();
55             println!("rechecking present...");
56             std::thread::sleep(std::time::Duration::from_millis(200));
57             if self.present.get() {
58                 return unsafe { (*self.cell.get()).assume_init_ref() };
59             }
60             println!("doing init");
61             let ptr = self.cell.get().cast::<T>();
62             match self.init.take() {
63                 Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },
64                 None => unsafe { core::hint::unreachable_unchecked() },
65             }
66             self.present.set(true);
67             unsafe { (*self.cell.get()).assume_init_ref() }
68         }
69     }
70 }
71 
72 pub struct CountInit;
73 
74 unsafe impl PinInit<CMutex<usize>> for CountInit {
__pinned_init( self, slot: *mut CMutex<usize>, ) -> Result<(), core::convert::Infallible>75     unsafe fn __pinned_init(
76         self,
77         slot: *mut CMutex<usize>,
78     ) -> Result<(), core::convert::Infallible> {
79         let init = CMutex::new(0);
80         std::thread::sleep(std::time::Duration::from_millis(1000));
81         unsafe { init.__pinned_init(slot) }
82     }
83 }
84 
85 pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);
86 
/// Hammers [`COUNT`] and a heap-pinned `CMutex` from `thread_count` worker threads.
///
/// Counters per worker:
/// - `COUNT` is incremented twice in each of `workload` iterations.
/// - The heap mutex is incremented once per iteration in each of two phases.
///
/// Both counters therefore end at `workload * thread_count * 2`. Asserting `COUNT` as well
/// checks that no increments to the lazily initialized static were lost.
fn main() {
    #[cfg(feature = "std")]
    {
        let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
        let thread_count = 20;
        let workload = 1_000;
        let expected = workload * thread_count * 2;

        let handles: Vec<_> = (0..thread_count)
            .map(|i| {
                let mtx = mtx.clone();
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        // Phase 1: interleave the static counter with the heap-pinned mutex.
                        for _ in 0..workload {
                            *COUNT.lock() += 1;
                            sleep(Duration::from_millis(10));
                            *mtx.lock() += 1;
                            sleep(Duration::from_millis(10));
                            *COUNT.lock() += 1;
                        }
                        println!("{i} halfway");
                        // Stagger the start of phase 2 per worker.
                        sleep(Duration::from_millis((i as u64) * 10));
                        // Phase 2: only the heap-pinned mutex.
                        for _ in 0..workload {
                            sleep(Duration::from_millis(10));
                            *mtx.lock() += 1;
                        }
                        println!("{i} finished");
                    })
                    .expect("should not fail")
            })
            .collect();

        for h in handles {
            h.join().expect("thread panicked");
        }
        println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock());
        assert_eq!(*mtx.lock(), expected);
        assert_eq!(*COUNT.lock(), expected);
    }
}
125