xref: /cloud-hypervisor/virtio-devices/src/vhost_user/fs.rs (revision f67b3f79ea19c9a66e04074cbbf5d292f6529e43)
1 // Copyright 2019 Intel Corporation. All Rights Reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use super::vu_common_ctrl::VhostUserHandle;
5 use super::{Error, Result, DEFAULT_VIRTIO_FEATURES};
6 use crate::seccomp_filters::Thread;
7 use crate::thread_helper::spawn_virtio_thread;
8 use crate::vhost_user::VhostUserCommon;
9 use crate::{
10     ActivateError, ActivateResult, Queue, UserspaceMapping, VirtioCommon, VirtioDevice,
11     VirtioDeviceType, VirtioInterrupt, VirtioSharedMemoryList,
12 };
13 use crate::{GuestMemoryMmap, GuestRegionMmap, MmapRegion};
14 use libc::{self, c_void, off64_t, pread64, pwrite64};
15 use seccompiler::SeccompAction;
16 use std::io;
17 use std::os::unix::io::AsRawFd;
18 use std::result;
19 use std::sync::{Arc, Barrier, Mutex};
20 use std::thread;
21 use versionize::{VersionMap, Versionize, VersionizeResult};
22 use versionize_derive::Versionize;
23 use vhost::vhost_user::message::{
24     VhostUserFSSlaveMsg, VhostUserFSSlaveMsgFlags, VhostUserProtocolFeatures,
25     VhostUserVirtioFeatures, VHOST_USER_FS_SLAVE_ENTRIES,
26 };
27 use vhost::vhost_user::{
28     HandlerResult, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler,
29 };
30 use vm_memory::{
31     Address, ByteValued, GuestAddress, GuestAddressSpace, GuestMemory, GuestMemoryAtomic,
32 };
33 use vm_migration::{
34     protocol::MemoryRangeTable, Migratable, MigratableError, Pausable, Snapshot, Snapshottable,
35     Transportable, VersionMapped,
36 };
37 use vmm_sys_util::eventfd::EventFd;
38 
// Number of queues added on top of the request queues when computing the
// total queue count (presumably the virtio-fs hiprio queue — see the
// `num_queues` calculation in `Fs::new`).
const NUM_QUEUE_OFFSET: usize = 1;
// Queue count assumed when the backend does not advertise the MQ protocol
// feature; also used as the device's minimum number of queues.
const DEFAULT_QUEUE_NUMBER: usize = 2;
41 
// Migratable state of the device: everything `Fs::state()` captures and
// `Fs::set_state()` restores across a snapshot/restore cycle.
#[derive(Versionize)]
pub struct State {
    // Features offered to the guest driver.
    pub avail_features: u64,
    // Features acknowledged by the guest driver.
    pub acked_features: u64,
    // Device configuration space (tag + number of request queues).
    pub config: VirtioFsConfig,
    // vhost-user protocol features acknowledged by the backend.
    pub acked_protocol_features: u64,
    // Number of queues negotiated with the vhost-user backend.
    pub vu_num_queues: usize,
    // Whether the backend acknowledged the slave request channel
    // (SLAVE_REQ | SLAVE_SEND_FD protocol features).
    pub slave_req_support: bool,
}

impl VersionMapped for State {}
53 
// Handler for requests initiated by the vhost-user backend (slave) over the
// slave channel, managing mappings and I/O within the shared cache window.
struct SlaveReqHandler {
    // Guest physical address at which the cache window starts.
    cache_offset: GuestAddress,
    // Size of the cache window, in bytes.
    cache_size: u64,
    // Host virtual address where the cache window is mapped.
    mmap_cache_addr: u64,
    // Guest memory, used to translate guest addresses that fall outside
    // the cache window (see `fs_slave_io`).
    mem: GuestMemoryAtomic<GuestMemoryMmap>,
}
60 
61 impl SlaveReqHandler {
62     // Make sure request is within cache range
63     fn is_req_valid(&self, offset: u64, len: u64) -> bool {
64         let end = match offset.checked_add(len) {
65             Some(n) => n,
66             None => return false,
67         };
68 
69         !(offset >= self.cache_size || end > self.cache_size)
70     }
71 }
72 
impl VhostUserMasterReqHandler for SlaveReqHandler {
    // The backend signaled a configuration change; nothing needs updating
    // on the VMM side, so the request is simply acknowledged.
    fn handle_config_change(&self) -> HandlerResult<u64> {
        debug!("handle_config_change");
        Ok(0)
    }

