// Copyright © 2019 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause

use std::collections::BTreeMap;
use std::mem::size_of;
use std::ops::Bound::Included;
use std::os::unix::io::AsRawFd;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Barrier, Mutex, RwLock};
use std::{io, result};

use anyhow::anyhow;
use seccompiler::SeccompAction;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use virtio_queue::{DescriptorChain, Queue, QueueT};
use vm_device::dma_mapping::ExternalDmaMapping;
use vm_memory::{
    Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic,
    GuestMemoryError, GuestMemoryLoadGuard,
};
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform;
use vmm_sys_util::eventfd::EventFd;

use super::{
    ActivateResult, EpollHelper, EpollHelperError, EpollHelperHandler, Error as DeviceError,
    VirtioCommon, VirtioDevice, VirtioDeviceType, EPOLL_HELPER_EVENT_LAST, VIRTIO_F_VERSION_1,
};
use crate::seccomp_filters::Thread;
use crate::thread_helper::spawn_virtio_thread;
use crate::{DmaRemapping, GuestMemoryMmap, VirtioInterrupt, VirtioInterruptType};

/// Queue sizes
const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];

/// New descriptors are pending on the request queue.
/// "requestq" is meant to be used anytime an action is required to be
/// performed on behalf of the guest driver.
const REQUEST_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 1;
/// New descriptors are pending on the event queue.
/// "eventq" lets the device report any fault or other asynchronous event to
/// the guest driver.
const _EVENT_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 2;

/// PROBE properties size.
/// This is the minimal size to provide at least one RESV_MEM property.
/// Because virtio-iommu expects one MSI reserved region, we must provide it,
/// otherwise the guest driver will fall back to a predefined region between
/// 0x8000000 and 0x80FFFFF, which is only relevant for the ARM architecture
/// but conflicts with x86.
const PROBE_PROP_SIZE: u32 =
    (size_of::<VirtioIommuProbeProperty>() + size_of::<VirtioIommuProbeResvMem>()) as u32;

/// Virtio IOMMU features
#[allow(unused)]
const VIRTIO_IOMMU_F_INPUT_RANGE: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_F_DOMAIN_RANGE: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_F_MAP_UNMAP: u32 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_F_BYPASS: u32 = 3;
const VIRTIO_IOMMU_F_PROBE: u32 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_F_MMIO: u32 = 5;
const VIRTIO_IOMMU_F_BYPASS_CONFIG: u32 = 6;

// Support 2MiB and 4KiB page sizes.
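// Each bit set in the mask advertises one supported page size, so
// (2 << 20) | (4 << 10) encodes 0x200000 (2 MiB) and 0x1000 (4 KiB).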
const VIRTIO_IOMMU_PAGE_SIZE_MASK: u64 = (2 << 20) | (4 << 10);

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuRange32 {
    start: u32,
    end: u32,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuRange64 {
    start: u64,
    end: u64,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuConfig {
    page_size_mask: u64,
    input_range: VirtioIommuRange64,
    domain_range: VirtioIommuRange32,
    probe_size: u32,
    bypass: u8,
    _reserved: [u8; 7],
}

/// Virtio IOMMU request type
const VIRTIO_IOMMU_T_ATTACH: u8 = 1;
const VIRTIO_IOMMU_T_DETACH: u8 = 2;
const VIRTIO_IOMMU_T_MAP: u8 = 3;
const VIRTIO_IOMMU_T_UNMAP: u8 = 4;
const VIRTIO_IOMMU_T_PROBE: u8 = 5;

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqHead {
    type_: u8,
    _reserved: [u8; 3],
}

/// Virtio IOMMU request status
const VIRTIO_IOMMU_S_OK: u8 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_S_IOERR: u8 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_S_UNSUPP: u8 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_S_DEVERR: u8 = 3;
#[allow(unused)]
const VIRTIO_IOMMU_S_INVAL: u8 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_S_RANGE: u8 = 5;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOENT: u8 = 6;
#[allow(unused)]
const VIRTIO_IOMMU_S_FAULT: u8 = 7;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOMEM: u8 = 8;

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuReqTail {
    status: u8,
    _reserved: [u8; 3],
}

/// ATTACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqAttach {
    domain: u32,
    endpoint: u32,
    flags: u32,
    _reserved: [u8; 4],
}

const VIRTIO_IOMMU_ATTACH_F_BYPASS: u32 = 1;

/// DETACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqDetach {
    domain: u32,
    endpoint: u32,
    _reserved: [u8; 8],
}

/// Virtio IOMMU request MAP flags
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MMIO: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MASK: u32 =
    VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE | VIRTIO_IOMMU_MAP_F_MMIO;

/// MAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqMap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    phys_start: u64,
    _flags: u32,
}

/// UNMAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqUnmap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    _reserved: [u8; 4],
}

/// Virtio IOMMU request PROBE types
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_NONE: u16 = 0;
const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_MASK: u16 = 0xfff;

/// PROBE request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuReqProbe {
    endpoint: u32,
    _reserved: [u64; 8],
}

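// PROBE replies are a stream of TLV-style properties: each one starts with
// this 16-bit type and 16-bit length header, followed by `length` bytes of
// payload (here a single RESV_MEM property describing the MSI region).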
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuProbeProperty {
    type_: u16,
    length: u16,
}

/// Virtio IOMMU request PROBE property RESV_MEM subtypes
#[allow(unused)]
const VIRTIO_IOMMU_RESV_MEM_T_RESERVED: u8 = 0;
const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuProbeResvMem {
    subtype: u8,
    _reserved: [u8; 3],
    start: u64,
    end: u64,
}

/// Virtio IOMMU fault flags
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_EXEC: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_ADDRESS: u32 = 1 << 8;

/// Virtio IOMMU fault reasons
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_UNKNOWN: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_DOMAIN: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_MAPPING: u32 = 2;

/// Fault reporting through eventq
#[allow(unused)]
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuFault {
    reason: u8,
    reserved: [u8; 3],
    flags: u32,
    endpoint: u32,
    reserved2: [u8; 4],
    address: u64,
}

// SAFETY: the structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange32 {}
// SAFETY: the structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange64 {}
// SAFETY: the structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuConfig {}
// SAFETY: the structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqHead {}
// SAFETY: the structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqTail {}
// SAFETY: the structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqAttach {}
// SAFETY: the structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqDetach {}
// SAFETY: the structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqMap {}
// SAFETY: the structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqUnmap {}
// SAFETY: the structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqProbe {}
// SAFETY: the structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeProperty {}
// SAFETY: the structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeResvMem {}
// SAFETY: the structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuFault {}

#[derive(Error, Debug)]
enum Error {
    #[error("Guest gave us bad memory addresses: {0}")]
    GuestMemory(GuestMemoryError),
    #[error("Guest gave us a write-only descriptor that the protocol says to read from")]
    UnexpectedWriteOnlyDescriptor,
    #[error("Guest gave us a read-only descriptor that the protocol says to write to")]
    UnexpectedReadOnlyDescriptor,
    #[error("Guest gave us too few descriptors in a descriptor chain")]
    DescriptorChainTooShort,
    #[error("Guest gave us a buffer that was too short to use")]
    BufferLengthTooSmall,
    #[error("Guest sent us an invalid request")]
    InvalidRequest,
    #[error("Guest sent us an invalid ATTACH request")]
    InvalidAttachRequest,
    #[error("Guest sent us an invalid DETACH request")]
    InvalidDetachRequest,
    #[error("Guest sent us an invalid MAP request")]
    InvalidMapRequest,
    #[error("Invalid to map because the domain is in bypass mode")]
    InvalidMapRequestBypassDomain,
    #[error("Invalid to map because the domain is missing")]
    InvalidMapRequestMissingDomain,
    #[error("Guest sent us an invalid UNMAP request")]
    InvalidUnmapRequest,
    #[error("Invalid to unmap because the domain is in bypass mode")]
    InvalidUnmapRequestBypassDomain,
    #[error("Invalid to unmap because the domain is missing")]
    InvalidUnmapRequestMissingDomain,
    #[error("Guest sent us an invalid PROBE request")]
    InvalidProbeRequest,
    #[error("Failed to perform external mapping: {0}")]
    ExternalMapping(io::Error),
    #[error("Failed to perform external unmapping: {0}")]
    ExternalUnmapping(io::Error),
    #[error("Failed adding used index: {0}")]
    QueueAddUsed(virtio_queue::Error),
}

struct Request {}

impl Request {
    // Parse the available vring buffer. Based on the hashmap table of external
    // mappings required from various devices such as VFIO or vhost-user ones,
    // this function might update the hashmap table of external mappings per
    // domain.
    // Basically, the VMM knows about the device_id <=> mapping relationship
    // before running the VM, but at runtime, a new domain <=> mapping hashmap
    // is created based on the information provided from the guest driver for
    // virtio-iommu (giving the link device_id <=> domain).
    fn parse(
        desc_chain: &mut DescriptorChain<GuestMemoryLoadGuard<GuestMemoryMmap>>,
        mapping: &Arc<IommuMapping>,
        ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
        msi_iova_space: (u64, u64),
    ) -> result::Result<usize, Error> {
        let desc = desc_chain
            .next()
            .ok_or(Error::DescriptorChainTooShort)
            .inspect_err(|_| {
                error!("Missing head descriptor");
            })?;

        // The descriptor contains the request type which MUST be readable.
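        // The expected chain layout is a device-readable buffer holding the
        // request head and type-specific payload, followed by a device-writable
        // buffer for the reply (optional PROBE properties plus the status tail).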
        if desc.is_write_only() {
            return Err(Error::UnexpectedWriteOnlyDescriptor);
        }

        if (desc.len() as usize) < size_of::<VirtioIommuReqHead>() {
            return Err(Error::InvalidRequest);
        }

        let req_head: VirtioIommuReqHead = desc_chain
            .memory()
            .read_obj(desc.addr())
            .map_err(Error::GuestMemory)?;
        let req_offset = size_of::<VirtioIommuReqHead>();
        let desc_size_left = (desc.len() as usize) - req_offset;
        let req_addr = if let Some(addr) = desc.addr().checked_add(req_offset as u64) {
            addr
        } else {
            return Err(Error::InvalidRequest);
        };

        let (msi_iova_start, msi_iova_end) = msi_iova_space;

        // Create the reply
        let mut reply: Vec<u8> = Vec::new();
        let mut status = VIRTIO_IOMMU_S_OK;
        let mut hdr_len = 0;

        let result = (|| {
            match req_head.type_ {
                VIRTIO_IOMMU_T_ATTACH => {
                    if desc_size_left != size_of::<VirtioIommuReqAttach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidAttachRequest);
                    }

                    let req: VirtioIommuReqAttach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Attach request {:?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;
                    let bypass =
                        (req.flags & VIRTIO_IOMMU_ATTACH_F_BYPASS) == VIRTIO_IOMMU_ATTACH_F_BYPASS;

                    let mut old_domain_id = domain_id;
                    if let Some(&id) = mapping.endpoints.read().unwrap().get(&endpoint) {
                        old_domain_id = id;
                    }

                    if old_domain_id != domain_id {
                        detach_endpoint_from_domain(endpoint, old_domain_id, mapping, ext_mapping)?;
                    }

                    // Add endpoint associated with specific domain
                    mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .insert(endpoint, domain_id);

                    // If any other mappings exist in the domain for other containers,
                    // make sure to issue these mappings for the new endpoint/container
                    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id)
                    {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            for (virt_start, addr_map) in &domain_mappings.mappings {
                                ext_map
                                    .map(*virt_start, addr_map.gpa, addr_map.size)
                                    .map_err(Error::ExternalMapping)?;
                            }
                        }
                    }

                    // Add new domain with no mapping if the entry didn't exist yet
                    let mut domains = mapping.domains.write().unwrap();
                    let domain = Domain {
                        mappings: BTreeMap::new(),
                        bypass,
                    };
                    domains.entry(domain_id).or_insert_with(|| domain);
                }
                VIRTIO_IOMMU_T_DETACH => {
                    if desc_size_left != size_of::<VirtioIommuReqDetach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidDetachRequest);
                    }

                    let req: VirtioIommuReqDetach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Detach request {:?}", req);

                    // Copy the value to use it as a proper reference.
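                    // (Fields of a #[repr(packed)] struct may be misaligned, so
                    // they are copied into locals instead of being borrowed.)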
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;

                    // Remove endpoint associated with specific domain
                    detach_endpoint_from_domain(endpoint, domain_id, mapping, ext_mapping)?;
                }
                VIRTIO_IOMMU_T_MAP => {
                    if desc_size_left != size_of::<VirtioIommuReqMap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequest);
                    }

                    let req: VirtioIommuReqMap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Map request {:?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidMapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // With virtio-iommu, each endpoint gets its own VFIO container, so
                    // every endpoint attached to the domain must be mapped separately:
                    // the mapping is applied per container, not per domain.
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - req.virt_start + 1;
                            ext_map
                                .map(req.virt_start, req.phys_start, size)
                                .map_err(Error::ExternalMapping)?;
                        }
                    }

                    // Add new mapping associated with the domain
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .insert(
                            req.virt_start,
                            Mapping {
                                gpa: req.phys_start,
                                size: req.virt_end - req.virt_start + 1,
                            },
                        );
                }
                VIRTIO_IOMMU_T_UNMAP => {
                    if desc_size_left != size_of::<VirtioIommuReqUnmap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequest);
                    }

                    let req: VirtioIommuReqUnmap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Unmap request {:?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;
                    let virt_start = req.virt_start;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidUnmapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // Trigger external unmapping if necessary.
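                    // As for MAP, each endpoint owns its own container, so the
                    // unmap must be replayed for every endpoint attached to the
                    // domain.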
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - virt_start + 1;
                            ext_map
                                .unmap(virt_start, size)
                                .map_err(Error::ExternalUnmapping)?;
                        }
                    }

                    // Remove all mappings associated with the domain within the requested range
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .retain(|&x, _| (x < req.virt_start || x > req.virt_end));
                }
                VIRTIO_IOMMU_T_PROBE => {
                    if desc_size_left != size_of::<VirtioIommuReqProbe>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidProbeRequest);
                    }

                    let req: VirtioIommuReqProbe = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Probe request {:?}", req);

                    let probe_prop = VirtioIommuProbeProperty {
                        type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM,
                        length: size_of::<VirtioIommuProbeResvMem>() as u16,
                    };
                    reply.extend_from_slice(probe_prop.as_slice());

                    let resv_mem = VirtioIommuProbeResvMem {
                        subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
                        start: msi_iova_start,
                        end: msi_iova_end,
                        ..Default::default()
                    };
                    reply.extend_from_slice(resv_mem.as_slice());

                    hdr_len = PROBE_PROP_SIZE;
                }
                _ => {
                    status = VIRTIO_IOMMU_S_INVAL;
                    return Err(Error::InvalidRequest);
                }
            }
            Ok(())
        })();

        let status_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;

        // The status MUST always be writable
        if !status_desc.is_write_only() {
            return Err(Error::UnexpectedReadOnlyDescriptor);
        }

        if status_desc.len() < hdr_len + size_of::<VirtioIommuReqTail>() as u32 {
            return Err(Error::BufferLengthTooSmall);
        }

        let tail = VirtioIommuReqTail {
            status,
            ..Default::default()
        };
        reply.extend_from_slice(tail.as_slice());

        // Make sure we return the result of the request to the guest before
        // we return a potential error internally.
        desc_chain
            .memory()
            .write_slice(reply.as_slice(), status_desc.addr())
            .map_err(Error::GuestMemory)?;

        // Return the error if the result was not Ok().
        result?;

        Ok((hdr_len as usize) + size_of::<VirtioIommuReqTail>())
    }
}

fn detach_endpoint_from_domain(
    endpoint: u32,
    domain_id: u32,
    mapping: &Arc<IommuMapping>,
    ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
) -> result::Result<(), Error> {
    // Remove endpoint associated with specific domain
    mapping.endpoints.write().unwrap().remove(&endpoint);

    // Trigger external unmapping for the endpoint if necessary.
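    // The domain's mappings no longer apply to this endpoint, so tear them
    // down in the endpoint's external (e.g. VFIO) container if one exists.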
    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id) {
        if let Some(ext_map) = ext_mapping.get(&endpoint) {
            for (virt_start, addr_map) in &domain_mappings.mappings {
                ext_map
                    .unmap(*virt_start, addr_map.size)
                    .map_err(Error::ExternalUnmapping)?;
            }
        }
    }

    if mapping
        .endpoints
        .write()
        .unwrap()
        .iter()
        .filter(|(_, &d)| d == domain_id)
        .count()
        == 0
    {
        mapping.domains.write().unwrap().remove(&domain_id);
    }

    Ok(())
}

struct IommuEpollHandler {
    mem: GuestMemoryAtomic<GuestMemoryMmap>,
    request_queue: Queue,
    _event_queue: Queue,
    interrupt_cb: Arc<dyn VirtioInterrupt>,
    request_queue_evt: EventFd,
    _event_queue_evt: EventFd,
    kill_evt: EventFd,
    pause_evt: EventFd,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    msi_iova_space: (u64, u64),
}

impl IommuEpollHandler {
    fn request_queue(&mut self) -> Result<bool, Error> {
        let mut used_descs = false;
        while let Some(mut desc_chain) = self.request_queue.pop_descriptor_chain(self.mem.memory())
        {
            let len = Request::parse(
                &mut desc_chain,
                &self.mapping,
                &self.ext_mapping.lock().unwrap(),
                self.msi_iova_space,
            )?;

            self.request_queue
                .add_used(desc_chain.memory(), desc_chain.head_index(), len as u32)
                .map_err(Error::QueueAddUsed)?;

            used_descs = true;
        }

        Ok(used_descs)
    }

    fn signal_used_queue(&self, queue_index: u16) -> result::Result<(), DeviceError> {
        self.interrupt_cb
            .trigger(VirtioInterruptType::Queue(queue_index))
            .map_err(|e| {
                error!("Failed to signal used queue: {:?}", e);
                DeviceError::FailedSignalingUsedQueue(e)
            })
    }

    fn run(
        &mut self,
        paused: Arc<AtomicBool>,
        paused_sync: Arc<Barrier>,
    ) -> result::Result<(), EpollHelperError> {
        let mut helper = EpollHelper::new(&self.kill_evt, &self.pause_evt)?;
        helper.add_event(self.request_queue_evt.as_raw_fd(), REQUEST_Q_EVENT)?;
        helper.run(paused, paused_sync, self)?;

        Ok(())
    }
}

impl EpollHelperHandler for IommuEpollHandler {
    fn handle_event(
        &mut self,
        _helper: &mut EpollHelper,
        event: &epoll::Event,
    ) -> result::Result<(), EpollHelperError> {
        let ev_type = event.data as u16;
        match ev_type {
            REQUEST_Q_EVENT => {
                self.request_queue_evt.read().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!("Failed to get queue event: {:?}", e))
                })?;

                let needs_notification = self.request_queue().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!(
                        "Failed to process request queue: {:?}",
                        e
                    ))
                })?;
                if needs_notification {
                    self.signal_used_queue(0).map_err(|e| {
                        EpollHelperError::HandleEvent(anyhow!(
                            "Failed to signal used queue: {:?}",
                            e
                        ))
                    })?;
                }
            }
            _ => {
                return Err(EpollHelperError::HandleEvent(anyhow!(
                    "Unexpected event: {}",
                    ev_type
                )));
            }
        }
        Ok(())
    }
}

#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
struct Mapping {
    gpa: u64,
    size: u64,
}

#[derive(Clone, Debug)]
struct Domain {
    mappings: BTreeMap<u64, Mapping>,
    bypass: bool,
}

#[derive(Debug)]
pub struct IommuMapping {
    // Domain each endpoint is attached to.
    endpoints: Arc<RwLock<BTreeMap<u32, u32>>>,
    // Information related to each domain.
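    // Each `Domain` holds its IOVA -> GPA mappings (keyed by IOVA start) and a
    // per-domain bypass flag.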
    domains: Arc<RwLock<BTreeMap<u32, Domain>>>,
    // Global flag indicating if endpoints that are not attached to any domain
    // are in bypass mode.
    bypass: AtomicBool,
}

impl DmaRemapping for IommuMapping {
    fn translate_gva(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GVA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                let range_start = if VIRTIO_IOMMU_PAGE_SIZE_MASK > addr {
                    0
                } else {
                    addr - VIRTIO_IOMMU_PAGE_SIZE_MASK
                };
                for (&key, &value) in domain
                    .mappings
                    .range((Included(&range_start), Included(&addr)))
                {
                    if addr >= key && addr < key + value.size {
                        let new_addr = addr - key + value.gpa;
                        debug!("Into GPA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::new(
            io::ErrorKind::Other,
            format!("failed to translate GVA addr 0x{addr:x}"),
        ))
    }

    fn translate_gpa(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GPA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                for (&key, &value) in domain.mappings.iter() {
                    if addr >= value.gpa && addr < value.gpa + value.size {
                        let new_addr = addr - value.gpa + key;
                        debug!("Into GVA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::new(
            io::ErrorKind::Other,
            format!("failed to translate GPA addr 0x{addr:x}"),
        ))
    }
}

#[derive(Debug)]
pub struct AccessPlatformMapping {
    id: u32,
    mapping: Arc<IommuMapping>,
}

impl AccessPlatformMapping {
    pub fn new(id: u32, mapping: Arc<IommuMapping>) -> Self {
        AccessPlatformMapping { id, mapping }
    }
}

impl AccessPlatform for AccessPlatformMapping {
    fn translate_gva(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gva(self.id, base)
    }
    fn translate_gpa(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gpa(self.id, base)
    }
}

pub struct Iommu {
    common: VirtioCommon,
    id: String,
    config: VirtioIommuConfig,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    seccomp_action: SeccompAction,
    exit_evt: EventFd,
    msi_iova_space: (u64, u64),
}

type EndpointsState = Vec<(u32, u32)>;
type DomainsState = Vec<(u32, (Vec<(u64, Mapping)>, bool))>;

#[derive(Serialize, Deserialize)]
pub struct IommuState {
    avail_features: u64,
    acked_features: u64,
    endpoints: EndpointsState,
    domains: DomainsState,
}
impl Iommu {
    pub fn new(
        id: String,
        seccomp_action: SeccompAction,
        exit_evt: EventFd,
        msi_iova_space: (u64, u64),
        state: Option<IommuState>,
    ) -> io::Result<(Self, Arc<IommuMapping>)> {
        let (avail_features, acked_features, endpoints, domains, paused) =
            if let Some(state) = state {
                info!("Restoring virtio-iommu {}", id);
                (
                    state.avail_features,
                    state.acked_features,
                    state.endpoints.into_iter().collect(),
                    state
                        .domains
                        .into_iter()
                        .map(|(k, v)| {
                            (
                                k,
                                Domain {
                                    mappings: v.0.into_iter().collect(),
                                    bypass: v.1,
                                },
                            )
                        })
                        .collect(),
                    true,
                )
            } else {
                let avail_features = (1u64 << VIRTIO_F_VERSION_1)
                    | (1u64 << VIRTIO_IOMMU_F_MAP_UNMAP)
                    | (1u64 << VIRTIO_IOMMU_F_PROBE)
                    | (1u64 << VIRTIO_IOMMU_F_BYPASS_CONFIG);

                (avail_features, 0, BTreeMap::new(), BTreeMap::new(), false)
            };

        let config = VirtioIommuConfig {
            page_size_mask: VIRTIO_IOMMU_PAGE_SIZE_MASK,
            probe_size: PROBE_PROP_SIZE,
            ..Default::default()
        };

        let mapping = Arc::new(IommuMapping {
            endpoints: Arc::new(RwLock::new(endpoints)),
            domains: Arc::new(RwLock::new(domains)),
            bypass: AtomicBool::new(true),
        });

        Ok((
            Iommu {
                id,
                common: VirtioCommon {
                    device_type: VirtioDeviceType::Iommu as u32,
                    queue_sizes: QUEUE_SIZES.to_vec(),
                    avail_features,
                    acked_features,
                    paused_sync: Some(Arc::new(Barrier::new(2))),
                    min_queues: NUM_QUEUES as u16,
                    paused: Arc::new(AtomicBool::new(paused)),
                    ..Default::default()
                },
                config,
                mapping: mapping.clone(),
                ext_mapping: Arc::new(Mutex::new(BTreeMap::new())),
                seccomp_action,
                exit_evt,
                msi_iova_space,
            },
            mapping,
        ))
    }

    fn state(&self) -> IommuState {
        IommuState {
            avail_features: self.common.avail_features,
            acked_features: self.common.acked_features,
            endpoints: self
                .mapping
                .endpoints
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .collect(),
            domains: self
                .mapping
                .domains
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .map(|(k, v)| (k, (v.mappings.into_iter().collect(), v.bypass)))
                .collect(),
        }
    }

    fn update_bypass(&mut self) {
        // Use bypass from config if VIRTIO_IOMMU_F_BYPASS_CONFIG has been negotiated
        if !self
            .common
            .feature_acked(VIRTIO_IOMMU_F_BYPASS_CONFIG.into())
        {
            return;
        }

        let bypass = self.config.bypass == 1;
        info!("Updating bypass mode to {}", bypass);
        self.mapping.bypass.store(bypass, Ordering::Release);
    }

    pub fn add_external_mapping(&mut self, device_id: u32, mapping: Arc<dyn ExternalDmaMapping>) {
        self.ext_mapping.lock().unwrap().insert(device_id, mapping);
    }

    #[cfg(fuzzing)]
    pub fn wait_for_epoll_threads(&mut self) {
        self.common.wait_for_epoll_threads();
    }
}

impl Drop for Iommu {
    fn drop(&mut self) {
        if let Some(kill_evt) = self.common.kill_evt.take() {
            // Ignore the result because there is nothing we can do about it.
            let _ = kill_evt.write(1);
        }
        self.common.wait_for_epoll_threads();
    }
}

impl VirtioDevice for Iommu {
    fn device_type(&self) -> u32 {
        self.common.device_type
    }

    fn queue_max_sizes(&self) -> &[u16] {
        &self.common.queue_sizes
    }

    fn features(&self) -> u64 {
        self.common.avail_features
    }

    fn ack_features(&mut self, value: u64) {
        self.common.ack_features(value)
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        self.read_config_from_slice(self.config.as_slice(), offset, data);
    }

    fn write_config(&mut self, offset: u64, data: &[u8]) {
        // The "bypass" field is the only mutable field
        let bypass_offset =
            (&self.config.bypass as *const _ as u64) - (&self.config as *const _ as u64);
        if offset != bypass_offset || data.len() != std::mem::size_of_val(&self.config.bypass) {
            error!(
                "Attempt to write to read-only field: offset {:x} length {}",
                offset,
                data.len()
            );
            return;
        }

        self.config.bypass = data[0];

        self.update_bypass();
    }

    fn activate(
        &mut self,
        mem: GuestMemoryAtomic<GuestMemoryMmap>,
        interrupt_cb: Arc<dyn VirtioInterrupt>,
        mut queues: Vec<(usize, Queue, EventFd)>,
    ) -> ActivateResult {
        self.common.activate(&queues, &interrupt_cb)?;
        let (kill_evt, pause_evt) = self.common.dup_eventfds();

        let (_, request_queue, request_queue_evt) = queues.remove(0);
        let (_, _event_queue, _event_queue_evt) = queues.remove(0);

        let mut handler = IommuEpollHandler {
            mem,
            request_queue,
            _event_queue,
            interrupt_cb,
            request_queue_evt,
            _event_queue_evt,
            kill_evt,
            pause_evt,
            mapping: self.mapping.clone(),
            ext_mapping: self.ext_mapping.clone(),
            msi_iova_space: self.msi_iova_space,
        };

        let paused = self.common.paused.clone();
        let paused_sync = self.common.paused_sync.clone();
        let mut epoll_threads = Vec::new();
        spawn_virtio_thread(
            &self.id,
            &self.seccomp_action,
            Thread::VirtioIommu,
            &mut epoll_threads,
            &self.exit_evt,
            move || handler.run(paused, paused_sync.unwrap()),
        )?;

        self.common.epoll_threads = Some(epoll_threads);

        event!("virtio-device", "activated", "id", &self.id);
        Ok(())
    }

    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
        let result = self.common.reset();
        event!("virtio-device", "reset", "id", &self.id);
        result
    }
}

impl Pausable for Iommu {
    fn pause(&mut self) -> result::Result<(), MigratableError> {
        self.common.pause()
    }

    fn resume(&mut self) -> result::Result<(), MigratableError> {
        self.common.resume()
    }
}

impl Snapshottable for Iommu {
    fn id(&self) -> String {
        self.id.clone()
    }

    fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
        Snapshot::new_from_state(&self.state())
    }
}
impl Transportable for Iommu {}
impl Migratable for Iommu {}
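
// A minimal sanity-check sketch for the translation logic above: it builds an
// IommuMapping by hand (one endpoint attached to one domain with a single
// 4 KiB mapping) and exercises translate_gva()/translate_gpa(). The endpoint,
// domain and address values are illustrative, not taken from real device usage.
#[cfg(test)]
mod tests {
    use super::*;

    fn test_mapping() -> IommuMapping {
        // Endpoint 3 attached to domain 0, which maps IOVA 0x1000..=0x1fff to
        // GPA 0x8000..=0x8fff.
        let mut mappings = BTreeMap::new();
        mappings.insert(
            0x1000,
            Mapping {
                gpa: 0x8000,
                size: 0x1000,
            },
        );
        let mut domains = BTreeMap::new();
        domains.insert(
            0,
            Domain {
                mappings,
                bypass: false,
            },
        );
        let mut endpoints = BTreeMap::new();
        endpoints.insert(3, 0);

        IommuMapping {
            endpoints: Arc::new(RwLock::new(endpoints)),
            domains: Arc::new(RwLock::new(domains)),
            bypass: AtomicBool::new(true),
        }
    }

    #[test]
    fn translate_attached_endpoint() {
        let mapping = test_mapping();
        // 0x1800 is 0x800 bytes into the mapping, so it translates to 0x8800.
        assert_eq!(mapping.translate_gva(3, 0x1800).unwrap(), 0x8800);
        assert_eq!(mapping.translate_gpa(3, 0x8800).unwrap(), 0x1800);
        // Addresses outside any mapping must fail to translate.
        assert!(mapping.translate_gva(3, 0x3000).is_err());
    }

    #[test]
    fn translate_unattached_endpoint_with_global_bypass() {
        let mapping = test_mapping();
        // Endpoint 7 is not attached to any domain; with the global bypass flag
        // set, the address is returned unchanged.
        assert_eq!(mapping.translate_gva(7, 0xdead_b000).unwrap(), 0xdead_b000);
    }
}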