xref: /cloud-hypervisor/virtio-devices/src/vhost_user/fs.rs (revision 7d7bfb2034001d4cb15df2ddc56d2d350c8da30f)
1 // Copyright 2019 Intel Corporation. All Rights Reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use super::vu_common_ctrl::VhostUserHandle;
5 use super::{Error, Result, DEFAULT_VIRTIO_FEATURES};
6 use crate::seccomp_filters::Thread;
7 use crate::thread_helper::spawn_virtio_thread;
8 use crate::vhost_user::VhostUserCommon;
9 use crate::{
10     ActivateError, ActivateResult, UserspaceMapping, VirtioCommon, VirtioDevice, VirtioDeviceType,
11     VirtioInterrupt, VirtioSharedMemoryList, VIRTIO_F_IOMMU_PLATFORM,
12 };
13 use crate::{GuestMemoryMmap, GuestRegionMmap, MmapRegion};
14 use libc::{self, c_void, off64_t, pread64, pwrite64};
15 use seccompiler::SeccompAction;
16 use std::io;
17 use std::os::unix::io::AsRawFd;
18 use std::result;
19 use std::sync::{Arc, Barrier, Mutex};
20 use std::thread;
21 use versionize::{VersionMap, Versionize, VersionizeResult};
22 use versionize_derive::Versionize;
23 use vhost::vhost_user::message::{
24     VhostUserFSSlaveMsg, VhostUserFSSlaveMsgFlags, VhostUserProtocolFeatures,
25     VhostUserVirtioFeatures, VHOST_USER_FS_SLAVE_ENTRIES,
26 };
27 use vhost::vhost_user::{
28     HandlerResult, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler,
29 };
30 use virtio_queue::Queue;
31 use vm_memory::{
32     Address, ByteValued, GuestAddress, GuestAddressSpace, GuestMemory, GuestMemoryAtomic,
33 };
34 use vm_migration::{
35     protocol::MemoryRangeTable, Migratable, MigratableError, Pausable, Snapshot, Snapshottable,
36     Transportable, VersionMapped,
37 };
38 use vmm_sys_util::eventfd::EventFd;
39 
// Number of queues needed in addition to the request queues: the total queue
// count used in Fs::new() is NUM_QUEUE_OFFSET + req_num_queues. Presumably
// this extra queue is the virtio-fs high-priority queue -- TODO confirm.
const NUM_QUEUE_OFFSET: usize = 1;
// Queue count assumed when the backend does not advertise the MQ protocol
// feature; also used as the device's minimum queue count.
const DEFAULT_QUEUE_NUMBER: usize = 2;
42 
/// Serializable device state captured for snapshot/restore.
///
/// Mirrors the fields of `Fs` (and its embedded `VirtioCommon` /
/// `VhostUserCommon`) that must survive a migration; see `Fs::state()` and
/// `Fs::set_state()`.
#[derive(Versionize)]
pub struct State {
    // Virtio feature bits the device offers to the guest.
    pub avail_features: u64,
    // Virtio feature bits acknowledged by the guest.
    pub acked_features: u64,
    // Raw virtio-fs configuration space (tag + request queue count).
    pub config: VirtioFsConfig,
    // vhost-user protocol feature bits negotiated with the backend.
    pub acked_protocol_features: u64,
    // Total number of virtqueues negotiated with the vhost-user backend.
    pub vu_num_queues: usize,
    // Whether the backend acked the slave request channel features
    // (SLAVE_REQ + SLAVE_SEND_FD).
    pub slave_req_support: bool,
}

impl VersionMapped for State {}
54 
/// Handler for requests initiated by the vhost-user backend (slave) over the
/// slave request channel, managing the shared memory cache window of the
/// virtio-fs device.
struct SlaveReqHandler {
    // Guest physical address where the cache window starts.
    cache_offset: GuestAddress,
    // Size of the cache window in bytes.
    cache_size: u64,
    // Host virtual address the cache window is mapped at in the VMM.
    mmap_cache_addr: u64,
    // Guest memory, used by fs_slave_io() to translate guest physical
    // addresses that fall outside the cache window.
    mem: GuestMemoryAtomic<GuestMemoryMmap>,
}
61 
62 impl SlaveReqHandler {
63     // Make sure request is within cache range
64     fn is_req_valid(&self, offset: u64, len: u64) -> bool {
65         let end = match offset.checked_add(len) {
66             Some(n) => n,
67             None => return false,
68         };
69 
70         !(offset >= self.cache_size || end > self.cache_size)
71     }
72 }
73 
impl VhostUserMasterReqHandler for SlaveReqHandler {
    /// Notification from the backend that its device configuration changed.
    /// No VMM-side state depends on it here, so just acknowledge.
    fn handle_config_change(&self) -> HandlerResult<u64> {
        debug!("handle_config_change");
        Ok(0)
    }

