1 // Copyright 2019 Intel Corporation. All Rights Reserved. 2 // SPDX-License-Identifier: Apache-2.0 3 4 use super::{Error, Result}; 5 use crate::vhost_user::Inflight; 6 use crate::{ 7 get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt, 8 VirtioInterruptType, 9 }; 10 use std::ffi; 11 use std::fs::File; 12 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; 13 use std::os::unix::net::UnixListener; 14 use std::sync::atomic::Ordering; 15 use std::sync::Arc; 16 use std::thread::sleep; 17 use std::time::{Duration, Instant}; 18 use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG}; 19 use vhost::vhost_user::message::{ 20 VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures, 21 }; 22 use vhost::vhost_user::{ 23 Frontend, FrontendReqHandler, VhostUserFrontend, VhostUserFrontendReqHandler, 24 }; 25 use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData}; 26 use virtio_queue::{Descriptor, Queue, QueueT}; 27 use vm_memory::{ 28 Address, Error as MmapError, FileOffset, GuestAddress, GuestMemory, GuestMemoryRegion, 29 }; 30 use vm_migration::protocol::MemoryRangeTable; 31 use vmm_sys_util::eventfd::EventFd; 32 33 // Size of a dirty page for vhost-user. 34 const VHOST_LOG_PAGE: u64 = 0x1000; 35 36 #[derive(Debug, Clone)] 37 pub struct VhostUserConfig { 38 pub socket: String, 39 pub num_queues: usize, 40 pub queue_size: u16, 41 } 42 43 #[derive(Clone)] 44 struct VringInfo { 45 config_data: VringConfigData, 46 used_guest_addr: u64, 47 } 48 49 #[derive(Clone)] 50 pub struct VhostUserHandle { 51 vu: Frontend, 52 ready: bool, 53 supports_migration: bool, 54 shm_log: Option<Arc<MmapRegion>>, 55 acked_features: u64, 56 vrings_info: Option<Vec<VringInfo>>, 57 queue_indexes: Vec<usize>, 58 } 59 60 impl VhostUserHandle { 61 pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> { 62 let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new(); 63 for region in mem.iter() { 64 let (mmap_handle, mmap_offset) = match region.file_offset() { 65 Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()), 66 None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)), 67 }; 68 69 let vhost_user_net_reg = VhostUserMemoryRegionInfo { 70 guest_phys_addr: region.start_addr().raw_value(), 71 memory_size: region.len(), 72 userspace_addr: region.as_ptr() as u64, 73 mmap_offset, 74 mmap_handle, 75 }; 76 77 regions.push(vhost_user_net_reg); 78 } 79 80 self.vu 81 .set_mem_table(regions.as_slice()) 82 .map_err(Error::VhostUserSetMemTable)?; 83 84 Ok(()) 85 } 86 87 pub fn add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()> { 88 let (mmap_handle, mmap_offset) = match region.file_offset() { 89 Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()), 90 None => return Err(Error::MissingRegionFd), 91 }; 92 93 let region = VhostUserMemoryRegionInfo { 94 guest_phys_addr: region.start_addr().raw_value(), 95 memory_size: region.len(), 96 userspace_addr: region.as_ptr() as u64, 97 mmap_offset, 98 mmap_handle, 99 }; 100 101 self.vu 102 .add_mem_region(®ion) 103 .map_err(Error::VhostUserAddMemReg) 104 } 105 106 pub fn negotiate_features_vhost_user( 107 &mut self, 108 avail_features: u64, 109 avail_protocol_features: VhostUserProtocolFeatures, 110 ) -> Result<(u64, u64)> { 111 // Set vhost-user owner. 112 self.vu.set_owner().map_err(Error::VhostUserSetOwner)?; 113 114 // Get features from backend, do negotiation to get a feature collection which 115 // both VMM and backend support. 116 let backend_features = self 117 .vu 118 .get_features() 119 .map_err(Error::VhostUserGetFeatures)?; 120 let acked_features = avail_features & backend_features; 121 122 let acked_protocol_features = 123 if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 { 124 let backend_protocol_features = self 125 .vu 126 .get_protocol_features() 127 .map_err(Error::VhostUserGetProtocolFeatures)?; 128 129 let acked_protocol_features = avail_protocol_features & backend_protocol_features; 130 131 self.vu 132 .set_protocol_features(acked_protocol_features) 133 .map_err(Error::VhostUserSetProtocolFeatures)?; 134 135 acked_protocol_features 136 } else { 137 VhostUserProtocolFeatures::empty() 138 }; 139 140 if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) 141 && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) 142 { 143 self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY); 144 } 145 146 self.update_supports_migration(acked_features, acked_protocol_features.bits()); 147 148 Ok((acked_features, acked_protocol_features.bits())) 149 } 150 151 #[allow(clippy::too_many_arguments)] 152 pub fn setup_vhost_user<S: VhostUserFrontendReqHandler>( 153 &mut self, 154 mem: &GuestMemoryMmap, 155 queues: Vec<(usize, Queue, EventFd)>, 156 virtio_interrupt: &Arc<dyn VirtioInterrupt>, 157 acked_features: u64, 158 backend_req_handler: &Option<FrontendReqHandler<S>>, 159 inflight: Option<&mut Inflight>, 160 ) -> Result<()> { 161 self.vu 162 .set_features(acked_features) 163 .map_err(Error::VhostUserSetFeatures)?; 164 165 // Update internal value after it's been sent to the backend. 166 self.acked_features = acked_features; 167 168 // Let's first provide the memory table to the backend. 169 self.update_mem_table(mem)?; 170 171 // Send set_vring_num here, since it could tell backends, like SPDK, 172 // how many virt queues to be handled, which backend required to know 173 // at early stage. 174 for (queue_index, queue, _) in queues.iter() { 175 self.vu 176 .set_vring_num(*queue_index, queue.size()) 177 .map_err(Error::VhostUserSetVringNum)?; 178 } 179 180 // Setup for inflight I/O tracking shared memory. 181 if let Some(inflight) = inflight { 182 if inflight.fd.is_none() { 183 let inflight_req_info = VhostUserInflight { 184 mmap_size: 0, 185 mmap_offset: 0, 186 num_queues: queues.len() as u16, 187 queue_size: queues[0].1.size(), 188 }; 189 let (info, fd) = self 190 .vu 191 .get_inflight_fd(&inflight_req_info) 192 .map_err(Error::VhostUserGetInflight)?; 193 inflight.info = info; 194 inflight.fd = Some(fd); 195 } 196 // Unwrapping the inflight fd is safe here since we know it can't be None. 197 self.vu 198 .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd()) 199 .map_err(Error::VhostUserSetInflight)?; 200 } 201 202 let mut vrings_info = Vec::new(); 203 for (queue_index, queue, queue_evt) in queues.iter() { 204 let actual_size: usize = queue.size().into(); 205 206 let config_data = VringConfigData { 207 queue_max_size: queue.max_size(), 208 queue_size: queue.size(), 209 flags: 0u32, 210 desc_table_addr: get_host_address_range( 211 mem, 212 GuestAddress(queue.desc_table()), 213 actual_size * std::mem::size_of::<Descriptor>(), 214 ) 215 .ok_or(Error::DescriptorTableAddress)? as u64, 216 // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u16, len: u16}; actual_size]}, 217 // i.e. 4 + (4 + 4) * actual_size. 218 used_ring_addr: get_host_address_range( 219 mem, 220 GuestAddress(queue.used_ring()), 221 4 + actual_size * 8, 222 ) 223 .ok_or(Error::UsedAddress)? as u64, 224 // The used ring is {flags: u16; idx: u16; elem [u16; actual_size]}, 225 // i.e. 4 + (2) * actual_size. 226 avail_ring_addr: get_host_address_range( 227 mem, 228 GuestAddress(queue.avail_ring()), 229 4 + actual_size * 2, 230 ) 231 .ok_or(Error::AvailAddress)? as u64, 232 log_addr: None, 233 }; 234 235 vrings_info.push(VringInfo { 236 config_data, 237 used_guest_addr: queue.used_ring(), 238 }); 239 240 self.vu 241 .set_vring_addr(*queue_index, &config_data) 242 .map_err(Error::VhostUserSetVringAddr)?; 243 self.vu 244 .set_vring_base( 245 *queue_index, 246 queue 247 .avail_idx(mem, Ordering::Acquire) 248 .map_err(Error::GetAvailableIndex)? 249 .0, 250 ) 251 .map_err(Error::VhostUserSetVringBase)?; 252 253 if let Some(eventfd) = 254 virtio_interrupt.notifier(VirtioInterruptType::Queue(*queue_index as u16)) 255 { 256 self.vu 257 .set_vring_call(*queue_index, &eventfd) 258 .map_err(Error::VhostUserSetVringCall)?; 259 } 260 261 self.vu 262 .set_vring_kick(*queue_index, queue_evt) 263 .map_err(Error::VhostUserSetVringKick)?; 264 265 self.queue_indexes.push(*queue_index); 266 } 267 268 self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?; 269 270 if let Some(backend_req_handler) = backend_req_handler { 271 self.vu 272 .set_backend_request_fd(&backend_req_handler.get_tx_raw_fd()) 273 .map_err(Error::VhostUserSetBackendRequestFd)?; 274 } 275 276 self.vrings_info = Some(vrings_info); 277 self.ready = true; 278 279 Ok(()) 280 } 281 282 fn enable_vhost_user_vrings(&mut self, queue_indexes: Vec<usize>, enable: bool) -> Result<()> { 283 for queue_index in queue_indexes { 284 self.vu 285 .set_vring_enable(queue_index, enable) 286 .map_err(Error::VhostUserSetVringEnable)?; 287 } 288 289 Ok(()) 290 } 291 292 pub fn reset_vhost_user(&mut self) -> Result<()> { 293 for queue_index in self.queue_indexes.drain(..) { 294 self.vu 295 .set_vring_enable(queue_index, false) 296 .map_err(Error::VhostUserSetVringEnable)?; 297 298 let _ = self 299 .vu 300 .get_vring_base(queue_index) 301 .map_err(Error::VhostUserGetVringBase)?; 302 } 303 304 Ok(()) 305 } 306 307 pub fn set_protocol_features_vhost_user( 308 &mut self, 309 acked_features: u64, 310 acked_protocol_features: u64, 311 ) -> Result<()> { 312 self.vu.set_owner().map_err(Error::VhostUserSetOwner)?; 313 self.vu 314 .get_features() 315 .map_err(Error::VhostUserGetFeatures)?; 316 317 if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 { 318 if let Some(acked_protocol_features) = 319 VhostUserProtocolFeatures::from_bits(acked_protocol_features) 320 { 321 self.vu 322 .set_protocol_features(acked_protocol_features) 323 .map_err(Error::VhostUserSetProtocolFeatures)?; 324 325 if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) { 326 self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY); 327 } 328 } 329 } 330 331 self.update_supports_migration(acked_features, acked_protocol_features); 332 333 Ok(()) 334 } 335 336 #[allow(clippy::too_many_arguments)] 337 pub fn reinitialize_vhost_user<S: VhostUserFrontendReqHandler>( 338 &mut self, 339 mem: &GuestMemoryMmap, 340 queues: Vec<(usize, Queue, EventFd)>, 341 virtio_interrupt: &Arc<dyn VirtioInterrupt>, 342 acked_features: u64, 343 acked_protocol_features: u64, 344 backend_req_handler: &Option<FrontendReqHandler<S>>, 345 inflight: Option<&mut Inflight>, 346 ) -> Result<()> { 347 self.set_protocol_features_vhost_user(acked_features, acked_protocol_features)?; 348 349 self.setup_vhost_user( 350 mem, 351 queues, 352 virtio_interrupt, 353 acked_features, 354 backend_req_handler, 355 inflight, 356 ) 357 } 358 359 pub fn connect_vhost_user( 360 server: bool, 361 socket_path: &str, 362 num_queues: u64, 363 unlink_socket: bool, 364 ) -> Result<Self> { 365 if server { 366 if unlink_socket { 367 std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?; 368 } 369 370 info!("Binding vhost-user listener..."); 371 let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?; 372 info!("Waiting for incoming vhost-user connection..."); 373 let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?; 374 375 Ok(VhostUserHandle { 376 vu: Frontend::from_stream(stream, num_queues), 377 ready: false, 378 supports_migration: false, 379 shm_log: None, 380 acked_features: 0, 381 vrings_info: None, 382 queue_indexes: Vec::new(), 383 }) 384 } else { 385 let now = Instant::now(); 386 387 // Retry connecting for a full minute 388 let err = loop { 389 let err = match Frontend::connect(socket_path, num_queues) { 390 Ok(m) => { 391 return Ok(VhostUserHandle { 392 vu: m, 393 ready: false, 394 supports_migration: false, 395 shm_log: None, 396 acked_features: 0, 397 vrings_info: None, 398 queue_indexes: Vec::new(), 399 }) 400 } 401 Err(e) => e, 402 }; 403 sleep(Duration::from_millis(100)); 404 405 if now.elapsed().as_secs() >= 60 { 406 break err; 407 } 408 }; 409 410 error!( 411 "Failed connecting the backend after trying for 1 minute: {:?}", 412 err 413 ); 414 Err(Error::VhostUserConnect) 415 } 416 } 417 418 pub fn socket_handle(&mut self) -> &mut Frontend { 419 &mut self.vu 420 } 421 422 pub fn pause_vhost_user(&mut self) -> Result<()> { 423 if self.ready { 424 self.enable_vhost_user_vrings(self.queue_indexes.clone(), false)?; 425 } 426 427 Ok(()) 428 } 429 430 pub fn resume_vhost_user(&mut self) -> Result<()> { 431 if self.ready { 432 self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?; 433 } 434 435 Ok(()) 436 } 437 438 fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) { 439 if (acked_features & u64::from(vhost::vhost_kern::vhost_binding::VHOST_F_LOG_ALL) != 0) 440 && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0) 441 { 442 self.supports_migration = true; 443 } 444 } 445 446 fn update_log_base(&mut self, last_ram_addr: u64) -> Result<Option<Arc<MmapRegion>>> { 447 // Create the memfd 448 let fd = memfd_create( 449 &ffi::CString::new("vhost_user_dirty_log").unwrap(), 450 libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING, 451 ) 452 .map_err(Error::MemfdCreate)?; 453 454 // SAFETY: we checked the file descriptor is valid 455 let file = unsafe { File::from_raw_fd(fd) }; 456 // The size of the memory mapping corresponds to the size of a bitmap 457 // covering all guest pages for addresses from 0 to the last physical 458 // address in guest RAM. 459 // A page is always 4kiB from a vhost-user perspective, and each bit is 460 // a page. That's how we can compute mmap_size from the last address. 461 let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1; 462 let mmap_handle = file.as_raw_fd(); 463 464 // Set shm_log region size 465 file.set_len(mmap_size).map_err(Error::SetFileSize)?; 466 467 // Set the seals 468 // SAFETY: FFI call with valid arguments 469 let res = unsafe { 470 libc::fcntl( 471 file.as_raw_fd(), 472 libc::F_ADD_SEALS, 473 libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL, 474 ) 475 }; 476 if res < 0 { 477 return Err(Error::SetSeals(std::io::Error::last_os_error())); 478 } 479 480 // Mmap shm_log region 481 let region = MmapRegion::build( 482 Some(FileOffset::new(file, 0)), 483 mmap_size as usize, 484 libc::PROT_READ | libc::PROT_WRITE, 485 libc::MAP_SHARED, 486 ) 487 .map_err(Error::NewMmapRegion)?; 488 489 // Make sure we hold onto the region to prevent the mapping from being 490 // released. 491 let old_region = self.shm_log.replace(Arc::new(region)); 492 493 // Send the shm_log fd over to the backend 494 let log = VhostUserDirtyLogRegion { 495 mmap_size, 496 mmap_offset: 0, 497 mmap_handle, 498 }; 499 self.vu 500 .set_log_base(0, Some(log)) 501 .map_err(Error::VhostUserSetLogBase)?; 502 503 Ok(old_region) 504 } 505 506 fn set_vring_logging(&mut self, enable: bool) -> Result<()> { 507 if let Some(vrings_info) = &self.vrings_info { 508 for (i, vring_info) in vrings_info.iter().enumerate() { 509 let mut config_data = vring_info.config_data; 510 config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 }; 511 config_data.log_addr = if enable { 512 Some(vring_info.used_guest_addr) 513 } else { 514 None 515 }; 516 517 self.vu 518 .set_vring_addr(i, &config_data) 519 .map_err(Error::VhostUserSetVringAddr)?; 520 } 521 } 522 523 Ok(()) 524 } 525 526 pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> { 527 if !self.supports_migration { 528 return Err(Error::MigrationNotSupported); 529 } 530 531 // Set the shm log region 532 self.update_log_base(last_ram_addr)?; 533 534 // Enable VHOST_F_LOG_ALL feature 535 let features = self.acked_features | (1 << VHOST_F_LOG_ALL); 536 self.vu 537 .set_features(features) 538 .map_err(Error::VhostUserSetFeatures)?; 539 540 // Enable dirty page logging of used ring for all queues 541 self.set_vring_logging(true) 542 } 543 544 pub fn stop_dirty_log(&mut self) -> Result<()> { 545 if !self.supports_migration { 546 return Err(Error::MigrationNotSupported); 547 } 548 549 // Disable dirty page logging of used ring for all queues 550 self.set_vring_logging(false)?; 551 552 // Disable VHOST_F_LOG_ALL feature 553 self.vu 554 .set_features(self.acked_features) 555 .map_err(Error::VhostUserSetFeatures)?; 556 557 // This is important here since the log region goes out of scope, 558 // invoking the Drop trait, hence unmapping the memory. 559 self.shm_log = None; 560 561 Ok(()) 562 } 563 564 pub fn dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable> { 565 // The log region is updated by creating a new region that is sent to 566 // the backend. This ensures the backend stops logging to the previous 567 // region. The previous region is returned and processed to create the 568 // bitmap representing the dirty pages. 569 if let Some(region) = self.update_log_base(last_ram_addr)? { 570 // Be careful with the size, as it was based on u8, meaning we must 571 // divide it by 8. 572 let len = region.size() / 8; 573 // SAFETY: region is of size len 574 let bitmap = unsafe { 575 // Cast the pointer to u64 576 let ptr = region.as_ptr() as *const u64; 577 std::slice::from_raw_parts(ptr, len).to_vec() 578 }; 579 Ok(MemoryRangeTable::from_bitmap(bitmap, 0, 4096)) 580 } else { 581 Err(Error::MissingShmLogRegion) 582 } 583 } 584 } 585 586 fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error> { 587 // SAFETY: FFI call with valid arguments 588 let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) }; 589 590 if res < 0 { 591 Err(std::io::Error::last_os_error()) 592 } else { 593 Ok(res as RawFd) 594 } 595 } 596