    // Map ranges of the file referenced by `fd` into the cache window, at
    // the offsets and with the protection flags requested by the backend.
    fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
        debug!("fs_slave_map");

        for i in 0..VHOST_USER_FS_SLAVE_ENTRIES {
            let offset = fs.cache_offset[i];
            let len = fs.len[i];

            // Ignore if the length is 0.
            if len == 0 {
                continue;
            }

            if !self.is_req_valid(offset, len) {
                return Err(io::Error::from_raw_os_error(libc::EINVAL));
            }

            let addr = self.mmap_cache_addr + offset;
            let flags = fs.flags[i];
            // SAFETY: the [addr, addr + len) range has been validated to be
            // fully contained within the cache window, and MAP_FIXED replaces
            // whatever mapping previously existed there.
            let ret = unsafe {
                libc::mmap(
                    addr as *mut libc::c_void,
                    len as usize,
                    flags.bits() as i32,
                    libc::MAP_SHARED | libc::MAP_FIXED,
                    fd.as_raw_fd(),
                    fs.fd_offset[i] as libc::off_t,
                )
            };
            if ret == libc::MAP_FAILED {
                return Err(io::Error::last_os_error());
            }

            // NOTE(review): the fd is closed inside the loop, so a message
            // carrying more than one non-empty entry would attempt to mmap an
            // already-closed fd on the next iteration — confirm the backend
            // only ever sends a single entry per fd.
            let ret = unsafe { libc::close(fd.as_raw_fd()) };
            if ret == -1 {
                return Err(io::Error::last_os_error());
            }
        }