    /// Handle a MAP request: mmap ranges of the backend-provided file
    /// descriptor into the cache window at the requested offsets.
    ///
    /// Zero-length entries are skipped; any entry falling outside the cache
    /// window fails the whole request with EINVAL.
    fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
        debug!("fs_slave_map");

        for i in 0..VHOST_USER_FS_SLAVE_ENTRIES {
            let offset = fs.cache_offset[i];
            let len = fs.len[i];

            // Ignore if the length is 0.
            if len == 0 {
                continue;
            }

            if !self.is_req_valid(offset, len) {
                return Err(io::Error::from_raw_os_error(libc::EINVAL));
            }

            let addr = self.mmap_cache_addr + offset;
            let flags = fs.flags[i];
            // SAFETY: is_req_valid() established that [offset, offset + len)
            // lies within the cache window, so MAP_FIXED only replaces pages
            // inside the window reserved for this device.
            // NOTE(review): the message flag bits are passed directly as the
            // mmap protection value -- assumes MAP_R/MAP_W numerically match
            // PROT_READ/PROT_WRITE; confirm against the vhost-user spec.
            let ret = unsafe {
                libc::mmap(
                    addr as *mut libc::c_void,
                    len as usize,
                    flags.bits() as i32,
                    libc::MAP_SHARED | libc::MAP_FIXED,
                    fd.as_raw_fd(),
                    fs.fd_offset[i] as libc::off_t,
                )
            };
            if ret == libc::MAP_FAILED {
                return Err(io::Error::last_os_error());
            }
        }

