// Copyright 2019 Intel Corporation. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use std::ffi;
use std::fs::File;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixListener;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::thread::sleep;
use std::time::{Duration, Instant};

use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG};
use vhost::vhost_user::message::{
    VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
};
use vhost::vhost_user::{
    Frontend, FrontendReqHandler, VhostUserFrontend, VhostUserFrontendReqHandler,
};
use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
use virtio_queue::desc::RawDescriptor;
use virtio_queue::{Queue, QueueT};
use vm_memory::{
    Address, Error as MmapError, FileOffset, GuestAddress, GuestMemory, GuestMemoryRegion,
};
use vm_migration::protocol::MemoryRangeTable;
use vmm_sys_util::eventfd::EventFd;

use super::{Error, Result};
use crate::vhost_user::Inflight;
use crate::{
    get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt,
    VirtioInterruptType,
};

// Size of a dirty page for vhost-user.
const VHOST_LOG_PAGE: u64 = 0x1000;

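/// Basic configuration for a vhost-user device: the backend socket path and
/// the virtqueue layout.
///
/// A minimal construction sketch (the values below are illustrative, not
/// defaults):
/// ```ignore
/// let cfg = VhostUserConfig {
///     socket: "/tmp/vhost-user.sock".to_string(), // hypothetical path
///     num_queues: 2,
///     queue_size: 256,
/// };
/// ```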
#[derive(Debug, Clone)]
pub struct VhostUserConfig {
    pub socket: String,
    pub num_queues: usize,
    pub queue_size: u16,
}

#[derive(Clone)]
struct VringInfo {
    config_data: VringConfigData,
    used_guest_addr: u64,
}

#[derive(Clone)]
pub struct VhostUserHandle {
    vu: Frontend,
    ready: bool,
    supports_migration: bool,
    shm_log: Option<Arc<MmapRegion>>,
    acked_features: u64,
    vrings_info: Option<Vec<VringInfo>>,
    queue_indexes: Vec<usize>,
}

impl VhostUserHandle {
    pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> {
        let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
        for region in mem.iter() {
            let (mmap_handle, mmap_offset) = match region.file_offset() {
                Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()),
                None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)),
            };

            let vhost_user_net_reg = VhostUserMemoryRegionInfo {
                guest_phys_addr: region.start_addr().raw_value(),
                memory_size: region.len(),
                userspace_addr: region.as_ptr() as u64,
                mmap_offset,
                mmap_handle,
            };

            regions.push(vhost_user_net_reg);
        }

        self.vu
            .set_mem_table(regions.as_slice())
            .map_err(Error::VhostUserSetMemTable)?;

        Ok(())
    }

    pub fn add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()> {
        let (mmap_handle, mmap_offset) = match region.file_offset() {
            Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
            None => return Err(Error::MissingRegionFd),
        };

        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: region.start_addr().raw_value(),
            memory_size: region.len(),
            userspace_addr: region.as_ptr() as u64,
            mmap_offset,
            mmap_handle,
        };

        self.vu
            .add_mem_region(&region)
            .map_err(Error::VhostUserAddMemReg)
    }

    pub fn negotiate_features_vhost_user(
        &mut self,
        avail_features: u64,
        avail_protocol_features: VhostUserProtocolFeatures,
    ) -> Result<(u64, u64)> {
        // Set vhost-user owner.
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;

        // Get the features from the backend and negotiate the subset
        // supported by both the VMM and the backend.
        let backend_features = self
            .vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;
        let acked_features = avail_features & backend_features;

        let acked_protocol_features =
            if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
                let backend_protocol_features = self
                    .vu
                    .get_protocol_features()
                    .map_err(Error::VhostUserGetProtocolFeatures)?;

                let acked_protocol_features = avail_protocol_features & backend_protocol_features;

                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                acked_protocol_features
            } else {
                VhostUserProtocolFeatures::empty()
            };

        if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
            && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
        {
            self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
        }

        self.update_supports_migration(acked_features, acked_protocol_features.bits());

        Ok((acked_features, acked_protocol_features.bits()))
    }

    #[allow(clippy::too_many_arguments)]
    pub fn setup_vhost_user<S: VhostUserFrontendReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<(usize, Queue, EventFd)>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        backend_req_handler: &Option<FrontendReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.vu
            .set_features(acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Update the internal value after it has been sent to the backend.
        self.acked_features = acked_features;

        // Let's first provide the memory table to the backend.
        self.update_mem_table(mem)?;

        // Send set_vring_num first, since it tells backends such as SPDK how
        // many virtqueues need to be handled, which some backends must know
        // at an early stage.
        for (queue_index, queue, _) in queues.iter() {
            self.vu
                .set_vring_num(*queue_index, queue.size())
                .map_err(Error::VhostUserSetVringNum)?;
        }

        // Set up the shared memory used for inflight I/O tracking.
        if let Some(inflight) = inflight {
            if inflight.fd.is_none() {
                let inflight_req_info = VhostUserInflight {
                    mmap_size: 0,
                    mmap_offset: 0,
                    num_queues: queues.len() as u16,
                    queue_size: queues[0].1.size(),
                };
                let (info, fd) = self
                    .vu
                    .get_inflight_fd(&inflight_req_info)
                    .map_err(Error::VhostUserGetInflight)?;
                inflight.info = info;
                inflight.fd = Some(fd);
            }
            // Unwrapping the inflight fd is safe here since we know it can't be None.
            self.vu
                .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd())
                .map_err(Error::VhostUserSetInflight)?;
        }

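        // For each queue: provide the ring addresses, set the initial base
        // index, then wire up the call (backend -> guest interrupt) and kick
        // (guest -> backend notification) eventfds.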
        let mut vrings_info = Vec::new();
        for (queue_index, queue, queue_evt) in queues.iter() {
            let actual_size: usize = queue.size().into();

            let config_data = VringConfigData {
                queue_max_size: queue.max_size(),
                queue_size: queue.size(),
                flags: 0u32,
                desc_table_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.desc_table()),
                    actual_size * std::mem::size_of::<RawDescriptor>(),
                )
                .ok_or(Error::DescriptorTableAddress)? as u64,
                // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u32, len: u32}; actual_size]},
                // i.e. 4 + (4 + 4) * actual_size.
                used_ring_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.used_ring()),
                    4 + actual_size * 8,
                )
                .ok_or(Error::UsedAddress)? as u64,
                // The available ring is {flags: u16; idx: u16; ring [u16; actual_size]},
                // i.e. 4 + 2 * actual_size.
                avail_ring_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.avail_ring()),
                    4 + actual_size * 2,
                )
                .ok_or(Error::AvailAddress)? as u64,
                log_addr: None,
            };
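
            // Worked example (illustrative): with actual_size = 256, the
            // descriptor table spans 256 * 16 = 4096 bytes, the used ring
            // 4 + 256 * 8 = 2052 bytes, and the available ring
            // 4 + 256 * 2 = 516 bytes.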

            vrings_info.push(VringInfo {
                config_data,
                used_guest_addr: queue.used_ring(),
            });

            self.vu
                .set_vring_addr(*queue_index, &config_data)
                .map_err(Error::VhostUserSetVringAddr)?;
            self.vu
                .set_vring_base(
                    *queue_index,
                    queue
                        .avail_idx(mem, Ordering::Acquire)
                        .map_err(Error::GetAvailableIndex)?
                        .0,
                )
                .map_err(Error::VhostUserSetVringBase)?;

            if let Some(eventfd) =
                virtio_interrupt.notifier(VirtioInterruptType::Queue(*queue_index as u16))
            {
                self.vu
                    .set_vring_call(*queue_index, &eventfd)
                    .map_err(Error::VhostUserSetVringCall)?;
            }

            self.vu
                .set_vring_kick(*queue_index, queue_evt)
                .map_err(Error::VhostUserSetVringKick)?;

            self.queue_indexes.push(*queue_index);
        }

        self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;

        if let Some(backend_req_handler) = backend_req_handler {
            self.vu
                .set_backend_request_fd(&backend_req_handler.get_tx_raw_fd())
                .map_err(Error::VhostUserSetBackendRequestFd)?;
        }

        self.vrings_info = Some(vrings_info);
        self.ready = true;

        Ok(())
    }

    fn enable_vhost_user_vrings(&mut self, queue_indexes: Vec<usize>, enable: bool) -> Result<()> {
        for queue_index in queue_indexes {
            self.vu
                .set_vring_enable(queue_index, enable)
                .map_err(Error::VhostUserSetVringEnable)?;
        }

        Ok(())
    }

    pub fn reset_vhost_user(&mut self) -> Result<()> {
        for queue_index in self.queue_indexes.drain(..) {
            self.vu
                .set_vring_enable(queue_index, false)
                .map_err(Error::VhostUserSetVringEnable)?;

            let _ = self
                .vu
                .get_vring_base(queue_index)
                .map_err(Error::VhostUserGetVringBase)?;
        }

        Ok(())
    }

    pub fn set_protocol_features_vhost_user(
        &mut self,
        acked_features: u64,
        acked_protocol_features: u64,
    ) -> Result<()> {
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
        self.vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;

        if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
            if let Some(acked_protocol_features) =
                VhostUserProtocolFeatures::from_bits(acked_protocol_features)
            {
                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
                    self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
                }
            }
        }

        self.update_supports_migration(acked_features, acked_protocol_features);

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    pub fn reinitialize_vhost_user<S: VhostUserFrontendReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<(usize, Queue, EventFd)>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        acked_protocol_features: u64,
        backend_req_handler: &Option<FrontendReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.set_protocol_features_vhost_user(acked_features, acked_protocol_features)?;

        self.setup_vhost_user(
            mem,
            queues,
            virtio_interrupt,
            acked_features,
            backend_req_handler,
            inflight,
        )
    }

    pub fn connect_vhost_user(
        server: bool,
        socket_path: &str,
        num_queues: u64,
        unlink_socket: bool,
    ) -> Result<Self> {
        if server {
            if unlink_socket {
                std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?;
            }

            info!("Binding vhost-user listener...");
            let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?;
            info!("Waiting for incoming vhost-user connection...");
            let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?;

            Ok(VhostUserHandle {
                vu: Frontend::from_stream(stream, num_queues),
                ready: false,
                supports_migration: false,
                shm_log: None,
                acked_features: 0,
                vrings_info: None,
                queue_indexes: Vec::new(),
            })
        } else {
            let now = Instant::now();

            // Retry connecting for a full minute
            let err = loop {
                let err = match Frontend::connect(socket_path, num_queues) {
                    Ok(m) => {
                        return Ok(VhostUserHandle {
                            vu: m,
                            ready: false,
                            supports_migration: false,
                            shm_log: None,
                            acked_features: 0,
                            vrings_info: None,
                            queue_indexes: Vec::new(),
                        })
                    }
                    Err(e) => e,
                };
                sleep(Duration::from_millis(100));

                if now.elapsed().as_secs() >= 60 {
                    break err;
                }
            };

            error!(
                "Failed to connect to the backend after trying for 1 minute: {:?}",
                err
            );
            Err(Error::VhostUserConnect)
        }
    }

    pub fn socket_handle(&mut self) -> &mut Frontend {
        &mut self.vu
    }

    pub fn pause_vhost_user(&mut self) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(self.queue_indexes.clone(), false)?;
        }

        Ok(())
    }

    pub fn resume_vhost_user(&mut self) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;
        }

        Ok(())
    }

    fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) {
        if (acked_features & (1 << VHOST_F_LOG_ALL) != 0)
            && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0)
        {
            self.supports_migration = true;
        }
    }

    fn update_log_base(&mut self, last_ram_addr: u64) -> Result<Option<Arc<MmapRegion>>> {
        // Create the memfd
        let fd = memfd_create(
            &ffi::CString::new("vhost_user_dirty_log").unwrap(),
            libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
        )
        .map_err(Error::MemfdCreate)?;

        // SAFETY: we checked the file descriptor is valid
        let file = unsafe { File::from_raw_fd(fd) };
        // The size of the memory mapping corresponds to the size of a bitmap
        // covering all guest pages for addresses from 0 to the last physical
        // address in guest RAM.
        // A page is always 4 KiB from a vhost-user perspective, and each bit
        // is a page. That's how we can compute mmap_size from the last address.
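        // For example (illustrative figures): with last_ram_addr at 4 GiB,
        // mmap_size = 0x1_0000_0000 / (0x1000 * 8) + 1 = 131073 bytes.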
        let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
        let mmap_handle = file.as_raw_fd();

        // Set shm_log region size
        file.set_len(mmap_size).map_err(Error::SetFileSize)?;

        // Set the seals
        // SAFETY: FFI call with valid arguments
        let res = unsafe {
            libc::fcntl(
                file.as_raw_fd(),
                libc::F_ADD_SEALS,
                libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL,
            )
        };
        if res < 0 {
            return Err(Error::SetSeals(std::io::Error::last_os_error()));
        }

        // Mmap the shm_log region
        let region = MmapRegion::build(
            Some(FileOffset::new(file, 0)),
            mmap_size as usize,
            libc::PROT_READ | libc::PROT_WRITE,
            libc::MAP_SHARED,
        )
        .map_err(Error::NewMmapRegion)?;

        // Make sure we hold onto the region to prevent the mapping from being
        // released.
        let old_region = self.shm_log.replace(Arc::new(region));

        // Send the shm_log fd over to the backend
        let log = VhostUserDirtyLogRegion {
            mmap_size,
            mmap_offset: 0,
            mmap_handle,
        };
        self.vu
            .set_log_base(0, Some(log))
            .map_err(Error::VhostUserSetLogBase)?;

        Ok(old_region)
    }

    fn set_vring_logging(&mut self, enable: bool) -> Result<()> {
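        // Re-send each vring's configuration with the VHOST_VRING_F_LOG flag
        // and the guest address of its used ring, so the backend starts (or
        // stops) logging writes to the used ring.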
        if let Some(vrings_info) = &self.vrings_info {
            for (i, vring_info) in vrings_info.iter().enumerate() {
                let mut config_data = vring_info.config_data;
                config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 };
                config_data.log_addr = if enable {
                    Some(vring_info.used_guest_addr)
                } else {
                    None
                };

                self.vu
                    .set_vring_addr(i, &config_data)
                    .map_err(Error::VhostUserSetVringAddr)?;
            }
        }

        Ok(())
    }

    pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Set the shm log region
        self.update_log_base(last_ram_addr)?;

        // Enable the VHOST_F_LOG_ALL feature
        let features = self.acked_features | (1 << VHOST_F_LOG_ALL);
        self.vu
            .set_features(features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Enable dirty page logging of the used ring for all queues
        self.set_vring_logging(true)
    }

    pub fn stop_dirty_log(&mut self) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Disable dirty page logging of the used ring for all queues
        self.set_vring_logging(false)?;

        // Disable the VHOST_F_LOG_ALL feature
        self.vu
            .set_features(self.acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // This step matters: dropping the log region here invokes its Drop
        // implementation, which unmaps the memory.
        self.shm_log = None;

        Ok(())
    }

    pub fn dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable> {
        // The log region is updated by creating a new region that is sent to
        // the backend. This ensures the backend stops logging to the previous
        // region. The previous region is returned and processed to create the
        // bitmap representing the dirty pages.
        if let Some(region) = self.update_log_base(last_ram_addr)? {
            // The region size is expressed in bytes, but the bitmap is read
            // as u64 words, so the length must be divided by 8.
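            // Each u64 word covers 64 pages, i.e. 64 * 4 KiB = 256 KiB of
            // guest memory.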
            let len = region.size() / 8;
            // SAFETY: region is of size len
            let bitmap = unsafe {
                // Cast the pointer to u64
                let ptr = region.as_ptr() as *const u64;
                std::slice::from_raw_parts(ptr, len).to_vec()
            };
            Ok(MemoryRangeTable::from_bitmap(bitmap, 0, 4096))
        } else {
            Err(Error::MissingShmLogRegion)
        }
    }
}

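/// Thin wrapper around the `memfd_create(2)` syscall, returning the new file
/// descriptor on success.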
fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error> {
    // SAFETY: FFI call with valid arguments
    let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) };

    if res < 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(res as RawFd)
    }
}