// Copyright 2019 Intel Corporation. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use super::{Error, Result};
use crate::vhost_user::Inflight;
use crate::{
    get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt,
    VirtioInterruptType,
};
use std::ffi;
use std::fs::File;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixListener;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::vec::Vec;
use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG};
use vhost::vhost_user::message::{
    VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
};
use vhost::vhost_user::{
    Frontend, FrontendReqHandler, VhostUserFrontend, VhostUserFrontendReqHandler,
};
use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
use virtio_queue::{Descriptor, Queue, QueueT};
use vm_memory::{
    Address, Error as MmapError, FileOffset, GuestAddress, GuestMemory, GuestMemoryRegion,
};
use vm_migration::protocol::MemoryRangeTable;
use vmm_sys_util::eventfd::EventFd;

// Size of a dirty page for vhost-user: each bit in the log covers one 4 KiB page.
const VHOST_LOG_PAGE: u64 = 0x1000;

#[derive(Debug, Clone)]
pub struct VhostUserConfig {
    pub socket: String,
    pub num_queues: usize,
    pub queue_size: u16,
}

#[derive(Clone)]
struct VringInfo {
    config_data: VringConfigData,
    used_guest_addr: u64,
}

#[derive(Clone)]
pub struct VhostUserHandle {
    vu: Frontend,
    ready: bool,
    supports_migration: bool,
    shm_log: Option<Arc<MmapRegion>>,
    acked_features: u64,
    vrings_info: Option<Vec<VringInfo>>,
    queue_indexes: Vec<usize>,
}

impl VhostUserHandle {
    pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> {
        let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
        for region in mem.iter() {
            let (mmap_handle, mmap_offset) = match region.file_offset() {
                Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
                None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)),
            };

            let vhost_user_reg = VhostUserMemoryRegionInfo {
                guest_phys_addr: region.start_addr().raw_value(),
                memory_size: region.len(),
                userspace_addr: region.as_ptr() as u64,
                mmap_offset,
                mmap_handle,
            };

            regions.push(vhost_user_reg);
        }

        self.vu
            .set_mem_table(regions.as_slice())
            .map_err(Error::VhostUserSetMemTable)?;

        Ok(())
    }

    pub fn add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()> {
        let (mmap_handle, mmap_offset) = match region.file_offset() {
            Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
            None => return Err(Error::MissingRegionFd),
        };

        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: region.start_addr().raw_value(),
            memory_size: region.len(),
            userspace_addr: region.as_ptr() as u64,
            mmap_offset,
            mmap_handle,
        };

        self.vu
            .add_mem_region(&region)
            .map_err(Error::VhostUserAddMemReg)
    }

    pub fn negotiate_features_vhost_user(
        &mut self,
        avail_features: u64,
        avail_protocol_features: VhostUserProtocolFeatures,
    ) -> Result<(u64, u64)> {
        // Set vhost-user owner.
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;

        // Get the feature set from the backend and negotiate the subset
        // supported by both the VMM and the backend.
        let backend_features = self
            .vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;
        let acked_features = avail_features & backend_features;

        let acked_protocol_features =
            if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
                let backend_protocol_features = self
                    .vu
                    .get_protocol_features()
                    .map_err(Error::VhostUserGetProtocolFeatures)?;

                let acked_protocol_features = avail_protocol_features & backend_protocol_features;

                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                acked_protocol_features
            } else {
                VhostUserProtocolFeatures::empty()
            };

        if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
            && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
        {
            self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
        }

        self.update_supports_migration(acked_features, acked_protocol_features.bits());

        Ok((acked_features, acked_protocol_features.bits()))
    }
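    /// Push the negotiated features, the memory table, and the per-queue
    /// configuration to the backend, then enable the vrings.
    ///
    /// A minimal sketch of the expected call order, assuming a connected
    /// handle plus pre-built `mem`, `queues`, `interrupt` and `handler`
    /// values (hypothetical caller code):
    ///
    /// ```ignore
    /// let (acked_features, _acked_protocol_features) =
    ///     vu.negotiate_features_vhost_user(avail_features, avail_protocol_features)?;
    /// vu.setup_vhost_user(&mem, queues, &interrupt, acked_features, &handler, None)?;
    /// ```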
    #[allow(clippy::too_many_arguments)]
    pub fn setup_vhost_user<S: VhostUserFrontendReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<(usize, Queue, EventFd)>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        backend_req_handler: &Option<FrontendReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.vu
            .set_features(acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Update the internal value after it's been sent to the backend.
        self.acked_features = acked_features;

        // Let's first provide the memory table to the backend.
        self.update_mem_table(mem)?;

        // Send set_vring_num here, since it tells backends such as SPDK how
        // many virtqueues they must handle, which some backends need to know
        // at an early stage.
        for (queue_index, queue, _) in queues.iter() {
            self.vu
                .set_vring_num(*queue_index, queue.size())
                .map_err(Error::VhostUserSetVringNum)?;
        }

        // Set up the shared memory for inflight I/O tracking.
        if let Some(inflight) = inflight {
            if inflight.fd.is_none() {
                let inflight_req_info = VhostUserInflight {
                    mmap_size: 0,
                    mmap_offset: 0,
                    num_queues: queues.len() as u16,
                    queue_size: queues[0].1.size(),
                };
                let (info, fd) = self
                    .vu
                    .get_inflight_fd(&inflight_req_info)
                    .map_err(Error::VhostUserGetInflight)?;
                inflight.info = info;
                inflight.fd = Some(fd);
            }
            // Unwrapping the inflight fd is safe here since we know it can't be None.
            self.vu
                .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd())
                .map_err(Error::VhostUserSetInflight)?;
        }

        let mut vrings_info = Vec::new();
        for (queue_index, queue, queue_evt) in queues.iter() {
            let actual_size: usize = queue.size().into();

            let config_data = VringConfigData {
                queue_max_size: queue.max_size(),
                queue_size: queue.size(),
                flags: 0u32,
                desc_table_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.desc_table()),
                    actual_size * std::mem::size_of::<Descriptor>(),
                )
                .ok_or(Error::DescriptorTableAddress)? as u64,
                // The used ring is {flags: u16; idx: u16; ring: [{id: u32, len: u32}; actual_size]},
                // i.e. 4 + (4 + 4) * actual_size bytes.
                used_ring_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.used_ring()),
                    4 + actual_size * 8,
                )
                .ok_or(Error::UsedAddress)? as u64,
                // The avail ring is {flags: u16; idx: u16; ring: [u16; actual_size]},
                // i.e. 4 + 2 * actual_size bytes.
                avail_ring_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.avail_ring()),
                    4 + actual_size * 2,
                )
                .ok_or(Error::AvailAddress)? as u64,
                log_addr: None,
            };

            vrings_info.push(VringInfo {
                config_data,
                used_guest_addr: queue.used_ring(),
            });

            self.vu
                .set_vring_addr(*queue_index, &config_data)
                .map_err(Error::VhostUserSetVringAddr)?;
            self.vu
                .set_vring_base(
                    *queue_index,
                    queue
                        .avail_idx(mem, Ordering::Acquire)
                        .map_err(Error::GetAvailableIndex)?
                        .0,
                )
                .map_err(Error::VhostUserSetVringBase)?;

            if let Some(eventfd) =
                virtio_interrupt.notifier(VirtioInterruptType::Queue(*queue_index as u16))
            {
                self.vu
                    .set_vring_call(*queue_index, &eventfd)
                    .map_err(Error::VhostUserSetVringCall)?;
            }

            self.vu
                .set_vring_kick(*queue_index, queue_evt)
                .map_err(Error::VhostUserSetVringKick)?;

            self.queue_indexes.push(*queue_index);
        }

        self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;

        if let Some(backend_req_handler) = backend_req_handler {
            self.vu
                .set_backend_request_fd(&backend_req_handler.get_tx_raw_fd())
                .map_err(Error::VhostUserSetBackendRequestFd)?;
        }

        self.vrings_info = Some(vrings_info);
        self.ready = true;

        Ok(())
    }
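    // Per the vhost-user specification, when VHOST_USER_F_PROTOCOL_FEATURES
    // has been negotiated, rings are initialized in a disabled state and must
    // be enabled explicitly with VHOST_USER_SET_VRING_ENABLE.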
    fn enable_vhost_user_vrings(&mut self, queue_indexes: Vec<usize>, enable: bool) -> Result<()> {
        for queue_index in queue_indexes {
            self.vu
                .set_vring_enable(queue_index, enable)
                .map_err(Error::VhostUserSetVringEnable)?;
        }

        Ok(())
    }

    pub fn reset_vhost_user(&mut self) -> Result<()> {
        for queue_index in self.queue_indexes.drain(..) {
            self.vu
                .set_vring_enable(queue_index, false)
                .map_err(Error::VhostUserSetVringEnable)?;

            let _ = self
                .vu
                .get_vring_base(queue_index)
                .map_err(Error::VhostUserGetVringBase)?;
        }

        Ok(())
    }

    pub fn set_protocol_features_vhost_user(
        &mut self,
        acked_features: u64,
        acked_protocol_features: u64,
    ) -> Result<()> {
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
        self.vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;

        if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
            if let Some(acked_protocol_features) =
                VhostUserProtocolFeatures::from_bits(acked_protocol_features)
            {
                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
                    self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
                }
            }
        }

        self.update_supports_migration(acked_features, acked_protocol_features);

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    pub fn reinitialize_vhost_user<S: VhostUserFrontendReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<(usize, Queue, EventFd)>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        acked_protocol_features: u64,
        backend_req_handler: &Option<FrontendReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.set_protocol_features_vhost_user(acked_features, acked_protocol_features)?;

        self.setup_vhost_user(
            mem,
            queues,
            virtio_interrupt,
            acked_features,
            backend_req_handler,
            inflight,
        )
    }
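    /// Obtain a vhost-user frontend handle, either by binding `socket_path`
    /// and waiting for the backend to connect (server mode), or by connecting
    /// to the backend's socket, retrying for up to one minute (client mode).
    ///
    /// A minimal client-mode sketch (hypothetical socket path and queue count):
    ///
    /// ```ignore
    /// let vu = VhostUserHandle::connect_vhost_user(false, "/tmp/vhost.sock", 2, false)?;
    /// ```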
    pub fn connect_vhost_user(
        server: bool,
        socket_path: &str,
        num_queues: u64,
        unlink_socket: bool,
    ) -> Result<Self> {
        if server {
            if unlink_socket {
                std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?;
            }

            info!("Binding vhost-user listener...");
            let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?;
            info!("Waiting for incoming vhost-user connection...");
            let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?;

            Ok(VhostUserHandle {
                vu: Frontend::from_stream(stream, num_queues),
                ready: false,
                supports_migration: false,
                shm_log: None,
                acked_features: 0,
                vrings_info: None,
                queue_indexes: Vec::new(),
            })
        } else {
            let now = Instant::now();

            // Retry connecting for a full minute.
            let err = loop {
                let err = match Frontend::connect(socket_path, num_queues) {
                    Ok(m) => {
                        return Ok(VhostUserHandle {
                            vu: m,
                            ready: false,
                            supports_migration: false,
                            shm_log: None,
                            acked_features: 0,
                            vrings_info: None,
                            queue_indexes: Vec::new(),
                        })
                    }
                    Err(e) => e,
                };
                sleep(Duration::from_millis(100));

                if now.elapsed().as_secs() >= 60 {
                    break err;
                }
            };

            error!(
                "Failed to connect to the backend after trying for 1 minute: {:?}",
                err
            );
            Err(Error::VhostUserConnect)
        }
    }

    pub fn socket_handle(&mut self) -> &mut Frontend {
        &mut self.vu
    }

    pub fn pause_vhost_user(&mut self) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(self.queue_indexes.clone(), false)?;
        }

        Ok(())
    }

    pub fn resume_vhost_user(&mut self) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;
        }

        Ok(())
    }

    fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) {
        // Migration requires dirty page logging: VHOST_F_LOG_ALL for the
        // logging itself and LOG_SHMFD to share the log region over the socket.
        if (acked_features & u64::from(VHOST_F_LOG_ALL) != 0)
            && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0)
        {
            self.supports_migration = true;
        }
    }
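    /// Allocate a fresh dirty-log region backed by a sealed memfd, map it,
    /// send it to the backend, and return the previous region (if any).
    ///
    /// Sizing example: with one bit per 4 KiB page, each byte of log covers
    /// 32 KiB of guest RAM, so a guest whose last RAM address is 4 GiB needs
    /// 4 GiB / 32 KiB = 128 KiB of log (plus one byte, the way the formula
    /// below rounds).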
    fn update_log_base(&mut self, last_ram_addr: u64) -> Result<Option<Arc<MmapRegion>>> {
        // Create the memfd.
        let fd = memfd_create(
            &ffi::CString::new("vhost_user_dirty_log").unwrap(),
            libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
        )
        .map_err(Error::MemfdCreate)?;

        // SAFETY: memfd_create() returned a valid fd that we now own.
        let file = unsafe { File::from_raw_fd(fd) };
        // The size of the memory mapping corresponds to the size of a bitmap
        // covering all guest pages for addresses from 0 to the last physical
        // address in guest RAM.
        // A page is always 4 KiB from a vhost-user perspective, and each bit
        // is a page. That's how we can compute mmap_size from the last address.
        let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
        let mmap_handle = file.as_raw_fd();

        // Set the shm_log region size.
        file.set_len(mmap_size).map_err(Error::SetFileSize)?;

        // Set the seals.
        // SAFETY: FFI call with valid arguments.
        let res = unsafe {
            libc::fcntl(
                file.as_raw_fd(),
                libc::F_ADD_SEALS,
                libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL,
            )
        };
        if res < 0 {
            return Err(Error::SetSeals(std::io::Error::last_os_error()));
        }

        // Mmap the shm_log region.
        let region = MmapRegion::build(
            Some(FileOffset::new(file, 0)),
            mmap_size as usize,
            libc::PROT_READ | libc::PROT_WRITE,
            libc::MAP_SHARED,
        )
        .map_err(Error::NewMmapRegion)?;

        // Make sure we hold onto the region to prevent the mapping from being
        // released.
        let old_region = self.shm_log.replace(Arc::new(region));

        // Send the shm_log fd over to the backend.
        let log = VhostUserDirtyLogRegion {
            mmap_size,
            mmap_offset: 0,
            mmap_handle,
        };
        self.vu
            .set_log_base(0, Some(log))
            .map_err(Error::VhostUserSetLogBase)?;

        Ok(old_region)
    }

    fn set_vring_logging(&mut self, enable: bool) -> Result<()> {
        if let Some(vrings_info) = &self.vrings_info {
            for (i, vring_info) in vrings_info.iter().enumerate() {
                let mut config_data = vring_info.config_data;
                config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 };
                config_data.log_addr = if enable {
                    Some(vring_info.used_guest_addr)
                } else {
                    None
                };

                self.vu
                    .set_vring_addr(i, &config_data)
                    .map_err(Error::VhostUserSetVringAddr)?;
            }
        }

        Ok(())
    }
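    /// Start logging dirty pages, both through the VHOST_F_LOG_ALL feature
    /// bit and through per-vring logging of the used ring.
    ///
    /// A minimal sketch of the dirty-logging lifecycle during migration
    /// (hypothetical caller code; `last_ram_addr` is assumed known):
    ///
    /// ```ignore
    /// vu.start_dirty_log(last_ram_addr)?;
    /// // ... copy guest memory while the guest keeps running ...
    /// let table = vu.dirty_log(last_ram_addr)?; // drain and re-arm the log
    /// // ... transfer the ranges listed in `table` ...
    /// vu.stop_dirty_log()?;
    /// ```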
    pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Set the shm log region.
        self.update_log_base(last_ram_addr)?;

        // Enable the VHOST_F_LOG_ALL feature.
        let features = self.acked_features | (1 << VHOST_F_LOG_ALL);
        self.vu
            .set_features(features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Enable dirty page logging of the used ring for all queues.
        self.set_vring_logging(true)
    }

    pub fn stop_dirty_log(&mut self) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Disable dirty page logging of the used ring for all queues.
        self.set_vring_logging(false)?;

        // Disable the VHOST_F_LOG_ALL feature.
        self.vu
            .set_features(self.acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // This is important here since the log region goes out of scope,
        // invoking the Drop trait, hence unmapping the memory.
        self.shm_log = None;

        Ok(())
    }

    pub fn dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable> {
        // The log region is updated by creating a new region that is sent to
        // the backend. This ensures the backend stops logging to the previous
        // region. The previous region is returned and processed to create the
        // bitmap representing the dirty pages.
        if let Some(region) = self.update_log_base(last_ram_addr)? {
            // The region size is in bytes, but the bitmap is read as u64
            // words, so the slice length is the byte size divided by 8.
            let len = region.size() / 8;
            // SAFETY: the region is `len * 8` bytes long, hence at least
            // `len` u64 words.
            let bitmap = unsafe {
                // Cast the pointer to u64.
                let ptr = region.as_ptr() as *const u64;
                std::slice::from_raw_parts(ptr, len).to_vec()
            };
            Ok(MemoryRangeTable::from_bitmap(bitmap, 0, 4096))
        } else {
            Err(Error::MissingShmLogRegion)
        }
    }
}

fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error> {
    // SAFETY: FFI call with valid arguments.
    let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) };

    if res < 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(res as RawFd)
    }
}
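
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal smoke test for the memfd_create() wrapper, assuming a Linux
    // host with memfd support: the returned fd must be usable as a regular
    // file that can be resized, mirroring how update_log_base() uses it.
    #[test]
    fn test_memfd_create_returns_usable_fd() {
        let fd = memfd_create(
            &ffi::CString::new("test_vhost_user_memfd").unwrap(),
            libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
        )
        .unwrap();

        // SAFETY: memfd_create() returned a valid fd that we now own.
        let file = unsafe { File::from_raw_fd(fd) };
        file.set_len(0x1000).unwrap();
        assert_eq!(file.metadata().unwrap().len(), 0x1000);
    }
}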