        Ok(0)
    }

    /// Handle an UNMAP request: replace the requested ranges of the cache
    /// window with fresh anonymous PROT_NONE mappings.
    ///
    /// An anonymous remapping is used instead of munmap(2) so the reserved
    /// address range stays mapped (and inaccessible) rather than leaving a
    /// hole another allocation could land in.
    fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
        debug!("fs_slave_unmap");

        for i in 0..VHOST_USER_FS_SLAVE_ENTRIES {
            let offset = fs.cache_offset[i];
            let mut len = fs.len[i];

            // Ignore if the length is 0.
            if len == 0 {
                continue;
            }

            // Need to handle a special case where the slave ask for the unmapping
            // of the entire mapping.
            if len == 0xffff_ffff_ffff_ffff {
                len = self.cache_size;
            }

            if !self.is_req_valid(offset, len) {
                return Err(io::Error::from_raw_os_error(libc::EINVAL));
            }

            let addr = self.mmap_cache_addr + offset;
            // SAFETY: the target range was validated against the cache window
            // above; MAP_FIXED atomically replaces the previous mapping.
            let ret = unsafe {
                libc::mmap(
                    addr as *mut libc::c_void,
                    len as usize,
                    libc::PROT_NONE,
                    libc::MAP_ANONYMOUS | libc::MAP_PRIVATE | libc::MAP_FIXED,
                    -1,
                    0,
                )
            };
            if ret == libc::MAP_FAILED {
                return Err(io::Error::last_os_error());
            }
        }

        Ok(0)
    }

    /// Handle a SYNC request: msync(2) the requested cache window ranges so
    /// mapped file contents are written back to storage.
    fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
        debug!("fs_slave_sync");

        for i in 0..VHOST_USER_FS_SLAVE_ENTRIES {
            let offset = fs.cache_offset[i];
            let len = fs.len[i];

            // Ignore if the length is 0.
            if len == 0 {
                continue;
            }

            if !self.is_req_valid(offset, len) {
                return Err(io::Error::from_raw_os_error(libc::EINVAL));
            }

            let addr = self.mmap_cache_addr + offset;
            // SAFETY: range validated against the cache window above; MS_SYNC
            // makes the call block until the write-back completes.
            let ret =
                unsafe { libc::msync(addr as *mut libc::c_void, len as usize, libc::MS_SYNC) };
            if ret == -1 {
                return Err(io::Error::last_os_error());
            }
        }

        Ok(0)
    }

    /// Handle an IO request: pread64/pwrite64 directly between the
    /// backend-provided fd and guest memory, returning the total number of
    /// bytes transferred.
    ///
    /// The `cache_offset` entries here are guest physical addresses: they may
    /// point either into the cache window or into regular guest RAM, and the
    /// two cases are translated to host addresses differently.
    fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
        debug!("fs_slave_io");

        // Total bytes transferred across all entries.
        let mut done: u64 = 0;
        for i in 0..VHOST_USER_FS_SLAVE_ENTRIES {
            // Ignore if the length is 0.
            if fs.len[i] == 0 {
                continue;
            }

            let mut foffset = fs.fd_offset[i];
            let mut len = fs.len[i] as usize;
            let gpa = fs.cache_offset[i];
            let cache_end = self.cache_offset.raw_value() + self.cache_size;
            let efault = libc::EFAULT;

            // Translate the guest physical address to a host virtual address.
            let mut ptr = if gpa >= self.cache_offset.raw_value() && gpa < cache_end {
                // The address lies in the cache window: compute the offset
                // from the window base and relocate it onto the VMM's mapping
                // of the window.
                let offset = gpa
                    .checked_sub(self.cache_offset.raw_value())
                    .ok_or_else(|| io::Error::from_raw_os_error(efault))?;
                let end = gpa
                    .checked_add(fs.len[i])
                    .ok_or_else(|| io::Error::from_raw_os_error(efault))?;

                // NOTE(review): this also rejects end == cache_end, i.e. a
                // request reaching exactly the end of the window -- confirm
                // whether the exclusive bound is intended.
                if end >= cache_end {
                    return Err(io::Error::from_raw_os_error(efault));
                }

                self.mmap_cache_addr + offset
            } else {
                // Regular guest RAM: resolve through the guest memory map.
                self.mem
                    .memory()
                    .get_host_address(GuestAddress(gpa))
                    .map_err(|e| {
                        error!(
                            "Failed to find RAM region associated with guest physical address 0x{:x}: {:?}",
                            gpa, e
                        );
                        io::Error::from_raw_os_error(efault)
                    })? as u64
            };

            // Transfer in a loop since pread64/pwrite64 may perform partial
            // I/O.
            while len > 0 {
                let ret = if (fs.flags[i] & VhostUserFSSlaveMsgFlags::MAP_W)
                    == VhostUserFSSlaveMsgFlags::MAP_W
                {
                    // MAP_W set: write from guest memory into the file.
                    debug!("write: foffset={}, len={}", foffset, len);
                    unsafe {
                        pwrite64(
                            fd.as_raw_fd(),
                            ptr as *const c_void,
                            len as usize,
                            foffset as off64_t,
                        )
                    }
                } else {
                    // MAP_W clear: read from the file into guest memory.
                    debug!("read: foffset={}, len={}", foffset, len);
                    unsafe {
                        pread64(
                            fd.as_raw_fd(),
                            ptr as *mut c_void,
                            len as usize,
                            foffset as off64_t,
                        )
                    }
                };

                if ret < 0 {
                    return Err(io::Error::last_os_error());
                }

                if ret == 0 {
                    // EOF
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "failed to access whole buffer",
                    ));
                }
                // Advance file offset, host pointer and totals by the number
                // of bytes actually transferred.
                len -= ret as usize;
                foffset += ret as u64;
                ptr += ret as u64;
                done += ret as u64;
            }
        }

        Ok(done)
    }
}
272 
/// Virtio-fs device configuration space as exposed to the guest.
///
/// The packed C representation keeps the layout byte-exact for config space
/// reads (see `read_config`).
#[derive(Copy, Clone, Versionize)]
#[repr(C, packed)]
pub struct VirtioFsConfig {
    // Filesystem tag identifying this device to the guest (zero-padded).
    pub tag: [u8; 36],
    // Number of request queues, i.e. excluding the NUM_QUEUE_OFFSET extra
    // queue(s) counted in Fs::new().
    pub num_request_queues: u32,
}
279 
280 impl Default for VirtioFsConfig {
281     fn default() -> Self {
282         VirtioFsConfig {
283             tag: [0; 36],
284             num_request_queues: 0,
285         }
286     }
287 }
288 
// SAFETY: VirtioFsConfig is #[repr(C, packed)] and contains only plain
// integer data (no padding, no references), so every byte pattern is a valid
// value and it can be safely reinterpreted as a byte slice.
unsafe impl ByteValued for VirtioFsConfig {}
290 
/// vhost-user virtio-fs device.
pub struct Fs {
    // Generic virtio device state (features, queue sizes, pause/kill events).
    common: VirtioCommon,
    // Generic vhost-user state (backend handle, protocol features, socket).
    vu_common: VhostUserCommon,
    // Device identifier, used in events and snapshots.
    id: String,
    // Virtio configuration space exposed to the guest.
    config: VirtioFsConfig,
    // Hold ownership of the memory that is allocated for the device
    // which will be automatically dropped when the device is dropped
    cache: Option<(VirtioSharedMemoryList, MmapRegion)>,
    // True when the backend acked SLAVE_REQ + SLAVE_SEND_FD (see Fs::new()).
    slave_req_support: bool,
    // Seccomp action applied to the worker thread on spawn.
    seccomp_action: SeccompAction,
    // Guest memory captured at activation; used for hotplugged region
    // updates and dirty logging.
    guest_memory: Option<GuestMemoryAtomic<GuestMemoryMmap>>,
    // Handle to the worker thread, kept so resume() can unpark it.
    epoll_thread: Option<thread::JoinHandle<()>>,
    // Exit event handed to the virtio thread spawner.
    exit_evt: EventFd,
    // Whether to advertise VIRTIO_F_IOMMU_PLATFORM to the guest.
    iommu: bool,
}
306 
impl Fs {
    /// Create a new virtio-fs device.
    ///
    /// Unless `restoring` is set, this connects to the vhost-user backend
    /// listening on `path`, negotiates virtio and vhost-user protocol
    /// features, verifies the backend supports enough queues and fills in
    /// the virtio configuration space from `tag` and `req_num_queues`.
    ///
    /// When `restoring` is set, only a device shell is returned; feature
    /// state and the backend connection are re-established later from the
    /// snapshot (see `set_state()`).
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        id: String,
        path: &str,
        tag: &str,
        req_num_queues: usize,
        queue_size: u16,
        cache: Option<(VirtioSharedMemoryList, MmapRegion)>,
        seccomp_action: SeccompAction,
        restoring: bool,
        exit_evt: EventFd,
        iommu: bool,
    ) -> Result<Fs> {
        let mut slave_req_support = false;

        // Calculate the actual number of queues needed.
        let num_queues = NUM_QUEUE_OFFSET + req_num_queues;

        if restoring {
            // We need 'queue_sizes' to report a number of queues that will be
            // enough to handle all the potential queues. VirtioPciDevice::new()
            // will create the actual queues based on this information.
            return Ok(Fs {
                common: VirtioCommon {
                    device_type: VirtioDeviceType::Fs as u32,
                    queue_sizes: vec![queue_size; num_queues],
                    paused_sync: Some(Arc::new(Barrier::new(2))),
                    min_queues: DEFAULT_QUEUE_NUMBER as u16,
                    ..Default::default()
                },
                vu_common: VhostUserCommon {
                    socket_path: path.to_string(),
                    vu_num_queues: num_queues,
                    ..Default::default()
                },
                id,
                config: VirtioFsConfig::default(),
                cache,
                slave_req_support,
                seccomp_action,
                guest_memory: None,
                epoll_thread: None,
                exit_evt,
                iommu,
            });
        }

        // Connect to the vhost-user socket.
        let mut vu = VhostUserHandle::connect_vhost_user(false, path, num_queues as u64, false)?;

        // Filling device and vring features VMM supports.
        let avail_features = DEFAULT_VIRTIO_FEATURES;

        let mut avail_protocol_features = VhostUserProtocolFeatures::MQ
            | VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS
            | VhostUserProtocolFeatures::REPLY_ACK
            | VhostUserProtocolFeatures::INFLIGHT_SHMFD
            | VhostUserProtocolFeatures::LOG_SHMFD;
        let slave_protocol_features =
            VhostUserProtocolFeatures::SLAVE_REQ | VhostUserProtocolFeatures::SLAVE_SEND_FD;
        // The slave request channel is only offered when a shared memory
        // cache window exists for the backend to manage.
        if cache.is_some() {
            avail_protocol_features |= slave_protocol_features;
        }

        let (acked_features, acked_protocol_features) =
            vu.negotiate_features_vhost_user(avail_features, avail_protocol_features)?;

        // Ask the backend for its queue count when it supports MQ; otherwise
        // assume the default.
        let backend_num_queues =
            if acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0 {
                vu.socket_handle()
                    .get_queue_num()
                    .map_err(Error::VhostUserGetQueueMaxNum)? as usize
            } else {
                DEFAULT_QUEUE_NUMBER
            };

        if num_queues > backend_num_queues {
            error!(
                "vhost-user-fs requested too many queues ({}) since the backend only supports {}\n",
                num_queues, backend_num_queues
            );
            return Err(Error::BadQueueNum);
        }

        // The slave request channel is usable only when both related
        // protocol features were acked by the backend.
        if acked_protocol_features & slave_protocol_features.bits()
            == slave_protocol_features.bits()
        {
            slave_req_support = true;
        }

        // Create virtio-fs device configuration.
        let mut config = VirtioFsConfig::default();
        let tag_bytes_vec = tag.to_string().into_bytes();
        config.tag[..tag_bytes_vec.len()].copy_from_slice(tag_bytes_vec.as_slice());
        config.num_request_queues = req_num_queues as u32;

        Ok(Fs {
            common: VirtioCommon {
                device_type: VirtioDeviceType::Fs as u32,
                avail_features: acked_features,
                // If part of the available features that have been acked, the
                // PROTOCOL_FEATURES bit must be already set through the VIRTIO
                // acked features as we know the guest would never ack it, thus
                // the feature would be lost.
                acked_features: acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(),
                queue_sizes: vec![queue_size; num_queues],
                paused_sync: Some(Arc::new(Barrier::new(2))),
                min_queues: DEFAULT_QUEUE_NUMBER as u16,
                ..Default::default()
            },
            vu_common: VhostUserCommon {
                vu: Some(Arc::new(Mutex::new(vu))),
                acked_protocol_features,
                socket_path: path.to_string(),
                vu_num_queues: num_queues,
                ..Default::default()
            },
            id,
            config,
            cache,
            slave_req_support,
            seccomp_action,
            guest_memory: None,
            epoll_thread: None,
            exit_evt,
            iommu,
        })
    }

    /// Capture the migratable device state.
    fn state(&self) -> State {
        State {
            avail_features: self.common.avail_features,
            acked_features: self.common.acked_features,
            config: self.config,
            acked_protocol_features: self.vu_common.acked_protocol_features,
            vu_num_queues: self.vu_common.vu_num_queues,
            slave_req_support: self.slave_req_support,
        }
    }

    /// Apply a previously captured state, then reconnect to the vhost-user
    /// backend. A failed reconnection is logged but not propagated.
    fn set_state(&mut self, state: &State) {
        self.common.avail_features = state.avail_features;
        self.common.acked_features = state.acked_features;
        self.config = state.config;
        self.vu_common.acked_protocol_features = state.acked_protocol_features;
        self.vu_common.vu_num_queues = state.vu_num_queues;
        self.slave_req_support = state.slave_req_support;

        if let Err(e) = self
            .vu_common
            .restore_backend_connection(self.common.acked_features)
        {
            error!(
                "Failed restoring connection with vhost-user backend: {:?}",
                e
            );
        }
    }
}
468 
impl Drop for Fs {
    fn drop(&mut self) {
        // Signal the worker thread (if one was spawned) so it can terminate.
        // The event is taken out so it only fires once.
        if let Some(kill_evt) = self.common.kill_evt.take() {
            // Ignore the result because there is nothing we can do about it.
            let _ = kill_evt.write(1);
        }
    }
}
477 
impl VirtioDevice for Fs {
    /// Virtio device type identifier (virtio-fs).
    fn device_type(&self) -> u32 {
        self.common.device_type
    }

