// Copyright 2019 Intel Corporation. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use super::super::{Descriptor, Queue};
use super::{Error, Result};
use crate::vhost_user::Inflight;
use crate::{
    get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt,
    VirtioInterruptType,
};
use std::convert::TryInto;
use std::ffi;
use std::fs::File;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixListener;
use std::sync::Arc;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::vec::Vec;
use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG};
use vhost::vhost_user::message::{
    VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
};
use vhost::vhost_user::{Master, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler};
use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
use vm_memory::{Address, Error as MmapError, FileOffset, GuestMemory, GuestMemoryRegion};
use vm_migration::protocol::MemoryRangeTable;
use vmm_sys_util::eventfd::EventFd;

// Size of a dirty page for vhost-user.
const VHOST_LOG_PAGE: u64 = 0x1000;

#[derive(Debug, Clone)]
pub struct VhostUserConfig {
    pub socket: String,
    pub num_queues: usize,
    pub queue_size: u16,
}

#[derive(Clone)]
struct VringInfo {
    config_data: VringConfigData,
    used_guest_addr: u64,
}

#[derive(Clone)]
pub struct VhostUserHandle {
    vu: Master,
    ready: bool,
    supports_migration: bool,
    shm_log: Option<Arc<MmapRegion>>,
    acked_features: u64,
    vrings_info: Option<Vec<VringInfo>>,
}

impl VhostUserHandle {
    pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> {
        let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
        for region in mem.iter() {
            let (mmap_handle, mmap_offset) = match region.file_offset() {
                Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
                None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)),
            };

            let vhost_user_net_reg = VhostUserMemoryRegionInfo {
                guest_phys_addr: region.start_addr().raw_value(),
                memory_size: region.len() as u64,
                userspace_addr: region.as_ptr() as u64,
                mmap_offset,
                mmap_handle,
            };

            regions.push(vhost_user_net_reg);
        }

        self.vu
            .set_mem_table(regions.as_slice())
            .map_err(Error::VhostUserSetMemTable)?;

        Ok(())
    }

    pub fn add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()> {
        let (mmap_handle, mmap_offset) = match region.file_offset() {
            Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
            None => return Err(Error::MissingRegionFd),
        };

        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: region.start_addr().raw_value(),
            memory_size: region.len() as u64,
            userspace_addr: region.as_ptr() as u64,
            mmap_offset,
            mmap_handle,
        };

        self.vu
            .add_mem_region(&region)
            .map_err(Error::VhostUserAddMemReg)
    }
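
    /// Negotiate virtio and vhost-user protocol features with the backend.
    ///
    /// This follows the front-end side of the vhost-user handshake: send
    /// SET_OWNER, fetch the backend's features with GET_FEATURES, and, if
    /// VHOST_USER_F_PROTOCOL_FEATURES was acked, negotiate the protocol
    /// features as well. A minimal usage sketch, assuming `vu_handle` was
    /// created with `connect_vhost_user()` (hypothetical variable names):
    ///
    /// ```ignore
    /// let avail_features = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
    /// let avail_protocol_features =
    ///     VhostUserProtocolFeatures::MQ | VhostUserProtocolFeatures::REPLY_ACK;
    /// let (acked_features, acked_protocol_features) =
    ///     vu_handle.negotiate_features_vhost_user(avail_features, avail_protocol_features)?;
    /// ```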
    pub fn negotiate_features_vhost_user(
        &mut self,
        avail_features: u64,
        avail_protocol_features: VhostUserProtocolFeatures,
    ) -> Result<(u64, u64)> {
        // Set the vhost-user owner.
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;

        // Get the features from the backend and negotiate the subset
        // supported by both the VMM and the backend.
        let backend_features = self
            .vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;
        let acked_features = avail_features & backend_features;

        let acked_protocol_features =
            if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
                let backend_protocol_features = self
                    .vu
                    .get_protocol_features()
                    .map_err(Error::VhostUserGetProtocolFeatures)?;

                let acked_protocol_features = avail_protocol_features & backend_protocol_features;

                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                acked_protocol_features
            } else {
                VhostUserProtocolFeatures::empty()
            };

        if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
            && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
        {
            self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
        }

        self.update_supports_migration(acked_features, acked_protocol_features.bits());

        Ok((acked_features, acked_protocol_features.bits()))
    }
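
    /// Configure the backend with everything it needs to run the device: the
    /// acked features, the guest memory table, and the geometry and eventfds
    /// of every vring, before finally enabling the vrings. SET_VRING_NUM is
    /// sent for every queue first (some backends, like SPDK, need the queue
    /// sizes early), then SET_VRING_ADDR, SET_VRING_BASE, SET_VRING_CALL and
    /// SET_VRING_KICK per queue, and finally SET_VRING_ENABLE. A minimal
    /// usage sketch, assuming the caller already owns `vu_handle`,
    /// `guest_memory`, `queues`, `queue_evts` and `virtio_interrupt`
    /// (hypothetical names):
    ///
    /// ```ignore
    /// vu_handle.setup_vhost_user(
    ///     &guest_memory,
    ///     queues,
    ///     queue_evts,
    ///     &virtio_interrupt,
    ///     acked_features,
    ///     &slave_req_handler, // &None for devices without backend requests
    ///     None,               // no inflight I/O tracking
    /// )?;
    /// ```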
    #[allow(clippy::too_many_arguments)]
    pub fn setup_vhost_user<S: VhostUserMasterReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<Queue>,
        queue_evts: Vec<EventFd>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        slave_req_handler: &Option<MasterReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.vu
            .set_features(acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Update the internal value after it's been sent to the backend.
        self.acked_features = acked_features;

        // Provide the memory table to the backend first.
        self.update_mem_table(mem)?;

        // Send set_vring_num here, since it tells backends such as SPDK how
        // many virtqueues will be handled; some backends need to know this
        // at an early stage.
        for (queue_index, queue) in queues.iter().enumerate() {
            self.vu
                .set_vring_num(queue_index, queue.actual_size())
                .map_err(Error::VhostUserSetVringNum)?;
        }

        // Set up the shared memory for inflight I/O tracking.
        if let Some(inflight) = inflight {
            if inflight.fd.is_none() {
                let inflight_req_info = VhostUserInflight {
                    mmap_size: 0,
                    mmap_offset: 0,
                    num_queues: queues.len() as u16,
                    queue_size: queues[0].actual_size(),
                };
                let (info, fd) = self
                    .vu
                    .get_inflight_fd(&inflight_req_info)
                    .map_err(Error::VhostUserGetInflight)?;
                inflight.info = info;
                inflight.fd = Some(fd);
            }
            // Unwrapping the inflight fd is safe here since we know it can't be None.
            self.vu
                .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd())
                .map_err(Error::VhostUserSetInflight)?;
        }

        let num_queues = queues.len();

        let mut vrings_info = Vec::new();
        for (queue_index, queue) in queues.into_iter().enumerate() {
            let actual_size: usize = queue.actual_size().try_into().unwrap();

            let config_data = VringConfigData {
                queue_max_size: queue.get_max_size(),
                queue_size: queue.actual_size(),
                flags: 0u32,
                desc_table_addr: get_host_address_range(
                    mem,
                    queue.desc_table,
                    actual_size * std::mem::size_of::<Descriptor>(),
                )
                .ok_or(Error::DescriptorTableAddress)? as u64,
                // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u32, len: u32}; actual_size]},
                // i.e. 4 + (4 + 4) * actual_size.
                used_ring_addr: get_host_address_range(mem, queue.used_ring, 4 + actual_size * 8)
                    .ok_or(Error::UsedAddress)? as u64,
                // The avail ring is {flags: u16; idx: u16; ring: [u16; actual_size]},
                // i.e. 4 + 2 * actual_size.
                avail_ring_addr: get_host_address_range(mem, queue.avail_ring, 4 + actual_size * 2)
                    .ok_or(Error::AvailAddress)? as u64,
                log_addr: None,
            };

            vrings_info.push(VringInfo {
                config_data,
                used_guest_addr: queue.used_ring.raw_value(),
            });

            self.vu
                .set_vring_addr(queue_index, &config_data)
                .map_err(Error::VhostUserSetVringAddr)?;
            self.vu
                .set_vring_base(
                    queue_index,
                    queue
                        .used_index_from_memory(mem)
                        .map_err(Error::GetAvailableIndex)?,
                )
                .map_err(Error::VhostUserSetVringBase)?;

            if let Some(eventfd) =
                virtio_interrupt.notifier(&VirtioInterruptType::Queue, Some(&queue))
            {
                self.vu
                    .set_vring_call(queue_index, &eventfd)
                    .map_err(Error::VhostUserSetVringCall)?;
            }

            self.vu
                .set_vring_kick(queue_index, &queue_evts[queue_index])
                .map_err(Error::VhostUserSetVringKick)?;
        }

        self.enable_vhost_user_vrings(num_queues, true)?;

        if let Some(slave_req_handler) = slave_req_handler {
            self.vu
                .set_slave_request_fd(&slave_req_handler.get_tx_raw_fd())
                .map_err(Error::VhostUserSetSlaveRequestFd)?;
        }

        self.vrings_info = Some(vrings_info);
        self.ready = true;

        Ok(())
    }

    fn enable_vhost_user_vrings(&mut self, num_queues: usize, enable: bool) -> Result<()> {
        for queue_index in 0..num_queues {
            self.vu
                .set_vring_enable(queue_index, enable)
                .map_err(Error::VhostUserSetVringEnable)?;
        }

        Ok(())
    }
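
    /// Return the backend to its pre-configured state: all vrings are
    /// disabled first so the backend stops processing, then
    /// VHOST_USER_RESET_OWNER is sent. A minimal sketch of a device reset
    /// path (hypothetical caller-side code):
    ///
    /// ```ignore
    /// vu_handle.reset_vhost_user(num_queues)?;
    /// // A later reinitialize_vhost_user() can bring the device back up.
    /// ```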
    pub fn reset_vhost_user(&mut self, num_queues: usize) -> Result<()> {
        self.enable_vhost_user_vrings(num_queues, false)?;

        // Reset the owner.
        self.vu.reset_owner().map_err(Error::VhostUserResetOwner)
    }

    pub fn set_protocol_features_vhost_user(
        &mut self,
        acked_features: u64,
        acked_protocol_features: u64,
    ) -> Result<()> {
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
        self.vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;

        if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
            if let Some(acked_protocol_features) =
                VhostUserProtocolFeatures::from_bits(acked_protocol_features)
            {
                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
                    self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
                }
            }
        }

        self.update_supports_migration(acked_features, acked_protocol_features);

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    pub fn reinitialize_vhost_user<S: VhostUserMasterReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<Queue>,
        queue_evts: Vec<EventFd>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        acked_protocol_features: u64,
        slave_req_handler: &Option<MasterReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.set_protocol_features_vhost_user(acked_features, acked_protocol_features)?;

        self.setup_vhost_user(
            mem,
            queues,
            queue_evts,
            virtio_interrupt,
            acked_features,
            slave_req_handler,
            inflight,
        )
    }

    pub fn connect_vhost_user(
        server: bool,
        socket_path: &str,
        num_queues: u64,
        unlink_socket: bool,
    ) -> Result<Self> {
        if server {
            if unlink_socket {
                std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?;
            }

            info!("Binding vhost-user listener...");
            let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?;
            info!("Waiting for incoming vhost-user connection...");
            let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?;

            Ok(VhostUserHandle {
                vu: Master::from_stream(stream, num_queues),
                ready: false,
                supports_migration: false,
                shm_log: None,
                acked_features: 0,
                vrings_info: None,
            })
        } else {
            let now = Instant::now();

            // Retry connecting for a full minute.
            let err = loop {
                let err = match Master::connect(socket_path, num_queues) {
                    Ok(m) => {
                        return Ok(VhostUserHandle {
                            vu: m,
                            ready: false,
                            supports_migration: false,
                            shm_log: None,
                            acked_features: 0,
                            vrings_info: None,
                        })
                    }
                    Err(e) => e,
                };
                sleep(Duration::from_millis(100));

                if now.elapsed().as_secs() >= 60 {
                    break err;
                }
            };

            error!(
                "Failed to connect to the backend after trying for 1 minute: {:?}",
                err
            );
            Err(Error::VhostUserConnect)
        }
    }

    pub fn socket_handle(&mut self) -> &mut Master {
        &mut self.vu
    }

    pub fn pause_vhost_user(&mut self, num_queues: usize) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(num_queues, false)?;
        }

        Ok(())
    }

    pub fn resume_vhost_user(&mut self, num_queues: usize) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(num_queues, true)?;
        }

        Ok(())
    }

    fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) {
        if (acked_features & u64::from(VHOST_F_LOG_ALL) != 0)
            && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0)
        {
            self.supports_migration = true;
        }
    }
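
    /// Create a new shared memory region for dirty page logging and hand it
    /// to the backend through SET_LOG_BASE. The region holds a bitmap with
    /// one bit per 4 KiB page, covering guest RAM from address 0 up to
    /// `last_ram_addr`. For example, with 4 GiB of guest RAM
    /// (`last_ram_addr = 0x1_0000_0000`), the computed size would be:
    ///
    /// ```ignore
    /// let mmap_size = (0x1_0000_0000u64 / (VHOST_LOG_PAGE * 8)) + 1; // 131_073 bytes (~128 KiB)
    /// ```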
    fn update_log_base(&mut self, last_ram_addr: u64) -> Result<Option<Arc<MmapRegion>>> {
        // Create the memfd.
        let fd = memfd_create(
            &ffi::CString::new("vhost_user_dirty_log").unwrap(),
            libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
        )
        .map_err(Error::MemfdCreate)?;

        // Safe because we checked the file descriptor is valid.
        let file = unsafe { File::from_raw_fd(fd) };
        // The size of the memory mapping corresponds to the size of a bitmap
        // covering all guest pages for addresses from 0 to the last physical
        // address in guest RAM.
        // A page is always 4 KiB from a vhost-user perspective, and each bit
        // is a page. That's how we can compute mmap_size from the last
        // address.
        let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
        let mmap_handle = file.as_raw_fd();

        // Set the shm_log region size.
        file.set_len(mmap_size).map_err(Error::SetFileSize)?;

        // Set the seals.
        let res = unsafe {
            libc::fcntl(
                file.as_raw_fd(),
                libc::F_ADD_SEALS,
                libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL,
            )
        };
        if res < 0 {
            return Err(Error::SetSeals(std::io::Error::last_os_error()));
        }

        // Mmap the shm_log region.
        let region = MmapRegion::build(
            Some(FileOffset::new(file, 0)),
            mmap_size as usize,
            libc::PROT_READ | libc::PROT_WRITE,
            libc::MAP_SHARED,
        )
        .map_err(Error::NewMmapRegion)?;

        // Make sure we hold onto the region to prevent the mapping from being
        // released.
        let old_region = self.shm_log.replace(Arc::new(region));

        // Send the shm_log fd over to the backend.
        let log = VhostUserDirtyLogRegion {
            mmap_size,
            mmap_offset: 0,
            mmap_handle,
        };
        self.vu
            .set_log_base(0, Some(log))
            .map_err(Error::VhostUserSetLogBase)?;

        Ok(old_region)
    }

    fn set_vring_logging(&mut self, enable: bool) -> Result<()> {
        if let Some(vrings_info) = &self.vrings_info {
            for (i, vring_info) in vrings_info.iter().enumerate() {
                let mut config_data = vring_info.config_data;
                config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 };
                config_data.log_addr = if enable {
                    Some(vring_info.used_guest_addr)
                } else {
                    None
                };

                self.vu
                    .set_vring_addr(i, &config_data)
                    .map_err(Error::VhostUserSetVringAddr)?;
            }
        }

        Ok(())
    }

    pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Set the shm log region.
        self.update_log_base(last_ram_addr)?;

        // Enable the VHOST_F_LOG_ALL feature.
        let features = self.acked_features | (1 << VHOST_F_LOG_ALL);
        self.vu
            .set_features(features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Enable dirty page logging of the used ring for all queues.
        self.set_vring_logging(true)
    }
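
    /// Stop dirty page logging and drop the shm log region. Together with
    /// `start_dirty_log()` and `dirty_log()`, this supports a pre-copy style
    /// migration loop along these lines (a sketch; `send_dirty_ranges` is a
    /// hypothetical caller-side helper):
    ///
    /// ```ignore
    /// vu_handle.start_dirty_log(last_ram_addr)?;
    /// while !converged {
    ///     let table = vu_handle.dirty_log(last_ram_addr)?;
    ///     send_dirty_ranges(&table)?;
    /// }
    /// vu_handle.stop_dirty_log()?;
    /// ```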
    pub fn stop_dirty_log(&mut self) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Disable dirty page logging of the used ring for all queues.
        self.set_vring_logging(false)?;

        // Disable the VHOST_F_LOG_ALL feature.
        self.vu
            .set_features(self.acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // This is important: dropping the shm log region here invokes its
        // Drop implementation, which unmaps the memory.
        self.shm_log = None;

        Ok(())
    }

    pub fn dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable> {
        // The log region is updated by creating a new region that is sent to
        // the backend. This ensures the backend stops logging to the previous
        // region. The previous region is returned and processed to create the
        // bitmap representing the dirty pages.
        if let Some(region) = self.update_log_base(last_ram_addr)? {
            // The region size is in bytes, but the bitmap is read as u64
            // words, hence the division by 8.
            let len = region.size() / 8;
            let bitmap = unsafe {
                // Cast the pointer to u64.
                let ptr = region.as_ptr() as *const u64;
                std::slice::from_raw_parts(ptr, len).to_vec()
            };
            Ok(MemoryRangeTable::from_bitmap(bitmap, 0))
        } else {
            Err(Error::MissingShmLogRegion)
        }
    }
}

fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error> {
    let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) };

    if res < 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(res as RawFd)
    }
}