        Ok(0)
    }

    // Unmap ranges of the cache window by replacing them with an anonymous
    // PROT_NONE mapping, keeping the address range reserved.
    fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
        debug!("fs_slave_unmap");

        for i in 0..VHOST_USER_FS_SLAVE_ENTRIES {
            let offset = fs.cache_offset[i];
            let mut len = fs.len[i];

            // Ignore if the length is 0.
            if len == 0 {
                continue;
            }

            // Need to handle a special case where the slave ask for the unmapping
            // of the entire mapping.
            if len == 0xffff_ffff_ffff_ffff {
                len = self.cache_size;
            }

            // Note: after the special case above, only offset == 0 can pass
            // this validation, i.e. a full unmap must start at the window base.
            if !self.is_req_valid(offset, len) {
                return Err(io::Error::from_raw_os_error(libc::EINVAL));
            }

            let addr = self.mmap_cache_addr + offset;
            // SAFETY: the range has been validated against the cache window;
            // overlaying anonymous PROT_NONE memory discards the previous file
            // mapping while keeping the range reserved.
            let ret = unsafe {
                libc::mmap(
                    addr as *mut libc::c_void,
                    len as usize,
                    libc::PROT_NONE,
                    libc::MAP_ANONYMOUS | libc::MAP_PRIVATE | libc::MAP_FIXED,
                    -1,
                    0,
                )
            };
            if ret == libc::MAP_FAILED {
                return Err(io::Error::last_os_error());
            }
        }

        Ok(0)
    }

    // Synchronously flush ranges of the cache window back to the mapped files.
    fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
        debug!("fs_slave_sync");

        for i in 0..VHOST_USER_FS_SLAVE_ENTRIES {
            let offset = fs.cache_offset[i];
            let len = fs.len[i];

            // Ignore if the length is 0.
            if len == 0 {
                continue;
            }

            if !self.is_req_valid(offset, len) {
                return Err(io::Error::from_raw_os_error(libc::EINVAL));
            }

            let addr = self.mmap_cache_addr + offset;
            // SAFETY: the [addr, addr + len) range has been validated to be
            // fully contained within the cache window.
            let ret =
                unsafe { libc::msync(addr as *mut libc::c_void, len as usize, libc::MS_SYNC) };
            if ret == -1 {
                return Err(io::Error::last_os_error());
            }
        }

        Ok(0)
    }

    // Perform I/O between the file referenced by `fd` and guest memory
    // (either the cache window or regular guest RAM), returning the total
    // number of bytes transferred.
    fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
        debug!("fs_slave_io");

        let mut done: u64 = 0;
        for i in 0..VHOST_USER_FS_SLAVE_ENTRIES {
            // Ignore if the length is 0.
            if fs.len[i] == 0 {
                continue;
            }

            let mut foffset = fs.fd_offset[i];
            let mut len = fs.len[i] as usize;
            let gpa = fs.cache_offset[i];
            let cache_end = self.cache_offset.raw_value() + self.cache_size;
            let efault = libc::EFAULT;

            // Translate the guest physical address into a host virtual
            // address, going through the cache window when the address falls
            // inside it, or through regular guest RAM otherwise.
            let mut ptr = if gpa >= self.cache_offset.raw_value() && gpa < cache_end {
                let offset = gpa
                    .checked_sub(self.cache_offset.raw_value())
                    .ok_or_else(|| io::Error::from_raw_os_error(efault))?;
                let end = gpa
                    .checked_add(fs.len[i])
                    .ok_or_else(|| io::Error::from_raw_os_error(efault))?;

                // NOTE(review): this rejects a range ending exactly at the
                // cache boundary (end == cache_end), which looks like it
                // should be accepted — worth confirming.
                if end >= cache_end {
                    return Err(io::Error::from_raw_os_error(efault));
                }

                self.mmap_cache_addr + offset
            } else {
                self.mem
                    .memory()
                    .get_host_address(GuestAddress(gpa))
                    .map_err(|e| {
                        error!(
                            "Failed to find RAM region associated with guest physical address 0x{:x}: {:?}",
                            gpa, e
                        );
                        io::Error::from_raw_os_error(efault)
                    })? as u64
            };

            // pread64/pwrite64 may transfer fewer bytes than requested, so
            // loop until the whole buffer has been processed.
            while len > 0 {
                let ret = if (fs.flags[i] & VhostUserFSSlaveMsgFlags::MAP_W)
                    == VhostUserFSSlaveMsgFlags::MAP_W
                {
                    debug!("write: foffset={}, len={}", foffset, len);
                    // SAFETY: `ptr` comes from the translation above. The
                    // cache path validated the full range; the guest-RAM path
                    // assumes the whole `len` span is mapped — TODO confirm.
                    unsafe {
                        pwrite64(
                            fd.as_raw_fd(),
                            ptr as *const c_void,
                            len as usize,
                            foffset as off64_t,
                        )
                    }
                } else {
                    debug!("read: foffset={}, len={}", foffset, len);
                    // SAFETY: same argument as for the pwrite64 branch above.
                    unsafe {
                        pread64(
                            fd.as_raw_fd(),
                            ptr as *mut c_void,
                            len as usize,
                            foffset as off64_t,
                        )
                    }
                };

                if ret < 0 {
                    return Err(io::Error::last_os_error());
                }

                if ret == 0 {
                    // EOF
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "failed to access whole buffer",
                    ));
                }
                // Advance by the number of bytes actually transferred.
                len -= ret as usize;
                foffset += ret as u64;
                ptr += ret as u64;
                done += ret as u64;
            }
        }

        let ret = unsafe { libc::close(fd.as_raw_fd()) };
        if ret == -1 {
            return Err(io::Error::last_os_error());
        }

        Ok(done)
    }
}
281 
// virtio-fs device configuration space exposed to the guest through
// `read_config` (filesystem tag and number of request queues).
#[derive(Copy, Clone, Versionize)]
#[repr(C, packed)]
pub struct VirtioFsConfig {
    // Filesystem tag, NUL-padded to 36 bytes.
    pub tag: [u8; 36],
    // Number of request queues exposed by the device.
    pub num_request_queues: u32,
}
288 
289 impl Default for VirtioFsConfig {
290     fn default() -> Self {
291         VirtioFsConfig {
292             tag: [0; 36],
293             num_request_queues: 0,
294         }
295     }
296 }
297 
// SAFETY: VirtioFsConfig is #[repr(C, packed)] and only contains plain
// integer data, so every byte pattern represents a valid value.
unsafe impl ByteValued for VirtioFsConfig {}
299 
// vhost-user filesystem (virtio-fs) device.
pub struct Fs {
    // State common to all virtio devices (features, queues, events, ...).
    common: VirtioCommon,
    // State common to all vhost-user devices (backend handle, socket, ...).
    vu_common: VhostUserCommon,
    // Device identifier, used for events and snapshots.
    id: String,
    // Configuration space exposed to the guest.
    config: VirtioFsConfig,
    // Hold ownership of the memory that is allocated for the device
    // which will be automatically dropped when the device is dropped
    cache: Option<(VirtioSharedMemoryList, MmapRegion)>,
    // Whether the backend acknowledged the slave request channel.
    slave_req_support: bool,
    // Seccomp action handed to the worker-thread spawning helper.
    seccomp_action: SeccompAction,
    // Guest memory, kept for memory hotplug and dirty logging.
    guest_memory: Option<GuestMemoryAtomic<GuestMemoryMmap>>,
    // Handle to the worker thread, used to unpark it on resume.
    epoll_thread: Option<thread::JoinHandle<()>>,
    // Event passed to the worker-thread spawning helper, presumably
    // signaled on thread exit — confirm against spawn_virtio_thread.
    exit_evt: EventFd,
}
314 
impl Fs {
    /// Create a new virtio-fs device.
    ///
    /// Connects to the vhost-user backend listening on `path` (unless
    /// `restoring` is set, in which case the connection is re-established
    /// later through `set_state()`), negotiates virtio and protocol
    /// features, and builds the device configuration from `tag` and
    /// `req_num_queues`.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        id: String,
        path: &str,
        tag: &str,
        req_num_queues: usize,
        queue_size: u16,
        cache: Option<(VirtioSharedMemoryList, MmapRegion)>,
        seccomp_action: SeccompAction,
        restoring: bool,
        exit_evt: EventFd,
    ) -> Result<Fs> {
        let mut slave_req_support = false;

        // Calculate the actual number of queues needed.
        let num_queues = NUM_QUEUE_OFFSET + req_num_queues;

        if restoring {
            // We need 'queue_sizes' to report a number of queues that will be
            // enough to handle all the potential queues. VirtioPciDevice::new()
            // will create the actual queues based on this information.
            return Ok(Fs {
                common: VirtioCommon {
                    device_type: VirtioDeviceType::Fs as u32,
                    queue_sizes: vec![queue_size; num_queues],
                    paused_sync: Some(Arc::new(Barrier::new(2))),
                    min_queues: DEFAULT_QUEUE_NUMBER as u16,
                    ..Default::default()
                },
                vu_common: VhostUserCommon {
                    socket_path: path.to_string(),
                    vu_num_queues: num_queues,
                    ..Default::default()
                },
                id,
                config: VirtioFsConfig::default(),
                cache,
                slave_req_support,
                seccomp_action,
                guest_memory: None,
                epoll_thread: None,
                exit_evt,
            });
        }

        // Connect to the vhost-user socket.
        let mut vu = VhostUserHandle::connect_vhost_user(false, path, num_queues as u64, false)?;

        // Filling device and vring features VMM supports.
        let avail_features = DEFAULT_VIRTIO_FEATURES;

        let mut avail_protocol_features = VhostUserProtocolFeatures::MQ
            | VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS
            | VhostUserProtocolFeatures::REPLY_ACK
            | VhostUserProtocolFeatures::INFLIGHT_SHMFD
            | VhostUserProtocolFeatures::LOG_SHMFD;
        let slave_protocol_features =
            VhostUserProtocolFeatures::SLAVE_REQ | VhostUserProtocolFeatures::SLAVE_SEND_FD;
        // The slave request channel is only offered when a cache window
        // exists for the backend to map files into.
        if cache.is_some() {
            avail_protocol_features |= slave_protocol_features;
        }

        let (acked_features, acked_protocol_features) =
            vu.negotiate_features_vhost_user(avail_features, avail_protocol_features)?;

        // Ask the backend for its supported queue count when it negotiated
        // the MQ protocol feature; otherwise assume the default.
        let backend_num_queues =
            if acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0 {
                vu.socket_handle()
                    .get_queue_num()
                    .map_err(Error::VhostUserGetQueueMaxNum)? as usize
            } else {
                DEFAULT_QUEUE_NUMBER
            };

        if num_queues > backend_num_queues {
            error!(
                "vhost-user-fs requested too many queues ({}) since the backend only supports {}\n",
                num_queues, backend_num_queues
            );
            return Err(Error::BadQueueNum);
        }

        // Slave requests are only usable when the backend acknowledged both
        // SLAVE_REQ and SLAVE_SEND_FD.
        if acked_protocol_features & slave_protocol_features.bits()
            == slave_protocol_features.bits()
        {
            slave_req_support = true;
        }

        // Create virtio-fs device configuration.
        let mut config = VirtioFsConfig::default();
        let tag_bytes_vec = tag.to_string().into_bytes();
        // NOTE(review): a tag longer than 36 bytes would make this slicing
        // panic — presumably callers validate the tag length beforehand;
        // worth confirming.
        config.tag[..tag_bytes_vec.len()].copy_from_slice(tag_bytes_vec.as_slice());
        config.num_request_queues = req_num_queues as u32;

        Ok(Fs {
            common: VirtioCommon {
                device_type: VirtioDeviceType::Fs as u32,
                // Only advertise to the guest what the backend acknowledged.
                avail_features: acked_features,
                acked_features: 0,
                queue_sizes: vec![queue_size; num_queues],
                paused_sync: Some(Arc::new(Barrier::new(2))),
                min_queues: DEFAULT_QUEUE_NUMBER as u16,
                ..Default::default()
            },
            vu_common: VhostUserCommon {
                vu: Some(Arc::new(Mutex::new(vu))),
                acked_protocol_features,
                socket_path: path.to_string(),
                vu_num_queues: num_queues,
                ..Default::default()
            },
            id,
            config,
            cache,
            slave_req_support,
            seccomp_action,
            guest_memory: None,
            epoll_thread: None,
            exit_evt,
        })
    }

    // Capture the migratable state of the device (see `State`).
    fn state(&self) -> State {
        State {
            avail_features: self.common.avail_features,
            acked_features: self.common.acked_features,
            config: self.config,
            acked_protocol_features: self.vu_common.acked_protocol_features,
            vu_num_queues: self.vu_common.vu_num_queues,
            slave_req_support: self.slave_req_support,
        }
    }

    // Restore the device state captured by `state()` and reconnect to the
    // vhost-user backend. A reconnection failure is logged but not fatal.
    fn set_state(&mut self, state: &State) {
        self.common.avail_features = state.avail_features;
        self.common.acked_features = state.acked_features;
        self.config = state.config;
        self.vu_common.acked_protocol_features = state.acked_protocol_features;
        self.vu_common.vu_num_queues = state.vu_num_queues;
        self.slave_req_support = state.slave_req_support;

        if let Err(e) = self
            .vu_common
            .restore_backend_connection(self.common.acked_features)
        {
            error!(
                "Failed restoring connection with vhost-user backend: {:?}",
                e
            );
        }
    }
}
469 
470 impl Drop for Fs {
471     fn drop(&mut self) {
472         if let Some(kill_evt) = self.common.kill_evt.take() {
473             // Ignore the result because there is nothing we can do about it.
474             let _ = kill_evt.write(1);
475         }
476     }
477 }
478 
impl VirtioDevice for Fs {
    // Virtio device type identifier.
    fn device_type(&self) -> u32 {
        self.common.device_type
    }