    /// Maximum size of each virtqueue.
    fn queue_max_sizes(&self) -> &[u16] {
        &self.common.queue_sizes
    }

    /// Feature bits offered to the guest, adding VIRTIO_F_IOMMU_PLATFORM
    /// when the device is placed behind a virtual IOMMU.
    fn features(&self) -> u64 {
        let mut features = self.common.avail_features;
        if self.iommu {
            features |= 1u64 << VIRTIO_F_IOMMU_PLATFORM;
        }
        features
    }

    /// Record the feature bits acknowledged by the guest.
    fn ack_features(&mut self, value: u64) {
        self.common.ack_features(value)
    }

    /// Guest read from the virtio configuration space.
    fn read_config(&self, offset: u64, data: &mut [u8]) {
        self.read_config_from_slice(self.config.as_slice(), offset, data);
    }

    /// Activate the device: optionally set up the slave request handler for
    /// the cache window, hand queues over to the vhost-user backend, and
    /// spawn the worker thread servicing events and backend reconnections.
    fn activate(
        &mut self,
        mem: GuestMemoryAtomic<GuestMemoryMmap>,
        interrupt_cb: Arc<dyn VirtioInterrupt>,
        queues: Vec<Queue<GuestMemoryAtomic<GuestMemoryMmap>>>,
        queue_evts: Vec<EventFd>,
    ) -> ActivateResult {
        self.common.activate(&queues, &queue_evts, &interrupt_cb)?;
        self.guest_memory = Some(mem.clone());

        // Initialize slave communication.
        // Only done when the backend acked the slave request features AND a
        // cache window is configured, since the handler exists to manage
        // that window.
        let slave_req_handler = if self.slave_req_support {
            if let Some(cache) = self.cache.as_ref() {
                let vu_master_req_handler = Arc::new(SlaveReqHandler {
                    cache_offset: cache.0.addr,
                    cache_size: cache.0.len,
                    mmap_cache_addr: cache.0.host_addr,
                    mem: mem.clone(),
                });

                let mut req_handler =
                    MasterReqHandler::new(vu_master_req_handler).map_err(|e| {
                        ActivateError::VhostUserFsSetup(Error::MasterReqHandlerCreation(e))
                    })?;

                // Mirror the negotiated REPLY_ACK protocol feature on the
                // request handler.
                if self.vu_common.acked_protocol_features
                    & VhostUserProtocolFeatures::REPLY_ACK.bits()
                    != 0
                {
                    req_handler.set_reply_ack_flag(true);
                }

                Some(req_handler)
            } else {
                None
            }
        } else {
            None
        };

        // Run a dedicated thread for handling potential reconnections with
        // the backend.
        let (kill_evt, pause_evt) = self.common.dup_eventfds();

        let mut handler = self.vu_common.activate(
            mem,
            queues,
            queue_evts,
            interrupt_cb,
            self.common.acked_features,
            slave_req_handler,
            kill_evt,
            pause_evt,
        )?;

        let paused = self.common.paused.clone();
        let paused_sync = self.common.paused_sync.clone();

        let mut epoll_threads = Vec::new();
        spawn_virtio_thread(
            &self.id,
            &self.seccomp_action,
            Thread::VirtioVhostFs,
            &mut epoll_threads,
            &self.exit_evt,
            move || {
                if let Err(e) = handler.run(paused, paused_sync.unwrap()) {
                    error!("Error running worker: {:?}", e);
                }
            },
        )?;
        self.epoll_thread = Some(epoll_threads.remove(0));

        event!("virtio-device", "activated", "id", &self.id);
        Ok(())
    }

