xref: /cloud-hypervisor/virtio-devices/src/vhost_user/fs.rs (revision eea9bcea38e0c5649f444c829f3a4f9c22aa486c)
1 // Copyright 2019 Intel Corporation. All Rights Reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use super::vu_common_ctrl::VhostUserHandle;
5 use super::{Error, Result, DEFAULT_VIRTIO_FEATURES};
6 use crate::seccomp_filters::Thread;
7 use crate::thread_helper::spawn_virtio_thread;
8 use crate::vhost_user::VhostUserCommon;
9 use crate::{
10     ActivateError, ActivateResult, UserspaceMapping, VirtioCommon, VirtioDevice, VirtioDeviceType,
11     VirtioInterrupt, VirtioSharedMemoryList, VIRTIO_F_IOMMU_PLATFORM,
12 };
13 use crate::{GuestMemoryMmap, GuestRegionMmap, MmapRegion};
14 use libc::{self, c_void, off64_t, pread64, pwrite64};
15 use seccompiler::SeccompAction;
16 use std::io;
17 use std::os::unix::io::AsRawFd;
18 use std::result;
19 use std::sync::{Arc, Barrier, Mutex};
20 use std::thread;
21 use versionize::{VersionMap, Versionize, VersionizeResult};
22 use versionize_derive::Versionize;
23 use vhost::vhost_user::message::{
24     VhostUserFSSlaveMsg, VhostUserFSSlaveMsgFlags, VhostUserProtocolFeatures,
25     VhostUserVirtioFeatures, VHOST_USER_FS_SLAVE_ENTRIES,
26 };
27 use vhost::vhost_user::{
28     HandlerResult, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler,
29 };
30 use virtio_queue::Queue;
31 use vm_memory::{
32     Address, ByteValued, GuestAddress, GuestAddressSpace, GuestMemory, GuestMemoryAtomic,
33 };
34 use vm_migration::{
35     protocol::MemoryRangeTable, Migratable, MigratableError, Pausable, Snapshot, Snapshottable,
36     Transportable, VersionMapped,
37 };
38 use vmm_sys_util::eventfd::EventFd;
39 
// One extra queue is always created on top of the request queues asked for
// by the user (see Fs::new(): num_queues = NUM_QUEUE_OFFSET + req_num_queues).
const NUM_QUEUE_OFFSET: usize = 1;
// Queue count assumed for the backend when it does not advertise the MQ
// vhost-user protocol feature.
const DEFAULT_QUEUE_NUMBER: usize = 2;
42 
// Snapshot of the device state, produced by Fs::state() and consumed by
// Fs::set_state() during snapshot/restore.
#[derive(Versionize)]
pub struct State {
    // VIRTIO features offered to the guest.
    pub avail_features: u64,
    // VIRTIO features acknowledged (by the guest and/or forced on, see
    // Fs::new()).
    pub acked_features: u64,
    // Device configuration space contents (tag + request queue count).
    pub config: VirtioFsConfig,
    // vhost-user protocol features negotiated with the backend.
    pub acked_protocol_features: u64,
    // Number of queues negotiated with the vhost-user backend.
    pub vu_num_queues: usize,
    // Whether the backend acknowledged the slave-request protocol features.
    pub slave_req_support: bool,
}

impl VersionMapped for State {}
54 
// Handler for slave requests issued by the vhost-user backend, operating on
// the shared memory cache window exposed by the device.
struct SlaveReqHandler {
    // Guest physical address where the cache window starts.
    cache_offset: GuestAddress,
    // Size of the cache window in bytes.
    cache_size: u64,
    // Host virtual address where the cache window is mapped in this process.
    mmap_cache_addr: u64,
    // Guest memory, used to translate guest physical addresses that fall
    // outside the cache window (see fs_slave_io()).
    mem: GuestMemoryAtomic<GuestMemoryMmap>,
}
61 
62 impl SlaveReqHandler {
63     // Make sure request is within cache range
64     fn is_req_valid(&self, offset: u64, len: u64) -> bool {
65         let end = match offset.checked_add(len) {
66             Some(n) => n,
67             None => return false,
68         };
69 
70         !(offset >= self.cache_size || end > self.cache_size)
71     }
72 }
73 
impl VhostUserMasterReqHandler for SlaveReqHandler {
    // The backend signalled a device configuration change. Nothing to do
    // here beyond acknowledging it.
    fn handle_config_change(&self) -> HandlerResult<u64> {
        debug!("handle_config_change");
        Ok(0)
    }

