// Copyright 2019 Intel Corporation. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use super::super::{Descriptor, Queue};
use super::{Error, Result};
use crate::vhost_user::Inflight;
use crate::{
    get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt,
    VirtioInterruptType,
};
use std::convert::TryInto;
use std::ffi;
use std::fs::File;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixListener;
use std::sync::Arc;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::vec::Vec;
use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG};
use vhost::vhost_user::message::{
    VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
};
use vhost::vhost_user::{Master, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler};
use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
use vm_memory::{Address, Error as MmapError, FileOffset, GuestMemory, GuestMemoryRegion};
use vm_migration::protocol::MemoryRangeTable;
use vmm_sys_util::eventfd::EventFd;

// Size of a dirty page for vhost-user.
const VHOST_LOG_PAGE: u64 = 0x1000;

#[derive(Debug, Clone)]
pub struct VhostUserConfig {
    pub socket: String,
    pub num_queues: usize,
    pub queue_size: u16,
}

#[derive(Clone)]
struct VringInfo {
    config_data: VringConfigData,
    used_guest_addr: u64,
}

#[derive(Clone)]
pub struct VhostUserHandle {
    vu: Master,
    ready: bool,
    supports_migration: bool,
    shm_log: Option<Arc<MmapRegion>>,
    acked_features: u64,
    vrings_info: Option<Vec<VringInfo>>,
}

impl VhostUserHandle {
    pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> {
        let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
        for region in mem.iter() {
            let (mmap_handle, mmap_offset) = match region.file_offset() {
                Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
                None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)),
            };

            let vhost_user_net_reg = VhostUserMemoryRegionInfo {
                guest_phys_addr: region.start_addr().raw_value(),
                memory_size: region.len() as u64,
                userspace_addr: region.as_ptr() as u64,
                mmap_offset,
                mmap_handle,
            };

            regions.push(vhost_user_net_reg);
        }

        self.vu
            .set_mem_table(regions.as_slice())
            .map_err(Error::VhostUserSetMemTable)?;

        Ok(())
    }

    pub fn add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()> {
        let (mmap_handle, mmap_offset) = match region.file_offset() {
            Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
            None => return Err(Error::MissingRegionFd),
        };

        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: region.start_addr().raw_value(),
            memory_size: region.len() as u64,
            userspace_addr: region.as_ptr() as u64,
            mmap_offset,
            mmap_handle,
        };

        self.vu
            .add_mem_region(&region)
            .map_err(Error::VhostUserAddMemReg)
    }
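
    // The handshake below follows the vhost-user protocol ordering: SET_OWNER,
    // then GET_FEATURES, then (only if VHOST_USER_F_PROTOCOL_FEATURES is part
    // of the negotiated set) GET_PROTOCOL_FEATURES / SET_PROTOCOL_FEATURES.
    // A minimal calling sketch, assuming `avail_features` comes from the
    // device model (hypothetical values):
    //
    //     let (features, protocol_features) = vu.negotiate_features_vhost_user(
    //         avail_features,
    //         VhostUserProtocolFeatures::MQ | VhostUserProtocolFeatures::LOG_SHMFD,
    //     )?;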
    pub fn negotiate_features_vhost_user(
        &mut self,
        avail_features: u64,
        avail_protocol_features: VhostUserProtocolFeatures,
    ) -> Result<(u64, u64)> {
        // Set vhost-user owner.
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;

        // Get features from the backend and negotiate the subset of features
        // supported by both the VMM and the backend.
        let backend_features = self
            .vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;
        let acked_features = avail_features & backend_features;

        let acked_protocol_features =
            if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
                let backend_protocol_features = self
                    .vu
                    .get_protocol_features()
                    .map_err(Error::VhostUserGetProtocolFeatures)?;

                let acked_protocol_features = avail_protocol_features & backend_protocol_features;

                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                acked_protocol_features
            } else {
                VhostUserProtocolFeatures::empty()
            };

        if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
            && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
        {
            self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
        }

        self.update_supports_migration(acked_features, acked_protocol_features.bits());

        Ok((acked_features, acked_protocol_features.bits()))
    }

    #[allow(clippy::too_many_arguments)]
    pub fn setup_vhost_user<S: VhostUserMasterReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<Queue>,
        queue_evts: Vec<EventFd>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        slave_req_handler: &Option<MasterReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.vu
            .set_features(acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Update the internal value after it has been sent to the backend.
        self.acked_features = acked_features;

        // First provide the memory table to the backend.
        self.update_mem_table(mem)?;

        // Send set_vring_num here, since it tells backends such as SPDK how
        // many virtqueues are to be handled, which they need to know at an
        // early stage.
        for (queue_index, queue) in queues.iter().enumerate() {
            self.vu
                .set_vring_num(queue_index, queue.actual_size())
                .map_err(Error::VhostUserSetVringNum)?;
        }

        // Set up the shared memory used for inflight I/O tracking.
        if let Some(inflight) = inflight {
            if inflight.fd.is_none() {
                let inflight_req_info = VhostUserInflight {
                    mmap_size: 0,
                    mmap_offset: 0,
                    num_queues: queues.len() as u16,
                    queue_size: queues[0].actual_size(),
                };
                let (info, fd) = self
                    .vu
                    .get_inflight_fd(&inflight_req_info)
                    .map_err(Error::VhostUserGetInflight)?;
                inflight.info = info;
                inflight.fd = Some(fd);
            }
            // Unwrapping the inflight fd is safe here since we know it can't be None.
            self.vu
                .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd())
                .map_err(Error::VhostUserSetInflight)?;
        }

        let num_queues = queues.len();

        let mut vrings_info = Vec::new();
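
        // Worked example for the vring sizing computed below (a sketch,
        // assuming a queue size of 256):
        //   descriptor table: 256 * size_of::<Descriptor>() = 256 * 16 = 4096 bytes
        //   used ring:        4 + 256 * 8 = 2052 bytes
        //   avail ring:       4 + 256 * 2 = 516 bytes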
        for (queue_index, queue) in queues.into_iter().enumerate() {
            let actual_size: usize = queue.actual_size().try_into().unwrap();

            let config_data = VringConfigData {
                queue_max_size: queue.get_max_size(),
                queue_size: queue.actual_size(),
                flags: 0u32,
                desc_table_addr: get_host_address_range(
                    mem,
                    queue.desc_table,
                    actual_size * std::mem::size_of::<Descriptor>(),
                )
                .ok_or(Error::DescriptorTableAddress)? as u64,
                // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u32, len: u32}; actual_size]},
                // i.e. 4 + (4 + 4) * actual_size.
                used_ring_addr: get_host_address_range(mem, queue.used_ring, 4 + actual_size * 8)
                    .ok_or(Error::UsedAddress)? as u64,
                // The avail ring is {flags: u16; idx: u16; ring: [u16; actual_size]},
                // i.e. 4 + 2 * actual_size.
                avail_ring_addr: get_host_address_range(mem, queue.avail_ring, 4 + actual_size * 2)
                    .ok_or(Error::AvailAddress)? as u64,
                log_addr: None,
            };

            vrings_info.push(VringInfo {
                config_data,
                used_guest_addr: queue.used_ring.raw_value(),
            });

            self.vu
                .set_vring_addr(queue_index, &config_data)
                .map_err(Error::VhostUserSetVringAddr)?;
            self.vu
                .set_vring_base(
                    queue_index,
                    queue
                        .used_index_from_memory(mem)
                        .map_err(Error::GetAvailableIndex)?,
                )
                .map_err(Error::VhostUserSetVringBase)?;

            if let Some(eventfd) =
                virtio_interrupt.notifier(&VirtioInterruptType::Queue, Some(&queue))
            {
                self.vu
                    .set_vring_call(queue_index, &eventfd)
                    .map_err(Error::VhostUserSetVringCall)?;
            }

            self.vu
                .set_vring_kick(queue_index, &queue_evts[queue_index])
                .map_err(Error::VhostUserSetVringKick)?;
        }

        self.enable_vhost_user_vrings(num_queues, true)?;

        if let Some(slave_req_handler) = slave_req_handler {
            self.vu
                .set_slave_request_fd(&slave_req_handler.get_tx_raw_fd())
                .map_err(Error::VhostUserSetSlaveRequestFd)?;
        }

        self.vrings_info = Some(vrings_info);
        self.ready = true;

        Ok(())
    }

    fn enable_vhost_user_vrings(&mut self, num_queues: usize, enable: bool) -> Result<()> {
        for queue_index in 0..num_queues {
            self.vu
                .set_vring_enable(queue_index, enable)
                .map_err(Error::VhostUserSetVringEnable)?;
        }

        Ok(())
    }

    pub fn reset_vhost_user(&mut self, num_queues: usize) -> Result<()> {
        self.enable_vhost_user_vrings(num_queues, false)?;

        // Reset the owner.
        self.vu.reset_owner().map_err(Error::VhostUserResetOwner)
    }
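
    // Re-runs the backend initialization sequence, typically after the
    // backend has been restarted and the socket reconnected. Note that the
    // previously acked features and protocol features are replayed as-is
    // rather than renegotiated.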
    #[allow(clippy::too_many_arguments)]
    pub fn reinitialize_vhost_user<S: VhostUserMasterReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<Queue>,
        queue_evts: Vec<EventFd>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        acked_protocol_features: u64,
        slave_req_handler: &Option<MasterReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
        self.vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;

        if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
            if let Some(acked_protocol_features) =
                VhostUserProtocolFeatures::from_bits(acked_protocol_features)
            {
                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
                    self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
                }
            }
        }

        self.update_supports_migration(acked_features, acked_protocol_features);

        self.setup_vhost_user(
            mem,
            queues,
            queue_evts,
            virtio_interrupt,
            acked_features,
            slave_req_handler,
            inflight,
        )
    }

    pub fn connect_vhost_user(
        server: bool,
        socket_path: &str,
        num_queues: u64,
        unlink_socket: bool,
    ) -> Result<Self> {
        if server {
            if unlink_socket {
                std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?;
            }

            info!("Binding vhost-user listener...");
            let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?;
            info!("Waiting for incoming vhost-user connection...");
            let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?;

            Ok(VhostUserHandle {
                vu: Master::from_stream(stream, num_queues),
                ready: false,
                supports_migration: false,
                shm_log: None,
                acked_features: 0,
                vrings_info: None,
            })
        } else {
            let now = Instant::now();

            // Retry connecting for a full minute.
            let err = loop {
                let err = match Master::connect(socket_path, num_queues) {
                    Ok(m) => {
                        return Ok(VhostUserHandle {
                            vu: m,
                            ready: false,
                            supports_migration: false,
                            shm_log: None,
                            acked_features: 0,
                            vrings_info: None,
                        })
                    }
                    Err(e) => e,
                };
                sleep(Duration::from_millis(100));

                if now.elapsed().as_secs() >= 60 {
                    break err;
                }
            };

            error!(
                "Failed connecting the backend after trying for 1 minute: {:?}",
                err
            );
            Err(Error::VhostUserConnect)
        }
    }

    pub fn socket_handle(&mut self) -> &mut Master {
        &mut self.vu
    }

    pub fn pause_vhost_user(&mut self, num_queues: usize) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(num_queues, false)?;
        }

        Ok(())
    }

    pub fn resume_vhost_user(&mut self, num_queues: usize) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(num_queues, true)?;
        }

        Ok(())
    }

    fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) {
        // VHOST_F_LOG_ALL is a bit number, not a mask, hence the shift.
        if (acked_features & (1 << VHOST_F_LOG_ALL) != 0)
            && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0)
        {
            self.supports_migration = true;
        }
    }
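
    // Dirty-log bitmap sizing, worked through (a sketch, assuming 4 GiB of
    // guest RAM): one bit tracks one 4 KiB page, so one byte covers
    // 8 * 4 KiB = 32 KiB of guest memory, and
    //   mmap_size = 0x1_0000_0000 / (4096 * 8) + 1 = 131_073 bytes (~128 KiB).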
    fn update_log_base(&mut self, last_ram_addr: u64) -> Result<Option<Arc<MmapRegion>>> {
        // Create the memfd.
        let fd = memfd_create(
            &ffi::CString::new("vhost_user_dirty_log").unwrap(),
            libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
        )
        .map_err(Error::MemfdCreate)?;

        // Safe because we checked the file descriptor is valid.
        let file = unsafe { File::from_raw_fd(fd) };
        // The size of the memory mapping corresponds to the size of a bitmap
        // covering all guest pages for addresses from 0 to the last physical
        // address in guest RAM. A page is always 4 KiB from a vhost-user
        // perspective, and each bit is a page. That's how we can compute
        // mmap_size from the last address.
        let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
        let mmap_handle = file.as_raw_fd();

        // Set the shm_log region size.
        file.set_len(mmap_size).map_err(Error::SetFileSize)?;

        // Set the seals.
        let res = unsafe {
            libc::fcntl(
                file.as_raw_fd(),
                libc::F_ADD_SEALS,
                libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL,
            )
        };
        if res < 0 {
            return Err(Error::SetSeals(std::io::Error::last_os_error()));
        }

        // Mmap the shm_log region.
        let region = MmapRegion::build(
            Some(FileOffset::new(file, 0)),
            mmap_size as usize,
            libc::PROT_READ | libc::PROT_WRITE,
            libc::MAP_SHARED,
        )
        .map_err(Error::NewMmapRegion)?;

        // Make sure we hold onto the region to prevent the mapping from being
        // released.
        let old_region = self.shm_log.take();
        self.shm_log = Some(Arc::new(region));

        // Send the shm_log fd over to the backend.
        let log = VhostUserDirtyLogRegion {
            mmap_size,
            mmap_offset: 0,
            mmap_handle,
        };
        self.vu
            .set_log_base(0, Some(log))
            .map_err(Error::VhostUserSetLogBase)?;

        Ok(old_region)
    }

    fn set_vring_logging(&mut self, enable: bool) -> Result<()> {
        if let Some(vrings_info) = &self.vrings_info {
            for (i, vring_info) in vrings_info.iter().enumerate() {
                let mut config_data = vring_info.config_data;
                config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 };
                config_data.log_addr = if enable {
                    Some(vring_info.used_guest_addr)
                } else {
                    None
                };

                self.vu
                    .set_vring_addr(i, &config_data)
                    .map_err(Error::VhostUserSetVringAddr)?;
            }
        }

        Ok(())
    }

    pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Set the shm log region.
        self.update_log_base(last_ram_addr)?;

        // Enable the VHOST_F_LOG_ALL feature.
        let features = self.acked_features | (1 << VHOST_F_LOG_ALL);
        self.vu
            .set_features(features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Enable dirty page logging of the used ring for all queues.
        self.set_vring_logging(true)
    }

    pub fn stop_dirty_log(&mut self) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Disable dirty page logging of the used ring for all queues.
        self.set_vring_logging(false)?;

        // Disable the VHOST_F_LOG_ALL feature.
        self.vu
            .set_features(self.acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Dropping the shm_log region here is important, since going out of
        // scope invokes the Drop trait, hence unmapping the memory.
        self.shm_log = None;

        Ok(())
    }
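
    // Each u64 word in the log covers 64 consecutive 4 KiB pages: bit n of
    // the bitmap set means guest page n is dirty, i.e. the range
    // [n * 4096, (n + 1) * 4096) must be re-transferred. The conversion from
    // bitmap to ranges happens in MemoryRangeTable::from_bitmap.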
    pub fn dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable> {
        // The log region is updated by creating a new region that is sent to
        // the backend. This ensures the backend stops logging to the previous
        // region. The previous region is returned and processed to create the
        // bitmap representing the dirty pages.
        if let Some(region) = self.update_log_base(last_ram_addr)? {
            // The region size is in bytes, and each u64 entry holds 8 bytes
            // worth of bits, hence the division by 8.
            let len = region.size() / 8;
            // Copy the bitmap out of the mapping rather than taking ownership
            // of it: the memory belongs to the mmap'ed region and must not be
            // freed through the Vec allocator.
            let bitmap = unsafe {
                std::slice::from_raw_parts(region.as_ptr() as *const u64, len).to_vec()
            };
            Ok(MemoryRangeTable::from_bitmap(bitmap, 0))
        } else {
            Err(Error::MissingShmLogRegion)
        }
    }
}

fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error> {
    let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) };

    if res < 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(res as RawFd)
    }
}
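
// Example end-to-end usage of the handle above (a sketch, not taken from the
// callers in this crate; `mem`, `queues`, `queue_evts`, `interrupt`,
// `avail_features` and the `ReqHandler` type are assumed to be provided by
// the device model):
//
//     let mut vu =
//         VhostUserHandle::connect_vhost_user(false, "/tmp/vhost.sock", 2, false)?;
//     let (features, protocol_features) = vu.negotiate_features_vhost_user(
//         avail_features,
//         VhostUserProtocolFeatures::MQ | VhostUserProtocolFeatures::LOG_SHMFD,
//     )?;
//     vu.setup_vhost_user(
//         &mem,
//         queues,
//         queue_evts,
//         &interrupt,
//         features,
//         &None::<MasterReqHandler<ReqHandler>>,
//         None,
//     )?;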