    /// Reset the device, returning the interrupt callback so it can be
    /// reused. Returns None if the vhost-user backend could not be reset.
    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
        // We first must resume the virtio thread if it was paused.
        if self.common.pause_evt.take().is_some() {
            self.common.resume().ok()?;
        }

        if let Some(vu) = &self.vu_common.vu {
            if let Err(e) = vu
                .lock()
                .unwrap()
                .reset_vhost_user(self.common.queue_sizes.len())
            {
                error!("Failed to reset vhost-user daemon: {:?}", e);
                return None;
            }
        }

        if let Some(kill_evt) = self.common.kill_evt.take() {
            // Ignore the result because there is nothing we can do about it.
            let _ = kill_evt.write(1);
        }

        event!("virtio-device", "reset", "id", &self.id);

        // Return the interrupt
        Some(self.common.interrupt_cb.take().unwrap())
    }

    /// Shut the vhost-user connection down.
    fn shutdown(&mut self) {
        self.vu_common.shutdown()
    }

    /// Shared memory region list (cache window) exposed by the device, if
    /// any.
    fn get_shm_regions(&self) -> Option<VirtioSharedMemoryList> {
        self.cache.as_ref().map(|cache| cache.0.clone())
    }

    /// Replace the shared memory region list; only supported when the device
    /// was created with a cache.
    fn set_shm_regions(
        &mut self,
        shm_regions: VirtioSharedMemoryList,
    ) -> std::result::Result<(), crate::Error> {
        if let Some(mut cache) = self.cache.as_mut() {
            cache.0 = shm_regions;
            Ok(())
        } else {
            Err(crate::Error::SetShmRegionsNotSupported)
        }
    }

    /// Propagate a newly added guest memory region to the backend.
    fn add_memory_region(
        &mut self,
        region: &Arc<GuestRegionMmap>,
    ) -> std::result::Result<(), crate::Error> {
        self.vu_common.add_memory_region(&self.guest_memory, region)
    }

    /// Userspace mappings backing this device: the cache window, when
    /// present.
    fn userspace_mappings(&self) -> Vec<UserspaceMapping> {
        let mut mappings = Vec::new();
        if let Some(cache) = self.cache.as_ref() {
            mappings.push(UserspaceMapping {
                host_addr: cache.0.host_addr,
                mem_slot: cache.0.mem_slot,
                addr: cache.0.addr,
                len: cache.0.len,
                mergeable: false,
            })
        }

        mappings
    }
}
650 
impl Pausable for Fs {
    /// Pause the vhost-user backend first, then the virtio worker thread.
    fn pause(&mut self) -> result::Result<(), MigratableError> {
        self.vu_common.pause()?;
        self.common.pause()
    }