    // Maximum size of each virtqueue.
    fn queue_max_sizes(&self) -> &[u16] {
        &self.common.queue_sizes
    }

    // Features offered to the guest driver.
    fn features(&self) -> u64 {
        self.common.avail_features
    }

    // Record the features acknowledged by the guest driver.
    fn ack_features(&mut self, value: u64) {
        self.common.ack_features(value)
    }

    // Expose the device configuration space to the guest.
    fn read_config(&self, offset: u64, data: &mut [u8]) {
        self.read_config_from_slice(self.config.as_slice(), offset, data);
    }

    // Activate the device: set up the optional slave request channel, hand
    // queues and memory over to the vhost-user backend, and spawn the
    // worker thread that services events.
    fn activate(
        &mut self,
        mem: GuestMemoryAtomic<GuestMemoryMmap>,
        interrupt_cb: Arc<dyn VirtioInterrupt>,
        queues: Vec<Queue>,
        queue_evts: Vec<EventFd>,
    ) -> ActivateResult {
        self.common.activate(&queues, &queue_evts, &interrupt_cb)?;
        self.guest_memory = Some(mem.clone());

        // Initialize slave communication.
        // Only set up when the backend negotiated slave requests AND a cache
        // window exists, since the handler operates on the cache window.
        let slave_req_handler = if self.slave_req_support {
            if let Some(cache) = self.cache.as_ref() {
                let vu_master_req_handler = Arc::new(SlaveReqHandler {
                    cache_offset: cache.0.addr,
                    cache_size: cache.0.len,
                    mmap_cache_addr: cache.0.host_addr,
                    mem: mem.clone(),
                });

                let mut req_handler =
                    MasterReqHandler::new(vu_master_req_handler).map_err(|e| {
                        ActivateError::VhostUserFsSetup(Error::MasterReqHandlerCreation(e))
                    })?;
                req_handler.set_reply_ack_flag(true);
                Some(req_handler)
            } else {
                None
            }
        } else {
            None
        };

        // The backend acknowledged features must contain the protocol feature
        // bit in case it was initially set but lost through the features
        // negotiation with the guest.
        let backend_acked_features = self.common.acked_features
            | (self.common.avail_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits());

        // Run a dedicated thread for handling potential reconnections with
        // the backend.
        let (kill_evt, pause_evt) = self.common.dup_eventfds();

        let mut handler = self.vu_common.activate(
            mem,
            queues,
            queue_evts,
            interrupt_cb,
            backend_acked_features,
            slave_req_handler,
            kill_evt,
            pause_evt,
        )?;

        let paused = self.common.paused.clone();
        let paused_sync = self.common.paused_sync.clone();

        let mut epoll_threads = Vec::new();
        spawn_virtio_thread(
            &self.id,
            &self.seccomp_action,
            Thread::VirtioVhostFs,
            &mut epoll_threads,
            &self.exit_evt,
            move || {
                if let Err(e) = handler.run(paused, paused_sync.unwrap()) {
                    error!("Error running worker: {:?}", e);
                }
            },
        )?;
        // Keep the thread handle so resume() can unpark the worker.
        self.epoll_thread = Some(epoll_threads.remove(0));

        event!("virtio-device", "activated", "id", &self.id);
        Ok(())
    }

