// Copyright © 2019 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause

use super::Error as DeviceError;
use super::{
    ActivateResult, EpollHelper, EpollHelperError, EpollHelperHandler, VirtioCommon, VirtioDevice,
    VirtioDeviceType, EPOLL_HELPER_EVENT_LAST, VIRTIO_F_VERSION_1,
};
use crate::seccomp_filters::Thread;
use crate::thread_helper::spawn_virtio_thread;
use crate::GuestMemoryMmap;
use crate::{DmaRemapping, VirtioInterrupt, VirtioInterruptType};
use anyhow::anyhow;
use seccompiler::SeccompAction;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::io;
use std::mem::size_of;
use std::ops::Bound::Included;
use std::os::unix::io::AsRawFd;
use std::result;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Barrier, Mutex, RwLock};
use thiserror::Error;
use virtio_queue::{DescriptorChain, Queue, QueueT};
use vm_device::dma_mapping::ExternalDmaMapping;
use vm_memory::{
    Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic,
    GuestMemoryError, GuestMemoryLoadGuard,
};
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform;
use vmm_sys_util::eventfd::EventFd;

/// Queue sizes
const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];

/// New descriptors are pending on the request queue.
/// "requestq" is meant to be used anytime an action is required to be
/// performed on behalf of the guest driver.
const REQUEST_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 1;
/// New descriptors are pending on the event queue.
/// "eventq" lets the device report any fault or other asynchronous event to
/// the guest driver.
const _EVENT_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 2;

/// PROBE properties size.
/// This is the minimal size needed to provide at least one RESV_MEM property.
/// Because virtio-iommu expects one MSI reserved region, we must provide it,
/// otherwise the guest driver will fall back to its predefined region between
/// 0x8000000 and 0x80FFFFF, which is only relevant for the ARM architecture
/// but would conflict with x86.
const PROBE_PROP_SIZE: u32 =
    (size_of::<VirtioIommuProbeProperty>() + size_of::<VirtioIommuProbeResvMem>()) as u32;

/// Virtio IOMMU features
#[allow(unused)]
const VIRTIO_IOMMU_F_INPUT_RANGE: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_F_DOMAIN_RANGE: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_F_MAP_UNMAP: u32 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_F_BYPASS: u32 = 3;
const VIRTIO_IOMMU_F_PROBE: u32 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_F_MMIO: u32 = 5;
const VIRTIO_IOMMU_F_BYPASS_CONFIG: u32 = 6;

// Support 2MiB and 4KiB page sizes.
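// (2 << 20) sets the 2 MiB bit and (4 << 10) the 4 KiB bit, so the mask
// advertised to the guest is 0x0020_1000.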
const VIRTIO_IOMMU_PAGE_SIZE_MASK: u64 = (2 << 20) | (4 << 10);

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuRange32 {
    start: u32,
    end: u32,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuRange64 {
    start: u64,
    end: u64,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuConfig {
    page_size_mask: u64,
    input_range: VirtioIommuRange64,
    domain_range: VirtioIommuRange32,
    probe_size: u32,
    bypass: u8,
    _reserved: [u8; 7],
}

/// Virtio IOMMU request type
const VIRTIO_IOMMU_T_ATTACH: u8 = 1;
const VIRTIO_IOMMU_T_DETACH: u8 = 2;
const VIRTIO_IOMMU_T_MAP: u8 = 3;
const VIRTIO_IOMMU_T_UNMAP: u8 = 4;
const VIRTIO_IOMMU_T_PROBE: u8 = 5;

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqHead {
    type_: u8,
    _reserved: [u8; 3],
}

/// Virtio IOMMU request status
const VIRTIO_IOMMU_S_OK: u8 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_S_IOERR: u8 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_S_UNSUPP: u8 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_S_DEVERR: u8 = 3;
#[allow(unused)]
const VIRTIO_IOMMU_S_INVAL: u8 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_S_RANGE: u8 = 5;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOENT: u8 = 6;
#[allow(unused)]
const VIRTIO_IOMMU_S_FAULT: u8 = 7;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOMEM: u8 = 8;

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuReqTail {
    status: u8,
    _reserved: [u8; 3],
}

/// ATTACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqAttach {
    domain: u32,
    endpoint: u32,
    flags: u32,
    _reserved: [u8; 4],
}

const VIRTIO_IOMMU_ATTACH_F_BYPASS: u32 = 1;

/// DETACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqDetach {
    domain: u32,
    endpoint: u32,
    _reserved: [u8; 8],
}

/// Virtio IOMMU request MAP flags
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MMIO: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MASK: u32 =
    VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE | VIRTIO_IOMMU_MAP_F_MMIO;

/// MAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqMap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    phys_start: u64,
    _flags: u32,
}

/// UNMAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqUnmap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    _reserved: [u8; 4],
}

/// Virtio IOMMU request PROBE types
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_NONE: u16 = 0;
const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_MASK: u16 = 0xfff;

/// PROBE request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuReqProbe {
    endpoint: u32,
    _reserved: [u64; 8],
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuProbeProperty {
    type_: u16,
    length: u16,
}

/// Virtio IOMMU request PROBE property RESV_MEM subtypes
#[allow(unused)]
const VIRTIO_IOMMU_RESV_MEM_T_RESERVED: u8 = 0;
const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuProbeResvMem {
    subtype: u8,
    _reserved: [u8; 3],
    start: u64,
    end: u64,
}

/// Virtio IOMMU fault flags
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_EXEC: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_ADDRESS: u32 = 1 << 8;

/// Virtio IOMMU fault reasons
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_UNKNOWN: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_DOMAIN: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_MAPPING: u32 = 2;

/// Fault reporting through eventq
#[allow(unused)]
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuFault {
    reason: u8,
    reserved: [u8; 3],
    flags: u32,
    endpoint: u32,
    reserved2: [u8; 4],
    address: u64,
}

// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange32 {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange64 {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuConfig {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqHead {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqTail {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqAttach {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqDetach {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqMap {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqUnmap {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqProbe {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeProperty {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeResvMem {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuFault {}

#[derive(Error, Debug)]
enum Error {
    #[error("Guest gave us bad memory addresses: {0}")]
    GuestMemory(GuestMemoryError),
    #[error("Guest gave us a write-only descriptor that the protocol says to read from")]
    UnexpectedWriteOnlyDescriptor,
    #[error("Guest gave us a read-only descriptor that the protocol says to write to")]
    UnexpectedReadOnlyDescriptor,
    #[error("Guest gave us too few descriptors in a descriptor chain")]
    DescriptorChainTooShort,
    #[error("Guest gave us a buffer that was too short to use")]
    BufferLengthTooSmall,
    #[error("Guest sent us an invalid request")]
    InvalidRequest,
    #[error("Guest sent us an invalid ATTACH request")]
    InvalidAttachRequest,
    #[error("Guest sent us an invalid DETACH request")]
    InvalidDetachRequest,
    #[error("Guest sent us an invalid MAP request")]
    InvalidMapRequest,
    #[error("Cannot map because the domain is in bypass mode")]
    InvalidMapRequestBypassDomain,
    #[error("Cannot map because the domain is missing")]
    InvalidMapRequestMissingDomain,
    #[error("Guest sent us an invalid UNMAP request")]
    InvalidUnmapRequest,
    #[error("Cannot unmap because the domain is in bypass mode")]
    InvalidUnmapRequestBypassDomain,
    #[error("Cannot unmap because the domain is missing")]
    InvalidUnmapRequestMissingDomain,
    #[error("Guest sent us an invalid PROBE request")]
    InvalidProbeRequest,
    #[error("Failed to perform external mapping: {0}")]
    ExternalMapping(io::Error),
    #[error("Failed to perform external unmapping: {0}")]
    ExternalUnmapping(io::Error),
    #[error("Failed to add used index: {0}")]
    QueueAddUsed(virtio_queue::Error),
}

struct Request {}

impl Request {
    // Parse the available vring buffer. Based on the map of external mappings
    // required by devices such as VFIO or vhost-user ones, this function
    // might update the per-domain map of external mappings.
    // The VMM knows the device_id <=> mapping relationship before running the
    // VM, but at runtime a new domain <=> mapping map is built from the
    // information provided by the guest virtio-iommu driver (which gives the
    // device_id <=> domain link).
    fn parse(
        desc_chain: &mut DescriptorChain<GuestMemoryLoadGuard<GuestMemoryMmap>>,
        mapping: &Arc<IommuMapping>,
        ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
        msi_iova_space: (u64, u64),
    ) -> result::Result<usize, Error> {
        let desc = desc_chain
            .next()
            .ok_or(Error::DescriptorChainTooShort)
            .map_err(|e| {
                error!("Missing head descriptor");
                e
            })?;

        // The descriptor contains the request type which MUST be readable.
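        // The chain is laid out as a device-readable part (request head plus
        // request-specific payload) followed by a device-writable part
        // (optional PROBE properties plus the status tail), hence the checks
        // on descriptor direction below.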
        if desc.is_write_only() {
            return Err(Error::UnexpectedWriteOnlyDescriptor);
        }

        if (desc.len() as usize) < size_of::<VirtioIommuReqHead>() {
            return Err(Error::InvalidRequest);
        }

        let req_head: VirtioIommuReqHead = desc_chain
            .memory()
            .read_obj(desc.addr())
            .map_err(Error::GuestMemory)?;
        let req_offset = size_of::<VirtioIommuReqHead>();
        let desc_size_left = (desc.len() as usize) - req_offset;
        let req_addr = if let Some(addr) = desc.addr().checked_add(req_offset as u64) {
            addr
        } else {
            return Err(Error::InvalidRequest);
        };

        let (msi_iova_start, msi_iova_end) = msi_iova_space;

        // Create the reply
        let mut reply: Vec<u8> = Vec::new();
        let mut status = VIRTIO_IOMMU_S_OK;
        let mut hdr_len = 0;

        let result = (|| {
            match req_head.type_ {
                VIRTIO_IOMMU_T_ATTACH => {
                    if desc_size_left != size_of::<VirtioIommuReqAttach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidAttachRequest);
                    }

                    let req: VirtioIommuReqAttach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Attach request {:?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;
                    let bypass =
                        (req.flags & VIRTIO_IOMMU_ATTACH_F_BYPASS) == VIRTIO_IOMMU_ATTACH_F_BYPASS;

                    let mut old_domain_id = domain_id;
                    if let Some(&id) = mapping.endpoints.read().unwrap().get(&endpoint) {
                        old_domain_id = id;
                    }

                    if old_domain_id != domain_id {
                        detach_endpoint_from_domain(endpoint, old_domain_id, mapping, ext_mapping)?;
                    }

                    // Add endpoint associated with specific domain
                    mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .insert(endpoint, domain_id);

                    // If any other mappings exist in the domain for other containers,
                    // make sure to issue these mappings for the new endpoint/container
                    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id)
                    {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            for (virt_start, addr_map) in &domain_mappings.mappings {
                                ext_map
                                    .map(*virt_start, addr_map.gpa, addr_map.size)
                                    .map_err(Error::ExternalMapping)?;
                            }
                        }
                    }

                    // Add new domain with no mapping if the entry didn't exist yet
                    let mut domains = mapping.domains.write().unwrap();
                    let domain = Domain {
                        mappings: BTreeMap::new(),
                        bypass,
                    };
                    domains.entry(domain_id).or_insert_with(|| domain);
                }
                VIRTIO_IOMMU_T_DETACH => {
                    if desc_size_left != size_of::<VirtioIommuReqDetach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidDetachRequest);
                    }

                    let req: VirtioIommuReqDetach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Detach request {:?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;

                    // Remove endpoint associated with specific domain
                    detach_endpoint_from_domain(endpoint, domain_id, mapping, ext_mapping)?;
                }
                VIRTIO_IOMMU_T_MAP => {
                    if desc_size_left != size_of::<VirtioIommuReqMap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequest);
                    }

                    let req: VirtioIommuReqMap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Map request {:?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidMapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // With virtio-iommu, every endpoint is backed by its own VFIO
                    // container. As a result, each endpoint within the domain must be
                    // mapped separately, since the mapping is done at the container
                    // level, not at the domain level.
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - req.virt_start + 1;
                            ext_map
                                .map(req.virt_start, req.phys_start, size)
                                .map_err(Error::ExternalMapping)?;
                        }
                    }

                    // Add new mapping associated with the domain
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .insert(
                            req.virt_start,
                            Mapping {
                                gpa: req.phys_start,
                                size: req.virt_end - req.virt_start + 1,
                            },
                        );
                }
                VIRTIO_IOMMU_T_UNMAP => {
                    if desc_size_left != size_of::<VirtioIommuReqUnmap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequest);
                    }

                    let req: VirtioIommuReqUnmap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Unmap request {:?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;
                    let virt_start = req.virt_start;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidUnmapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // Trigger external unmapping if necessary.
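                    // virt_start/virt_end describe an inclusive IOVA range, hence the
                    // "+ 1" when computing the size handed to the external mapper.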
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - virt_start + 1;
                            ext_map
                                .unmap(virt_start, size)
                                .map_err(Error::ExternalUnmapping)?;
                        }
                    }

                    // Remove all mappings associated with the domain within the requested range
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .retain(|&x, _| (x < req.virt_start || x > req.virt_end));
                }
                VIRTIO_IOMMU_T_PROBE => {
                    if desc_size_left != size_of::<VirtioIommuReqProbe>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidProbeRequest);
                    }

                    let req: VirtioIommuReqProbe = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Probe request {:?}", req);

                    let probe_prop = VirtioIommuProbeProperty {
                        type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM,
                        length: size_of::<VirtioIommuProbeResvMem>() as u16,
                    };
                    reply.extend_from_slice(probe_prop.as_slice());

                    let resv_mem = VirtioIommuProbeResvMem {
                        subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
                        start: msi_iova_start,
                        end: msi_iova_end,
                        ..Default::default()
                    };
                    reply.extend_from_slice(resv_mem.as_slice());

                    hdr_len = PROBE_PROP_SIZE;
                }
                _ => {
                    status = VIRTIO_IOMMU_S_INVAL;
                    return Err(Error::InvalidRequest);
                }
            }
            Ok(())
        })();

        let status_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;

        // The status MUST always be writable
        if !status_desc.is_write_only() {
            return Err(Error::UnexpectedReadOnlyDescriptor);
        }

        if status_desc.len() < hdr_len + size_of::<VirtioIommuReqTail>() as u32 {
            return Err(Error::BufferLengthTooSmall);
        }

        let tail = VirtioIommuReqTail {
            status,
            ..Default::default()
        };
        reply.extend_from_slice(tail.as_slice());

        // Make sure we return the result of the request to the guest before
        // we return a potential error internally.
        desc_chain
            .memory()
            .write_slice(reply.as_slice(), status_desc.addr())
            .map_err(Error::GuestMemory)?;

        // Return the error if the result was not Ok().
        result?;

        Ok((hdr_len as usize) + size_of::<VirtioIommuReqTail>())
    }
}

fn detach_endpoint_from_domain(
    endpoint: u32,
    domain_id: u32,
    mapping: &Arc<IommuMapping>,
    ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
) -> result::Result<(), Error> {
    // Remove endpoint associated with specific domain
    mapping.endpoints.write().unwrap().remove(&endpoint);

    // Trigger external unmapping for the endpoint if necessary.
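    // Only endpoints backed by an external DMA mapping (e.g. VFIO or
    // vhost-user) need this; others are only tracked in the internal tree.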
    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id) {
        if let Some(ext_map) = ext_mapping.get(&endpoint) {
            for (virt_start, addr_map) in &domain_mappings.mappings {
                ext_map
                    .unmap(*virt_start, addr_map.size)
                    .map_err(Error::ExternalUnmapping)?;
            }
        }
    }

    if mapping
        .endpoints
        .write()
        .unwrap()
        .iter()
        .filter(|(_, &d)| d == domain_id)
        .count()
        == 0
    {
        mapping.domains.write().unwrap().remove(&domain_id);
    }

    Ok(())
}

struct IommuEpollHandler {
    mem: GuestMemoryAtomic<GuestMemoryMmap>,
    request_queue: Queue,
    _event_queue: Queue,
    interrupt_cb: Arc<dyn VirtioInterrupt>,
    request_queue_evt: EventFd,
    _event_queue_evt: EventFd,
    kill_evt: EventFd,
    pause_evt: EventFd,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    msi_iova_space: (u64, u64),
}

impl IommuEpollHandler {
    fn request_queue(&mut self) -> Result<bool, Error> {
        let mut used_descs = false;
        while let Some(mut desc_chain) = self.request_queue.pop_descriptor_chain(self.mem.memory())
        {
            let len = Request::parse(
                &mut desc_chain,
                &self.mapping,
                &self.ext_mapping.lock().unwrap(),
                self.msi_iova_space,
            )?;

            self.request_queue
                .add_used(desc_chain.memory(), desc_chain.head_index(), len as u32)
                .map_err(Error::QueueAddUsed)?;

            used_descs = true;
        }

        Ok(used_descs)
    }

    fn signal_used_queue(&self, queue_index: u16) -> result::Result<(), DeviceError> {
        self.interrupt_cb
            .trigger(VirtioInterruptType::Queue(queue_index))
            .map_err(|e| {
                error!("Failed to signal used queue: {:?}", e);
                DeviceError::FailedSignalingUsedQueue(e)
            })
    }

    fn run(
        &mut self,
        paused: Arc<AtomicBool>,
        paused_sync: Arc<Barrier>,
    ) -> result::Result<(), EpollHelperError> {
        let mut helper = EpollHelper::new(&self.kill_evt, &self.pause_evt)?;
        helper.add_event(self.request_queue_evt.as_raw_fd(), REQUEST_Q_EVENT)?;
        helper.run(paused, paused_sync, self)?;

        Ok(())
    }
}

impl EpollHelperHandler for IommuEpollHandler {
    fn handle_event(
        &mut self,
        _helper: &mut EpollHelper,
        event: &epoll::Event,
    ) -> result::Result<(), EpollHelperError> {
        let ev_type = event.data as u16;
        match ev_type {
            REQUEST_Q_EVENT => {
                self.request_queue_evt.read().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!("Failed to get queue event: {:?}", e))
                })?;

                let needs_notification = self.request_queue().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!(
                        "Failed to process request queue: {:?}",
                        e
                    ))
                })?;
                if needs_notification {
                    self.signal_used_queue(0).map_err(|e| {
                        EpollHelperError::HandleEvent(anyhow!(
                            "Failed to signal used queue: {:?}",
                            e
                        ))
                    })?;
                }
            }
            _ => {
                return Err(EpollHelperError::HandleEvent(anyhow!(
                    "Unexpected event: {}",
                    ev_type
                )));
            }
        }
        Ok(())
    }
}

#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
struct Mapping {
    gpa: u64,
    size: u64,
}

#[derive(Clone, Debug)]
struct Domain {
    mappings: BTreeMap<u64, Mapping>,
    bypass: bool,
}

#[derive(Debug)]
pub struct IommuMapping {
    // Domain related to an endpoint.
    endpoints: Arc<RwLock<BTreeMap<u32, u32>>>,
    // Information related to each domain.
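    // Each domain holds its IOVA -> GPA mappings along with a per-domain
    // bypass flag.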
    domains: Arc<RwLock<BTreeMap<u32, Domain>>>,
    // Global flag indicating if endpoints that are not attached to any domain
    // are in bypass mode.
    bypass: AtomicBool,
}

impl DmaRemapping for IommuMapping {
    fn translate_gva(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GVA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                let range_start = if VIRTIO_IOMMU_PAGE_SIZE_MASK > addr {
                    0
                } else {
                    addr - VIRTIO_IOMMU_PAGE_SIZE_MASK
                };
                for (&key, &value) in domain
                    .mappings
                    .range((Included(&range_start), Included(&addr)))
                {
                    if addr >= key && addr < key + value.size {
                        let new_addr = addr - key + value.gpa;
                        debug!("Into GPA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::new(
            io::ErrorKind::Other,
            format!("failed to translate GVA addr 0x{addr:x}"),
        ))
    }

    fn translate_gpa(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GPA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                for (&key, &value) in domain.mappings.iter() {
                    if addr >= value.gpa && addr < value.gpa + value.size {
                        let new_addr = addr - value.gpa + key;
                        debug!("Into GVA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::new(
            io::ErrorKind::Other,
            format!("failed to translate GPA addr 0x{addr:x}"),
        ))
    }
}

#[derive(Debug)]
pub struct AccessPlatformMapping {
    id: u32,
    mapping: Arc<IommuMapping>,
}

impl AccessPlatformMapping {
    pub fn new(id: u32, mapping: Arc<IommuMapping>) -> Self {
        AccessPlatformMapping { id, mapping }
    }
}

impl AccessPlatform for AccessPlatformMapping {
    fn translate_gva(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gva(self.id, base)
    }
    fn translate_gpa(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gpa(self.id, base)
    }
}

pub struct Iommu {
    common: VirtioCommon,
    id: String,
    config: VirtioIommuConfig,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    seccomp_action: SeccompAction,
    exit_evt: EventFd,
    msi_iova_space: (u64, u64),
}

type EndpointsState = Vec<(u32, u32)>;
type DomainsState = Vec<(u32, (Vec<(u64, Mapping)>, bool))>;

#[derive(Serialize, Deserialize)]
pub struct IommuState {
    avail_features: u64,
    acked_features: u64,
    endpoints: EndpointsState,
    domains: DomainsState,
}

impl Iommu {
    pub fn new(
        id: String,
        seccomp_action: SeccompAction,
        exit_evt: EventFd,
        msi_iova_space: (u64, u64),
        state: Option<IommuState>,
    ) -> io::Result<(Self, Arc<IommuMapping>)> {
        let (avail_features, acked_features, endpoints, domains, paused) =
            if let Some(state) = state {
                info!("Restoring virtio-iommu {}", id);
                (
                    state.avail_features,
                    state.acked_features,
                    state.endpoints.into_iter().collect(),
                    state
                        .domains
                        .into_iter()
                        .map(|(k, v)| {
                            (
                                k,
                                Domain {
                                    mappings: v.0.into_iter().collect(),
                                    bypass: v.1,
                                },
                            )
                        })
                        .collect(),
                    true,
                )
            } else {
                let avail_features = 1u64 << VIRTIO_F_VERSION_1
                    | 1u64 << VIRTIO_IOMMU_F_MAP_UNMAP
                    | 1u64 << VIRTIO_IOMMU_F_PROBE
                    | 1u64 << VIRTIO_IOMMU_F_BYPASS_CONFIG;

                (avail_features, 0, BTreeMap::new(), BTreeMap::new(), false)
            };

        let config = VirtioIommuConfig {
            page_size_mask: VIRTIO_IOMMU_PAGE_SIZE_MASK,
            probe_size: PROBE_PROP_SIZE,
            ..Default::default()
        };

        let mapping = Arc::new(IommuMapping {
            endpoints: Arc::new(RwLock::new(endpoints)),
            domains: Arc::new(RwLock::new(domains)),
            bypass: AtomicBool::new(true),
        });

        Ok((
            Iommu {
                id,
                common: VirtioCommon {
                    device_type: VirtioDeviceType::Iommu as u32,
                    queue_sizes: QUEUE_SIZES.to_vec(),
                    avail_features,
                    acked_features,
                    paused_sync: Some(Arc::new(Barrier::new(2))),
                    min_queues: NUM_QUEUES as u16,
                    paused: Arc::new(AtomicBool::new(paused)),
                    ..Default::default()
                },
                config,
                mapping: mapping.clone(),
                ext_mapping: Arc::new(Mutex::new(BTreeMap::new())),
                seccomp_action,
                exit_evt,
                msi_iova_space,
            },
            mapping,
        ))
    }

    fn state(&self) -> IommuState {
        IommuState {
            avail_features: self.common.avail_features,
            acked_features: self.common.acked_features,
            endpoints: self
                .mapping
                .endpoints
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .collect(),
            domains: self
                .mapping
                .domains
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .map(|(k, v)| (k, (v.mappings.into_iter().collect(), v.bypass)))
                .collect(),
        }
    }

    fn update_bypass(&mut self) {
        // Use bypass from config if VIRTIO_IOMMU_F_BYPASS_CONFIG has been negotiated
        if !self
            .common
            .feature_acked(VIRTIO_IOMMU_F_BYPASS_CONFIG.into())
        {
            return;
        }

        let bypass = self.config.bypass == 1;
        info!("Updating bypass mode to {}", bypass);
        self.mapping.bypass.store(bypass, Ordering::Release);
    }

    pub fn add_external_mapping(&mut self, device_id: u32, mapping: Arc<dyn ExternalDmaMapping>) {
        self.ext_mapping.lock().unwrap().insert(device_id, mapping);
    }

    #[cfg(fuzzing)]
    pub fn wait_for_epoll_threads(&mut self) {
        self.common.wait_for_epoll_threads();
    }
}

impl Drop for Iommu {
    fn drop(&mut self) {
        if let Some(kill_evt) = self.common.kill_evt.take() {
            // Ignore the result because there is nothing we can do about it.
            let _ = kill_evt.write(1);
        }
        self.common.wait_for_epoll_threads();
    }
}

impl VirtioDevice for Iommu {
    fn device_type(&self) -> u32 {
        self.common.device_type
    }

    fn queue_max_sizes(&self) -> &[u16] {
        &self.common.queue_sizes
    }

    fn features(&self) -> u64 {
        self.common.avail_features
    }

    fn ack_features(&mut self, value: u64) {
        self.common.ack_features(value)
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        self.read_config_from_slice(self.config.as_slice(), offset, data);
    }

    fn write_config(&mut self, offset: u64, data: &[u8]) {
        // The "bypass" field is the only mutable field
        let bypass_offset =
            (&self.config.bypass as *const _ as u64) - (&self.config as *const _ as u64);
        if offset != bypass_offset || data.len() != std::mem::size_of_val(&self.config.bypass) {
            error!(
                "Attempt to write to read-only field: offset {:x} length {}",
                offset,
                data.len()
            );
            return;
        }

        self.config.bypass = data[0];

        self.update_bypass();
    }

    fn activate(
        &mut self,
        mem: GuestMemoryAtomic<GuestMemoryMmap>,
        interrupt_cb: Arc<dyn VirtioInterrupt>,
        mut queues: Vec<(usize, Queue, EventFd)>,
    ) -> ActivateResult {
        self.common.activate(&queues, &interrupt_cb)?;
        let (kill_evt, pause_evt) = self.common.dup_eventfds();

        let (_, request_queue, request_queue_evt) = queues.remove(0);
        let (_, _event_queue, _event_queue_evt) = queues.remove(0);

        let mut handler = IommuEpollHandler {
            mem,
            request_queue,
            _event_queue,
            interrupt_cb,
            request_queue_evt,
            _event_queue_evt,
            kill_evt,
            pause_evt,
            mapping: self.mapping.clone(),
            ext_mapping: self.ext_mapping.clone(),
            msi_iova_space: self.msi_iova_space,
        };

        let paused = self.common.paused.clone();
        let paused_sync = self.common.paused_sync.clone();
        let mut epoll_threads = Vec::new();
        spawn_virtio_thread(
            &self.id,
            &self.seccomp_action,
            Thread::VirtioIommu,
            &mut epoll_threads,
            &self.exit_evt,
            move || handler.run(paused, paused_sync.unwrap()),
        )?;

        self.common.epoll_threads = Some(epoll_threads);

        event!("virtio-device", "activated", "id", &self.id);
        Ok(())
    }

    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
        let result = self.common.reset();
        event!("virtio-device", "reset", "id", &self.id);
        result
    }
}

impl Pausable for Iommu {
    fn pause(&mut self) -> result::Result<(), MigratableError> {
        self.common.pause()
    }

    fn resume(&mut self) -> result::Result<(), MigratableError> {
        self.common.resume()
    }
}

impl Snapshottable for Iommu {
    fn id(&self) -> String {
        self.id.clone()
    }

    fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
        Snapshot::new_from_state(&self.state())
    }
}
impl Transportable for Iommu {}
impl Migratable for Iommu {}
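
// A minimal test sketch (not part of the original device code) illustrating
// how IommuMapping translates addresses once an endpoint is attached to a
// domain holding a single mapping. The endpoint/domain IDs and addresses
// below are arbitrary illustrative values.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn translate_single_mapping() {
        let mapping = IommuMapping {
            endpoints: Arc::new(RwLock::new(BTreeMap::new())),
            domains: Arc::new(RwLock::new(BTreeMap::new())),
            bypass: AtomicBool::new(false),
        };

        // Attach endpoint 1 to domain 2.
        mapping.endpoints.write().unwrap().insert(1, 2);

        // Map IOVA 0x10000..=0x10fff to GPA 0x80000 within domain 2.
        let mut mappings = BTreeMap::new();
        mappings.insert(
            0x10000,
            Mapping {
                gpa: 0x80000,
                size: 0x1000,
            },
        );
        mapping.domains.write().unwrap().insert(
            2,
            Domain {
                mappings,
                bypass: false,
            },
        );

        // IOVA -> GPA translation inside the mapped range.
        assert_eq!(mapping.translate_gva(1, 0x10010).unwrap(), 0x80010);
        // GPA -> IOVA translation back.
        assert_eq!(mapping.translate_gpa(1, 0x80010).unwrap(), 0x10010);
        // Addresses outside any mapping must fail.
        assert!(mapping.translate_gva(1, 0x20000).is_err());
    }
}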