// Copyright 2019 Intel Corporation. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use std::ffi;
use std::fs::File;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixListener;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::thread::sleep;
use std::time::{Duration, Instant};

use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG};
use vhost::vhost_user::message::{
    VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
};
use vhost::vhost_user::{
    Frontend, FrontendReqHandler, VhostUserFrontend, VhostUserFrontendReqHandler,
};
use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
use virtio_queue::{Descriptor, Queue, QueueT};
use vm_memory::{
    Address, Error as MmapError, FileOffset, GuestAddress, GuestMemory, GuestMemoryRegion,
};
use vm_migration::protocol::MemoryRangeTable;
use vmm_sys_util::eventfd::EventFd;

use super::{Error, Result};
use crate::vhost_user::Inflight;
use crate::{
    get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt,
    VirtioInterruptType,
};

// Size of a dirty page for vhost-user.
const VHOST_LOG_PAGE: u64 = 0x1000;

#[derive(Debug, Clone)]
pub struct VhostUserConfig {
    pub socket: String,
    pub num_queues: usize,
    pub queue_size: u16,
}

#[derive(Clone)]
struct VringInfo {
    config_data: VringConfigData,
    used_guest_addr: u64,
}

#[derive(Clone)]
pub struct VhostUserHandle {
    vu: Frontend,
    ready: bool,
    supports_migration: bool,
    shm_log: Option<Arc<MmapRegion>>,
    acked_features: u64,
    vrings_info: Option<Vec<VringInfo>>,
    queue_indexes: Vec<usize>,
}

impl VhostUserHandle {
    pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> {
        let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
        for region in mem.iter() {
            let (mmap_handle, mmap_offset) = match region.file_offset() {
                Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
                None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)),
            };

            let vhost_user_net_reg = VhostUserMemoryRegionInfo {
                guest_phys_addr: region.start_addr().raw_value(),
                memory_size: region.len(),
                userspace_addr: region.as_ptr() as u64,
                mmap_offset,
                mmap_handle,
            };

            regions.push(vhost_user_net_reg);
        }

        self.vu
            .set_mem_table(regions.as_slice())
            .map_err(Error::VhostUserSetMemTable)?;

        Ok(())
    }

    pub fn add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()> {
        let (mmap_handle, mmap_offset) = match region.file_offset() {
            Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
            None => return Err(Error::MissingRegionFd),
        };

        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: region.start_addr().raw_value(),
            memory_size: region.len(),
            userspace_addr: region.as_ptr() as u64,
            mmap_offset,
            mmap_handle,
        };

        self.vu
            .add_mem_region(&region)
            .map_err(Error::VhostUserAddMemReg)
    }

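    /// Negotiates virtio features and vhost-user protocol features with the
    /// backend, returning the acked feature bits and the acked protocol
    /// feature bits.
    ///
    /// A minimal usage sketch, assuming a connected handle (the feature sets
    /// shown are illustrative):
    ///
    /// ```ignore
    /// let (acked_features, acked_protocol_features) = vu.negotiate_features_vhost_user(
    ///     avail_features,
    ///     VhostUserProtocolFeatures::MQ | VhostUserProtocolFeatures::REPLY_ACK,
    /// )?;
    /// ```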
    pub fn negotiate_features_vhost_user(
        &mut self,
        avail_features: u64,
        avail_protocol_features: VhostUserProtocolFeatures,
    ) -> Result<(u64, u64)> {
        // Set vhost-user owner.
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;

        // Get the features supported by the backend, then negotiate down to
        // the set that both the VMM and the backend support.
        let backend_features = self
            .vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;
        let acked_features = avail_features & backend_features;

        let acked_protocol_features =
            if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
                let backend_protocol_features = self
                    .vu
                    .get_protocol_features()
                    .map_err(Error::VhostUserGetProtocolFeatures)?;

                let acked_protocol_features = avail_protocol_features & backend_protocol_features;

                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                acked_protocol_features
            } else {
                VhostUserProtocolFeatures::empty()
            };

        if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
            && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
        {
            self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
        }

        self.update_supports_migration(acked_features, acked_protocol_features.bits());

        Ok((acked_features, acked_protocol_features.bits()))
    }

    #[allow(clippy::too_many_arguments)]
    pub fn setup_vhost_user<S: VhostUserFrontendReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<(usize, Queue, EventFd)>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        backend_req_handler: &Option<FrontendReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.vu
            .set_features(acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Update the internal value after it has been sent to the backend.
        self.acked_features = acked_features;

        // Let's first provide the memory table to the backend.
        self.update_mem_table(mem)?;

        // Send set_vring_num here, since it tells backends such as SPDK how
        // many virtqueues will be handled, which they need to know at an
        // early stage.
        for (queue_index, queue, _) in queues.iter() {
            self.vu
                .set_vring_num(*queue_index, queue.size())
                .map_err(Error::VhostUserSetVringNum)?;
        }

        // Set up the shared memory used for inflight I/O tracking.
        if let Some(inflight) = inflight {
            if inflight.fd.is_none() {
                let inflight_req_info = VhostUserInflight {
                    mmap_size: 0,
                    mmap_offset: 0,
                    num_queues: queues.len() as u16,
                    queue_size: queues[0].1.size(),
                };
                let (info, fd) = self
                    .vu
                    .get_inflight_fd(&inflight_req_info)
                    .map_err(Error::VhostUserGetInflight)?;
                inflight.info = info;
                inflight.fd = Some(fd);
            }
            // Unwrapping the inflight fd is safe here since we know it can't be None.
            self.vu
                .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd())
                .map_err(Error::VhostUserSetInflight)?;
        }

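        // Each vring below is brought up in the order the backend expects:
        // set the ring addresses (SET_VRING_ADDR), seed the next available
        // index (SET_VRING_BASE), wire up the interrupt and kick eventfds
        // (SET_VRING_CALL / SET_VRING_KICK), and only enable the rings
        // (SET_VRING_ENABLE) once every queue has been configured.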
        let mut vrings_info = Vec::new();
        for (queue_index, queue, queue_evt) in queues.iter() {
            let actual_size: usize = queue.size().into();

            let config_data = VringConfigData {
                queue_max_size: queue.max_size(),
                queue_size: queue.size(),
                flags: 0u32,
                desc_table_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.desc_table()),
                    actual_size * std::mem::size_of::<Descriptor>(),
                )
                .ok_or(Error::DescriptorTableAddress)? as u64,
                // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u32, len: u32}; actual_size]},
                // i.e. 4 + (4 + 4) * actual_size.
                used_ring_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.used_ring()),
                    4 + actual_size * 8,
                )
                .ok_or(Error::UsedAddress)? as u64,
                // The avail ring is {flags: u16; idx: u16; ring [u16; actual_size]},
                // i.e. 4 + 2 * actual_size.
                avail_ring_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.avail_ring()),
                    4 + actual_size * 2,
                )
                .ok_or(Error::AvailAddress)? as u64,
                log_addr: None,
            };

            vrings_info.push(VringInfo {
                config_data,
                used_guest_addr: queue.used_ring(),
            });

            self.vu
                .set_vring_addr(*queue_index, &config_data)
                .map_err(Error::VhostUserSetVringAddr)?;
            self.vu
                .set_vring_base(
                    *queue_index,
                    queue
                        .avail_idx(mem, Ordering::Acquire)
                        .map_err(Error::GetAvailableIndex)?
                        .0,
                )
                .map_err(Error::VhostUserSetVringBase)?;

            if let Some(eventfd) =
                virtio_interrupt.notifier(VirtioInterruptType::Queue(*queue_index as u16))
            {
                self.vu
                    .set_vring_call(*queue_index, &eventfd)
                    .map_err(Error::VhostUserSetVringCall)?;
            }

            self.vu
                .set_vring_kick(*queue_index, queue_evt)
                .map_err(Error::VhostUserSetVringKick)?;

            self.queue_indexes.push(*queue_index);
        }

        self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;

        if let Some(backend_req_handler) = backend_req_handler {
            self.vu
                .set_backend_request_fd(&backend_req_handler.get_tx_raw_fd())
                .map_err(Error::VhostUserSetBackendRequestFd)?;
        }

        self.vrings_info = Some(vrings_info);
        self.ready = true;

        Ok(())
    }

    fn enable_vhost_user_vrings(&mut self, queue_indexes: Vec<usize>, enable: bool) -> Result<()> {
        for queue_index in queue_indexes {
            self.vu
                .set_vring_enable(queue_index, enable)
                .map_err(Error::VhostUserSetVringEnable)?;
        }

        Ok(())
    }

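    /// Disables every vring and retrieves its base index. Per the vhost-user
    /// protocol, GET_VRING_BASE also stops the ring, so this quiesces the
    /// backend; the returned index is intentionally discarded here.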
    pub fn reset_vhost_user(&mut self) -> Result<()> {
        for queue_index in self.queue_indexes.drain(..) {
            self.vu
                .set_vring_enable(queue_index, false)
                .map_err(Error::VhostUserSetVringEnable)?;

            let _ = self
                .vu
                .get_vring_base(queue_index)
                .map_err(Error::VhostUserGetVringBase)?;
        }

        Ok(())
    }

    pub fn set_protocol_features_vhost_user(
        &mut self,
        acked_features: u64,
        acked_protocol_features: u64,
    ) -> Result<()> {
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
        self.vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;

        if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
            if let Some(acked_protocol_features) =
                VhostUserProtocolFeatures::from_bits(acked_protocol_features)
            {
                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
                    self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
                }
            }
        }

        self.update_supports_migration(acked_features, acked_protocol_features);

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    pub fn reinitialize_vhost_user<S: VhostUserFrontendReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<(usize, Queue, EventFd)>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        acked_protocol_features: u64,
        backend_req_handler: &Option<FrontendReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.set_protocol_features_vhost_user(acked_features, acked_protocol_features)?;

        self.setup_vhost_user(
            mem,
            queues,
            virtio_interrupt,
            acked_features,
            backend_req_handler,
            inflight,
        )
    }

    pub fn connect_vhost_user(
        server: bool,
        socket_path: &str,
        num_queues: u64,
        unlink_socket: bool,
    ) -> Result<Self> {
        if server {
            if unlink_socket {
                std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?;
            }

            info!("Binding vhost-user listener...");
            let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?;
            info!("Waiting for incoming vhost-user connection...");
            let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?;

            Ok(VhostUserHandle {
                vu: Frontend::from_stream(stream, num_queues),
                ready: false,
                supports_migration: false,
                shm_log: None,
                acked_features: 0,
                vrings_info: None,
                queue_indexes: Vec::new(),
            })
        } else {
            let now = Instant::now();

            // Retry connecting for a full minute.
            let err = loop {
                let err = match Frontend::connect(socket_path, num_queues) {
                    Ok(m) => {
                        return Ok(VhostUserHandle {
                            vu: m,
                            ready: false,
                            supports_migration: false,
                            shm_log: None,
                            acked_features: 0,
                            vrings_info: None,
                            queue_indexes: Vec::new(),
                        })
                    }
                    Err(e) => e,
                };
                sleep(Duration::from_millis(100));

                if now.elapsed().as_secs() >= 60 {
                    break err;
                }
            };

            error!(
                "Failed to connect to the backend after trying for 1 minute: {:?}",
                err
            );
            Err(Error::VhostUserConnect)
        }
    }

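    // A typical client-side bring-up, sketched with illustrative values
    // (types and error handling elided):
    //
    //     let mut vu = VhostUserHandle::connect_vhost_user(false, "/tmp/vhost.sock", 2, false)?;
    //     let (features, protocol_features) =
    //         vu.negotiate_features_vhost_user(avail_features, avail_protocol_features)?;
    //     vu.setup_vhost_user(&mem, queues, &interrupt, features, &backend_req_handler, None)?;
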
    pub fn socket_handle(&mut self) -> &mut Frontend {
        &mut self.vu
    }

    pub fn pause_vhost_user(&mut self) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(self.queue_indexes.clone(), false)?;
        }

        Ok(())
    }

    pub fn resume_vhost_user(&mut self) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;
        }

        Ok(())
    }

    fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) {
        // VHOST_F_LOG_ALL is a bit index, so test the corresponding feature
        // bit, consistent with how start_dirty_log() sets it.
        if (acked_features & (1 << VHOST_F_LOG_ALL) != 0)
            && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0)
        {
            self.supports_migration = true;
        }
    }

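    // The dirty log is a bitmap with one bit per VHOST_LOG_PAGE (4 KiB) page,
    // so each byte of the log covers 8 pages (32 KiB) of guest RAM. As a
    // worked example, a guest whose last RAM address is 4 GiB (0x1_0000_0000)
    // needs (0x1_0000_0000 / 0x8000) + 1 = 131_073 bytes of log.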
    fn update_log_base(&mut self, last_ram_addr: u64) -> Result<Option<Arc<MmapRegion>>> {
        // Create the memfd.
        let fd = memfd_create(
            &ffi::CString::new("vhost_user_dirty_log").unwrap(),
            libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
        )
        .map_err(Error::MemfdCreate)?;

        // SAFETY: we checked the file descriptor is valid.
        let file = unsafe { File::from_raw_fd(fd) };
        // The size of the memory mapping corresponds to the size of a bitmap
        // covering all guest pages for addresses from 0 to the last physical
        // address in guest RAM.
        // A page is always 4 KiB from a vhost-user perspective, and each bit
        // covers one page. That is how mmap_size is derived from the last
        // address.
        let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
        let mmap_handle = file.as_raw_fd();

        // Set the shm_log region size.
        file.set_len(mmap_size).map_err(Error::SetFileSize)?;

        // Set the seals.
        // SAFETY: FFI call with valid arguments.
        let res = unsafe {
            libc::fcntl(
                file.as_raw_fd(),
                libc::F_ADD_SEALS,
                libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL,
            )
        };
        if res < 0 {
            return Err(Error::SetSeals(std::io::Error::last_os_error()));
        }

        // Mmap the shm_log region.
        let region = MmapRegion::build(
            Some(FileOffset::new(file, 0)),
            mmap_size as usize,
            libc::PROT_READ | libc::PROT_WRITE,
            libc::MAP_SHARED,
        )
        .map_err(Error::NewMmapRegion)?;

        // Make sure we hold onto the region to prevent the mapping from being
        // released.
        let old_region = self.shm_log.replace(Arc::new(region));

        // Send the shm_log fd over to the backend.
        let log = VhostUserDirtyLogRegion {
            mmap_size,
            mmap_offset: 0,
            mmap_handle,
        };
        self.vu
            .set_log_base(0, Some(log))
            .map_err(Error::VhostUserSetLogBase)?;

        Ok(old_region)
    }

    fn set_vring_logging(&mut self, enable: bool) -> Result<()> {
        if let Some(vrings_info) = &self.vrings_info {
            for (i, vring_info) in vrings_info.iter().enumerate() {
                let mut config_data = vring_info.config_data;
                config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 };
                config_data.log_addr = if enable {
                    Some(vring_info.used_guest_addr)
                } else {
                    None
                };

                self.vu
                    .set_vring_addr(i, &config_data)
                    .map_err(Error::VhostUserSetVringAddr)?;
            }
        }

        Ok(())
    }

    pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Set the shm log region.
        self.update_log_base(last_ram_addr)?;

        // Enable the VHOST_F_LOG_ALL feature.
        let features = self.acked_features | (1 << VHOST_F_LOG_ALL);
        self.vu
            .set_features(features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Enable dirty page logging of the used ring for all queues.
        self.set_vring_logging(true)
    }

    pub fn stop_dirty_log(&mut self) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Disable dirty page logging of the used ring for all queues.
        self.set_vring_logging(false)?;

        // Disable the VHOST_F_LOG_ALL feature.
        self.vu
            .set_features(self.acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // This is important here since the log region goes out of scope,
        // invoking the Drop trait, hence unmapping the memory.
        self.shm_log = None;

        Ok(())
    }

    pub fn dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable> {
        // The log region is updated by creating a new region that is sent to
        // the backend. This ensures the backend stops logging to the previous
        // region. The previous region is returned and processed to create the
        // bitmap representing the dirty pages.
        if let Some(region) = self.update_log_base(last_ram_addr)? {
            // The region size is in bytes, but the bitmap is read as u64
            // words, so the word count is the byte size divided by 8.
            let len = region.size() / 8;
            // SAFETY: the region contains `len` consecutive u64 words.
            let bitmap = unsafe {
                // Cast the pointer to u64.
                let ptr = region.as_ptr() as *const u64;
                std::slice::from_raw_parts(ptr, len).to_vec()
            };
            Ok(MemoryRangeTable::from_bitmap(bitmap, 0, 4096))
        } else {
            Err(Error::MissingShmLogRegion)
        }
    }
}

fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error> {
    // SAFETY: FFI call with valid arguments.
    let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) };

    if res < 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(res as RawFd)
    }
}
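
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal sanity check for the memfd wrapper above; assumes a Linux host
    // where the memfd_create(2) syscall is available.
    #[test]
    fn test_memfd_create_is_usable() {
        let fd = memfd_create(
            &ffi::CString::new("test_vhost_user_memfd").unwrap(),
            libc::MFD_CLOEXEC,
        )
        .unwrap();
        // SAFETY: memfd_create returned a valid, exclusively owned descriptor.
        let file = unsafe { File::from_raw_fd(fd) };
        // The descriptor must behave like a regular file: resizing it should
        // succeed and be observable.
        file.set_len(4096).unwrap();
        assert_eq!(file.metadata().unwrap().len(), 4096);
    }
}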