xref: /cloud-hypervisor/virtio-devices/src/vhost_user/fs.rs (revision eeae63b4595fbf0cc69f62b6e9d9a79c543c4ac7)
1 // Copyright 2019 Intel Corporation. All Rights Reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use std::sync::atomic::AtomicBool;
5 use std::sync::{Arc, Barrier, Mutex};
6 use std::{result, thread};
7 
8 use seccompiler::SeccompAction;
9 use serde::{Deserialize, Serialize};
10 use serde_with::{serde_as, Bytes};
11 use vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures};
12 use vhost::vhost_user::{FrontendReqHandler, VhostUserFrontend, VhostUserFrontendReqHandler};
13 use virtio_queue::Queue;
14 use vm_memory::{ByteValued, GuestMemoryAtomic};
15 use vm_migration::protocol::MemoryRangeTable;
16 use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
17 use vmm_sys_util::eventfd::EventFd;
18 
19 use super::vu_common_ctrl::VhostUserHandle;
20 use super::{Error, Result, DEFAULT_VIRTIO_FEATURES};
21 use crate::seccomp_filters::Thread;
22 use crate::thread_helper::spawn_virtio_thread;
23 use crate::vhost_user::VhostUserCommon;
24 use crate::{
25     ActivateResult, GuestMemoryMmap, GuestRegionMmap, MmapRegion, UserspaceMapping, VirtioCommon,
26     VirtioDevice, VirtioDeviceType, VirtioInterrupt, VirtioSharedMemoryList,
27     VIRTIO_F_IOMMU_PLATFORM,
28 };
29 
/// Number of queues the device exposes in addition to the requested request
/// queues (presumably the virtio-fs high-priority queue — confirm against the
/// virtio-fs specification).
const NUM_QUEUE_OFFSET: usize = 1;
/// Queue count assumed for the backend when the vhost-user `MQ` protocol
/// feature is not negotiated: the extra queue plus a single request queue.
const DEFAULT_QUEUE_NUMBER: usize = NUM_QUEUE_OFFSET + 1;
32 
/// Snapshot of a virtio-fs device, produced by `Fs::state()` and used by
/// `Fs::new()` to restore a device.
#[derive(Serialize, Deserialize)]
pub struct State {
    /// Virtio features offered to the driver (`VirtioCommon::avail_features`).
    pub avail_features: u64,
    /// Virtio features acknowledged so far (`VirtioCommon::acked_features`).
    pub acked_features: u64,
    /// Device configuration space contents (tag and request queue count).
    pub config: VirtioFsConfig,
    /// vhost-user protocol features negotiated with the backend.
    pub acked_protocol_features: u64,
    /// Number of queues used with the vhost-user backend.
    pub vu_num_queues: usize,
    /// Not used by this device: `Fs::state()` always records `false`, and
    /// `Fs::new()` does not read it on restore.
    pub backend_req_support: bool,
}
42 
43 struct BackendReqHandler {}
44 impl VhostUserFrontendReqHandler for BackendReqHandler {}
45 
46 pub const VIRTIO_FS_TAG_LEN: usize = 36;
47 #[serde_as]
48 #[derive(Copy, Clone, Serialize, Deserialize)]
49 #[repr(C, packed)]
50 pub struct VirtioFsConfig {
51     #[serde_as(as = "Bytes")]
52     pub tag: [u8; VIRTIO_FS_TAG_LEN],
53     pub num_request_queues: u32,
54 }
55 
56 impl Default for VirtioFsConfig {
57     fn default() -> Self {
58         VirtioFsConfig {
59             tag: [0; VIRTIO_FS_TAG_LEN],
60             num_request_queues: 0,
61         }
62     }
63 }
64 
65 // SAFETY: only a series of integers
66 unsafe impl ByteValued for VirtioFsConfig {}
67 
/// virtio-fs device whose queues are handed to a vhost-user backend reached
/// through a UNIX socket (see `Fs::new`).
pub struct Fs {
    // Generic virtio state: features, queue sizes, pause/kill events, interrupt.
    common: VirtioCommon,
    // vhost-user state: backend handle, negotiated protocol features, socket
    // path and backend queue count.
    vu_common: VhostUserCommon,
    // Device identifier used for the worker thread, events and snapshots.
    id: String,
    // Configuration space exposed to the guest through `read_config()`.
    config: VirtioFsConfig,
    // Hold ownership of the memory that is allocated for the device
    // which will be automatically dropped when the device is dropped
    cache: Option<(VirtioSharedMemoryList, MmapRegion)>,
    // Seccomp action applied when spawning the worker thread.
    seccomp_action: SeccompAction,
    // Guest memory recorded at activation; used for memory hotplug and
    // dirty-page logging.
    guest_memory: Option<GuestMemoryAtomic<GuestMemoryMmap>>,
    // Worker thread spawned by `activate()`; unparked on resume, joined on drop.
    epoll_thread: Option<thread::JoinHandle<()>>,
    // Passed to `spawn_virtio_thread` (presumably signalled if the worker
    // thread fails — confirm in thread_helper).
    exit_evt: EventFd,
    // When true, `features()` additionally advertises VIRTIO_F_IOMMU_PLATFORM.
    iommu: bool,
}
82 
83 impl Fs {
84     /// Create a new virtio-fs device.
85     #[allow(clippy::too_many_arguments)]
86     pub fn new(
87         id: String,
88         path: &str,
89         tag: &str,
90         req_num_queues: usize,
91         queue_size: u16,
92         cache: Option<(VirtioSharedMemoryList, MmapRegion)>,
93         seccomp_action: SeccompAction,
94         exit_evt: EventFd,
95         iommu: bool,
96         state: Option<State>,
97     ) -> Result<Fs> {
98         // Calculate the actual number of queues needed.
99         let num_queues = NUM_QUEUE_OFFSET + req_num_queues;
100 
101         // Connect to the vhost-user socket.
102         let mut vu = VhostUserHandle::connect_vhost_user(false, path, num_queues as u64, false)?;
103 
104         let (
105             avail_features,
106             acked_features,
107             acked_protocol_features,
108             vu_num_queues,
109             config,
110             paused,
111         ) = if let Some(state) = state {
112             info!("Restoring vhost-user-fs {}", id);
113 
114             vu.set_protocol_features_vhost_user(
115                 state.acked_features,
116                 state.acked_protocol_features,
117             )?;
118 
119             (
120                 state.avail_features,
121                 state.acked_features,
122                 state.acked_protocol_features,
123                 state.vu_num_queues,
124                 state.config,
125                 true,
126             )
127         } else {
128             // Filling device and vring features VMM supports.
129             let avail_features = DEFAULT_VIRTIO_FEATURES;
130 
131             let avail_protocol_features = VhostUserProtocolFeatures::MQ
132                 | VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS
133                 | VhostUserProtocolFeatures::REPLY_ACK
134                 | VhostUserProtocolFeatures::INFLIGHT_SHMFD
135                 | VhostUserProtocolFeatures::LOG_SHMFD;
136 
137             let (acked_features, acked_protocol_features) =
138                 vu.negotiate_features_vhost_user(avail_features, avail_protocol_features)?;
139 
140             let backend_num_queues =
141                 if acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0 {
142                     vu.socket_handle()
143                         .get_queue_num()
144                         .map_err(Error::VhostUserGetQueueMaxNum)? as usize
145                 } else {
146                     DEFAULT_QUEUE_NUMBER
147                 };
148 
149             if num_queues > backend_num_queues {
150                 error!(
151                 "vhost-user-fs requested too many queues ({}) since the backend only supports {}\n",
152                 num_queues, backend_num_queues
153             );
154                 return Err(Error::BadQueueNum);
155             }
156 
157             // Create virtio-fs device configuration.
158             let mut config = VirtioFsConfig::default();
159             let tag_bytes_slice = tag.as_bytes();
160             let len = if tag_bytes_slice.len() < config.tag.len() {
161                 tag_bytes_slice.len()
162             } else {
163                 config.tag.len()
164             };
165             config.tag[..len].copy_from_slice(tag_bytes_slice[..len].as_ref());
166             config.num_request_queues = req_num_queues as u32;
167 
168             (
169                 acked_features,
170                 // If part of the available features that have been acked, the
171                 // PROTOCOL_FEATURES bit must be already set through the VIRTIO
172                 // acked features as we know the guest would never ack it, thus
173                 // the feature would be lost.
174                 acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(),
175                 acked_protocol_features,
176                 num_queues,
177                 config,
178                 false,
179             )
180         };
181 
182         Ok(Fs {
183             common: VirtioCommon {
184                 device_type: VirtioDeviceType::Fs as u32,
185                 avail_features,
186                 acked_features,
187                 queue_sizes: vec![queue_size; num_queues],
188                 paused_sync: Some(Arc::new(Barrier::new(2))),
189                 min_queues: 1,
190                 paused: Arc::new(AtomicBool::new(paused)),
191                 ..Default::default()
192             },
193             vu_common: VhostUserCommon {
194                 vu: Some(Arc::new(Mutex::new(vu))),
195                 acked_protocol_features,
196                 socket_path: path.to_string(),
197                 vu_num_queues,
198                 ..Default::default()
199             },
200             id,
201             config,
202             cache,
203             seccomp_action,
204             guest_memory: None,
205             epoll_thread: None,
206             exit_evt,
207             iommu,
208         })
209     }
210 
211     fn state(&self) -> State {
212         State {
213             avail_features: self.common.avail_features,
214             acked_features: self.common.acked_features,
215             config: self.config,
216             acked_protocol_features: self.vu_common.acked_protocol_features,
217             vu_num_queues: self.vu_common.vu_num_queues,
218             backend_req_support: false,
219         }
220     }
221 }
222 
223 impl Drop for Fs {
224     fn drop(&mut self) {
225         if let Some(kill_evt) = self.common.kill_evt.take() {
226             // Ignore the result because there is nothing we can do about it.
227             let _ = kill_evt.write(1);
228         }
229         self.common.wait_for_epoll_threads();
230         if let Some(thread) = self.epoll_thread.take() {
231             if let Err(e) = thread.join() {
232                 error!("Error joining thread: {:?}", e);
233             }
234         }
235     }
236 }
237 
238 impl VirtioDevice for Fs {
239     fn device_type(&self) -> u32 {
240         self.common.device_type
241     }
242 
243     fn queue_max_sizes(&self) -> &[u16] {
244         &self.common.queue_sizes
245     }
246 
247     fn features(&self) -> u64 {
248         let mut features = self.common.avail_features;
249         if self.iommu {
250             features |= 1u64 << VIRTIO_F_IOMMU_PLATFORM;
251         }
252         features
253     }
254 
255     fn ack_features(&mut self, value: u64) {
256         self.common.ack_features(value)
257     }
258 
259     fn read_config(&self, offset: u64, data: &mut [u8]) {
260         self.read_config_from_slice(self.config.as_slice(), offset, data);
261     }
262 
263     fn activate(
264         &mut self,
265         mem: GuestMemoryAtomic<GuestMemoryMmap>,
266         interrupt_cb: Arc<dyn VirtioInterrupt>,
267         queues: Vec<(usize, Queue, EventFd)>,
268     ) -> ActivateResult {
269         self.common.activate(&queues, &interrupt_cb)?;
270         self.guest_memory = Some(mem.clone());
271 
272         let backend_req_handler: Option<FrontendReqHandler<BackendReqHandler>> = None;
273         // Run a dedicated thread for handling potential reconnections with
274         // the backend.
275         let (kill_evt, pause_evt) = self.common.dup_eventfds();
276 
277         let mut handler = self.vu_common.activate(
278             mem,
279             queues,
280             interrupt_cb,
281             self.common.acked_features,
282             backend_req_handler,
283             kill_evt,
284             pause_evt,
285         )?;
286 
287         let paused = self.common.paused.clone();
288         let paused_sync = self.common.paused_sync.clone();
289 
290         let mut epoll_threads = Vec::new();
291         spawn_virtio_thread(
292             &self.id,
293             &self.seccomp_action,
294             Thread::VirtioVhostFs,
295             &mut epoll_threads,
296             &self.exit_evt,
297             move || handler.run(paused, paused_sync.unwrap()),
298         )?;
299         self.epoll_thread = Some(epoll_threads.remove(0));
300 
301         event!("virtio-device", "activated", "id", &self.id);
302         Ok(())
303     }
304 
305     fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
306         // We first must resume the virtio thread if it was paused.
307         if self.common.pause_evt.take().is_some() {
308             self.common.resume().ok()?;
309         }
310 
311         if let Some(vu) = &self.vu_common.vu {
312             if let Err(e) = vu.lock().unwrap().reset_vhost_user() {
313                 error!("Failed to reset vhost-user daemon: {:?}", e);
314                 return None;
315             }
316         }
317 
318         if let Some(kill_evt) = self.common.kill_evt.take() {
319             // Ignore the result because there is nothing we can do about it.
320             let _ = kill_evt.write(1);
321         }
322 
323         event!("virtio-device", "reset", "id", &self.id);
324 
325         // Return the interrupt
326         Some(self.common.interrupt_cb.take().unwrap())
327     }
328 
329     fn shutdown(&mut self) {
330         self.vu_common.shutdown()
331     }
332 
333     fn get_shm_regions(&self) -> Option<VirtioSharedMemoryList> {
334         self.cache.as_ref().map(|cache| cache.0.clone())
335     }
336 
337     fn set_shm_regions(
338         &mut self,
339         shm_regions: VirtioSharedMemoryList,
340     ) -> std::result::Result<(), crate::Error> {
341         if let Some(cache) = self.cache.as_mut() {
342             cache.0 = shm_regions;
343             Ok(())
344         } else {
345             Err(crate::Error::SetShmRegionsNotSupported)
346         }
347     }
348 
349     fn add_memory_region(
350         &mut self,
351         region: &Arc<GuestRegionMmap>,
352     ) -> std::result::Result<(), crate::Error> {
353         self.vu_common.add_memory_region(&self.guest_memory, region)
354     }
355 
356     fn userspace_mappings(&self) -> Vec<UserspaceMapping> {
357         let mut mappings = Vec::new();
358         if let Some(cache) = self.cache.as_ref() {
359             mappings.push(UserspaceMapping {
360                 host_addr: cache.0.host_addr,
361                 mem_slot: cache.0.mem_slot,
362                 addr: cache.0.addr,
363                 len: cache.0.len,
364                 mergeable: false,
365             })
366         }
367 
368         mappings
369     }
370 }
371 
372 impl Pausable for Fs {
373     fn pause(&mut self) -> result::Result<(), MigratableError> {
374         self.vu_common.pause()?;
375         self.common.pause()
376     }
377 
378     fn resume(&mut self) -> result::Result<(), MigratableError> {
379         self.common.resume()?;
380 
381         if let Some(epoll_thread) = &self.epoll_thread {
382             epoll_thread.thread().unpark();
383         }
384 
385         self.vu_common.resume()
386     }
387 }
388 
389 impl Snapshottable for Fs {
390     fn id(&self) -> String {
391         self.id.clone()
392     }
393 
394     fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
395         self.vu_common.snapshot(&self.state())
396     }
397 }
398 impl Transportable for Fs {}
399 
400 impl Migratable for Fs {
401     fn start_dirty_log(&mut self) -> std::result::Result<(), MigratableError> {
402         self.vu_common.start_dirty_log(&self.guest_memory)
403     }
404 
405     fn stop_dirty_log(&mut self) -> std::result::Result<(), MigratableError> {
406         self.vu_common.stop_dirty_log()
407     }
408 
409     fn dirty_log(&mut self) -> std::result::Result<MemoryRangeTable, MigratableError> {
410         self.vu_common.dirty_log(&self.guest_memory)
411     }
412 
413     fn start_migration(&mut self) -> std::result::Result<(), MigratableError> {
414         self.vu_common.start_migration()
415     }
416 
417     fn complete_migration(&mut self) -> std::result::Result<(), MigratableError> {
418         self.vu_common
419             .complete_migration(self.common.kill_evt.take())
420     }
421 }
422