// Copyright 2019 Intel Corporation. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use super::{Error, Result};
use crate::vhost_user::Inflight;
use crate::{
    get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt,
    VirtioInterruptType,
};
use std::convert::TryInto;
use std::ffi;
use std::fs::File;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixListener;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::thread::sleep;
use std::time::{Duration, Instant};
use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG};
use vhost::vhost_user::message::{
    VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
};
use vhost::vhost_user::{Master, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler};
use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
use virtio_queue::{Descriptor, Queue};
use vm_memory::{
    Address, Error as MmapError, FileOffset, GuestMemory, GuestMemoryAtomic, GuestMemoryRegion,
};
use vm_migration::protocol::MemoryRangeTable;
use vmm_sys_util::eventfd::EventFd;

// Size of a dirty page for vhost-user.
const VHOST_LOG_PAGE: u64 = 0x1000;

#[derive(Debug, Clone)]
pub struct VhostUserConfig {
    pub socket: String,
    pub num_queues: usize,
    pub queue_size: u16,
}

#[derive(Clone)]
struct VringInfo {
    config_data: VringConfigData,
    used_guest_addr: u64,
}

#[derive(Clone)]
pub struct VhostUserHandle {
    vu: Master,
    ready: bool,
    supports_migration: bool,
    shm_log: Option<Arc<MmapRegion>>,
    acked_features: u64,
    vrings_info: Option<Vec<VringInfo>>,
}

impl VhostUserHandle {
    pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> {
        let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
        for region in mem.iter() {
            let (mmap_handle, mmap_offset) = match region.file_offset() {
                Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
                None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)),
            };

            let vhost_user_reg = VhostUserMemoryRegionInfo {
                guest_phys_addr: region.start_addr().raw_value(),
                memory_size: region.len() as u64,
                userspace_addr: region.as_ptr() as u64,
                mmap_offset,
                mmap_handle,
            };

            regions.push(vhost_user_reg);
        }

        self.vu
            .set_mem_table(regions.as_slice())
            .map_err(Error::VhostUserSetMemTable)?;

        Ok(())
    }

    pub fn add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()> {
        let (mmap_handle, mmap_offset) = match region.file_offset() {
            Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
            None => return Err(Error::MissingRegionFd),
        };

        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: region.start_addr().raw_value(),
            memory_size: region.len() as u64,
            userspace_addr: region.as_ptr() as u64,
            mmap_offset,
            mmap_handle,
        };

        self.vu
            .add_mem_region(&region)
            .map_err(Error::VhostUserAddMemReg)
    }

    pub fn negotiate_features_vhost_user(
        &mut self,
        avail_features: u64,
        avail_protocol_features: VhostUserProtocolFeatures,
    ) -> Result<(u64, u64)> {
        // Set vhost-user owner.
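        // Per the vhost-user protocol, VHOST_USER_SET_OWNER is issued once
        // when the connection is established, to mark this connection as the
        // master for the session.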
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;

        // Get the feature set from the backend and negotiate the collection
        // of features supported by both the VMM and the backend.
        let backend_features = self
            .vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;
        let acked_features = avail_features & backend_features;

        let acked_protocol_features =
            if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
                let backend_protocol_features = self
                    .vu
                    .get_protocol_features()
                    .map_err(Error::VhostUserGetProtocolFeatures)?;

                let acked_protocol_features = avail_protocol_features & backend_protocol_features;

                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                acked_protocol_features
            } else {
                VhostUserProtocolFeatures::empty()
            };

        if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
            && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
        {
            self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
        }

        self.update_supports_migration(acked_features, acked_protocol_features.bits());

        Ok((acked_features, acked_protocol_features.bits()))
    }

    #[allow(clippy::too_many_arguments)]
    pub fn setup_vhost_user<S: VhostUserMasterReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<Queue<GuestMemoryAtomic<GuestMemoryMmap>>>,
        queue_evts: Vec<EventFd>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        slave_req_handler: &Option<MasterReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.vu
            .set_features(acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Update the internal value after it has been sent to the backend.
        self.acked_features = acked_features;

        // Let's first provide the memory table to the backend.
        self.update_mem_table(mem)?;

        // Send SET_VRING_NUM here, since it tells backends like SPDK how many
        // virtqueues will be handled, which such backends need to know at an
        // early stage.
        for (queue_index, queue) in queues.iter().enumerate() {
            self.vu
                .set_vring_num(queue_index, queue.max_size())
                .map_err(Error::VhostUserSetVringNum)?;
        }

        // Set up the shared memory for inflight I/O tracking.
        if let Some(inflight) = inflight {
            if inflight.fd.is_none() {
                let inflight_req_info = VhostUserInflight {
                    mmap_size: 0,
                    mmap_offset: 0,
                    num_queues: queues.len() as u16,
                    queue_size: queues[0].max_size(),
                };
                let (info, fd) = self
                    .vu
                    .get_inflight_fd(&inflight_req_info)
                    .map_err(Error::VhostUserGetInflight)?;
                inflight.info = info;
                inflight.fd = Some(fd);
            }
            // Unwrapping the inflight fd is safe here since we know it can't be None.
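            // Sharing this buffer lets the backend record in-flight
            // descriptors there, so they can be resubmitted if the backend
            // crashes or reconnects.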
            self.vu
                .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd())
                .map_err(Error::VhostUserSetInflight)?;
        }

        let num_queues = queues.len();

        let mut vrings_info = Vec::new();
        for (queue_index, queue) in queues.into_iter().enumerate() {
            let actual_size: usize = queue.max_size().try_into().unwrap();

            let config_data = VringConfigData {
                queue_max_size: queue.max_size(),
                queue_size: queue.max_size(),
                flags: 0u32,
                desc_table_addr: get_host_address_range(
                    mem,
                    queue.state.desc_table,
                    actual_size * std::mem::size_of::<Descriptor>(),
                )
                .ok_or(Error::DescriptorTableAddress)? as u64,
                // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u32, len: u32}; actual_size]},
                // i.e. 4 + (4 + 4) * actual_size.
                used_ring_addr: get_host_address_range(
                    mem,
                    queue.state.used_ring,
                    4 + actual_size * 8,
                )
                .ok_or(Error::UsedAddress)? as u64,
                // The avail ring is {flags: u16; idx: u16; ring [u16; actual_size]},
                // i.e. 4 + 2 * actual_size.
                avail_ring_addr: get_host_address_range(
                    mem,
                    queue.state.avail_ring,
                    4 + actual_size * 2,
                )
                .ok_or(Error::AvailAddress)? as u64,
                log_addr: None,
            };

            vrings_info.push(VringInfo {
                config_data,
                used_guest_addr: queue.state.used_ring.raw_value(),
            });

            self.vu
                .set_vring_addr(queue_index, &config_data)
                .map_err(Error::VhostUserSetVringAddr)?;
            self.vu
                .set_vring_base(
                    queue_index,
                    queue
                        .avail_idx(Ordering::Acquire)
                        .map_err(Error::GetAvailableIndex)?
                        .0,
                )
                .map_err(Error::VhostUserSetVringBase)?;

            if let Some(eventfd) =
                virtio_interrupt.notifier(VirtioInterruptType::Queue(queue_index as u16))
            {
                self.vu
                    .set_vring_call(queue_index, &eventfd)
                    .map_err(Error::VhostUserSetVringCall)?;
            }

            self.vu
                .set_vring_kick(queue_index, &queue_evts[queue_index])
                .map_err(Error::VhostUserSetVringKick)?;
        }

        self.enable_vhost_user_vrings(num_queues, true)?;

        if let Some(slave_req_handler) = slave_req_handler {
            self.vu
                .set_slave_request_fd(&slave_req_handler.get_tx_raw_fd())
                .map_err(Error::VhostUserSetSlaveRequestFd)?;
        }

        self.vrings_info = Some(vrings_info);
        self.ready = true;

        Ok(())
    }

    fn enable_vhost_user_vrings(&mut self, num_queues: usize, enable: bool) -> Result<()> {
        for queue_index in 0..num_queues {
            self.vu
                .set_vring_enable(queue_index, enable)
                .map_err(Error::VhostUserSetVringEnable)?;
        }

        Ok(())
    }

    pub fn reset_vhost_user(&mut self, num_queues: usize) -> Result<()> {
        self.enable_vhost_user_vrings(num_queues, false)?;

        // Reset the owner.
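        // VHOST_USER_RESET_OWNER asks the backend to drop the state attached
        // to this session. The request is marked deprecated by the vhost-user
        // specification, but some existing backends still rely on it.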
        self.vu.reset_owner().map_err(Error::VhostUserResetOwner)
    }

    pub fn set_protocol_features_vhost_user(
        &mut self,
        acked_features: u64,
        acked_protocol_features: u64,
    ) -> Result<()> {
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
        self.vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;

        if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
            if let Some(acked_protocol_features) =
                VhostUserProtocolFeatures::from_bits(acked_protocol_features)
            {
                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
                    self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
                }
            }
        }

        self.update_supports_migration(acked_features, acked_protocol_features);

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    pub fn reinitialize_vhost_user<S: VhostUserMasterReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<Queue<GuestMemoryAtomic<GuestMemoryMmap>>>,
        queue_evts: Vec<EventFd>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        acked_protocol_features: u64,
        slave_req_handler: &Option<MasterReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.set_protocol_features_vhost_user(acked_features, acked_protocol_features)?;

        self.setup_vhost_user(
            mem,
            queues,
            queue_evts,
            virtio_interrupt,
            acked_features,
            slave_req_handler,
            inflight,
        )
    }

    pub fn connect_vhost_user(
        server: bool,
        socket_path: &str,
        num_queues: u64,
        unlink_socket: bool,
    ) -> Result<Self> {
        if server {
            if unlink_socket {
                std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?;
            }

            info!("Binding vhost-user listener...");
            let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?;
            info!("Waiting for incoming vhost-user connection...");
            let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?;

            Ok(VhostUserHandle {
                vu: Master::from_stream(stream, num_queues),
                ready: false,
                supports_migration: false,
                shm_log: None,
                acked_features: 0,
                vrings_info: None,
            })
        } else {
            let now = Instant::now();

            // Retry connecting for a full minute.
            let err = loop {
                let err = match Master::connect(socket_path, num_queues) {
                    Ok(m) => {
                        return Ok(VhostUserHandle {
                            vu: m,
                            ready: false,
                            supports_migration: false,
                            shm_log: None,
                            acked_features: 0,
                            vrings_info: None,
                        })
                    }
                    Err(e) => e,
                };
                sleep(Duration::from_millis(100));

                if now.elapsed().as_secs() >= 60 {
                    break err;
                }
            };

            error!(
                "Failed to connect to the backend after trying for 1 minute: {:?}",
                err
            );
            Err(Error::VhostUserConnect)
        }
    }

    pub fn socket_handle(&mut self) -> &mut Master {
        &mut self.vu
    }

    pub fn pause_vhost_user(&mut self, num_queues: usize) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(num_queues, false)?;
        }

        Ok(())
    }

    pub fn resume_vhost_user(&mut self, num_queues: usize) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(num_queues, true)?;
        }

        Ok(())
    }

    fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) {
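        // Migration requires the backend to support dirty page logging
        // (VHOST_F_LOG_ALL) and to accept the log region as a shared memfd
        // (the LOG_SHMFD protocol feature).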
        if (acked_features & (1 << VHOST_F_LOG_ALL) != 0)
            && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0)
        {
            self.supports_migration = true;
        }
    }

    fn update_log_base(&mut self, last_ram_addr: u64) -> Result<Option<Arc<MmapRegion>>> {
        // Create the memfd backing the dirty log.
        let fd = memfd_create(
            &ffi::CString::new("vhost_user_dirty_log").unwrap(),
            libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
        )
        .map_err(Error::MemfdCreate)?;

        // Safe because we checked the file descriptor is valid.
        let file = unsafe { File::from_raw_fd(fd) };
        // The size of the memory mapping corresponds to the size of a bitmap
        // covering all guest pages for addresses from 0 to the last physical
        // address in guest RAM.
        // A page is always 4 KiB from a vhost-user perspective, and each bit
        // is a page. That's how we can compute mmap_size from the last
        // address.
        let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
        let mmap_handle = file.as_raw_fd();

        // Set the shm_log region size.
        file.set_len(mmap_size).map_err(Error::SetFileSize)?;

        // Set the seals.
        let res = unsafe {
            libc::fcntl(
                file.as_raw_fd(),
                libc::F_ADD_SEALS,
                libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL,
            )
        };
        if res < 0 {
            return Err(Error::SetSeals(std::io::Error::last_os_error()));
        }

        // Mmap the shm_log region.
        let region = MmapRegion::build(
            Some(FileOffset::new(file, 0)),
            mmap_size as usize,
            libc::PROT_READ | libc::PROT_WRITE,
            libc::MAP_SHARED,
        )
        .map_err(Error::NewMmapRegion)?;

        // Make sure we hold onto the region to prevent the mapping from being
        // released.
        let old_region = self.shm_log.replace(Arc::new(region));

        // Send the shm_log fd over to the backend.
        let log = VhostUserDirtyLogRegion {
            mmap_size,
            mmap_offset: 0,
            mmap_handle,
        };
        self.vu
            .set_log_base(0, Some(log))
            .map_err(Error::VhostUserSetLogBase)?;

        Ok(old_region)
    }

    fn set_vring_logging(&mut self, enable: bool) -> Result<()> {
        if let Some(vrings_info) = &self.vrings_info {
            for (i, vring_info) in vrings_info.iter().enumerate() {
                let mut config_data = vring_info.config_data;
                config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 };
                config_data.log_addr = if enable {
                    Some(vring_info.used_guest_addr)
                } else {
                    None
                };

                self.vu
                    .set_vring_addr(i, &config_data)
                    .map_err(Error::VhostUserSetVringAddr)?;
            }
        }

        Ok(())
    }

    pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Set the shm log region.
        self.update_log_base(last_ram_addr)?;

        // Enable the VHOST_F_LOG_ALL feature.
        let features = self.acked_features | (1 << VHOST_F_LOG_ALL);
        self.vu
            .set_features(features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Enable dirty page logging of the used ring for all queues.
        self.set_vring_logging(true)
    }

    pub fn stop_dirty_log(&mut self) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Disable dirty page logging of the used ring for all queues.
        self.set_vring_logging(false)?;

        // Disable the VHOST_F_LOG_ALL feature.
        self.vu
            .set_features(self.acked_features)
            .map_err(Error::VhostUserSetFeatures)?;
        // Dropping the shm log region here is important: setting it to None
        // invokes the Drop implementation, which unmaps the memory.
        self.shm_log = None;

        Ok(())
    }

    pub fn dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable> {
        // The log region is updated by creating a new region that is sent to
        // the backend. This ensures the backend stops logging to the previous
        // region. The previous region is returned and processed to create the
        // bitmap representing the dirty pages.
        if let Some(region) = self.update_log_base(last_ram_addr)? {
            // Be careful with the size: region.size() is in bytes, but the
            // bitmap is read as u64 words, so it must be divided by 8.
            let len = region.size() / 8;
            let bitmap = unsafe {
                // Cast the pointer to u64.
                let ptr = region.as_ptr() as *const u64;
                std::slice::from_raw_parts(ptr, len).to_vec()
            };
            Ok(MemoryRangeTable::from_bitmap(bitmap, 0, 4096))
        } else {
            Err(Error::MissingShmLogRegion)
        }
    }
}

fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error> {
    let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) };

    if res < 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(res as RawFd)
    }
}
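// A minimal sketch of the dirty log sizing arithmetic used by
// `update_log_base` above. The 1 GiB `last_ram_addr` value is only a
// hypothetical example: one bit of the log covers one 4 KiB page, so one
// byte covers 32 KiB of guest RAM.
#[cfg(test)]
mod dirty_log_tests {
    use super::VHOST_LOG_PAGE;

    #[test]
    fn dirty_log_size_covers_last_ram_addr() {
        // Hypothetical guest with 1 GiB of RAM.
        let last_ram_addr: u64 = 1 << 30;

        // Same formula as `update_log_base`: 8 pages per byte, plus one
        // byte so the last address is always covered.
        let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
        assert_eq!(mmap_size, 32769);

        // The resulting bitmap covers at least every page up to
        // `last_ram_addr`.
        assert!(mmap_size * 8 * VHOST_LOG_PAGE > last_ram_addr);
    }
}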