    /// Resume the virtio worker thread, then the vhost-user backend —
    /// the reverse order of pause().
    fn resume(&mut self) -> result::Result<(), MigratableError> {
        self.common.resume()?;

        // NOTE(review): the worker presumably parks itself while paused;
        // unpark it so it picks up the resumed state.
        if let Some(epoll_thread) = &self.epoll_thread {
            epoll_thread.thread().unpark();
        }

        self.vu_common.resume()
    }
}
667 
impl Snapshottable for Fs {
    fn id(&self) -> String {
        self.id.clone()
    }

    /// Snapshot the device state (delegated to the shared vhost-user code).
    fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
        self.vu_common.snapshot(&self.id(), &self.state())
    }

    /// Restore the device state from a snapshot; set_state() also attempts
    /// to reconnect to the vhost-user backend.
    fn restore(&mut self, snapshot: Snapshot) -> std::result::Result<(), MigratableError> {
        self.set_state(&snapshot.to_versioned_state(&self.id)?);
        Ok(())
    }
}
// No transport-specific behavior needed; the trait's default methods apply.
impl Transportable for Fs {}
683 
impl Migratable for Fs {
    /// Start dirty page logging in the backend for live migration.
    fn start_dirty_log(&mut self) -> std::result::Result<(), MigratableError> {
        self.vu_common.start_dirty_log(&self.guest_memory)
    }

    /// Stop dirty page logging in the backend.
    fn stop_dirty_log(&mut self) -> std::result::Result<(), MigratableError> {
        self.vu_common.stop_dirty_log()
    }

    /// Retrieve the memory ranges dirtied since logging started.
    fn dirty_log(&mut self) -> std::result::Result<MemoryRangeTable, MigratableError> {
        self.vu_common.dirty_log(&self.guest_memory)
    }

    /// Notify the shared vhost-user code that a migration is starting.
    fn start_migration(&mut self) -> std::result::Result<(), MigratableError> {
        self.vu_common.start_migration()
    }

    /// Complete the migration, handing the kill event over so the worker
    /// thread can be stopped.
    fn complete_migration(&mut self) -> std::result::Result<(), MigratableError> {
        self.vu_common
            .complete_migration(self.common.kill_evt.take())
    }
}
706