xref: /cloud-hypervisor/virtio-devices/src/vhost_user/vu_common_ctrl.rs (revision 190d90196fff389b60b93b57acf958957b71b249)
1 // Copyright 2019 Intel Corporation. All Rights Reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use std::ffi;
5 use std::fs::File;
6 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
7 use std::os::unix::net::UnixListener;
8 use std::sync::atomic::Ordering;
9 use std::sync::Arc;
10 use std::thread::sleep;
11 use std::time::{Duration, Instant};
12 
13 use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG};
14 use vhost::vhost_user::message::{
15     VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
16 };
17 use vhost::vhost_user::{
18     Frontend, FrontendReqHandler, VhostUserFrontend, VhostUserFrontendReqHandler,
19 };
20 use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
21 use virtio_queue::desc::RawDescriptor;
22 use virtio_queue::{Queue, QueueT};
23 use vm_memory::{
24     Address, Error as MmapError, FileOffset, GuestAddress, GuestMemory, GuestMemoryRegion,
25 };
26 use vm_migration::protocol::MemoryRangeTable;
27 use vmm_sys_util::eventfd::EventFd;
28 
29 use super::{Error, Result};
30 use crate::vhost_user::Inflight;
31 use crate::{
32     get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt,
33     VirtioInterruptType,
34 };
35 
36 // Size of a dirty page for vhost-user.
37 const VHOST_LOG_PAGE: u64 = 0x1000;
38 
/// User-provided configuration for a vhost-user device.
#[derive(Debug, Clone)]
pub struct VhostUserConfig {
    /// Path to the vhost-user Unix domain socket.
    pub socket: String,
    /// Number of virtqueues exposed by the device.
    pub num_queues: usize,
    /// Size (number of descriptors) of each virtqueue.
    pub queue_size: u16,
}
45 
// Per-vring data saved at setup time so that dirty logging can later be
// toggled by re-sending the vring addresses with updated flags/log address.
#[derive(Clone)]
struct VringInfo {
    // Configuration originally sent through VHOST_USER_SET_VRING_ADDR.
    config_data: VringConfigData,
    // Guest physical address of the used ring; reused as the log address
    // when dirty logging is enabled (see set_vring_logging()).
    used_guest_addr: u64,
}
51 
#[derive(Clone)]
pub struct VhostUserHandle {
    // Frontend side of the vhost-user connection to the backend.
    vu: Frontend,
    // Set once setup_vhost_user() completed; pause/resume only act when ready.
    ready: bool,
    // True when both VHOST_F_LOG_ALL and LOG_SHMFD have been negotiated
    // (see update_supports_migration()).
    supports_migration: bool,
    // Shared memory region the backend logs dirty pages into; held here so
    // the mapping is not released while logging is active.
    shm_log: Option<Arc<MmapRegion>>,
    // Features acknowledged during negotiation, restored when logging stops.
    acked_features: u64,
    // Per-vring info saved at setup time, used to toggle vring logging.
    vrings_info: Option<Vec<VringInfo>>,
    // Indexes of the queues configured by this handle.
    queue_indexes: Vec<usize>,
}
62 
63 impl VhostUserHandle {
update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()>64     pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> {
65         let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
66         for region in mem.iter() {
67             let (mmap_handle, mmap_offset) = match region.file_offset() {
68                 Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()),
69                 None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)),
70             };
71 
72             let vhost_user_net_reg = VhostUserMemoryRegionInfo {
73                 guest_phys_addr: region.start_addr().raw_value(),
74                 memory_size: region.len(),
75                 userspace_addr: region.as_ptr() as u64,
76                 mmap_offset,
77                 mmap_handle,
78             };
79 
80             regions.push(vhost_user_net_reg);
81         }
82 
83         self.vu
84             .set_mem_table(regions.as_slice())
85             .map_err(Error::VhostUserSetMemTable)?;
86 
87         Ok(())
88     }
89 
add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()>90     pub fn add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()> {
91         let (mmap_handle, mmap_offset) = match region.file_offset() {
92             Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
93             None => return Err(Error::MissingRegionFd),
94         };
95 
96         let region = VhostUserMemoryRegionInfo {
97             guest_phys_addr: region.start_addr().raw_value(),
98             memory_size: region.len(),
99             userspace_addr: region.as_ptr() as u64,
100             mmap_offset,
101             mmap_handle,
102         };
103 
104         self.vu
105             .add_mem_region(&region)
106             .map_err(Error::VhostUserAddMemReg)
107     }
108 
negotiate_features_vhost_user( &mut self, avail_features: u64, avail_protocol_features: VhostUserProtocolFeatures, ) -> Result<(u64, u64)>109     pub fn negotiate_features_vhost_user(
110         &mut self,
111         avail_features: u64,
112         avail_protocol_features: VhostUserProtocolFeatures,
113     ) -> Result<(u64, u64)> {
114         // Set vhost-user owner.
115         self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
116 
117         // Get features from backend, do negotiation to get a feature collection which
118         // both VMM and backend support.
119         let backend_features = self
120             .vu
121             .get_features()
122             .map_err(Error::VhostUserGetFeatures)?;
123         let acked_features = avail_features & backend_features;
124 
125         let acked_protocol_features =
126             if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
127                 let backend_protocol_features = self
128                     .vu
129                     .get_protocol_features()
130                     .map_err(Error::VhostUserGetProtocolFeatures)?;
131 
132                 let acked_protocol_features = avail_protocol_features & backend_protocol_features;
133 
134                 self.vu
135                     .set_protocol_features(acked_protocol_features)
136                     .map_err(Error::VhostUserSetProtocolFeatures)?;
137 
138                 acked_protocol_features
139             } else {
140                 VhostUserProtocolFeatures::empty()
141             };
142 
143         if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
144             && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
145         {
146             self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
147         }
148 
149         self.update_supports_migration(acked_features, acked_protocol_features.bits());
150 
151         Ok((acked_features, acked_protocol_features.bits()))
152     }
153 
    #[allow(clippy::too_many_arguments)]
    /// Drives the full vhost-user device setup sequence.
    ///
    /// The ordering follows the vhost-user protocol: features, memory
    /// table, vring sizes, optional inflight tracking, vring
    /// addresses/bases/eventfds, and finally enabling the vrings.
    pub fn setup_vhost_user<S: VhostUserFrontendReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<(usize, Queue, EventFd)>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        backend_req_handler: &Option<FrontendReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.vu
            .set_features(acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Update internal value after it's been sent to the backend.
        self.acked_features = acked_features;

        // Let's first provide the memory table to the backend.
        self.update_mem_table(mem)?;

        // Send set_vring_num here, since it could tell backends, like SPDK,
        // how many virt queues to be handled, which backend required to know
        // at early stage.
        for (queue_index, queue, _) in queues.iter() {
            self.vu
                .set_vring_num(*queue_index, queue.size())
                .map_err(Error::VhostUserSetVringNum)?;
        }

        // Setup for inflight I/O tracking shared memory.
        if let Some(inflight) = inflight {
            // Only ask the backend for a new inflight region when we don't
            // already hold one (e.g. across a reconnection).
            if inflight.fd.is_none() {
                let inflight_req_info = VhostUserInflight {
                    mmap_size: 0,
                    mmap_offset: 0,
                    num_queues: queues.len() as u16,
                    queue_size: queues[0].1.size(),
                };
                let (info, fd) = self
                    .vu
                    .get_inflight_fd(&inflight_req_info)
                    .map_err(Error::VhostUserGetInflight)?;
                inflight.info = info;
                inflight.fd = Some(fd);
            }
            // Unwrapping the inflight fd is safe here since we know it can't be None.
            self.vu
                .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd())
                .map_err(Error::VhostUserSetInflight)?;
        }

        let mut vrings_info = Vec::new();
        for (queue_index, queue, queue_evt) in queues.iter() {
            let actual_size: usize = queue.size().into();

            let config_data = VringConfigData {
                queue_max_size: queue.max_size(),
                queue_size: queue.size(),
                flags: 0u32,
                // Guest addresses are translated into VMM userspace
                // addresses, as required by VHOST_USER_SET_VRING_ADDR.
                desc_table_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.desc_table()),
                    actual_size * std::mem::size_of::<RawDescriptor>(),
                )
                .ok_or(Error::DescriptorTableAddress)? as u64,
                // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u16, len: u16}; actual_size]},
                // i.e. 4 + (4 + 4) * actual_size.
                used_ring_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.used_ring()),
                    4 + actual_size * 8,
                )
                .ok_or(Error::UsedAddress)? as u64,
                // The avail ring is {flags: u16; idx: u16; elem [u16; actual_size]},
                // i.e. 4 + (2) * actual_size.
                avail_ring_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.avail_ring()),
                    4 + actual_size * 2,
                )
                .ok_or(Error::AvailAddress)? as u64,
                log_addr: None,
            };

            // Remember the vring layout so dirty logging can be toggled
            // later (see set_vring_logging()).
            vrings_info.push(VringInfo {
                config_data,
                used_guest_addr: queue.used_ring(),
            });

            self.vu
                .set_vring_addr(*queue_index, &config_data)
                .map_err(Error::VhostUserSetVringAddr)?;
            self.vu
                .set_vring_base(
                    *queue_index,
                    queue
                        .avail_idx(mem, Ordering::Acquire)
                        .map_err(Error::GetAvailableIndex)?
                        .0,
                )
                .map_err(Error::VhostUserSetVringBase)?;

            // Route the backend's call notifications to the guest interrupt.
            if let Some(eventfd) =
                virtio_interrupt.notifier(VirtioInterruptType::Queue(*queue_index as u16))
            {
                self.vu
                    .set_vring_call(*queue_index, &eventfd)
                    .map_err(Error::VhostUserSetVringCall)?;
            }

            self.vu
                .set_vring_kick(*queue_index, queue_evt)
                .map_err(Error::VhostUserSetVringKick)?;

            self.queue_indexes.push(*queue_index);
        }

        self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;

        if let Some(backend_req_handler) = backend_req_handler {
            self.vu
                .set_backend_request_fd(&backend_req_handler.get_tx_raw_fd())
                .map_err(Error::VhostUserSetBackendRequestFd)?;
        }

        self.vrings_info = Some(vrings_info);
        self.ready = true;

        Ok(())
    }