    // Map ranges of the backend-provided file descriptor into the cache
    // window, one entry at a time. Entries with a zero length are skipped.
    fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
        debug!("fs_slave_map");

        for i in 0..VHOST_USER_FS_SLAVE_ENTRIES {
            let offset = fs.cache_offset[i];
            let len = fs.len[i];

            // Ignore if the length is 0.
            if len == 0 {
                continue;
            }

            // Reject any range that does not fit inside the cache window.
            if !self.is_req_valid(offset, len) {
                return Err(io::Error::from_raw_os_error(libc::EINVAL));
            }

            let addr = self.mmap_cache_addr + offset;
            let flags = fs.flags[i];
            // SAFETY: the destination range was validated to fall within the
            // mmap'ed cache window; MAP_FIXED deliberately replaces the
            // existing mapping at that address.
            let ret = unsafe {
                libc::mmap(
                    addr as *mut libc::c_void,
                    len as usize,
                    flags.bits() as i32,
                    libc::MAP_SHARED | libc::MAP_FIXED,
                    fd.as_raw_fd(),
                    fs.fd_offset[i] as libc::off_t,
                )
            };
            if ret == libc::MAP_FAILED {
                return Err(io::Error::last_os_error());
            }
        }

        Ok(0)
    }

    // Unmap ranges of the cache window by replacing them with inaccessible
    // anonymous memory (PROT_NONE), so the window stays fully reserved.
    fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
        debug!("fs_slave_unmap");

        for i in 0..VHOST_USER_FS_SLAVE_ENTRIES {
            let offset = fs.cache_offset[i];
            let mut len = fs.len[i];

            // Ignore if the length is 0.
            if len == 0 {
                continue;
            }

            // Need to handle a special case where the slave asks for the
            // unmapping of the entire mapping (length of all 1s means the
            // whole cache window).
            if len == 0xffff_ffff_ffff_ffff {
                len = self.cache_size;
            }

            if !self.is_req_valid(offset, len) {
                return Err(io::Error::from_raw_os_error(libc::EINVAL));
            }

            let addr = self.mmap_cache_addr + offset;
            // SAFETY: the target range was validated against the cache
            // window; the anonymous PROT_NONE mapping keeps the address
            // range reserved while making it inaccessible.
            let ret = unsafe {
                libc::mmap(
                    addr as *mut libc::c_void,
                    len as usize,
                    libc::PROT_NONE,
                    libc::MAP_ANONYMOUS | libc::MAP_PRIVATE | libc::MAP_FIXED,
                    -1,
                    0,
                )
            };
            if ret == libc::MAP_FAILED {
                return Err(io::Error::last_os_error());
            }
        }

        Ok(0)
    }

    // Synchronously flush ranges of the cache window back to the backing
    // file with msync(MS_SYNC).
    fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
        debug!("fs_slave_sync");

        for i in 0..VHOST_USER_FS_SLAVE_ENTRIES {
            let offset = fs.cache_offset[i];
            let len = fs.len[i];

            // Ignore if the length is 0.
            if len == 0 {
                continue;
            }

            if !self.is_req_valid(offset, len) {
                return Err(io::Error::from_raw_os_error(libc::EINVAL));
            }

            let addr = self.mmap_cache_addr + offset;
            // SAFETY: the range was validated to fall within the mmap'ed
            // cache window.
            let ret =
                unsafe { libc::msync(addr as *mut libc::c_void, len as usize, libc::MS_SYNC) };
            if ret == -1 {
                return Err(io::Error::last_os_error());
            }
        }

        Ok(0)
    }

    // Perform I/O between the backend-provided file descriptor and guest
    // memory on behalf of the backend. For each entry, the guest address is
    // translated either through the cache window or through regular guest
    // RAM, then pread64/pwrite64 loops until the whole buffer is processed.
    // Returns the total number of bytes transferred.
    fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
        debug!("fs_slave_io");

        let mut done: u64 = 0;
        for i in 0..VHOST_USER_FS_SLAVE_ENTRIES {
            // Ignore if the length is 0.
            if fs.len[i] == 0 {
                continue;
            }

            let mut foffset = fs.fd_offset[i];
            let mut len = fs.len[i] as usize;
            let gpa = fs.cache_offset[i];
            let cache_end = self.cache_offset.raw_value() + self.cache_size;
            let efault = libc::EFAULT;

            // Translate the guest physical address to a host virtual
            // address, either via the cache window or via guest RAM.
            let mut ptr = if gpa >= self.cache_offset.raw_value() && gpa < cache_end {
                let offset = gpa
                    .checked_sub(self.cache_offset.raw_value())
                    .ok_or_else(|| io::Error::from_raw_os_error(efault))?;
                let end = gpa
                    .checked_add(fs.len[i])
                    .ok_or_else(|| io::Error::from_raw_os_error(efault))?;

                // NOTE(review): this check is strict, so a buffer ending
                // exactly at the cache window end is rejected — confirm
                // whether `>` was intended.
                if end >= cache_end {
                    return Err(io::Error::from_raw_os_error(efault));
                }

                self.mmap_cache_addr + offset
            } else {
                self.mem
                    .memory()
                    .get_host_address(GuestAddress(gpa))
                    .map_err(|e| {
                        error!(
                            "Failed to find RAM region associated with guest physical address 0x{:x}: {:?}",
                            gpa, e
                        );
                        io::Error::from_raw_os_error(efault)
                    })? as u64
            };

            // pread64/pwrite64 may transfer fewer bytes than requested, so
            // loop until the whole range is covered.
            while len > 0 {
                let ret = if (fs.flags[i] & VhostUserFSSlaveMsgFlags::MAP_W)
                    == VhostUserFSSlaveMsgFlags::MAP_W
                {
                    debug!("write: foffset={}, len={}", foffset, len);
                    // SAFETY: ptr points to translated, validated guest
                    // memory and len does not exceed the remaining buffer.
                    unsafe {
                        pwrite64(
                            fd.as_raw_fd(),
                            ptr as *const c_void,
                            len as usize,
                            foffset as off64_t,
                        )
                    }
                } else {
                    debug!("read: foffset={}, len={}", foffset, len);
                    // SAFETY: same invariants as the pwrite64 branch above.
                    unsafe {
                        pread64(
                            fd.as_raw_fd(),
                            ptr as *mut c_void,
                            len as usize,
                            foffset as off64_t,
                        )
                    }
                };

                if ret < 0 {
                    return Err(io::Error::last_os_error());
                }

                if ret == 0 {
                    // EOF
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "failed to access whole buffer",
                    ));
                }
                // Advance past the bytes actually transferred.
                len -= ret as usize;
                foffset += ret as u64;
                ptr += ret as u64;
                done += ret as u64;
            }
        }

        Ok(done)
    }
}
272 
// virtio-fs device configuration space, as exposed to the guest driver.
#[derive(Copy, Clone, Versionize)]
#[repr(C, packed)]
pub struct VirtioFsConfig {
    // Filesystem tag identifying the mount, zero-padded to 36 bytes.
    pub tag: [u8; 36],
    // Number of request queues exposed by the device.
    pub num_request_queues: u32,
}
279 
280 impl Default for VirtioFsConfig {
281     fn default() -> Self {
282         VirtioFsConfig {
283             tag: [0; 36],
284             num_request_queues: 0,
285         }
286     }
287 }
288 
289 unsafe impl ByteValued for VirtioFsConfig {}
290 
// vhost-user-fs (virtio-fs) device frontend.
pub struct Fs {
    // Generic VIRTIO device state (features, queues, pause/kill events).
    common: VirtioCommon,
    // Generic vhost-user state (socket handle, negotiated protocol features).
    vu_common: VhostUserCommon,
    // Device identifier used for events and snapshots.
    id: String,
    // Device configuration space contents.
    config: VirtioFsConfig,
    // Hold ownership of the memory that is allocated for the device
    // which will be automatically dropped when the device is dropped
    cache: Option<(VirtioSharedMemoryList, MmapRegion)>,
    // Whether the backend acknowledged the slave-request protocol features.
    slave_req_support: bool,
    // Seccomp policy applied to the device worker thread.
    seccomp_action: SeccompAction,
    // Guest memory captured at activation time; used for memory hotplug and
    // dirty logging.
    guest_memory: Option<GuestMemoryAtomic<GuestMemoryMmap>>,
    // Handle of the worker thread spawned at activation.
    epoll_thread: Option<thread::JoinHandle<()>>,
    // Event used to signal the VMM that the device hit a fatal error.
    exit_evt: EventFd,
    // Whether the device sits behind an IOMMU.
    iommu: bool,
}
306 
307 impl Fs {
308     /// Create a new virtio-fs device.
309     #[allow(clippy::too_many_arguments)]
310     pub fn new(
311         id: String,
312         path: &str,
313         tag: &str,
314         req_num_queues: usize,
315         queue_size: u16,
316         cache: Option<(VirtioSharedMemoryList, MmapRegion)>,
317         seccomp_action: SeccompAction,
318         restoring: bool,
319         exit_evt: EventFd,
320         iommu: bool,
321     ) -> Result<Fs> {
322         let mut slave_req_support = false;
323 
324         // Calculate the actual number of queues needed.
325         let num_queues = NUM_QUEUE_OFFSET + req_num_queues;
326 
327         if restoring {
328             // We need 'queue_sizes' to report a number of queues that will be
329             // enough to handle all the potential queues. VirtioPciDevice::new()
330             // will create the actual queues based on this information.
331             return Ok(Fs {
332                 common: VirtioCommon {
333                     device_type: VirtioDeviceType::Fs as u32,
334                     queue_sizes: vec![queue_size; num_queues],
335                     paused_sync: Some(Arc::new(Barrier::new(2))),
336                     min_queues: 1,
337                     ..Default::default()
338                 },
339                 vu_common: VhostUserCommon {
340                     socket_path: path.to_string(),
341                     vu_num_queues: num_queues,
342                     ..Default::default()
343                 },
344                 id,
345                 config: VirtioFsConfig::default(),
346                 cache,
347                 slave_req_support,
348                 seccomp_action,
349                 guest_memory: None,
350                 epoll_thread: None,
351                 exit_evt,
352                 iommu,
353             });
354         }
355 
356         // Connect to the vhost-user socket.
357         let mut vu = VhostUserHandle::connect_vhost_user(false, path, num_queues as u64, false)?;
358 
359         // Filling device and vring features VMM supports.
360         let avail_features = DEFAULT_VIRTIO_FEATURES;
361 
362         let mut avail_protocol_features = VhostUserProtocolFeatures::MQ
363             | VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS
364             | VhostUserProtocolFeatures::REPLY_ACK
365             | VhostUserProtocolFeatures::INFLIGHT_SHMFD
366             | VhostUserProtocolFeatures::LOG_SHMFD;
367         let slave_protocol_features =
368             VhostUserProtocolFeatures::SLAVE_REQ | VhostUserProtocolFeatures::SLAVE_SEND_FD;
369         if cache.is_some() {
370             avail_protocol_features |= slave_protocol_features;
371         }
372 
373         let (acked_features, acked_protocol_features) =
374             vu.negotiate_features_vhost_user(avail_features, avail_protocol_features)?;
375 
376         let backend_num_queues =
377             if acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0 {
378                 vu.socket_handle()
379                     .get_queue_num()
380                     .map_err(Error::VhostUserGetQueueMaxNum)? as usize
381             } else {
382                 DEFAULT_QUEUE_NUMBER
383             };
384 
385         if num_queues > backend_num_queues {
386             error!(
387                 "vhost-user-fs requested too many queues ({}) since the backend only supports {}\n",
388                 num_queues, backend_num_queues
389             );
390             return Err(Error::BadQueueNum);
391         }
392 
393         if acked_protocol_features & slave_protocol_features.bits()
394             == slave_protocol_features.bits()
395         {
396             slave_req_support = true;
397         }
398 
399         // Create virtio-fs device configuration.
400         let mut config = VirtioFsConfig::default();
401         let tag_bytes_vec = tag.to_string().into_bytes();
402         config.tag[..tag_bytes_vec.len()].copy_from_slice(tag_bytes_vec.as_slice());
403         config.num_request_queues = req_num_queues as u32;
404 
405         Ok(Fs {
406             common: VirtioCommon {
407                 device_type: VirtioDeviceType::Fs as u32,
408                 avail_features: acked_features,
409                 // If part of the available features that have been acked, the
410                 // PROTOCOL_FEATURES bit must be already set through the VIRTIO
411                 // acked features as we know the guest would never ack it, thus
412                 // the feature would be lost.
413                 acked_features: acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(),
414                 queue_sizes: vec![queue_size; num_queues],
415                 paused_sync: Some(Arc::new(Barrier::new(2))),
416                 min_queues: 1,
417                 ..Default::default()
418             },
419             vu_common: VhostUserCommon {
420                 vu: Some(Arc::new(Mutex::new(vu))),
421                 acked_protocol_features,
422                 socket_path: path.to_string(),
423                 vu_num_queues: num_queues,
424                 ..Default::default()
425             },
426             id,
427             config,
428             cache,
429             slave_req_support,
430             seccomp_action,
431             guest_memory: None,
432             epoll_thread: None,
433             exit_evt,
434             iommu,
435         })
436     }
437 
438     fn state(&self) -> State {
439         State {
440             avail_features: self.common.avail_features,
441             acked_features: self.common.acked_features,
442             config: self.config,
443             acked_protocol_features: self.vu_common.acked_protocol_features,
444             vu_num_queues: self.vu_common.vu_num_queues,
445             slave_req_support: self.slave_req_support,
446         }
447     }
448 
449     fn set_state(&mut self, state: &State) {
450         self.common.avail_features = state.avail_features;
451         self.common.acked_features = state.acked_features;
452         self.config = state.config;
453         self.vu_common.acked_protocol_features = state.acked_protocol_features;
454         self.vu_common.vu_num_queues = state.vu_num_queues;
455         self.slave_req_support = state.slave_req_support;
456 
457         if let Err(e) = self
458             .vu_common
459             .restore_backend_connection(self.common.acked_features)
460         {
461             error!(
462                 "Failed restoring connection with vhost-user backend: {:?}",
463                 e
464             );
465         }
466     }
467 }
468 
469 impl Drop for Fs {
470     fn drop(&mut self) {
471         if let Some(kill_evt) = self.common.kill_evt.take() {
472             // Ignore the result because there is nothing we can do about it.
473             let _ = kill_evt.write(1);
474         }
475     }
476 }
477 
478 impl VirtioDevice for Fs {
479     fn device_type(&self) -> u32 {
480         self.common.device_type
481     }
482 
483     fn queue_max_sizes(&self) -> &[u16] {
484         &self.common.queue_sizes
485     }
486 
487     fn features(&self) -> u64 {
488         let mut features = self.common.avail_features;
489         if self.iommu {
490             features |= 1u64 << VIRTIO_F_IOMMU_PLATFORM;
491         }
492         features
493     }
494 
495     fn ack_features(&mut self, value: u64) {
496         self.common.ack_features(value)
497     }
498 
499     fn read_config(&self, offset: u64, data: &mut [u8]) {
500         self.read_config_from_slice(self.config.as_slice(), offset, data);
501     }
502 
503     fn activate(
504         &mut self,
505         mem: GuestMemoryAtomic<GuestMemoryMmap>,
506         interrupt_cb: Arc<dyn VirtioInterrupt>,
507         queues: Vec<(usize, Queue, EventFd)>,
508     ) -> ActivateResult {
509         self.common.activate(&queues, &interrupt_cb)?;
510         self.guest_memory = Some(mem.clone());
511 
512         // Initialize slave communication.
513         let slave_req_handler = if self.slave_req_support {
514             if let Some(cache) = self.cache.as_ref() {
515                 let vu_master_req_handler = Arc::new(SlaveReqHandler {
516                     cache_offset: cache.0.addr,
517                     cache_size: cache.0.len,
518                     mmap_cache_addr: cache.0.host_addr,
519                     mem: mem.clone(),
520                 });
521 
522                 let mut req_handler =
523                     MasterReqHandler::new(vu_master_req_handler).map_err(|e| {
524                         ActivateError::VhostUserFsSetup(Error::MasterReqHandlerCreation(e))
525                     })?;
526 
527                 if self.vu_common.acked_protocol_features
528                     & VhostUserProtocolFeatures::REPLY_ACK.bits()
529                     != 0
530                 {
531                     req_handler.set_reply_ack_flag(true);
532                 }
533 
534                 Some(req_handler)
535             } else {
536                 None
537             }
538         } else {
539             None
540         };
541 
542         // Run a dedicated thread for handling potential reconnections with
543         // the backend.
544         let (kill_evt, pause_evt) = self.common.dup_eventfds();
545 
546         let mut handler = self.vu_common.activate(
547             mem,
548             queues,
549             interrupt_cb,
550             self.common.acked_features,
551             slave_req_handler,
552             kill_evt,
553             pause_evt,
554         )?;
555 
556         let paused = self.common.paused.clone();
557         let paused_sync = self.common.paused_sync.clone();
558 
559         let mut epoll_threads = Vec::new();
560         spawn_virtio_thread(
561             &self.id,
562             &self.seccomp_action,
563             Thread::VirtioVhostFs,
564             &mut epoll_threads,
565             &self.exit_evt,
566             move || handler.run(paused, paused_sync.unwrap()),
567         )?;
568         self.epoll_thread = Some(epoll_threads.remove(0));
569 
570         event!("virtio-device", "activated", "id", &self.id);
571         Ok(())
572     }
573 
574     fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
575         // We first must resume the virtio thread if it was paused.
576         if self.common.pause_evt.take().is_some() {
577             self.common.resume().ok()?;
578         }
579 
580         if let Some(vu) = &self.vu_common.vu {
581             if let Err(e) = vu.lock().unwrap().reset_vhost_user() {
582                 error!("Failed to reset vhost-user daemon: {:?}", e);
583                 return None;
584             }
585         }
586 
587         if let Some(kill_evt) = self.common.kill_evt.take() {
588             // Ignore the result because there is nothing we can do about it.
589             let _ = kill_evt.write(1);
590         }
591 
592         event!("virtio-device", "reset", "id", &self.id);
593 
594         // Return the interrupt
595         Some(self.common.interrupt_cb.take().unwrap())
596     }
597 
598     fn shutdown(&mut self) {
599         self.vu_common.shutdown()
600     }
601 
602     fn get_shm_regions(&self) -> Option<VirtioSharedMemoryList> {
603         self.cache.as_ref().map(|cache| cache.0.clone())
604     }
605 
606     fn set_shm_regions(
607         &mut self,
608         shm_regions: VirtioSharedMemoryList,
609     ) -> std::result::Result<(), crate::Error> {
610         if let Some(mut cache) = self.cache.as_mut() {
611             cache.0 = shm_regions;
612             Ok(())
613         } else {
614             Err(crate::Error::SetShmRegionsNotSupported)
615         }
616     }
617 
618     fn add_memory_region(
619         &mut self,
620         region: &Arc<GuestRegionMmap>,
621     ) -> std::result::Result<(), crate::Error> {
622         self.vu_common.add_memory_region(&self.guest_memory, region)
623     }
624 
625     fn userspace_mappings(&self) -> Vec<UserspaceMapping> {
626         let mut mappings = Vec::new();
627         if let Some(cache) = self.cache.as_ref() {
628             mappings.push(UserspaceMapping {
629                 host_addr: cache.0.host_addr,
630                 mem_slot: cache.0.mem_slot,
631                 addr: cache.0.addr,
632                 len: cache.0.len,
633                 mergeable: false,
634             })
635         }
636 
637         mappings
638     }
639 }
640 
641 impl Pausable for Fs {
642     fn pause(&mut self) -> result::Result<(), MigratableError> {
643         self.vu_common.pause()?;
644         self.common.pause()
645     }
646 
647     fn resume(&mut self) -> result::Result<(), MigratableError> {
648         self.common.resume()?;
649 
650         if let Some(epoll_thread) = &self.epoll_thread {
651             epoll_thread.thread().unpark();
652         }
653 
654         self.vu_common.resume()
655     }
656 }
657 
658 impl Snapshottable for Fs {
659     fn id(&self) -> String {
660         self.id.clone()
661     }
662 
663     fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
664         self.vu_common.snapshot(&self.id(), &self.state())
665     }
666 
667     fn restore(&mut self, snapshot: Snapshot) -> std::result::Result<(), MigratableError> {
668         self.set_state(&snapshot.to_versioned_state(&self.id)?);
669         Ok(())
670     }
671 }
672 impl Transportable for Fs {}
673 
674 impl Migratable for Fs {
675     fn start_dirty_log(&mut self) -> std::result::Result<(), MigratableError> {
676         self.vu_common.start_dirty_log(&self.guest_memory)
677     }
678 
679     fn stop_dirty_log(&mut self) -> std::result::Result<(), MigratableError> {
680         self.vu_common.stop_dirty_log()
681     }
682 
683     fn dirty_log(&mut self) -> std::result::Result<MemoryRangeTable, MigratableError> {
684         self.vu_common.dirty_log(&self.guest_memory)
685     }
686 
687     fn start_migration(&mut self) -> std::result::Result<(), MigratableError> {
688         self.vu_common.start_migration()
689     }
690 
691     fn complete_migration(&mut self) -> std::result::Result<(), MigratableError> {
692         self.vu_common
693             .complete_migration(self.common.kill_evt.take())
694     }
695 }
696