    // Reset the device, returning the interrupt callback to the caller so
    // it can be reused on a later activation.
    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
        // We first must resume the virtio thread if it was paused.
        if self.common.pause_evt.take().is_some() {
            self.common.resume().ok()?;
        }

        if let Some(vu) = &self.vu_common.vu {
            if let Err(e) = vu
                .lock()
                .unwrap()
                .reset_vhost_user(self.common.queue_sizes.len())
            {
                error!("Failed to reset vhost-user daemon: {:?}", e);
                return None;
            }
        }

        if let Some(kill_evt) = self.common.kill_evt.take() {
            // Ignore the result because there is nothing we can do about it.
            let _ = kill_evt.write(1);
        }

        event!("virtio-device", "reset", "id", &self.id);

        // Return the interrupt
        Some(self.common.interrupt_cb.take().unwrap())
    }

    // Shut down the connection with the vhost-user backend.
    fn shutdown(&mut self) {
        self.vu_common.shutdown()
    }

    // Shared memory (cache) regions exposed to the guest, if any.
    fn get_shm_regions(&self) -> Option<VirtioSharedMemoryList> {
        self.cache.as_ref().map(|cache| cache.0.clone())
    }

    // Replace the shared memory regions; only supported for devices that
    // were created with a cache window.
    fn set_shm_regions(
        &mut self,
        shm_regions: VirtioSharedMemoryList,
    ) -> std::result::Result<(), crate::Error> {
        if let Some(mut cache) = self.cache.as_mut() {
            cache.0 = shm_regions;
            Ok(())
        } else {
            Err(crate::Error::SetShmRegionsNotSupported)
        }
    }

    // Propagate a hotplugged memory region to the vhost-user backend.
    fn add_memory_region(
        &mut self,
        region: &Arc<GuestRegionMmap>,
    ) -> std::result::Result<(), crate::Error> {
        self.vu_common.add_memory_region(&self.guest_memory, region)
    }

    // Userspace mappings backing the device: the cache window, when present.
    fn userspace_mappings(&self) -> Vec<UserspaceMapping> {
        let mut mappings = Vec::new();
        if let Some(cache) = self.cache.as_ref() {
            mappings.push(UserspaceMapping {
                host_addr: cache.0.host_addr,
                mem_slot: cache.0.mem_slot,
                addr: cache.0.addr,
                len: cache.0.len,
                mergeable: false,
            })
        }

        mappings
    }
}
646 
647 impl Pausable for Fs {
648     fn pause(&mut self) -> result::Result<(), MigratableError> {
649         self.vu_common.pause()?;
650         self.common.pause()
651     }
652 
653     fn resume(&mut self) -> result::Result<(), MigratableError> {
654         self.common.resume()?;
655 
656         if let Some(epoll_thread) = &self.epoll_thread {
657             epoll_thread.thread().unpark();
658         }
659 
660         self.vu_common.resume()
661     }
662 }
663 
664 impl Snapshottable for Fs {
665     fn id(&self) -> String {
666         self.id.clone()
667     }
668 
669     fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
670         self.vu_common.snapshot(&self.id(), &self.state())
671     }
672 
673     fn restore(&mut self, snapshot: Snapshot) -> std::result::Result<(), MigratableError> {
674         self.set_state(&snapshot.to_versioned_state(&self.id)?);
675         Ok(())
676     }
677 }
678 impl Transportable for Fs {}
679 
680 impl Migratable for Fs {
681     fn start_dirty_log(&mut self) -> std::result::Result<(), MigratableError> {
682         self.vu_common.start_dirty_log(&self.guest_memory)
683     }
684 
685     fn stop_dirty_log(&mut self) -> std::result::Result<(), MigratableError> {
686         self.vu_common.stop_dirty_log()
687     }
688 
689     fn dirty_log(&mut self) -> std::result::Result<MemoryRangeTable, MigratableError> {
690         self.vu_common.dirty_log(&self.guest_memory)
691     }
692 
693     fn complete_migration(&mut self) -> std::result::Result<(), MigratableError> {
694         self.vu_common
695             .complete_migration(self.common.kill_evt.take())
696     }
697 }
698