284 
enable_vhost_user_vrings(&mut self, queue_indexes: Vec<usize>, enable: bool) -> Result<()>285     fn enable_vhost_user_vrings(&mut self, queue_indexes: Vec<usize>, enable: bool) -> Result<()> {
286         for queue_index in queue_indexes {
287             self.vu
288                 .set_vring_enable(queue_index, enable)
289                 .map_err(Error::VhostUserSetVringEnable)?;
290         }
291 
292         Ok(())
293     }
294 
reset_vhost_user(&mut self) -> Result<()>295     pub fn reset_vhost_user(&mut self) -> Result<()> {
296         for queue_index in self.queue_indexes.drain(..) {
297             self.vu
298                 .set_vring_enable(queue_index, false)
299                 .map_err(Error::VhostUserSetVringEnable)?;
300 
301             let _ = self
302                 .vu
303                 .get_vring_base(queue_index)
304                 .map_err(Error::VhostUserGetVringBase)?;
305         }
306 
307         Ok(())
308     }
309 
set_protocol_features_vhost_user( &mut self, acked_features: u64, acked_protocol_features: u64, ) -> Result<()>310     pub fn set_protocol_features_vhost_user(
311         &mut self,
312         acked_features: u64,
313         acked_protocol_features: u64,
314     ) -> Result<()> {
315         self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
316         self.vu
317             .get_features()
318             .map_err(Error::VhostUserGetFeatures)?;
319 
320         if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
321             if let Some(acked_protocol_features) =
322                 VhostUserProtocolFeatures::from_bits(acked_protocol_features)
323             {
324                 self.vu
325                     .set_protocol_features(acked_protocol_features)
326                     .map_err(Error::VhostUserSetProtocolFeatures)?;
327 
328                 if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
329                     self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
330                 }
331             }
332         }
333 
334         self.update_supports_migration(acked_features, acked_protocol_features);
335 
336         Ok(())
337     }
338 
339     #[allow(clippy::too_many_arguments)]
reinitialize_vhost_user<S: VhostUserFrontendReqHandler>( &mut self, mem: &GuestMemoryMmap, queues: Vec<(usize, Queue, EventFd)>, virtio_interrupt: &Arc<dyn VirtioInterrupt>, acked_features: u64, acked_protocol_features: u64, backend_req_handler: &Option<FrontendReqHandler<S>>, inflight: Option<&mut Inflight>, ) -> Result<()>340     pub fn reinitialize_vhost_user<S: VhostUserFrontendReqHandler>(
341         &mut self,
342         mem: &GuestMemoryMmap,
343         queues: Vec<(usize, Queue, EventFd)>,
344         virtio_interrupt: &Arc<dyn VirtioInterrupt>,
345         acked_features: u64,
346         acked_protocol_features: u64,
347         backend_req_handler: &Option<FrontendReqHandler<S>>,
348         inflight: Option<&mut Inflight>,
349     ) -> Result<()> {
350         self.set_protocol_features_vhost_user(acked_features, acked_protocol_features)?;
351 
352         self.setup_vhost_user(
353             mem,
354             queues,
355             virtio_interrupt,
356             acked_features,
357             backend_req_handler,
358             inflight,
359         )
360     }
361 
connect_vhost_user( server: bool, socket_path: &str, num_queues: u64, unlink_socket: bool, ) -> Result<Self>362     pub fn connect_vhost_user(
363         server: bool,
364         socket_path: &str,
365         num_queues: u64,
366         unlink_socket: bool,
367     ) -> Result<Self> {
368         if server {
369             if unlink_socket {
370                 std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?;
371             }
372 
373             info!("Binding vhost-user listener...");
374             let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?;
375             info!("Waiting for incoming vhost-user connection...");
376             let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?;
377 
378             Ok(VhostUserHandle {
379                 vu: Frontend::from_stream(stream, num_queues),
380                 ready: false,
381                 supports_migration: false,
382                 shm_log: None,
383                 acked_features: 0,
384                 vrings_info: None,
385                 queue_indexes: Vec::new(),
386             })
387         } else {
388             let now = Instant::now();
389 
390             // Retry connecting for a full minute
391             let err = loop {
392                 let err = match Frontend::connect(socket_path, num_queues) {
393                     Ok(m) => {
394                         return Ok(VhostUserHandle {
395                             vu: m,
396                             ready: false,
397                             supports_migration: false,
398                             shm_log: None,
399                             acked_features: 0,
400                             vrings_info: None,
401                             queue_indexes: Vec::new(),
402                         })
403                     }
404                     Err(e) => e,
405                 };
406                 sleep(Duration::from_millis(100));
407 
408                 if now.elapsed().as_secs() >= 60 {
409                     break err;
410                 }
411             };
412 
413             error!(
414                 "Failed connecting the backend after trying for 1 minute: {:?}",
415                 err
416             );
417             Err(Error::VhostUserConnect)
418         }
419     }
420 
    /// Gives mutable access to the underlying vhost-user frontend so
    /// callers can issue requests not wrapped by this handle.
    pub fn socket_handle(&mut self) -> &mut Frontend {
        &mut self.vu
    }
