xref: /cloud-hypervisor/virtio-devices/src/vhost_user/fs.rs (revision 3ce0fef7fd546467398c914dbc74d8542e45cf6f)
// Copyright 2019 Intel Corporation. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use super::vu_common_ctrl::VhostUserHandle;
use super::{Error, Result, DEFAULT_VIRTIO_FEATURES};
use crate::seccomp_filters::Thread;
use crate::thread_helper::spawn_virtio_thread;
use crate::vhost_user::VhostUserCommon;
use crate::{
    ActivateError, ActivateResult, UserspaceMapping, VirtioCommon, VirtioDevice, VirtioDeviceType,
    VirtioInterrupt, VirtioSharedMemoryList, VIRTIO_F_IOMMU_PLATFORM,
};
use crate::{GuestMemoryMmap, GuestRegionMmap, MmapRegion};
use libc::{self, c_void, off64_t, pread64, pwrite64};
use seccompiler::SeccompAction;
use std::io;
use std::os::unix::io::AsRawFd;
use std::result;
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Barrier, Mutex};
use std::thread;
use versionize::{VersionMap, Versionize, VersionizeResult};
use versionize_derive::Versionize;
use vhost::vhost_user::message::{
    VhostUserFSBackendMsg, VhostUserFSBackendMsgFlags, VhostUserProtocolFeatures,
    VhostUserVirtioFeatures, VHOST_USER_FS_BACKEND_ENTRIES,
};
use vhost::vhost_user::{
    FrontendReqHandler, HandlerResult, VhostUserFrontend, VhostUserFrontendReqHandler,
};
use virtio_queue::Queue;
use vm_memory::{
    Address, ByteValued, GuestAddress, GuestAddressSpace, GuestMemory, GuestMemoryAtomic,
};
use vm_migration::{
    protocol::MemoryRangeTable, Migratable, MigratableError, Pausable, Snapshot, Snapshottable,
    Transportable, VersionMapped,
};
use vmm_sys_util::eventfd::EventFd;

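// The device exposes one hiprio queue in addition to the request queues,
// hence the offset of 1. When the backend does not advertise the MQ
// protocol feature, DEFAULT_QUEUE_NUMBER queues are assumed.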
const NUM_QUEUE_OFFSET: usize = 1;
const DEFAULT_QUEUE_NUMBER: usize = 2;

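/// Migratable state of the vhost-user-fs device, captured by `Fs::state()`
/// and handed back through `Fs::new()` on restore.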
#[derive(Versionize)]
pub struct State {
    pub avail_features: u64,
    pub acked_features: u64,
    pub config: VirtioFsConfig,
    pub acked_protocol_features: u64,
    pub vu_num_queues: usize,
    pub backend_req_support: bool,
}

impl VersionMapped for State {}

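/// Context needed to serve backend-initiated requests targeting the DAX
/// cache window: the window's guest address and size, its mmap address in
/// the VMM, and the guest memory used to resolve regular RAM addresses.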
struct BackendReqHandler {
    cache_offset: GuestAddress,
    cache_size: u64,
    mmap_cache_addr: u64,
    mem: GuestMemoryAtomic<GuestMemoryMmap>,
}

impl BackendReqHandler {
    // Make sure the request is fully contained within the cache range.
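    // For example, with cache_size = 0x1000: (offset 0x0, len 0x1000) is
    // valid, (offset 0x800, len 0x900) extends past the window, and any
    // offset/len whose sum overflows u64 is rejected.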
    fn is_req_valid(&self, offset: u64, len: u64) -> bool {
        let end = match offset.checked_add(len) {
            Some(n) => n,
            None => return false,
        };

        offset < self.cache_size && end <= self.cache_size
    }
}

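// Handling of requests initiated by the vhost-user backend over the backend
// channel (BACKEND_REQ protocol feature), as opposed to replies to requests
// the frontend itself sent.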
impl VhostUserFrontendReqHandler for BackendReqHandler {
    fn handle_config_change(&self) -> HandlerResult<u64> {
        debug!("handle_config_change");
        Ok(0)
    }

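    // Map file-backed regions into the DAX cache window: each entry maps
    // fs.len[i] bytes of fd at fs.fd_offset[i] to offset fs.cache_offset[i]
    // within the pre-reserved window, using MAP_FIXED.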
    fn fs_backend_map(&self, fs: &VhostUserFSBackendMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
        debug!("fs_backend_map");

        for i in 0..VHOST_USER_FS_BACKEND_ENTRIES {
            let offset = fs.cache_offset[i];
            let len = fs.len[i];

            // Ignore if the length is 0.
            if len == 0 {
                continue;
            }

            if !self.is_req_valid(offset, len) {
                return Err(io::Error::from_raw_os_error(libc::EINVAL));
            }

            let addr = self.mmap_cache_addr + offset;
            let flags = fs.flags[i];
            // SAFETY: FFI call with valid arguments
            let ret = unsafe {
                libc::mmap(
                    addr as *mut libc::c_void,
                    len as usize,
                    flags.bits() as i32,
                    libc::MAP_SHARED | libc::MAP_FIXED,
                    fd.as_raw_fd(),
                    fs.fd_offset[i] as libc::off_t,
                )
            };
            if ret == libc::MAP_FAILED {
                return Err(io::Error::last_os_error());
            }
        }

        Ok(0)
    }

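    // Unmap regions by overlaying them with an anonymous PROT_NONE mapping,
    // dropping the file-backed pages while keeping the window reserved.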
    fn fs_backend_unmap(&self, fs: &VhostUserFSBackendMsg) -> HandlerResult<u64> {
        debug!("fs_backend_unmap");

        for i in 0..VHOST_USER_FS_BACKEND_ENTRIES {
            let mut len = fs.len[i];

            // Ignore if the length is 0.
            if len == 0 {
                continue;
            }

            // Handle the special case where the backend asks for the entire
            // mapping to be unmapped.
            let offset = if len == 0xffff_ffff_ffff_ffff {
                len = self.cache_size;
                0
            } else {
                fs.cache_offset[i]
            };

            if !self.is_req_valid(offset, len) {
                return Err(io::Error::from_raw_os_error(libc::EINVAL));
            }

            let addr = self.mmap_cache_addr + offset;
            // SAFETY: FFI call with valid arguments
            let ret = unsafe {
                libc::mmap(
                    addr as *mut libc::c_void,
                    len as usize,
                    libc::PROT_NONE,
                    libc::MAP_ANONYMOUS | libc::MAP_PRIVATE | libc::MAP_FIXED,
                    -1,
                    0,
                )
            };
            if ret == libc::MAP_FAILED {
                return Err(io::Error::last_os_error());
            }
        }

        Ok(0)
    }

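    // Flush previously mapped regions of the cache window back to the
    // underlying file with a synchronous msync().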
    fn fs_backend_sync(&self, fs: &VhostUserFSBackendMsg) -> HandlerResult<u64> {
        debug!("fs_backend_sync");

        for i in 0..VHOST_USER_FS_BACKEND_ENTRIES {
            let offset = fs.cache_offset[i];
            let len = fs.len[i];

            // Ignore if the length is 0.
            if len == 0 {
                continue;
            }

            if !self.is_req_valid(offset, len) {
                return Err(io::Error::from_raw_os_error(libc::EINVAL));
            }

            let addr = self.mmap_cache_addr + offset;
            let ret =
                // SAFETY: FFI call with valid arguments
                unsafe { libc::msync(addr as *mut libc::c_void, len as usize, libc::MS_SYNC) };
            if ret == -1 {
                return Err(io::Error::last_os_error());
            }
        }

        Ok(0)
    }

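    // Perform I/O on behalf of the backend. The guest address of each entry
    // may point either into the DAX cache window or into regular guest RAM;
    // it is resolved to a host address before looping on pread64()/pwrite64()
    // until the whole buffer has been transferred.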
    fn fs_backend_io(&self, fs: &VhostUserFSBackendMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
        debug!("fs_backend_io");

        let mut done: u64 = 0;
        for i in 0..VHOST_USER_FS_BACKEND_ENTRIES {
            // Ignore if the length is 0.
            if fs.len[i] == 0 {
                continue;
            }

            let mut foffset = fs.fd_offset[i];
            let mut len = fs.len[i] as usize;
            let gpa = fs.cache_offset[i];
            let cache_end = self.cache_offset.raw_value() + self.cache_size;
            let efault = libc::EFAULT;

            let mut ptr = if gpa >= self.cache_offset.raw_value() && gpa < cache_end {
                let offset = gpa
                    .checked_sub(self.cache_offset.raw_value())
                    .ok_or_else(|| io::Error::from_raw_os_error(efault))?;
                let end = gpa
                    .checked_add(fs.len[i])
                    .ok_or_else(|| io::Error::from_raw_os_error(efault))?;

                if end >= cache_end {
                    return Err(io::Error::from_raw_os_error(efault));
                }

                self.mmap_cache_addr + offset
            } else {
                self.mem
                    .memory()
                    .get_host_address(GuestAddress(gpa))
                    .map_err(|e| {
                        error!(
                            "Failed to find RAM region associated with guest physical address 0x{:x}: {:?}",
                            gpa, e
                        );
                        io::Error::from_raw_os_error(efault)
                    })? as u64
            };

            while len > 0 {
                let ret = if (fs.flags[i] & VhostUserFSBackendMsgFlags::MAP_W)
                    == VhostUserFSBackendMsgFlags::MAP_W
                {
                    debug!("write: foffset={}, len={}", foffset, len);
                    // SAFETY: FFI call with valid arguments
                    unsafe {
                        pwrite64(
                            fd.as_raw_fd(),
                            ptr as *const c_void,
                            len,
                            foffset as off64_t,
                        )
                    }
                } else {
                    debug!("read: foffset={}, len={}", foffset, len);
                    // SAFETY: FFI call with valid arguments
                    unsafe { pread64(fd.as_raw_fd(), ptr as *mut c_void, len, foffset as off64_t) }
                };

                if ret < 0 {
                    return Err(io::Error::last_os_error());
                }

                if ret == 0 {
                    // EOF
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "failed to access whole buffer",
                    ));
                }
                len -= ret as usize;
                foffset += ret as u64;
                ptr += ret as u64;
                done += ret as u64;
            }
        }

        Ok(done)
    }
}

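/// Virtio configuration space layout defined by the virtio-fs specification:
/// a NUL-padded tag naming the filesystem and the number of request queues.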
#[derive(Copy, Clone, Versionize)]
#[repr(C, packed)]
pub struct VirtioFsConfig {
    pub tag: [u8; 36],
    pub num_request_queues: u32,
}

impl Default for VirtioFsConfig {
    fn default() -> Self {
        VirtioFsConfig {
            tag: [0; 36],
            num_request_queues: 0,
        }
    }
}

// SAFETY: only a series of integers
unsafe impl ByteValued for VirtioFsConfig {}

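/// vhost-user-fs device acting as the frontend for an external backend such
/// as virtiofsd, optionally exposing a shared memory region as a DAX cache.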
pub struct Fs {
    common: VirtioCommon,
    vu_common: VhostUserCommon,
    id: String,
    config: VirtioFsConfig,
    // Holds ownership of the memory allocated for the DAX cache, which is
    // automatically released when the device is dropped.
    cache: Option<(VirtioSharedMemoryList, MmapRegion)>,
    backend_req_support: bool,
    seccomp_action: SeccompAction,
    guest_memory: Option<GuestMemoryAtomic<GuestMemoryMmap>>,
    epoll_thread: Option<thread::JoinHandle<()>>,
    exit_evt: EventFd,
    iommu: bool,
}

impl Fs {
    /// Create a new virtio-fs device.
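    ///
    /// A construction sketch (the socket path and values are illustrative
    /// and a running vhost-user backend is assumed, so this is not a
    /// runnable doctest):
    ///
    /// ```ignore
    /// let fs = Fs::new(
    ///     "_fs0".to_string(),
    ///     "/tmp/virtiofsd.sock", // vhost-user socket served by the backend
    ///     "myfs",                // tag the guest uses to mount the share
    ///     1,                     // request queues (the hiprio queue is added)
    ///     1024,                  // queue size
    ///     None,                  // no DAX cache window
    ///     SeccompAction::Trap,
    ///     EventFd::new(libc::EFD_NONBLOCK).unwrap(),
    ///     false,                 // no vIOMMU in front of the device
    ///     None,                  // fresh start, not restoring saved state
    /// )?;
    /// ```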
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        id: String,
        path: &str,
        tag: &str,
        req_num_queues: usize,
        queue_size: u16,
        cache: Option<(VirtioSharedMemoryList, MmapRegion)>,
        seccomp_action: SeccompAction,
        exit_evt: EventFd,
        iommu: bool,
        state: Option<State>,
    ) -> Result<Fs> {
        let mut backend_req_support = false;

        // Calculate the actual number of queues needed.
        let num_queues = NUM_QUEUE_OFFSET + req_num_queues;

        // Connect to the vhost-user socket.
        let mut vu = VhostUserHandle::connect_vhost_user(false, path, num_queues as u64, false)?;

        let (
            avail_features,
            acked_features,
            acked_protocol_features,
            vu_num_queues,
            config,
            backend_req_support,
            paused,
        ) = if let Some(state) = state {
            info!("Restoring vhost-user-fs {}", id);

            vu.set_protocol_features_vhost_user(
                state.acked_features,
                state.acked_protocol_features,
            )?;

            (
                state.avail_features,
                state.acked_features,
                state.acked_protocol_features,
                state.vu_num_queues,
                state.config,
                state.backend_req_support,
                true,
            )
        } else {
            // Fill in the device and vring features supported by the VMM.
            let avail_features = DEFAULT_VIRTIO_FEATURES;

            let mut avail_protocol_features = VhostUserProtocolFeatures::MQ
                | VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS
                | VhostUserProtocolFeatures::REPLY_ACK
                | VhostUserProtocolFeatures::INFLIGHT_SHMFD
                | VhostUserProtocolFeatures::LOG_SHMFD;
            let backend_protocol_features =
                VhostUserProtocolFeatures::BACKEND_REQ | VhostUserProtocolFeatures::BACKEND_SEND_FD;
            if cache.is_some() {
                avail_protocol_features |= backend_protocol_features;
            }

            let (acked_features, acked_protocol_features) =
                vu.negotiate_features_vhost_user(avail_features, avail_protocol_features)?;

            let backend_num_queues =
                if acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0 {
                    vu.socket_handle()
                        .get_queue_num()
                        .map_err(Error::VhostUserGetQueueMaxNum)? as usize
                } else {
                    DEFAULT_QUEUE_NUMBER
                };

            if num_queues > backend_num_queues {
                error!(
                    "vhost-user-fs requested too many queues ({}) while the backend supports only {}",
                    num_queues, backend_num_queues
                );
                return Err(Error::BadQueueNum);
            }

            if acked_protocol_features & backend_protocol_features.bits()
                == backend_protocol_features.bits()
            {
                backend_req_support = true;
            }

            // Create virtio-fs device configuration.
            let mut config = VirtioFsConfig::default();
            let tag_bytes_vec = tag.to_string().into_bytes();
            config.tag[..tag_bytes_vec.len()].copy_from_slice(tag_bytes_vec.as_slice());
            config.num_request_queues = req_num_queues as u32;

            (
                acked_features,
                // Preserve the PROTOCOL_FEATURES bit if the backend acked it:
                // the guest never acks this vhost-user specific bit, so it
                // must be recorded in the acked features now or it would be
                // lost.
                acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(),
                acked_protocol_features,
                num_queues,
                config,
                backend_req_support,
                false,
            )
        };

        Ok(Fs {
            common: VirtioCommon {
                device_type: VirtioDeviceType::Fs as u32,
                avail_features,
                acked_features,
                queue_sizes: vec![queue_size; num_queues],
                paused_sync: Some(Arc::new(Barrier::new(2))),
                min_queues: 1,
                paused: Arc::new(AtomicBool::new(paused)),
                ..Default::default()
            },
            vu_common: VhostUserCommon {
                vu: Some(Arc::new(Mutex::new(vu))),
                acked_protocol_features,
                socket_path: path.to_string(),
                vu_num_queues,
                ..Default::default()
            },
            id,
            config,
            cache,
            backend_req_support,
            seccomp_action,
            guest_memory: None,
            epoll_thread: None,
            exit_evt,
            iommu,
        })
    }

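    /// Collect the current device state, later saved through `snapshot()`.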
    fn state(&self) -> State {
        State {
            avail_features: self.common.avail_features,
            acked_features: self.common.acked_features,
            config: self.config,
            acked_protocol_features: self.vu_common.acked_protocol_features,
            vu_num_queues: self.vu_common.vu_num_queues,
            backend_req_support: self.backend_req_support,
        }
    }
}

impl Drop for Fs {
    fn drop(&mut self) {
        if let Some(kill_evt) = self.common.kill_evt.take() {
            // Ignore the result because there is nothing we can do about it.
            let _ = kill_evt.write(1);
        }
        self.common.wait_for_epoll_threads();
        if let Some(thread) = self.epoll_thread.take() {
            if let Err(e) = thread.join() {
                error!("Error joining thread: {:?}", e);
            }
        }
    }
}

impl VirtioDevice for Fs {
    fn device_type(&self) -> u32 {
        self.common.device_type
    }

    fn queue_max_sizes(&self) -> &[u16] {
        &self.common.queue_sizes
    }

    fn features(&self) -> u64 {
        let mut features = self.common.avail_features;
        if self.iommu {
            features |= 1u64 << VIRTIO_F_IOMMU_PLATFORM;
        }
        features
    }

    fn ack_features(&mut self, value: u64) {
        self.common.ack_features(value)
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        self.read_config_from_slice(self.config.as_slice(), offset, data);
    }

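    // Device activation: build the backend request handler when a DAX cache
    // is in use, hand the queues and memory over to the vhost-user backend,
    // and spawn the thread handling events and potential reconnections.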
    fn activate(
        &mut self,
        mem: GuestMemoryAtomic<GuestMemoryMmap>,
        interrupt_cb: Arc<dyn VirtioInterrupt>,
        queues: Vec<(usize, Queue, EventFd)>,
    ) -> ActivateResult {
        self.common.activate(&queues, &interrupt_cb)?;
        self.guest_memory = Some(mem.clone());

        // Initialize backend communication.
        let backend_req_handler = if self.backend_req_support {
            if let Some(cache) = self.cache.as_ref() {
                let vu_frontend_req_handler = Arc::new(BackendReqHandler {
                    cache_offset: cache.0.addr,
                    cache_size: cache.0.len,
                    mmap_cache_addr: cache.0.host_addr,
                    mem: mem.clone(),
                });

                let mut req_handler =
                    FrontendReqHandler::new(vu_frontend_req_handler).map_err(|e| {
                        ActivateError::VhostUserFsSetup(Error::FrontendReqHandlerCreation(e))
                    })?;

                if self.vu_common.acked_protocol_features
                    & VhostUserProtocolFeatures::REPLY_ACK.bits()
                    != 0
                {
                    req_handler.set_reply_ack_flag(true);
                }

                Some(req_handler)
            } else {
                None
            }
        } else {
            None
        };

        // Run a dedicated thread for handling potential reconnections with
        // the backend.
        let (kill_evt, pause_evt) = self.common.dup_eventfds();

        let mut handler = self.vu_common.activate(
            mem,
            queues,
            interrupt_cb,
            self.common.acked_features,
            backend_req_handler,
            kill_evt,
            pause_evt,
        )?;

        let paused = self.common.paused.clone();
        let paused_sync = self.common.paused_sync.clone();

        let mut epoll_threads = Vec::new();
        spawn_virtio_thread(
            &self.id,
            &self.seccomp_action,
            Thread::VirtioVhostFs,
            &mut epoll_threads,
            &self.exit_evt,
            move || handler.run(paused, paused_sync.unwrap()),
        )?;
        self.epoll_thread = Some(epoll_threads.remove(0));

        event!("virtio-device", "activated", "id", &self.id);
        Ok(())
    }

    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
        // First, resume the virtio thread if it was paused.
        if self.common.pause_evt.take().is_some() {
            self.common.resume().ok()?;
        }

        if let Some(vu) = &self.vu_common.vu {
            if let Err(e) = vu.lock().unwrap().reset_vhost_user() {
                error!("Failed to reset vhost-user daemon: {:?}", e);
                return None;
            }
        }

        if let Some(kill_evt) = self.common.kill_evt.take() {
            // Ignore the result because there is nothing we can do about it.
            let _ = kill_evt.write(1);
        }

        event!("virtio-device", "reset", "id", &self.id);

        // Return the interrupt callback.
        Some(self.common.interrupt_cb.take().unwrap())
    }

    fn shutdown(&mut self) {
        self.vu_common.shutdown()
    }

    fn get_shm_regions(&self) -> Option<VirtioSharedMemoryList> {
        self.cache.as_ref().map(|cache| cache.0.clone())
    }

    fn set_shm_regions(
        &mut self,
        shm_regions: VirtioSharedMemoryList,
    ) -> std::result::Result<(), crate::Error> {
        if let Some(cache) = self.cache.as_mut() {
            cache.0 = shm_regions;
            Ok(())
        } else {
            Err(crate::Error::SetShmRegionsNotSupported)
        }
    }

    fn add_memory_region(
        &mut self,
        region: &Arc<GuestRegionMmap>,
    ) -> std::result::Result<(), crate::Error> {
        self.vu_common.add_memory_region(&self.guest_memory, region)
    }

    fn userspace_mappings(&self) -> Vec<UserspaceMapping> {
        let mut mappings = Vec::new();
        if let Some(cache) = self.cache.as_ref() {
            mappings.push(UserspaceMapping {
                host_addr: cache.0.host_addr,
                mem_slot: cache.0.mem_slot,
                addr: cache.0.addr,
                len: cache.0.len,
                mergeable: false,
            })
        }

        mappings
    }
}

impl Pausable for Fs {
    fn pause(&mut self) -> result::Result<(), MigratableError> {
        self.vu_common.pause()?;
        self.common.pause()
    }

    fn resume(&mut self) -> result::Result<(), MigratableError> {
        self.common.resume()?;

        if let Some(epoll_thread) = &self.epoll_thread {
            epoll_thread.thread().unpark();
        }

        self.vu_common.resume()
    }
}

impl Snapshottable for Fs {
    fn id(&self) -> String {
        self.id.clone()
    }

    fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
        self.vu_common.snapshot(&self.state())
    }
}

impl Transportable for Fs {}

impl Migratable for Fs {
    fn start_dirty_log(&mut self) -> std::result::Result<(), MigratableError> {
        self.vu_common.start_dirty_log(&self.guest_memory)
    }

    fn stop_dirty_log(&mut self) -> std::result::Result<(), MigratableError> {
        self.vu_common.stop_dirty_log()
    }

    fn dirty_log(&mut self) -> std::result::Result<MemoryRangeTable, MigratableError> {
        self.vu_common.dirty_log(&self.guest_memory)
    }

    fn start_migration(&mut self) -> std::result::Result<(), MigratableError> {
        self.vu_common.start_migration()
    }

    fn complete_migration(&mut self) -> std::result::Result<(), MigratableError> {
        self.vu_common
            .complete_migration(self.common.kill_evt.take())
    }
}
689