// Copyright 2019 Intel Corporation. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use super::{Error, Result};
use crate::vhost_user::Inflight;
use crate::{
    get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt,
    VirtioInterruptType,
};
use std::convert::TryInto;
use std::ffi;
use std::fs::File;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixListener;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::vec::Vec;
use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG};
use vhost::vhost_user::message::{
    VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
};
use vhost::vhost_user::{Master, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler};
use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
use virtio_queue::{Descriptor, Queue};
use vm_memory::{
    Address, Error as MmapError, FileOffset, GuestMemory, GuestMemoryAtomic, GuestMemoryRegion,
};
use vm_migration::protocol::MemoryRangeTable;
use vmm_sys_util::eventfd::EventFd;

// Size of a dirty page for vhost-user.
const VHOST_LOG_PAGE: u64 = 0x1000;

#[derive(Debug, Clone)]
pub struct VhostUserConfig {
    pub socket: String,
    pub num_queues: usize,
    pub queue_size: u16,
}

#[derive(Clone)]
struct VringInfo {
    config_data: VringConfigData,
    used_guest_addr: u64,
}

#[derive(Clone)]
pub struct VhostUserHandle {
    vu: Master,
    ready: bool,
    supports_migration: bool,
    shm_log: Option<Arc<MmapRegion>>,
    acked_features: u64,
    vrings_info: Option<Vec<VringInfo>>,
    queue_indexes: Vec<usize>,
}

impl VhostUserHandle {
    pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> {
        let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
        for region in mem.iter() {
            let (mmap_handle, mmap_offset) = match region.file_offset() {
                Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()),
                None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)),
            };

            let vhost_user_net_reg = VhostUserMemoryRegionInfo {
                guest_phys_addr: region.start_addr().raw_value(),
                memory_size: region.len() as u64,
                userspace_addr: region.as_ptr() as u64,
                mmap_offset,
                mmap_handle,
            };

            regions.push(vhost_user_net_reg);
        }

        self.vu
            .set_mem_table(regions.as_slice())
            .map_err(Error::VhostUserSetMemTable)?;

        Ok(())
    }

    pub fn add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()> {
        let (mmap_handle, mmap_offset) = match region.file_offset() {
            Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
            None => return Err(Error::MissingRegionFd),
        };

        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: region.start_addr().raw_value(),
            memory_size: region.len() as u64,
            userspace_addr: region.as_ptr() as u64,
            mmap_offset,
            mmap_handle,
        };

        self.vu
            .add_mem_region(&region)
            .map_err(Error::VhostUserAddMemReg)
    }
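
    /// Negotiate virtio features and, when the backend advertises
    /// VHOST_USER_F_PROTOCOL_FEATURES, the vhost-user protocol features as
    /// well. Returns the pair of acked (virtio, protocol) feature bitmaps
    /// supported by both the VMM and the backend.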
    pub fn negotiate_features_vhost_user(
        &mut self,
        avail_features: u64,
        avail_protocol_features: VhostUserProtocolFeatures,
    ) -> Result<(u64, u64)> {
        // Set vhost-user owner.
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;

        // Get features from backend, do negotiation to get a feature collection which
        // both VMM and backend support.
        let backend_features = self
            .vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;
        let acked_features = avail_features & backend_features;

        let acked_protocol_features =
            if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
                let backend_protocol_features = self
                    .vu
                    .get_protocol_features()
                    .map_err(Error::VhostUserGetProtocolFeatures)?;

                let acked_protocol_features = avail_protocol_features & backend_protocol_features;

                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                acked_protocol_features
            } else {
                VhostUserProtocolFeatures::empty()
            };

        if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
            && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
        {
            self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
        }

        self.update_supports_migration(acked_features, acked_protocol_features.bits());

        Ok((acked_features, acked_protocol_features.bits()))
    }

    #[allow(clippy::too_many_arguments)]
    pub fn setup_vhost_user<S: VhostUserMasterReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<(usize, Queue<GuestMemoryAtomic<GuestMemoryMmap>>, EventFd)>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        slave_req_handler: &Option<MasterReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.vu
            .set_features(acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Update internal value after it's been sent to the backend.
        self.acked_features = acked_features;

        // Let's first provide the memory table to the backend.
        self.update_mem_table(mem)?;

        // Send set_vring_num here, since it tells backends such as SPDK how
        // many virtqueues will be handled, which some backends need to know
        // at an early stage.
        for (queue_index, queue, _) in queues.iter() {
            self.vu
                .set_vring_num(*queue_index, queue.state.size)
                .map_err(Error::VhostUserSetVringNum)?;
        }

        // Setup for inflight I/O tracking shared memory.
        if let Some(inflight) = inflight {
            if inflight.fd.is_none() {
                let inflight_req_info = VhostUserInflight {
                    mmap_size: 0,
                    mmap_offset: 0,
                    num_queues: queues.len() as u16,
                    queue_size: queues[0].1.state.size,
                };
                let (info, fd) = self
                    .vu
                    .get_inflight_fd(&inflight_req_info)
                    .map_err(Error::VhostUserGetInflight)?;
                inflight.info = info;
                inflight.fd = Some(fd);
            }
            // Unwrapping the inflight fd is safe here since we know it can't be None.
            self.vu
                .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd())
                .map_err(Error::VhostUserSetInflight)?;
        }
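
        // Configure each vring with the backend: translate the descriptor
        // table and the used/available rings into host virtual addresses,
        // program the base index, and wire up the call (interrupt) and
        // kick (notification) eventfds.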
        let mut vrings_info = Vec::new();
        for (queue_index, queue, queue_evt) in queues.iter() {
            let actual_size: usize = queue.state.size.try_into().unwrap();

            let config_data = VringConfigData {
                queue_max_size: queue.max_size(),
                queue_size: queue.state.size,
                flags: 0u32,
                desc_table_addr: get_host_address_range(
                    mem,
                    queue.state.desc_table,
                    actual_size * std::mem::size_of::<Descriptor>(),
                )
                .ok_or(Error::DescriptorTableAddress)? as u64,
                // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u32, len: u32}; actual_size]},
                // i.e. 4 + (4 + 4) * actual_size.
                used_ring_addr: get_host_address_range(
                    mem,
                    queue.state.used_ring,
                    4 + actual_size * 8,
                )
                .ok_or(Error::UsedAddress)? as u64,
                // The available ring is {flags: u16; idx: u16; elem [u16; actual_size]},
                // i.e. 4 + 2 * actual_size.
                avail_ring_addr: get_host_address_range(
                    mem,
                    queue.state.avail_ring,
                    4 + actual_size * 2,
                )
                .ok_or(Error::AvailAddress)? as u64,
                log_addr: None,
            };

            vrings_info.push(VringInfo {
                config_data,
                used_guest_addr: queue.state.used_ring.raw_value(),
            });

            self.vu
                .set_vring_addr(*queue_index, &config_data)
                .map_err(Error::VhostUserSetVringAddr)?;
            self.vu
                .set_vring_base(
                    *queue_index,
                    queue
                        .avail_idx(Ordering::Acquire)
                        .map_err(Error::GetAvailableIndex)?
                        .0,
                )
                .map_err(Error::VhostUserSetVringBase)?;

            if let Some(eventfd) =
                virtio_interrupt.notifier(VirtioInterruptType::Queue(*queue_index as u16))
            {
                self.vu
                    .set_vring_call(*queue_index, &eventfd)
                    .map_err(Error::VhostUserSetVringCall)?;
            }

            self.vu
                .set_vring_kick(*queue_index, queue_evt)
                .map_err(Error::VhostUserSetVringKick)?;

            self.queue_indexes.push(*queue_index);
        }

        self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;

        if let Some(slave_req_handler) = slave_req_handler {
            self.vu
                .set_slave_request_fd(&slave_req_handler.get_tx_raw_fd())
                .map_err(Error::VhostUserSetSlaveRequestFd)?;
        }

        self.vrings_info = Some(vrings_info);
        self.ready = true;

        Ok(())
    }

    fn enable_vhost_user_vrings(&mut self, queue_indexes: Vec<usize>, enable: bool) -> Result<()> {
        for queue_index in queue_indexes {
            self.vu
                .set_vring_enable(queue_index, enable)
                .map_err(Error::VhostUserSetVringEnable)?;
        }

        Ok(())
    }
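
    /// Disable every configured vring and retrieve its base index. Per the
    /// vhost-user protocol, GET_VRING_BASE also tells the backend to stop
    /// processing that ring.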
    pub fn reset_vhost_user(&mut self) -> Result<()> {
        for queue_index in self.queue_indexes.drain(..) {
            self.vu
                .set_vring_enable(queue_index, false)
                .map_err(Error::VhostUserSetVringEnable)?;

            let _ = self
                .vu
                .get_vring_base(queue_index)
                .map_err(Error::VhostUserGetVringBase)?;
        }

        Ok(())
    }

    pub fn set_protocol_features_vhost_user(
        &mut self,
        acked_features: u64,
        acked_protocol_features: u64,
    ) -> Result<()> {
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
        self.vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;

        if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
            if let Some(acked_protocol_features) =
                VhostUserProtocolFeatures::from_bits(acked_protocol_features)
            {
                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
                    self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
                }
            }
        }

        self.update_supports_migration(acked_features, acked_protocol_features);

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    pub fn reinitialize_vhost_user<S: VhostUserMasterReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<(usize, Queue<GuestMemoryAtomic<GuestMemoryMmap>>, EventFd)>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        acked_protocol_features: u64,
        slave_req_handler: &Option<MasterReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.set_protocol_features_vhost_user(acked_features, acked_protocol_features)?;

        self.setup_vhost_user(
            mem,
            queues,
            virtio_interrupt,
            acked_features,
            slave_req_handler,
            inflight,
        )
    }
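
    /// Establish the vhost-user connection. In server mode, bind the socket
    /// and wait for the backend to connect. In client mode, attempt to
    /// connect to the backend's socket, retrying every 100ms for up to a
    /// minute before giving up.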
    pub fn connect_vhost_user(
        server: bool,
        socket_path: &str,
        num_queues: u64,
        unlink_socket: bool,
    ) -> Result<Self> {
        if server {
            if unlink_socket {
                std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?;
            }

            info!("Binding vhost-user listener...");
            let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?;
            info!("Waiting for incoming vhost-user connection...");
            let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?;

            Ok(VhostUserHandle {
                vu: Master::from_stream(stream, num_queues),
                ready: false,
                supports_migration: false,
                shm_log: None,
                acked_features: 0,
                vrings_info: None,
                queue_indexes: Vec::new(),
            })
        } else {
            let now = Instant::now();

            // Retry connecting for a full minute
            let err = loop {
                let err = match Master::connect(socket_path, num_queues) {
                    Ok(m) => {
                        return Ok(VhostUserHandle {
                            vu: m,
                            ready: false,
                            supports_migration: false,
                            shm_log: None,
                            acked_features: 0,
                            vrings_info: None,
                            queue_indexes: Vec::new(),
                        })
                    }
                    Err(e) => e,
                };
                sleep(Duration::from_millis(100));

                if now.elapsed().as_secs() >= 60 {
                    break err;
                }
            };

            error!(
                "Failed to connect to the backend after trying for 1 minute: {:?}",
                err
            );
            Err(Error::VhostUserConnect)
        }
    }

    pub fn socket_handle(&mut self) -> &mut Master {
        &mut self.vu
    }

    pub fn pause_vhost_user(&mut self) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(self.queue_indexes.clone(), false)?;
        }

        Ok(())
    }

    pub fn resume_vhost_user(&mut self) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;
        }

        Ok(())
    }

    fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) {
        if (acked_features & u64::from(vhost::vhost_kern::vhost_binding::VHOST_F_LOG_ALL) != 0)
            && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0)
        {
            self.supports_migration = true;
        }
    }
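
    /// Allocate a new shared memory region for dirty page logging, send it
    /// to the backend, and return the previous region (if any) so the
    /// caller can extract the dirty bitmap from it.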
    fn update_log_base(&mut self, last_ram_addr: u64) -> Result<Option<Arc<MmapRegion>>> {
        // Create the memfd
        let fd = memfd_create(
            &ffi::CString::new("vhost_user_dirty_log").unwrap(),
            libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
        )
        .map_err(Error::MemfdCreate)?;

        // Safe because we checked the file descriptor is valid
        let file = unsafe { File::from_raw_fd(fd) };
        // The size of the memory mapping corresponds to the size of a bitmap
        // covering all guest pages for addresses from 0 to the last physical
        // address in guest RAM.
        // A page is always 4KiB from a vhost-user perspective, and each bit
        // represents a page, so one byte covers 8 pages. That's how
        // mmap_size is computed from the last address.
        let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
        let mmap_handle = file.as_raw_fd();

        // Set shm_log region size
        file.set_len(mmap_size).map_err(Error::SetFileSize)?;

        // Set the seals
        let res = unsafe {
            libc::fcntl(
                file.as_raw_fd(),
                libc::F_ADD_SEALS,
                libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL,
            )
        };
        if res < 0 {
            return Err(Error::SetSeals(std::io::Error::last_os_error()));
        }

        // Mmap shm_log region
        let region = MmapRegion::build(
            Some(FileOffset::new(file, 0)),
            mmap_size as usize,
            libc::PROT_READ | libc::PROT_WRITE,
            libc::MAP_SHARED,
        )
        .map_err(Error::NewMmapRegion)?;

        // Make sure we hold onto the region to prevent the mapping from being
        // released.
        let old_region = self.shm_log.replace(Arc::new(region));

        // Send the shm_log fd over to the backend
        let log = VhostUserDirtyLogRegion {
            mmap_size,
            mmap_offset: 0,
            mmap_handle,
        };
        self.vu
            .set_log_base(0, Some(log))
            .map_err(Error::VhostUserSetLogBase)?;

        Ok(old_region)
    }

    fn set_vring_logging(&mut self, enable: bool) -> Result<()> {
        if let Some(vrings_info) = &self.vrings_info {
            for (i, vring_info) in vrings_info.iter().enumerate() {
                let mut config_data = vring_info.config_data;
                config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 };
                config_data.log_addr = if enable {
                    Some(vring_info.used_guest_addr)
                } else {
                    None
                };

                self.vu
                    .set_vring_addr(i, &config_data)
                    .map_err(Error::VhostUserSetVringAddr)?;
            }
        }

        Ok(())
    }

    pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Set the shm log region
        self.update_log_base(last_ram_addr)?;

        // Enable VHOST_F_LOG_ALL feature
        let features = self.acked_features | (1 << VHOST_F_LOG_ALL);
        self.vu
            .set_features(features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Enable dirty page logging of used ring for all queues
        self.set_vring_logging(true)
    }

    pub fn stop_dirty_log(&mut self) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Disable dirty page logging of used ring for all queues
        self.set_vring_logging(false)?;

        // Disable VHOST_F_LOG_ALL feature
        self.vu
            .set_features(self.acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // This is important here since dropping the log region invokes its
        // Drop implementation, which unmaps the memory.
        self.shm_log = None;

        Ok(())
    }

    pub fn dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable> {
        // The log region is updated by creating a new region that is sent to
        // the backend. This ensures the backend stops logging to the previous
        // region. The previous region is returned and processed to create the
        // bitmap representing the dirty pages.
        if let Some(region) = self.update_log_base(last_ram_addr)? {
            // The region size is expressed in bytes, but the slice is read
            // as u64 words, hence the division by 8.
            let len = region.size() / 8;
            let bitmap = unsafe {
                // Cast the pointer to u64
                let ptr = region.as_ptr() as *const u64;
                std::slice::from_raw_parts(ptr, len).to_vec()
            };
            Ok(MemoryRangeTable::from_bitmap(bitmap, 0, 4096))
        } else {
            Err(Error::MissingShmLogRegion)
        }
    }
}

fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error> {
    let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) };

    if res < 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(res as RawFd)
    }
}
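
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sanity check for the memfd_create wrapper -- a sketch that
    // assumes a Linux host with memfd_create support. It verifies the
    // syscall returns a valid file descriptor that can be sized, mirroring
    // how update_log_base uses it.
    #[test]
    fn test_memfd_create() {
        let fd = memfd_create(
            &ffi::CString::new("test_memfd_create").unwrap(),
            libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
        )
        .unwrap();
        assert!(fd >= 0);

        // Take ownership so the fd is closed when `file` is dropped.
        let file = unsafe { File::from_raw_fd(fd) };
        file.set_len(0x1000).unwrap();
    }
}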