424 
pause_vhost_user(&mut self) -> Result<()>425     pub fn pause_vhost_user(&mut self) -> Result<()> {
426         if self.ready {
427             self.enable_vhost_user_vrings(self.queue_indexes.clone(), false)?;
428         }
429 
430         Ok(())
431     }
432 
resume_vhost_user(&mut self) -> Result<()>433     pub fn resume_vhost_user(&mut self) -> Result<()> {
434         if self.ready {
435             self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;
436         }
437 
438         Ok(())
439     }
440 
update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64)441     fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) {
442         if (acked_features & u64::from(vhost::vhost_kern::vhost_binding::VHOST_F_LOG_ALL) != 0)
443             && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0)
444         {
445             self.supports_migration = true;
446         }
447     }
448 
    /// (Re)creates the dirty-log shared memory region and hands it to the
    /// backend via VHOST_USER_SET_LOG_BASE.
    ///
    /// Returns the previous log region (if any) so the caller can read the
    /// dirty bitmap the backend wrote into it. The ordering below matters:
    /// size the memfd, seal it, map it, and only then share it.
    fn update_log_base(&mut self, last_ram_addr: u64) -> Result<Option<Arc<MmapRegion>>> {
        // Create the memfd
        let fd = memfd_create(
            &ffi::CString::new("vhost_user_dirty_log").unwrap(),
            libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
        )
        .map_err(Error::MemfdCreate)?;

        // SAFETY: we checked the file descriptor is valid
        let file = unsafe { File::from_raw_fd(fd) };
        // The size of the memory mapping corresponds to the size of a bitmap
        // covering all guest pages for addresses from 0 to the last physical
        // address in guest RAM.
        // A page is always 4kiB from a vhost-user perspective, and each bit is
        // a page. That's how we can compute mmap_size from the last address.
        let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
        let mmap_handle = file.as_raw_fd();

        // Set shm_log region size
        file.set_len(mmap_size).map_err(Error::SetFileSize)?;

        // Set the seals so neither side can grow/shrink the region (or add
        // further seals) once it has been shared with the backend.
        // SAFETY: FFI call with valid arguments
        let res = unsafe {
            libc::fcntl(
                file.as_raw_fd(),
                libc::F_ADD_SEALS,
                libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL,
            )
        };
        if res < 0 {
            return Err(Error::SetSeals(std::io::Error::last_os_error()));
        }

        // Mmap shm_log region
        let region = MmapRegion::build(
            Some(FileOffset::new(file, 0)),
            mmap_size as usize,
            libc::PROT_READ | libc::PROT_WRITE,
            libc::MAP_SHARED,
        )
        .map_err(Error::NewMmapRegion)?;

        // Make sure we hold onto the region to prevent the mapping from being
        // released.
        let old_region = self.shm_log.replace(Arc::new(region));

        // Send the shm_log fd over to the backend
        let log = VhostUserDirtyLogRegion {
            mmap_size,
            mmap_offset: 0,
            mmap_handle,
        };
        self.vu
            .set_log_base(0, Some(log))
            .map_err(Error::VhostUserSetLogBase)?;

        Ok(old_region)
    }
