// Copyright © 2019 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause

use super::Error as DeviceError;
use super::{
    ActivateResult, EpollHelper, EpollHelperError, EpollHelperHandler, VirtioCommon, VirtioDevice,
    VirtioDeviceType, EPOLL_HELPER_EVENT_LAST, VIRTIO_F_VERSION_1,
};
use crate::seccomp_filters::Thread;
use crate::thread_helper::spawn_virtio_thread;
use crate::GuestMemoryMmap;
use crate::{DmaRemapping, VirtioInterrupt, VirtioInterruptType};
use anyhow::anyhow;
use seccompiler::SeccompAction;
use std::collections::BTreeMap;
use std::io;
use std::mem::size_of;
use std::ops::Bound::Included;
use std::os::unix::io::AsRawFd;
use std::result;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Barrier, Mutex, RwLock};
use thiserror::Error;
use versionize::{VersionMap, Versionize, VersionizeResult};
use versionize_derive::Versionize;
use virtio_queue::{DescriptorChain, Queue, QueueT};
use vm_device::dma_mapping::ExternalDmaMapping;
use vm_memory::{
    Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic,
    GuestMemoryError, GuestMemoryLoadGuard,
};
use vm_migration::VersionMapped;
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform;
use vmm_sys_util::eventfd::EventFd;

/// Queue sizes
const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];

/// New descriptors are pending on the request queue.
/// "requestq" is used whenever an action must be performed on behalf of the
/// guest driver.
const REQUEST_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 1;
/// New descriptors are pending on the event queue.
/// "eventq" lets the device report faults and other asynchronous events to
/// the guest driver.
const _EVENT_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 2;

/// PROBE properties size.
/// This is the minimal size needed to carry at least one RESV_MEM property.
/// Because virtio-iommu expects one MSI reserved region, we must provide it,
/// otherwise the guest driver falls back to a predefined region between
/// 0x8000000 and 0x80FFFFF, which is only relevant for the ARM architecture
/// but conflicts with x86.
const PROBE_PROP_SIZE: u32 =
    (size_of::<VirtioIommuProbeProperty>() + size_of::<VirtioIommuProbeResvMem>()) as u32;

/// Virtio IOMMU features
#[allow(unused)]
const VIRTIO_IOMMU_F_INPUT_RANGE: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_F_DOMAIN_RANGE: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_F_MAP_UNMAP: u32 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_F_BYPASS: u32 = 3;
const VIRTIO_IOMMU_F_PROBE: u32 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_F_MMIO: u32 = 5;
const VIRTIO_IOMMU_F_BYPASS_CONFIG: u32 = 6;

// Support 2MiB and 4KiB page sizes.
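// Each bit set in the mask advertises one supported page granule: bit 12
// (4 KiB) and bit 21 (2 MiB), so the value below equals (1 << 21) | (1 << 12).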
const VIRTIO_IOMMU_PAGE_SIZE_MASK: u64 = (2 << 20) | (4 << 10);

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuRange32 {
    start: u32,
    end: u32,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuRange64 {
    start: u64,
    end: u64,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuConfig {
    page_size_mask: u64,
    input_range: VirtioIommuRange64,
    domain_range: VirtioIommuRange32,
    probe_size: u32,
    bypass: u8,
    _reserved: [u8; 7],
}

/// Virtio IOMMU request type
const VIRTIO_IOMMU_T_ATTACH: u8 = 1;
const VIRTIO_IOMMU_T_DETACH: u8 = 2;
const VIRTIO_IOMMU_T_MAP: u8 = 3;
const VIRTIO_IOMMU_T_UNMAP: u8 = 4;
const VIRTIO_IOMMU_T_PROBE: u8 = 5;

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqHead {
    type_: u8,
    _reserved: [u8; 3],
}

/// Virtio IOMMU request status
const VIRTIO_IOMMU_S_OK: u8 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_S_IOERR: u8 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_S_UNSUPP: u8 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_S_DEVERR: u8 = 3;
#[allow(unused)]
const VIRTIO_IOMMU_S_INVAL: u8 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_S_RANGE: u8 = 5;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOENT: u8 = 6;
#[allow(unused)]
const VIRTIO_IOMMU_S_FAULT: u8 = 7;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOMEM: u8 = 8;

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuReqTail {
    status: u8,
    _reserved: [u8; 3],
}

/// ATTACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqAttach {
    domain: u32,
    endpoint: u32,
    flags: u32,
    _reserved: [u8; 4],
}

const VIRTIO_IOMMU_ATTACH_F_BYPASS: u32 = 1;

/// DETACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqDetach {
    domain: u32,
    endpoint: u32,
    _reserved: [u8; 8],
}

/// Virtio IOMMU request MAP flags
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MMIO: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MASK: u32 =
    VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE | VIRTIO_IOMMU_MAP_F_MMIO;

/// MAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqMap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    phys_start: u64,
    _flags: u32,
}

/// UNMAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqUnmap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    _reserved: [u8; 4],
}

/// Virtio IOMMU request PROBE types
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_NONE: u16 = 0;
const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_MASK: u16 = 0xfff;

/// PROBE request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuReqProbe {
    endpoint: u32,
    _reserved: [u64; 8],
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuProbeProperty {
    type_: u16,
    length: u16,
}

/// Virtio IOMMU request PROBE property RESV_MEM subtypes
#[allow(unused)]
const VIRTIO_IOMMU_RESV_MEM_T_RESERVED: u8 = 0;
const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuProbeResvMem {
    subtype: u8,
    _reserved: [u8; 3],
    start: u64,
    end: u64,
}

/// Virtio IOMMU fault flags
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_EXEC: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_ADDRESS: u32 = 1 << 8;

/// Virtio IOMMU fault reasons
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_UNKNOWN: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_DOMAIN: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_MAPPING: u32 = 2;

/// Fault reporting through eventq
#[allow(unused)]
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuFault {
    reason: u8,
    reserved: [u8; 3],
    flags: u32,
    endpoint: u32,
    reserved2: [u8; 4],
    address: u64,
}

// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange32 {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange64 {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuConfig {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqHead {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqTail {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqAttach {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqDetach {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqMap {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqUnmap {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqProbe {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeProperty {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeResvMem {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuFault {}

#[derive(Error, Debug)]
enum Error {
    #[error("Guest gave us bad memory addresses: {0}")]
    GuestMemory(GuestMemoryError),
    #[error("Guest gave us a write-only descriptor that the protocol says to read from")]
    UnexpectedWriteOnlyDescriptor,
    #[error("Guest gave us a read-only descriptor that the protocol says to write to")]
    UnexpectedReadOnlyDescriptor,
    #[error("Guest gave us too few descriptors in a descriptor chain")]
    DescriptorChainTooShort,
    #[error("Guest gave us a buffer that was too short to use")]
    BufferLengthTooSmall,
    #[error("Guest sent us an invalid request")]
    InvalidRequest,
    #[error("Guest sent us an invalid ATTACH request")]
    InvalidAttachRequest,
    #[error("Guest sent us an invalid DETACH request")]
    InvalidDetachRequest,
    #[error("Guest sent us an invalid MAP request")]
    InvalidMapRequest,
    #[error("Cannot map because the domain is in bypass mode")]
    InvalidMapRequestBypassDomain,
    #[error("Cannot map because the domain is missing")]
    InvalidMapRequestMissingDomain,
    #[error("Guest sent us an invalid UNMAP request")]
    InvalidUnmapRequest,
    #[error("Cannot unmap because the domain is in bypass mode")]
    InvalidUnmapRequestBypassDomain,
    #[error("Cannot unmap because the domain is missing")]
    InvalidUnmapRequestMissingDomain,
    #[error("Guest sent us an invalid PROBE request")]
    InvalidProbeRequest,
    #[error("Failed to perform external mapping: {0}")]
    ExternalMapping(io::Error),
    #[error("Failed to perform external unmapping: {0}")]
    ExternalUnmapping(io::Error),
    #[error("Failed adding used index: {0}")]
    QueueAddUsed(virtio_queue::Error),
}

struct Request {}

impl Request {
    // Parse the available vring buffer. Based on the hashmap table of external
    // mappings required by various devices such as VFIO or vhost-user ones,
    // this function might update the hashmap table of external mappings per
    // domain.
    // Basically, the VMM knows about the device_id <=> mapping relationship
    // before running the VM, but at runtime a new domain <=> mapping hashmap
    // is created based on the information provided by the guest driver for
    // virtio-iommu (giving the link device_id <=> domain).
    fn parse(
        desc_chain: &mut DescriptorChain<GuestMemoryLoadGuard<GuestMemoryMmap>>,
        mapping: &Arc<IommuMapping>,
        ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
        msi_iova_space: (u64, u64),
    ) -> result::Result<usize, Error> {
        let desc = desc_chain
            .next()
            .ok_or(Error::DescriptorChainTooShort)
            .map_err(|e| {
                error!("Missing head descriptor");
                e
            })?;

        // The descriptor contains the request type which MUST be readable.
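        // A request chain is laid out as a readable head (VirtioIommuReqHead
        // plus a type-specific payload) followed by a writable part holding
        // the status tail; PROBE requests additionally expect room in the
        // writable part for the properties returned by the device.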
        if desc.is_write_only() {
            return Err(Error::UnexpectedWriteOnlyDescriptor);
        }

        if (desc.len() as usize) < size_of::<VirtioIommuReqHead>() {
            return Err(Error::InvalidRequest);
        }

        let req_head: VirtioIommuReqHead = desc_chain
            .memory()
            .read_obj(desc.addr())
            .map_err(Error::GuestMemory)?;
        let req_offset = size_of::<VirtioIommuReqHead>();
        let desc_size_left = (desc.len() as usize) - req_offset;
        let req_addr = if let Some(addr) = desc.addr().checked_add(req_offset as u64) {
            addr
        } else {
            return Err(Error::InvalidRequest);
        };

        let (msi_iova_start, msi_iova_end) = msi_iova_space;

        // Create the reply
        let mut reply: Vec<u8> = Vec::new();
        let mut status = VIRTIO_IOMMU_S_OK;
        let mut hdr_len = 0;

        let result = (|| {
            match req_head.type_ {
                VIRTIO_IOMMU_T_ATTACH => {
                    if desc_size_left != size_of::<VirtioIommuReqAttach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidAttachRequest);
                    }

                    let req: VirtioIommuReqAttach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Attach request {:?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;
                    let bypass =
                        (req.flags & VIRTIO_IOMMU_ATTACH_F_BYPASS) == VIRTIO_IOMMU_ATTACH_F_BYPASS;

                    let mut old_domain_id = domain_id;
                    if let Some(&id) = mapping.endpoints.read().unwrap().get(&endpoint) {
                        old_domain_id = id;
                    }

                    if old_domain_id != domain_id {
                        detach_endpoint_from_domain(endpoint, old_domain_id, mapping, ext_mapping)?;
                    }

                    // Add endpoint associated with specific domain
                    mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .insert(endpoint, domain_id);

                    // If any other mappings exist in the domain for other containers,
                    // make sure to issue these mappings for the new endpoint/container
                    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id)
                    {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            for (virt_start, addr_map) in &domain_mappings.mappings {
                                ext_map
                                    .map(*virt_start, addr_map.gpa, addr_map.size)
                                    .map_err(Error::ExternalMapping)?;
                            }
                        }
                    }

                    // Add new domain with no mapping if the entry didn't exist yet
                    let mut domains = mapping.domains.write().unwrap();
                    let domain = Domain {
                        mappings: BTreeMap::new(),
                        bypass,
                    };
                    domains.entry(domain_id).or_insert_with(|| domain);
                }
                VIRTIO_IOMMU_T_DETACH => {
                    if desc_size_left != size_of::<VirtioIommuReqDetach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidDetachRequest);
                    }

                    let req: VirtioIommuReqDetach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Detach request {:?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;

                    // Remove endpoint associated with specific domain
                    detach_endpoint_from_domain(endpoint, domain_id, mapping, ext_mapping)?;
                }
                VIRTIO_IOMMU_T_MAP => {
                    if desc_size_left != size_of::<VirtioIommuReqMap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequest);
                    }

                    let req: VirtioIommuReqMap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Map request {:?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidMapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // For the virtio-iommu, every endpoint is backed by its own VFIO
                    // container. As a result, each endpoint within the domain must be
                    // mapped separately, since mappings are applied per container
                    // rather than per domain.
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - req.virt_start + 1;
                            ext_map
                                .map(req.virt_start, req.phys_start, size)
                                .map_err(Error::ExternalMapping)?;
                        }
                    }

                    // Add new mapping associated with the domain
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .insert(
                            req.virt_start,
                            Mapping {
                                gpa: req.phys_start,
                                size: req.virt_end - req.virt_start + 1,
                            },
                        );
                }
                VIRTIO_IOMMU_T_UNMAP => {
                    if desc_size_left != size_of::<VirtioIommuReqUnmap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequest);
                    }

                    let req: VirtioIommuReqUnmap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Unmap request {:?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;
                    let virt_start = req.virt_start;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidUnmapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // Trigger external unmapping if necessary.
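                    // As with MAP, the unmapping is replayed for every endpoint
                    // attached to the domain, since each endpoint owns its own
                    // container.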
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - virt_start + 1;
                            ext_map
                                .unmap(virt_start, size)
                                .map_err(Error::ExternalUnmapping)?;
                        }
                    }

                    // Remove all mappings associated with the domain within the requested range
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .retain(|&x, _| (x < req.virt_start || x > req.virt_end));
                }
                VIRTIO_IOMMU_T_PROBE => {
                    if desc_size_left != size_of::<VirtioIommuReqProbe>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidProbeRequest);
                    }

                    let req: VirtioIommuReqProbe = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Probe request {:?}", req);

                    let probe_prop = VirtioIommuProbeProperty {
                        type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM,
                        length: size_of::<VirtioIommuProbeResvMem>() as u16,
                    };
                    reply.extend_from_slice(probe_prop.as_slice());

                    let resv_mem = VirtioIommuProbeResvMem {
                        subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
                        start: msi_iova_start,
                        end: msi_iova_end,
                        ..Default::default()
                    };
                    reply.extend_from_slice(resv_mem.as_slice());

                    hdr_len = PROBE_PROP_SIZE;
                }
                _ => {
                    status = VIRTIO_IOMMU_S_INVAL;
                    return Err(Error::InvalidRequest);
                }
            }
            Ok(())
        })();

        let status_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;

        // The status MUST always be writable
        if !status_desc.is_write_only() {
            return Err(Error::UnexpectedReadOnlyDescriptor);
        }

        if status_desc.len() < hdr_len + size_of::<VirtioIommuReqTail>() as u32 {
            return Err(Error::BufferLengthTooSmall);
        }

        let tail = VirtioIommuReqTail {
            status,
            ..Default::default()
        };
        reply.extend_from_slice(tail.as_slice());

        // Make sure we return the result of the request to the guest before
        // we return a potential error internally.
        desc_chain
            .memory()
            .write_slice(reply.as_slice(), status_desc.addr())
            .map_err(Error::GuestMemory)?;

        // Return the error if the result was not Ok().
        result?;

        Ok((hdr_len as usize) + size_of::<VirtioIommuReqTail>())
    }
}

fn detach_endpoint_from_domain(
    endpoint: u32,
    domain_id: u32,
    mapping: &Arc<IommuMapping>,
    ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
) -> result::Result<(), Error> {
    // Remove endpoint associated with specific domain
    mapping.endpoints.write().unwrap().remove(&endpoint);

    // Trigger external unmapping for the endpoint if necessary.
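    // Every mapping recorded for the domain is torn down from this endpoint's
    // container; once the last endpoint leaves the domain, the domain entry
    // itself is dropped below.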
    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id) {
        if let Some(ext_map) = ext_mapping.get(&endpoint) {
            for (virt_start, addr_map) in &domain_mappings.mappings {
                ext_map
                    .unmap(*virt_start, addr_map.size)
                    .map_err(Error::ExternalUnmapping)?;
            }
        }
    }

    if mapping
        .endpoints
        .write()
        .unwrap()
        .iter()
        .filter(|(_, &d)| d == domain_id)
        .count()
        == 0
    {
        mapping.domains.write().unwrap().remove(&domain_id);
    }

    Ok(())
}

struct IommuEpollHandler {
    mem: GuestMemoryAtomic<GuestMemoryMmap>,
    request_queue: Queue,
    _event_queue: Queue,
    interrupt_cb: Arc<dyn VirtioInterrupt>,
    request_queue_evt: EventFd,
    _event_queue_evt: EventFd,
    kill_evt: EventFd,
    pause_evt: EventFd,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    msi_iova_space: (u64, u64),
}

impl IommuEpollHandler {
    fn request_queue(&mut self) -> Result<bool, Error> {
        let mut used_descs = false;
        while let Some(mut desc_chain) = self.request_queue.pop_descriptor_chain(self.mem.memory())
        {
            let len = Request::parse(
                &mut desc_chain,
                &self.mapping,
                &self.ext_mapping.lock().unwrap(),
                self.msi_iova_space,
            )?;

            self.request_queue
                .add_used(desc_chain.memory(), desc_chain.head_index(), len as u32)
                .map_err(Error::QueueAddUsed)?;

            used_descs = true;
        }

        Ok(used_descs)
    }

    fn signal_used_queue(&self, queue_index: u16) -> result::Result<(), DeviceError> {
        self.interrupt_cb
            .trigger(VirtioInterruptType::Queue(queue_index))
            .map_err(|e| {
                error!("Failed to signal used queue: {:?}", e);
                DeviceError::FailedSignalingUsedQueue(e)
            })
    }

    fn run(
        &mut self,
        paused: Arc<AtomicBool>,
        paused_sync: Arc<Barrier>,
    ) -> result::Result<(), EpollHelperError> {
        let mut helper = EpollHelper::new(&self.kill_evt, &self.pause_evt)?;
        helper.add_event(self.request_queue_evt.as_raw_fd(), REQUEST_Q_EVENT)?;
        helper.run(paused, paused_sync, self)?;

        Ok(())
    }
}

impl EpollHelperHandler for IommuEpollHandler {
    fn handle_event(
        &mut self,
        _helper: &mut EpollHelper,
        event: &epoll::Event,
    ) -> result::Result<(), EpollHelperError> {
        let ev_type = event.data as u16;
        match ev_type {
            REQUEST_Q_EVENT => {
                self.request_queue_evt.read().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!("Failed to get queue event: {:?}", e))
                })?;

                let needs_notification = self.request_queue().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!(
                        "Failed to process request queue: {:?}",
                        e
                    ))
                })?;
                if needs_notification {
                    self.signal_used_queue(0).map_err(|e| {
                        EpollHelperError::HandleEvent(anyhow!(
                            "Failed to signal used queue: {:?}",
                            e
                        ))
                    })?;
                }
            }
            _ => {
                return Err(EpollHelperError::HandleEvent(anyhow!(
                    "Unexpected event: {}",
                    ev_type
                )));
            }
        }
        Ok(())
    }
}

#[derive(Clone, Copy, Debug, Versionize)]
struct Mapping {
    gpa: u64,
    size: u64,
}

#[derive(Clone, Debug)]
struct Domain {
    mappings: BTreeMap<u64, Mapping>,
    bypass: bool,
}

#[derive(Debug)]
pub struct IommuMapping {
    // Domain related to an endpoint.
    endpoints: Arc<RwLock<BTreeMap<u32, u32>>>,
    // Information related to each domain.
    domains: Arc<RwLock<BTreeMap<u32, Domain>>>,
    // Global flag indicating if endpoints that are not attached to any domain
    // are in bypass mode.
    bypass: AtomicBool,
}

impl DmaRemapping for IommuMapping {
    fn translate_gva(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GVA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                let range_start = if VIRTIO_IOMMU_PAGE_SIZE_MASK > addr {
                    0
                } else {
                    addr - VIRTIO_IOMMU_PAGE_SIZE_MASK
                };
                for (&key, &value) in domain
                    .mappings
                    .range((Included(&range_start), Included(&addr)))
                {
                    if addr >= key && addr < key + value.size {
                        let new_addr = addr - key + value.gpa;
                        debug!("Into GPA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::new(
            io::ErrorKind::Other,
            format!("failed to translate GVA addr 0x{addr:x}"),
        ))
    }

    fn translate_gpa(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GPA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                for (&key, &value) in domain.mappings.iter() {
                    if addr >= value.gpa && addr < value.gpa + value.size {
                        let new_addr = addr - value.gpa + key;
                        debug!("Into GVA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::new(
            io::ErrorKind::Other,
            format!("failed to translate GPA addr 0x{addr:x}"),
        ))
    }
}

#[derive(Debug)]
pub struct AccessPlatformMapping {
    id: u32,
    mapping: Arc<IommuMapping>,
}

impl AccessPlatformMapping {
    pub fn new(id: u32, mapping: Arc<IommuMapping>) -> Self {
        AccessPlatformMapping { id, mapping }
    }
}

impl AccessPlatform for AccessPlatformMapping {
    fn translate_gva(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gva(self.id, base)
    }
    fn translate_gpa(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gpa(self.id, base)
    }
}

pub struct Iommu {
    common: VirtioCommon,
    id: String,
    config: VirtioIommuConfig,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    seccomp_action: SeccompAction,
    exit_evt: EventFd,
    msi_iova_space: (u64, u64),
}

type EndpointsState = Vec<(u32, u32)>;
type DomainsState = Vec<(u32, (Vec<(u64, Mapping)>, bool))>;

#[derive(Versionize)]
pub struct IommuState {
    avail_features: u64,
    acked_features: u64,
    endpoints: EndpointsState,
    domains: DomainsState,
}

impl VersionMapped for IommuState {}

impl Iommu {
    pub fn new(
        id: String,
        seccomp_action: SeccompAction,
        exit_evt: EventFd,
        msi_iova_space: (u64, u64),
        state: Option<IommuState>,
    ) -> io::Result<(Self, Arc<IommuMapping>)> {
        let (avail_features, acked_features, endpoints, domains, paused) =
            if let Some(state) = state {
                info!("Restoring virtio-iommu {}", id);
                (
                    state.avail_features,
                    state.acked_features,
                    state.endpoints.into_iter().collect(),
                    state
                        .domains
                        .into_iter()
                        .map(|(k, v)| {
                            (
                                k,
                                Domain {
                                    mappings: v.0.into_iter().collect(),
                                    bypass: v.1,
                                },
                            )
                        })
                        .collect(),
                    true,
                )
            } else {
                let avail_features = 1u64 << VIRTIO_F_VERSION_1
                    | 1u64 << VIRTIO_IOMMU_F_MAP_UNMAP
                    | 1u64 << VIRTIO_IOMMU_F_PROBE
                    | 1u64 << VIRTIO_IOMMU_F_BYPASS_CONFIG;

                (avail_features, 0, BTreeMap::new(), BTreeMap::new(), false)
            };

        let config = VirtioIommuConfig {
            page_size_mask: VIRTIO_IOMMU_PAGE_SIZE_MASK,
            probe_size: PROBE_PROP_SIZE,
            ..Default::default()
        };

        let mapping = Arc::new(IommuMapping {
            endpoints: Arc::new(RwLock::new(endpoints)),
            domains: Arc::new(RwLock::new(domains)),
            bypass: AtomicBool::new(true),
        });

        Ok((
            Iommu {
                id,
                common: VirtioCommon {
                    device_type: VirtioDeviceType::Iommu as u32,
                    queue_sizes: QUEUE_SIZES.to_vec(),
                    avail_features,
                    acked_features,
                    paused_sync: Some(Arc::new(Barrier::new(2))),
                    min_queues: NUM_QUEUES as u16,
                    paused: Arc::new(AtomicBool::new(paused)),
                    ..Default::default()
                },
                config,
                mapping: mapping.clone(),
                ext_mapping: Arc::new(Mutex::new(BTreeMap::new())),
                seccomp_action,
                exit_evt,
                msi_iova_space,
            },
            mapping,
        ))
    }

    fn state(&self) -> IommuState {
        IommuState {
            avail_features: self.common.avail_features,
            acked_features: self.common.acked_features,
            endpoints: self
                .mapping
                .endpoints
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .collect(),
            domains: self
                .mapping
                .domains
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .map(|(k, v)| (k, (v.mappings.into_iter().collect(), v.bypass)))
                .collect(),
        }
    }

    fn update_bypass(&mut self) {
        // Use bypass from config if VIRTIO_IOMMU_F_BYPASS_CONFIG has been negotiated
        if !self
            .common
            .feature_acked(VIRTIO_IOMMU_F_BYPASS_CONFIG.into())
        {
            return;
        }

        let bypass = self.config.bypass == 1;
        info!("Updating bypass mode to {}", bypass);
        self.mapping.bypass.store(bypass, Ordering::Release);
    }

    pub fn add_external_mapping(&mut self, device_id: u32, mapping: Arc<dyn ExternalDmaMapping>) {
        self.ext_mapping.lock().unwrap().insert(device_id, mapping);
    }

    #[cfg(fuzzing)]
    pub fn wait_for_epoll_threads(&mut self) {
        self.common.wait_for_epoll_threads();
    }
}

impl Drop for Iommu {
    fn drop(&mut self) {
        if let Some(kill_evt) = self.common.kill_evt.take() {
            // Ignore the result because there is nothing we can do about it.
            let _ = kill_evt.write(1);
        }
        self.common.wait_for_epoll_threads();
    }
}

impl VirtioDevice for Iommu {
    fn device_type(&self) -> u32 {
        self.common.device_type
    }

    fn queue_max_sizes(&self) -> &[u16] {
        &self.common.queue_sizes
    }

    fn features(&self) -> u64 {
        self.common.avail_features
    }

    fn ack_features(&mut self, value: u64) {
        self.common.ack_features(value)
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        self.read_config_from_slice(self.config.as_slice(), offset, data);
    }

    fn write_config(&mut self, offset: u64, data: &[u8]) {
        // The "bypass" field is the only mutable field
        let bypass_offset =
            (&self.config.bypass as *const _ as u64) - (&self.config as *const _ as u64);
        if offset != bypass_offset || data.len() != std::mem::size_of_val(&self.config.bypass) {
            error!(
                "Attempt to write to read-only field: offset {:x} length {}",
                offset,
                data.len()
            );
            return;
        }

        self.config.bypass = data[0];

        self.update_bypass();
    }

    fn activate(
        &mut self,
        mem: GuestMemoryAtomic<GuestMemoryMmap>,
        interrupt_cb: Arc<dyn VirtioInterrupt>,
        mut queues: Vec<(usize, Queue, EventFd)>,
    ) -> ActivateResult {
        self.common.activate(&queues, &interrupt_cb)?;
        let (kill_evt, pause_evt) = self.common.dup_eventfds();

        let (_, request_queue, request_queue_evt) = queues.remove(0);
        let (_, _event_queue, _event_queue_evt) = queues.remove(0);

        let mut handler = IommuEpollHandler {
            mem,
            request_queue,
            _event_queue,
            interrupt_cb,
            request_queue_evt,
            _event_queue_evt,
            kill_evt,
            pause_evt,
            mapping: self.mapping.clone(),
            ext_mapping: self.ext_mapping.clone(),
            msi_iova_space: self.msi_iova_space,
        };

        let paused = self.common.paused.clone();
        let paused_sync = self.common.paused_sync.clone();
        let mut epoll_threads = Vec::new();
        spawn_virtio_thread(
            &self.id,
            &self.seccomp_action,
            Thread::VirtioIommu,
            &mut epoll_threads,
            &self.exit_evt,
            move || handler.run(paused, paused_sync.unwrap()),
        )?;

        self.common.epoll_threads = Some(epoll_threads);

        event!("virtio-device", "activated", "id", &self.id);
        Ok(())
    }

    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
        let result = self.common.reset();
        event!("virtio-device", "reset", "id", &self.id);
        result
    }
}

impl Pausable for Iommu {
    fn pause(&mut self) -> result::Result<(), MigratableError> {
        self.common.pause()
    }

    fn resume(&mut self) -> result::Result<(), MigratableError> {
        self.common.resume()
    }
}

impl Snapshottable for Iommu {
    fn id(&self) -> String {
        self.id.clone()
    }

    fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
        Snapshot::new_from_versioned_state(&self.state())
    }
}
impl Transportable for Iommu {}
impl Migratable for Iommu {}
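
// A minimal sketch exercising the IommuMapping translation path: an endpoint
// that is not attached to any domain falls back to the identity mapping when
// global bypass is enabled, while an attached endpoint is translated through
// the domain's mappings. The endpoint, domain and addresses used here are
// illustrative values only.
#[cfg(test)]
mod iommu_mapping_sketch {
    use super::*;

    #[test]
    fn translate_gva_with_single_mapping() {
        let mapping = IommuMapping {
            endpoints: Arc::new(RwLock::new(BTreeMap::new())),
            domains: Arc::new(RwLock::new(BTreeMap::new())),
            bypass: AtomicBool::new(true),
        };

        // Unattached endpoint with global bypass enabled: identity mapping.
        assert_eq!(mapping.translate_gva(0, 0x1100).unwrap(), 0x1100);

        // Attach endpoint 0 to domain 1 and record a 4 KiB mapping
        // 0x1000 (IOVA) -> 0x8000 (GPA).
        mapping.endpoints.write().unwrap().insert(0, 1);
        let mut mappings = BTreeMap::new();
        mappings.insert(
            0x1000,
            Mapping {
                gpa: 0x8000,
                size: 0x1000,
            },
        );
        mapping.domains.write().unwrap().insert(
            1,
            Domain {
                mappings,
                bypass: false,
            },
        );

        // 0x1100 falls inside the mapping: 0x1100 - 0x1000 + 0x8000 = 0x8100.
        assert_eq!(mapping.translate_gva(0, 0x1100).unwrap(), 0x8100);

        // An address outside any mapping must fail to translate.
        assert!(mapping.translate_gva(0, 0x5000).is_err());
    }
}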