// Copyright © 2019 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause

use std::collections::BTreeMap;
use std::mem::size_of;
use std::os::unix::io::AsRawFd;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Barrier, Mutex, RwLock};
use std::{io, result};

use anyhow::anyhow;
use seccompiler::SeccompAction;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use virtio_queue::{DescriptorChain, Queue, QueueT};
use vm_device::dma_mapping::ExternalDmaMapping;
use vm_memory::{
    Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic,
    GuestMemoryError, GuestMemoryLoadGuard,
};
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform;
use vmm_sys_util::eventfd::EventFd;

use super::{
    ActivateResult, EpollHelper, EpollHelperError, EpollHelperHandler, Error as DeviceError,
    VirtioCommon, VirtioDevice, VirtioDeviceType, EPOLL_HELPER_EVENT_LAST, VIRTIO_F_VERSION_1,
};
use crate::seccomp_filters::Thread;
use crate::thread_helper::spawn_virtio_thread;
use crate::{DmaRemapping, GuestMemoryMmap, VirtioInterrupt, VirtioInterruptType};

/// Queue sizes
const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];

/// New descriptors are pending on the request queue.
/// "requestq" is meant to be used anytime an action is required to be
/// performed on behalf of the guest driver.
const REQUEST_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 1;
/// New descriptors are pending on the event queue.
/// "eventq" lets the device report any fault or other asynchronous event to
/// the guest driver.
const _EVENT_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 2;

/// PROBE properties size.
/// This is the minimal size to provide at least one RESV_MEM property.
/// Because virtio-iommu expects one MSI reserved region, we must provide it,
/// otherwise the guest driver will fall back to a predefined region between
/// 0x8000000 and 0x80FFFFF, which is only relevant for the ARM architecture
/// but would conflict on x86.
const PROBE_PROP_SIZE: u32 =
    (size_of::<VirtioIommuProbeProperty>() + size_of::<VirtioIommuProbeResvMem>()) as u32;

/// Virtio IOMMU features
#[allow(unused)]
const VIRTIO_IOMMU_F_INPUT_RANGE: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_F_DOMAIN_RANGE: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_F_MAP_UNMAP: u32 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_F_BYPASS: u32 = 3;
const VIRTIO_IOMMU_F_PROBE: u32 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_F_MMIO: u32 = 5;
const VIRTIO_IOMMU_F_BYPASS_CONFIG: u32 = 6;

// Support 2MiB and 4KiB page sizes.
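// The mask is simply the OR of the supported page sizes, each a power of two:
// 2 << 20 == 0x20_0000 (2 MiB) and 4 << 10 == 0x1000 (4 KiB), so the constant
// below evaluates to 0x20_1000.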
const VIRTIO_IOMMU_PAGE_SIZE_MASK: u64 = (2 << 20) | (4 << 10);

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuRange32 {
    start: u32,
    end: u32,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuRange64 {
    start: u64,
    end: u64,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuConfig {
    page_size_mask: u64,
    input_range: VirtioIommuRange64,
    domain_range: VirtioIommuRange32,
    probe_size: u32,
    bypass: u8,
    _reserved: [u8; 7],
}

/// Virtio IOMMU request type
const VIRTIO_IOMMU_T_ATTACH: u8 = 1;
const VIRTIO_IOMMU_T_DETACH: u8 = 2;
const VIRTIO_IOMMU_T_MAP: u8 = 3;
const VIRTIO_IOMMU_T_UNMAP: u8 = 4;
const VIRTIO_IOMMU_T_PROBE: u8 = 5;

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqHead {
    type_: u8,
    _reserved: [u8; 3],
}

/// Virtio IOMMU request status
const VIRTIO_IOMMU_S_OK: u8 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_S_IOERR: u8 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_S_UNSUPP: u8 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_S_DEVERR: u8 = 3;
#[allow(unused)]
const VIRTIO_IOMMU_S_INVAL: u8 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_S_RANGE: u8 = 5;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOENT: u8 = 6;
#[allow(unused)]
const VIRTIO_IOMMU_S_FAULT: u8 = 7;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOMEM: u8 = 8;

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuReqTail {
    status: u8,
    _reserved: [u8; 3],
}

/// ATTACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqAttach {
    domain: u32,
    endpoint: u32,
    flags: u32,
    _reserved: [u8; 4],
}

const VIRTIO_IOMMU_ATTACH_F_BYPASS: u32 = 1;

/// DETACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqDetach {
    domain: u32,
    endpoint: u32,
    _reserved: [u8; 8],
}

/// Virtio IOMMU request MAP flags
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MMIO: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MASK: u32 =
    VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE | VIRTIO_IOMMU_MAP_F_MMIO;

/// MAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqMap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    phys_start: u64,
    _flags: u32,
}

/// UNMAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqUnmap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    _reserved: [u8; 4],
}

/// Virtio IOMMU request PROBE types
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_NONE: u16 = 0;
const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_MASK: u16 = 0xfff;

/// PROBE request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuReqProbe {
    endpoint: u32,
    _reserved: [u64; 8],
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuProbeProperty {
    type_: u16,
    length: u16,
}

/// Virtio IOMMU request PROBE property RESV_MEM subtypes
#[allow(unused)]
const VIRTIO_IOMMU_RESV_MEM_T_RESERVED: u8 = 0;
const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuProbeResvMem {
    subtype: u8,
    _reserved: [u8; 3],
    start: u64,
    end: u64,
}

/// Virtio IOMMU fault flags
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_EXEC: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_ADDRESS: u32 = 1 << 8;

/// Virtio IOMMU fault reasons
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_UNKNOWN: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_DOMAIN: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_MAPPING: u32 = 2;

/// Fault reporting through eventq
#[allow(unused)]
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuFault {
    reason: u8,
    reserved: [u8; 3],
    flags: u32,
    endpoint: u32,
    reserved2: [u8; 4],
    address: u64,
}

// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange32 {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange64 {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuConfig {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqHead {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqTail {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqAttach {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqDetach {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqMap {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqUnmap {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqProbe {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeProperty {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeResvMem {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuFault {}

#[derive(Error, Debug)]
enum Error {
    #[error("Guest gave us bad memory addresses: {0}")]
    GuestMemory(GuestMemoryError),
    #[error("Guest gave us a write-only descriptor that the protocol says to read from")]
    UnexpectedWriteOnlyDescriptor,
    #[error("Guest gave us a read-only descriptor that the protocol says to write to")]
    UnexpectedReadOnlyDescriptor,
    #[error("Guest gave us too few descriptors in a descriptor chain")]
    DescriptorChainTooShort,
descriptor chain")] 302 DescriptorChainTooShort, 303 #[error("Guest gave us a buffer that was too short to use")] 304 BufferLengthTooSmall, 305 #[error("Guest sent us invalid request")] 306 InvalidRequest, 307 #[error("Guest sent us invalid ATTACH request")] 308 InvalidAttachRequest, 309 #[error("Guest sent us invalid DETACH request")] 310 InvalidDetachRequest, 311 #[error("Guest sent us invalid MAP request")] 312 InvalidMapRequest, 313 #[error("Invalid to map because the domain is in bypass mode")] 314 InvalidMapRequestBypassDomain, 315 #[error("Invalid to map because the domain is missing")] 316 InvalidMapRequestMissingDomain, 317 #[error("Guest sent us invalid UNMAP request")] 318 InvalidUnmapRequest, 319 #[error("Invalid to unmap because the domain is in bypass mode")] 320 InvalidUnmapRequestBypassDomain, 321 #[error("Invalid to unmap because the domain is missing")] 322 InvalidUnmapRequestMissingDomain, 323 #[error("Guest sent us invalid PROBE request")] 324 InvalidProbeRequest, 325 #[error("Failed to performing external mapping: {0}")] 326 ExternalMapping(io::Error), 327 #[error("Failed to performing external unmapping: {0}")] 328 ExternalUnmapping(io::Error), 329 #[error("Failed adding used index: {0}")] 330 QueueAddUsed(virtio_queue::Error), 331 } 332 333 struct Request {} 334 335 impl Request { 336 // Parse the available vring buffer. Based on the hashmap table of external 337 // mappings required from various devices such as VFIO or vhost-user ones, 338 // this function might update the hashmap table of external mappings per 339 // domain. 340 // Basically, the VMM knows about the device_id <=> mapping relationship 341 // before running the VM, but at runtime, a new domain <=> mapping hashmap 342 // is created based on the information provided from the guest driver for 343 // virtio-iommu (giving the link device_id <=> domain). 344 fn parse( 345 desc_chain: &mut DescriptorChain<GuestMemoryLoadGuard<GuestMemoryMmap>>, 346 mapping: &Arc<IommuMapping>, 347 ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>, 348 msi_iova_space: (u64, u64), 349 ) -> result::Result<usize, Error> { 350 let desc = desc_chain 351 .next() 352 .ok_or(Error::DescriptorChainTooShort) 353 .inspect_err(|_| { 354 error!("Missing head descriptor"); 355 })?; 356 357 // The descriptor contains the request type which MUST be readable. 
        if desc.is_write_only() {
            return Err(Error::UnexpectedWriteOnlyDescriptor);
        }

        if (desc.len() as usize) < size_of::<VirtioIommuReqHead>() {
            return Err(Error::InvalidRequest);
        }

        let req_head: VirtioIommuReqHead = desc_chain
            .memory()
            .read_obj(desc.addr())
            .map_err(Error::GuestMemory)?;
        let req_offset = size_of::<VirtioIommuReqHead>();
        let desc_size_left = (desc.len() as usize) - req_offset;
        let req_addr = if let Some(addr) = desc.addr().checked_add(req_offset as u64) {
            addr
        } else {
            return Err(Error::InvalidRequest);
        };

        let (msi_iova_start, msi_iova_end) = msi_iova_space;

        // Create the reply
        let mut reply: Vec<u8> = Vec::new();
        let mut status = VIRTIO_IOMMU_S_OK;
        let mut hdr_len = 0;

        let result = (|| {
            match req_head.type_ {
                VIRTIO_IOMMU_T_ATTACH => {
                    if desc_size_left != size_of::<VirtioIommuReqAttach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidAttachRequest);
                    }

                    let req: VirtioIommuReqAttach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Attach request 0x{:x?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;
                    let bypass =
                        (req.flags & VIRTIO_IOMMU_ATTACH_F_BYPASS) == VIRTIO_IOMMU_ATTACH_F_BYPASS;

                    let mut old_domain_id = domain_id;
                    if let Some(&id) = mapping.endpoints.read().unwrap().get(&endpoint) {
                        old_domain_id = id;
                    }

                    if old_domain_id != domain_id {
                        detach_endpoint_from_domain(endpoint, old_domain_id, mapping, ext_mapping)?;
                    }

                    // Add endpoint associated with specific domain
                    mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .insert(endpoint, domain_id);

                    // If any other mappings exist in the domain for other containers,
                    // make sure to issue these mappings for the new endpoint/container
                    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id)
                    {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            for (virt_start, addr_map) in &domain_mappings.mappings {
                                ext_map
                                    .map(*virt_start, addr_map.gpa, addr_map.size)
                                    .map_err(Error::ExternalMapping)?;
                            }
                        }
                    }

                    // Add new domain with no mapping if the entry didn't exist yet
                    let mut domains = mapping.domains.write().unwrap();
                    let domain = Domain {
                        mappings: BTreeMap::new(),
                        bypass,
                    };
                    domains.entry(domain_id).or_insert_with(|| domain);
                }
                VIRTIO_IOMMU_T_DETACH => {
                    if desc_size_left != size_of::<VirtioIommuReqDetach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidDetachRequest);
                    }

                    let req: VirtioIommuReqDetach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Detach request 0x{:x?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;

                    // Remove endpoint associated with specific domain
                    detach_endpoint_from_domain(endpoint, domain_id, mapping, ext_mapping)?;
                }
                VIRTIO_IOMMU_T_MAP => {
                    if desc_size_left != size_of::<VirtioIommuReqMap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequest);
                    }

                    let req: VirtioIommuReqMap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Map request 0x{:x?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidMapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // With the virtio-iommu every endpoint is backed by its own VFIO
                    // container. As a result, each endpoint within the domain needs to
                    // be mapped separately, since the mapping is performed at the
                    // container level, not at the domain level.
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - req.virt_start + 1;
                            ext_map
                                .map(req.virt_start, req.phys_start, size)
                                .map_err(Error::ExternalMapping)?;
                        }
                    }

                    // Add new mapping associated with the domain
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .insert(
                            req.virt_start,
                            Mapping {
                                gpa: req.phys_start,
                                size: req.virt_end - req.virt_start + 1,
                            },
                        );
                }
                VIRTIO_IOMMU_T_UNMAP => {
                    if desc_size_left != size_of::<VirtioIommuReqUnmap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequest);
                    }

                    let req: VirtioIommuReqUnmap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Unmap request 0x{:x?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;
                    let virt_start = req.virt_start;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidUnmapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // Trigger external unmapping if necessary.
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - virt_start + 1;
                            ext_map
                                .unmap(virt_start, size)
                                .map_err(Error::ExternalUnmapping)?;
                        }
                    }

                    // Remove all mappings associated with the domain within the requested range
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .retain(|&x, _| (x < req.virt_start || x > req.virt_end));
                }
                VIRTIO_IOMMU_T_PROBE => {
                    if desc_size_left != size_of::<VirtioIommuReqProbe>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidProbeRequest);
                    }

                    let req: VirtioIommuReqProbe = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Probe request 0x{:x?}", req);

                    let probe_prop = VirtioIommuProbeProperty {
                        type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM,
                        length: size_of::<VirtioIommuProbeResvMem>() as u16,
                    };
                    reply.extend_from_slice(probe_prop.as_slice());

                    let resv_mem = VirtioIommuProbeResvMem {
                        subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
                        start: msi_iova_start,
                        end: msi_iova_end,
                        ..Default::default()
                    };
                    reply.extend_from_slice(resv_mem.as_slice());

                    hdr_len = PROBE_PROP_SIZE;
                }
                _ => {
                    status = VIRTIO_IOMMU_S_INVAL;
                    return Err(Error::InvalidRequest);
                }
            }
            Ok(())
        })();

        let status_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;

        // The status MUST always be writable
        if !status_desc.is_write_only() {
            return Err(Error::UnexpectedReadOnlyDescriptor);
        }

        if status_desc.len() < hdr_len + size_of::<VirtioIommuReqTail>() as u32 {
            return Err(Error::BufferLengthTooSmall);
        }

        let tail = VirtioIommuReqTail {
            status,
            ..Default::default()
        };
        reply.extend_from_slice(tail.as_slice());

        // Make sure we return the result of the request to the guest before
        // we return a potential error internally.
        desc_chain
            .memory()
            .write_slice(reply.as_slice(), status_desc.addr())
            .map_err(Error::GuestMemory)?;

        // Return the error if the result was not Ok().
        result?;

        Ok((hdr_len as usize) + size_of::<VirtioIommuReqTail>())
    }
}

fn detach_endpoint_from_domain(
    endpoint: u32,
    domain_id: u32,
    mapping: &Arc<IommuMapping>,
    ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
) -> result::Result<(), Error> {
    // Remove endpoint associated with specific domain
    mapping.endpoints.write().unwrap().remove(&endpoint);

    // Trigger external unmapping for the endpoint if necessary.
    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id) {
        if let Some(ext_map) = ext_mapping.get(&endpoint) {
            for (virt_start, addr_map) in &domain_mappings.mappings {
                ext_map
                    .unmap(*virt_start, addr_map.size)
                    .map_err(Error::ExternalUnmapping)?;
            }
        }
    }

    if mapping
        .endpoints
        .write()
        .unwrap()
        .iter()
        .filter(|(_, &d)| d == domain_id)
        .count()
        == 0
    {
        mapping.domains.write().unwrap().remove(&domain_id);
    }

    Ok(())
}

struct IommuEpollHandler {
    mem: GuestMemoryAtomic<GuestMemoryMmap>,
    request_queue: Queue,
    _event_queue: Queue,
    interrupt_cb: Arc<dyn VirtioInterrupt>,
    request_queue_evt: EventFd,
    _event_queue_evt: EventFd,
    kill_evt: EventFd,
    pause_evt: EventFd,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    msi_iova_space: (u64, u64),
}

impl IommuEpollHandler {
    fn request_queue(&mut self) -> Result<bool, Error> {
        let mut used_descs = false;
        while let Some(mut desc_chain) = self.request_queue.pop_descriptor_chain(self.mem.memory())
        {
            let len = Request::parse(
                &mut desc_chain,
                &self.mapping,
                &self.ext_mapping.lock().unwrap(),
                self.msi_iova_space,
            )?;

            self.request_queue
                .add_used(desc_chain.memory(), desc_chain.head_index(), len as u32)
                .map_err(Error::QueueAddUsed)?;

            used_descs = true;
        }

        Ok(used_descs)
    }

    fn signal_used_queue(&self, queue_index: u16) -> result::Result<(), DeviceError> {
        self.interrupt_cb
            .trigger(VirtioInterruptType::Queue(queue_index))
            .map_err(|e| {
                error!("Failed to signal used queue: {:?}", e);
                DeviceError::FailedSignalingUsedQueue(e)
            })
    }

    fn run(
        &mut self,
        paused: Arc<AtomicBool>,
        paused_sync: Arc<Barrier>,
    ) -> result::Result<(), EpollHelperError> {
        let mut helper = EpollHelper::new(&self.kill_evt, &self.pause_evt)?;
        helper.add_event(self.request_queue_evt.as_raw_fd(), REQUEST_Q_EVENT)?;
        helper.run(paused, paused_sync, self)?;

        Ok(())
    }
}

impl EpollHelperHandler for IommuEpollHandler {
    fn handle_event(
        &mut self,
        _helper: &mut EpollHelper,
        event: &epoll::Event,
    ) -> result::Result<(), EpollHelperError> {
        let ev_type = event.data as u16;
        match ev_type {
            REQUEST_Q_EVENT => {
                self.request_queue_evt.read().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!("Failed to get queue event: {:?}", e))
                })?;

                let needs_notification = self.request_queue().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!(
                        "Failed to process request queue: {:?}",
                        e
                    ))
                })?;
                if needs_notification {
                    self.signal_used_queue(0).map_err(|e| {
                        EpollHelperError::HandleEvent(anyhow!(
                            "Failed to signal used queue: {:?}",
                            e
                        ))
                    })?;
                }
            }
            _ => {
                return Err(EpollHelperError::HandleEvent(anyhow!(
                    "Unexpected event: {}",
                    ev_type
                )));
            }
        }
        Ok(())
    }
}

#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
struct Mapping {
    gpa: u64,
    size: u64,
}

#[derive(Clone, Debug)]
struct Domain {
    mappings: BTreeMap<u64, Mapping>,
    bypass: bool,
}

#[derive(Debug)]
pub struct IommuMapping {
    // Domain related to an endpoint.
    endpoints: Arc<RwLock<BTreeMap<u32, u32>>>,
    // Information related to each domain.
    domains: Arc<RwLock<BTreeMap<u32, Domain>>>,
    // Global flag indicating if endpoints that are not attached to any domain
    // are in bypass mode.
    bypass: AtomicBool,
}

impl DmaRemapping for IommuMapping {
    fn translate_gva(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GVA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                for (&key, &value) in domain.mappings.iter() {
                    if addr >= key && addr < key + value.size {
                        let new_addr = addr - key + value.gpa;
                        debug!("Into GPA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::new(
            io::ErrorKind::Other,
            format!("failed to translate GVA addr 0x{addr:x}"),
        ))
    }

    fn translate_gpa(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GPA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                for (&key, &value) in domain.mappings.iter() {
                    if addr >= value.gpa && addr < value.gpa + value.size {
                        let new_addr = addr - value.gpa + key;
                        debug!("Into GVA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::new(
            io::ErrorKind::Other,
            format!("failed to translate GPA addr 0x{addr:x}"),
        ))
    }
}

#[derive(Debug)]
pub struct AccessPlatformMapping {
    id: u32,
    mapping: Arc<IommuMapping>,
}

impl AccessPlatformMapping {
    pub fn new(id: u32, mapping: Arc<IommuMapping>) -> Self {
        AccessPlatformMapping { id, mapping }
    }
}

impl AccessPlatform for AccessPlatformMapping {
    fn translate_gva(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gva(self.id, base)
    }
    fn translate_gpa(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gpa(self.id, base)
    }
}

pub struct Iommu {
    common: VirtioCommon,
    id: String,
    config: VirtioIommuConfig,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    seccomp_action: SeccompAction,
    exit_evt: EventFd,
    msi_iova_space: (u64, u64),
}

type EndpointsState = Vec<(u32, u32)>;
type DomainsState = Vec<(u32, (Vec<(u64, Mapping)>, bool))>;

#[derive(Serialize, Deserialize)]
pub struct IommuState {
    avail_features: u64,
    acked_features: u64,
    endpoints: EndpointsState,
    domains: DomainsState,
}

impl Iommu {
    pub fn new(
        id: String,
        seccomp_action: SeccompAction,
        exit_evt: EventFd,
        msi_iova_space: (u64, u64),
        address_width_bits: u8,
        state: Option<IommuState>,
    ) -> io::Result<(Self, Arc<IommuMapping>)> {
        let (mut avail_features, acked_features, endpoints, domains, paused) =
            if let Some(state) = state {
info!("Restoring virtio-iommu {}", id); 915 ( 916 state.avail_features, 917 state.acked_features, 918 state.endpoints.into_iter().collect(), 919 state 920 .domains 921 .into_iter() 922 .map(|(k, v)| { 923 ( 924 k, 925 Domain { 926 mappings: v.0.into_iter().collect(), 927 bypass: v.1, 928 }, 929 ) 930 }) 931 .collect(), 932 true, 933 ) 934 } else { 935 let avail_features = (1u64 << VIRTIO_F_VERSION_1) 936 | (1u64 << VIRTIO_IOMMU_F_MAP_UNMAP) 937 | (1u64 << VIRTIO_IOMMU_F_PROBE) 938 | (1u64 << VIRTIO_IOMMU_F_BYPASS_CONFIG); 939 940 (avail_features, 0, BTreeMap::new(), BTreeMap::new(), false) 941 }; 942 943 let mut config = VirtioIommuConfig { 944 page_size_mask: VIRTIO_IOMMU_PAGE_SIZE_MASK, 945 probe_size: PROBE_PROP_SIZE, 946 ..Default::default() 947 }; 948 949 if address_width_bits < 64 { 950 avail_features |= 1u64 << VIRTIO_IOMMU_F_INPUT_RANGE; 951 config.input_range = VirtioIommuRange64 { 952 start: 0, 953 end: (1u64 << address_width_bits) - 1, 954 } 955 } 956 957 let mapping = Arc::new(IommuMapping { 958 endpoints: Arc::new(RwLock::new(endpoints)), 959 domains: Arc::new(RwLock::new(domains)), 960 bypass: AtomicBool::new(true), 961 }); 962 963 Ok(( 964 Iommu { 965 id, 966 common: VirtioCommon { 967 device_type: VirtioDeviceType::Iommu as u32, 968 queue_sizes: QUEUE_SIZES.to_vec(), 969 avail_features, 970 acked_features, 971 paused_sync: Some(Arc::new(Barrier::new(2))), 972 min_queues: NUM_QUEUES as u16, 973 paused: Arc::new(AtomicBool::new(paused)), 974 ..Default::default() 975 }, 976 config, 977 mapping: mapping.clone(), 978 ext_mapping: Arc::new(Mutex::new(BTreeMap::new())), 979 seccomp_action, 980 exit_evt, 981 msi_iova_space, 982 }, 983 mapping, 984 )) 985 } 986 987 fn state(&self) -> IommuState { 988 IommuState { 989 avail_features: self.common.avail_features, 990 acked_features: self.common.acked_features, 991 endpoints: self 992 .mapping 993 .endpoints 994 .read() 995 .unwrap() 996 .clone() 997 .into_iter() 998 .collect(), 999 domains: self 1000 .mapping 1001 .domains 1002 .read() 1003 .unwrap() 1004 .clone() 1005 .into_iter() 1006 .map(|(k, v)| (k, (v.mappings.into_iter().collect(), v.bypass))) 1007 .collect(), 1008 } 1009 } 1010 1011 fn update_bypass(&mut self) { 1012 // Use bypass from config if VIRTIO_IOMMU_F_BYPASS_CONFIG has been negotiated 1013 if !self 1014 .common 1015 .feature_acked(VIRTIO_IOMMU_F_BYPASS_CONFIG.into()) 1016 { 1017 return; 1018 } 1019 1020 let bypass = self.config.bypass == 1; 1021 info!("Updating bypass mode to {}", bypass); 1022 self.mapping.bypass.store(bypass, Ordering::Release); 1023 } 1024 1025 pub fn add_external_mapping(&mut self, device_id: u32, mapping: Arc<dyn ExternalDmaMapping>) { 1026 self.ext_mapping.lock().unwrap().insert(device_id, mapping); 1027 } 1028 1029 #[cfg(fuzzing)] 1030 pub fn wait_for_epoll_threads(&mut self) { 1031 self.common.wait_for_epoll_threads(); 1032 } 1033 } 1034 1035 impl Drop for Iommu { 1036 fn drop(&mut self) { 1037 if let Some(kill_evt) = self.common.kill_evt.take() { 1038 // Ignore the result because there is nothing we can do about it. 
            let _ = kill_evt.write(1);
        }
        self.common.wait_for_epoll_threads();
    }
}

impl VirtioDevice for Iommu {
    fn device_type(&self) -> u32 {
        self.common.device_type
    }

    fn queue_max_sizes(&self) -> &[u16] {
        &self.common.queue_sizes
    }

    fn features(&self) -> u64 {
        self.common.avail_features
    }

    fn ack_features(&mut self, value: u64) {
        self.common.ack_features(value)
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        self.read_config_from_slice(self.config.as_slice(), offset, data);
    }

    fn write_config(&mut self, offset: u64, data: &[u8]) {
        // The "bypass" field is the only mutable field
        let bypass_offset =
            (&self.config.bypass as *const _ as u64) - (&self.config as *const _ as u64);
        if offset != bypass_offset || data.len() != std::mem::size_of_val(&self.config.bypass) {
            error!(
                "Attempt to write to read-only field: offset {:x} length {}",
                offset,
                data.len()
            );
            return;
        }

        self.config.bypass = data[0];

        self.update_bypass();
    }

    fn activate(
        &mut self,
        mem: GuestMemoryAtomic<GuestMemoryMmap>,
        interrupt_cb: Arc<dyn VirtioInterrupt>,
        mut queues: Vec<(usize, Queue, EventFd)>,
    ) -> ActivateResult {
        self.common.activate(&queues, &interrupt_cb)?;
        let (kill_evt, pause_evt) = self.common.dup_eventfds();

        let (_, request_queue, request_queue_evt) = queues.remove(0);
        let (_, _event_queue, _event_queue_evt) = queues.remove(0);

        let mut handler = IommuEpollHandler {
            mem,
            request_queue,
            _event_queue,
            interrupt_cb,
            request_queue_evt,
            _event_queue_evt,
            kill_evt,
            pause_evt,
            mapping: self.mapping.clone(),
            ext_mapping: self.ext_mapping.clone(),
            msi_iova_space: self.msi_iova_space,
        };

        let paused = self.common.paused.clone();
        let paused_sync = self.common.paused_sync.clone();
        let mut epoll_threads = Vec::new();
        spawn_virtio_thread(
            &self.id,
            &self.seccomp_action,
            Thread::VirtioIommu,
            &mut epoll_threads,
            &self.exit_evt,
            move || handler.run(paused, paused_sync.unwrap()),
        )?;

        self.common.epoll_threads = Some(epoll_threads);

        event!("virtio-device", "activated", "id", &self.id);
        Ok(())
    }

    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
        let result = self.common.reset();
        event!("virtio-device", "reset", "id", &self.id);
        result
    }
}

impl Pausable for Iommu {
    fn pause(&mut self) -> result::Result<(), MigratableError> {
        self.common.pause()
    }

    fn resume(&mut self) -> result::Result<(), MigratableError> {
        self.common.resume()
    }
}

impl Snapshottable for Iommu {
    fn id(&self) -> String {
        self.id.clone()
    }

    fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
        Snapshot::new_from_state(&self.state())
    }
}
impl Transportable for Iommu {}
impl Migratable for Iommu {}
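
// A minimal sketch of how `IommuMapping` resolves addresses, written as an
// in-file unit test; the endpoint ID, domain ID and addresses below are
// arbitrary illustrative values, not anything mandated by the device.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn translate_through_single_mapping() {
        // Endpoint 3 is attached to domain 1, which maps the 8KiB IOVA range
        // starting at 0x1000 onto the GPA range starting at 0x8000_0000.
        let mapping = IommuMapping {
            endpoints: Arc::new(RwLock::new(BTreeMap::from([(3, 1)]))),
            domains: Arc::new(RwLock::new(BTreeMap::from([(
                1,
                Domain {
                    mappings: BTreeMap::from([(
                        0x1000,
                        Mapping {
                            gpa: 0x8000_0000,
                            size: 0x2000,
                        },
                    )]),
                    bypass: false,
                },
            )]))),
            bypass: AtomicBool::new(false),
        };

        // IOVA 0x1800 falls 0x800 bytes into the mapping, so it translates to
        // GPA 0x8000_0800, and translating that GPA back returns the IOVA.
        assert_eq!(mapping.translate_gva(3, 0x1800).unwrap(), 0x8000_0800);
        assert_eq!(mapping.translate_gpa(3, 0x8000_0800).unwrap(), 0x1800);

        // An address outside every mapping of a non-bypass domain is an error.
        assert!(mapping.translate_gva(3, 0x4000).is_err());

        // An endpoint that is not attached to any domain follows the global
        // bypass flag: identity mapping when it is set.
        mapping.bypass.store(true, Ordering::Release);
        assert_eq!(mapping.translate_gva(99, 0xdead_0000).unwrap(), 0xdead_0000);
    }
}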