508 
set_vring_logging(&mut self, enable: bool) -> Result<()>509     fn set_vring_logging(&mut self, enable: bool) -> Result<()> {
510         if let Some(vrings_info) = &self.vrings_info {
511             for (i, vring_info) in vrings_info.iter().enumerate() {
512                 let mut config_data = vring_info.config_data;
513                 config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 };
514                 config_data.log_addr = if enable {
515                     Some(vring_info.used_guest_addr)
516                 } else {
517                     None
518                 };
519 
520                 self.vu
521                     .set_vring_addr(i, &config_data)
522                     .map_err(Error::VhostUserSetVringAddr)?;
523             }
524         }
525 
526         Ok(())
527     }
528 
start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()>529     pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> {
530         if !self.supports_migration {
531             return Err(Error::MigrationNotSupported);
532         }
533 
534         // Set the shm log region
535         self.update_log_base(last_ram_addr)?;
536 
537         // Enable VHOST_F_LOG_ALL feature
538         let features = self.acked_features | (1 << VHOST_F_LOG_ALL);
539         self.vu
540             .set_features(features)
541             .map_err(Error::VhostUserSetFeatures)?;
542 
543         // Enable dirty page logging of used ring for all queues
544         self.set_vring_logging(true)
545     }
546 
stop_dirty_log(&mut self) -> Result<()>547     pub fn stop_dirty_log(&mut self) -> Result<()> {
548         if !self.supports_migration {
549             return Err(Error::MigrationNotSupported);
550         }
551 
552         // Disable dirty page logging of used ring for all queues
553         self.set_vring_logging(false)?;
554 
555         // Disable VHOST_F_LOG_ALL feature
556         self.vu
557             .set_features(self.acked_features)
558             .map_err(Error::VhostUserSetFeatures)?;
559 
560         // This is important here since the log region goes out of scope,
561         // invoking the Drop trait, hence unmapping the memory.
562         self.shm_log = None;
563 
564         Ok(())
565     }
566 
dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable>567     pub fn dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable> {
568         // The log region is updated by creating a new region that is sent to
569         // the backend. This ensures the backend stops logging to the previous
570         // region. The previous region is returned and processed to create the
571         // bitmap representing the dirty pages.
572         if let Some(region) = self.update_log_base(last_ram_addr)? {
573             // Be careful with the size, as it was based on u8, meaning we must
574             // divide it by 8.
575             let len = region.size() / 8;
576             // SAFETY: region is of size len
577             let bitmap = unsafe {
578                 // Cast the pointer to u64
579                 let ptr = region.as_ptr() as *const u64;
580                 std::slice::from_raw_parts(ptr, len).to_vec()
581             };
582             Ok(MemoryRangeTable::from_bitmap(bitmap, 0, 4096))
583         } else {
584             Err(Error::MissingShmLogRegion)
585         }
586     }
587 }
588 
memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error>589 fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error> {
590     // SAFETY: FFI call with valid arguments
591     let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) };
592 
593     if res < 0 {
594         Err(std::io::Error::last_os_error())
595     } else {
596         Ok(res as RawFd)
597     }
598 }
599