// Copyright © 2019 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause

use super::Error as DeviceError;
use super::{
    ActivateResult, EpollHelper, EpollHelperError, EpollHelperHandler, VirtioCommon, VirtioDevice,
    VirtioDeviceType, EPOLL_HELPER_EVENT_LAST, VIRTIO_F_VERSION_1,
};
use crate::seccomp_filters::Thread;
use crate::thread_helper::spawn_virtio_thread;
use crate::GuestMemoryMmap;
use crate::{DmaRemapping, VirtioInterrupt, VirtioInterruptType};
use anyhow::anyhow;
use seccompiler::SeccompAction;
use std::collections::BTreeMap;
use std::io;
use std::mem::size_of;
use std::ops::Bound::Included;
use std::os::unix::io::AsRawFd;
use std::result;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Barrier, Mutex, RwLock};
use thiserror::Error;
use versionize::{VersionMap, Versionize, VersionizeResult};
use versionize_derive::Versionize;
use virtio_queue::{DescriptorChain, Queue, QueueT};
use vm_device::dma_mapping::ExternalDmaMapping;
use vm_memory::{
    Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic,
    GuestMemoryError, GuestMemoryLoadGuard,
};
use vm_migration::VersionMapped;
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform;
use vmm_sys_util::eventfd::EventFd;

/// Queue sizes
const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];

/// New descriptors are pending on the request queue.
/// "requestq" is used whenever an action must be performed on behalf of the
/// guest driver.
const REQUEST_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 1;
/// New descriptors are pending on the event queue.
/// "eventq" lets the device report any fault or other asynchronous event to
/// the guest driver.
const _EVENT_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 2;

/// PROBE properties size.
/// This is the minimal size to provide at least one RESV_MEM property.
/// Because virtio-iommu expects one MSI reserved region, we must provide it,
/// otherwise the guest driver falls back to a predefined region between
/// 0x8000000 and 0x80FFFFF, which is only relevant for the ARM architecture
/// but would conflict with x86.
const PROBE_PROP_SIZE: u32 =
    (size_of::<VirtioIommuProbeProperty>() + size_of::<VirtioIommuProbeResvMem>()) as u32;
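
// Sanity check (illustrative): one 4-byte property header plus one 20-byte
// RESV_MEM property body gives 24 bytes.
const _: () = assert!(PROBE_PROP_SIZE == 24);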

/// Virtio IOMMU features
#[allow(unused)]
const VIRTIO_IOMMU_F_INPUT_RANGE: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_F_DOMAIN_RANGE: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_F_MAP_UNMAP: u32 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_F_BYPASS: u32 = 3;
const VIRTIO_IOMMU_F_PROBE: u32 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_F_MMIO: u32 = 5;
const VIRTIO_IOMMU_F_BYPASS_CONFIG: u32 = 6;

// Support 2MiB and 4KiB page sizes.
const VIRTIO_IOMMU_PAGE_SIZE_MASK: u64 = (2 << 20) | (4 << 10);
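// Sanity check (illustrative): each supported page size is encoded as a set
// bit, so 2MiB | 4KiB is 0x20_0000 | 0x1000.
const _: () = assert!(VIRTIO_IOMMU_PAGE_SIZE_MASK == 0x20_1000);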

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuRange32 {
    start: u32,
    end: u32,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuRange64 {
    start: u64,
    end: u64,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuConfig {
    page_size_mask: u64,
    input_range: VirtioIommuRange64,
    domain_range: VirtioIommuRange32,
    probe_size: u32,
    bypass: u8,
    _reserved: [u8; 7],
}

/// Virtio IOMMU request type
const VIRTIO_IOMMU_T_ATTACH: u8 = 1;
const VIRTIO_IOMMU_T_DETACH: u8 = 2;
const VIRTIO_IOMMU_T_MAP: u8 = 3;
const VIRTIO_IOMMU_T_UNMAP: u8 = 4;
const VIRTIO_IOMMU_T_PROBE: u8 = 5;

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqHead {
    type_: u8,
    _reserved: [u8; 3],
}

/// Virtio IOMMU request status
const VIRTIO_IOMMU_S_OK: u8 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_S_IOERR: u8 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_S_UNSUPP: u8 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_S_DEVERR: u8 = 3;
#[allow(unused)]
const VIRTIO_IOMMU_S_INVAL: u8 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_S_RANGE: u8 = 5;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOENT: u8 = 6;
#[allow(unused)]
const VIRTIO_IOMMU_S_FAULT: u8 = 7;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOMEM: u8 = 8;

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuReqTail {
    status: u8,
    _reserved: [u8; 3],
}

/// ATTACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqAttach {
    domain: u32,
    endpoint: u32,
    flags: u32,
    _reserved: [u8; 4],
}

const VIRTIO_IOMMU_ATTACH_F_BYPASS: u32 = 1;

/// DETACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqDetach {
    domain: u32,
    endpoint: u32,
    _reserved: [u8; 8],
}

/// Virtio IOMMU request MAP flags
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MMIO: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MASK: u32 =
    VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE | VIRTIO_IOMMU_MAP_F_MMIO;

/// MAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqMap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    phys_start: u64,
    _flags: u32,
}

/// UNMAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqUnmap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    _reserved: [u8; 4],
}

/// Virtio IOMMU request PROBE types
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_NONE: u16 = 0;
const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_MASK: u16 = 0xfff;

/// PROBE request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuReqProbe {
    endpoint: u32,
    _reserved: [u64; 8],
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuProbeProperty {
    type_: u16,
    length: u16,
}

/// Virtio IOMMU request PROBE property RESV_MEM subtypes
#[allow(unused)]
const VIRTIO_IOMMU_RESV_MEM_T_RESERVED: u8 = 0;
const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuProbeResvMem {
    subtype: u8,
    _reserved: [u8; 3],
    start: u64,
    end: u64,
}

/// Virtio IOMMU fault flags
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_EXEC: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_ADDRESS: u32 = 1 << 8;

/// Virtio IOMMU fault reasons
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_UNKNOWN: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_DOMAIN: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_MAPPING: u32 = 2;

/// Fault reporting through eventq
#[allow(unused)]
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuFault {
    reason: u8,
    reserved: [u8; 3],
    flags: u32,
    endpoint: u32,
    reserved2: [u8; 4],
    address: u64,
}

// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange32 {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange64 {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuConfig {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqHead {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqTail {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqAttach {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqDetach {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqMap {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqUnmap {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqProbe {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeProperty {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeResvMem {}
// SAFETY: the data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuFault {}
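
// Layout checks (illustrative) backing the SAFETY comments above: with
// #[repr(packed)] each structure's size is exactly the sum of its field
// sizes, i.e. there is no implicit padding.
const _: () = assert!(size_of::<VirtioIommuConfig>() == 44);
const _: () = assert!(size_of::<VirtioIommuReqAttach>() == 16);
const _: () = assert!(size_of::<VirtioIommuReqMap>() == 32);
const _: () = assert!(size_of::<VirtioIommuFault>() == 24);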

#[derive(Error, Debug)]
enum Error {
    #[error("Guest gave us bad memory addresses: {0}")]
    GuestMemory(GuestMemoryError),
    #[error("Guest gave us a write-only descriptor that protocol says to read from")]
    UnexpectedWriteOnlyDescriptor,
    #[error("Guest gave us a read-only descriptor that protocol says to write to")]
    UnexpectedReadOnlyDescriptor,
    #[error("Guest gave us too few descriptors in a descriptor chain")]
    DescriptorChainTooShort,
    #[error("Guest gave us a buffer that was too short to use")]
    BufferLengthTooSmall,
    #[error("Guest sent us an invalid request")]
    InvalidRequest,
    #[error("Guest sent us an invalid ATTACH request")]
    InvalidAttachRequest,
    #[error("Guest sent us an invalid DETACH request")]
    InvalidDetachRequest,
    #[error("Guest sent us an invalid MAP request")]
    InvalidMapRequest,
    #[error("Cannot map because the domain is in bypass mode")]
    InvalidMapRequestBypassDomain,
    #[error("Cannot map because the domain is missing")]
    InvalidMapRequestMissingDomain,
    #[error("Guest sent us an invalid UNMAP request")]
    InvalidUnmapRequest,
    #[error("Cannot unmap because the domain is in bypass mode")]
    InvalidUnmapRequestBypassDomain,
    #[error("Cannot unmap because the domain is missing")]
    InvalidUnmapRequestMissingDomain,
    #[error("Guest sent us an invalid PROBE request")]
    InvalidProbeRequest,
    #[error("Failed to perform external mapping: {0}")]
    ExternalMapping(io::Error),
    #[error("Failed to perform external unmapping: {0}")]
    ExternalUnmapping(io::Error),
    #[error("Failed adding used index: {0}")]
    QueueAddUsed(virtio_queue::Error),
}

struct Request {}

impl Request {
    // Parse the available vring buffer. Based on the table of external
    // mappings required by various devices such as VFIO or vhost-user ones,
    // this function may update the per-domain table of external mappings.
    // In other words, the VMM knows the device_id <=> mapping relationship
    // before the VM runs, but the domain <=> mapping table is only built at
    // runtime, from the information the guest virtio-iommu driver provides
    // (which gives the device_id <=> domain link).
    fn parse(
        desc_chain: &mut DescriptorChain<GuestMemoryLoadGuard<GuestMemoryMmap>>,
        mapping: &Arc<IommuMapping>,
        ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
        msi_iova_space: (u64, u64),
    ) -> result::Result<usize, Error> {
        let desc = desc_chain
            .next()
            .ok_or(Error::DescriptorChainTooShort)
            .map_err(|e| {
                error!("Missing head descriptor");
                e
            })?;

        // The descriptor contains the request type which MUST be readable.
        if desc.is_write_only() {
            return Err(Error::UnexpectedWriteOnlyDescriptor);
        }

        if (desc.len() as usize) < size_of::<VirtioIommuReqHead>() {
            return Err(Error::InvalidRequest);
        }

        let req_head: VirtioIommuReqHead = desc_chain
            .memory()
            .read_obj(desc.addr())
            .map_err(Error::GuestMemory)?;
        let req_offset = size_of::<VirtioIommuReqHead>();
        let desc_size_left = (desc.len() as usize) - req_offset;
        let req_addr = if let Some(addr) = desc.addr().checked_add(req_offset as u64) {
            addr
        } else {
            return Err(Error::InvalidRequest);
        };

        let (msi_iova_start, msi_iova_end) = msi_iova_space;

        // Create the reply
        let mut reply: Vec<u8> = Vec::new();
        let mut status = VIRTIO_IOMMU_S_OK;
        let mut hdr_len = 0;

        let result = (|| {
            match req_head.type_ {
                VIRTIO_IOMMU_T_ATTACH => {
                    if desc_size_left != size_of::<VirtioIommuReqAttach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidAttachRequest);
                    }

                    let req: VirtioIommuReqAttach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Attach request {:?}", req);

                    // Copy the fields out of the packed struct to avoid unaligned references.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;
                    let bypass =
                        (req.flags & VIRTIO_IOMMU_ATTACH_F_BYPASS) == VIRTIO_IOMMU_ATTACH_F_BYPASS;

                    let mut old_domain_id = domain_id;
                    if let Some(&id) = mapping.endpoints.read().unwrap().get(&endpoint) {
                        old_domain_id = id;
                    }

                    if old_domain_id != domain_id {
                        detach_endpoint_from_domain(endpoint, old_domain_id, mapping, ext_mapping)?;
                    }

                    // Add endpoint associated with specific domain
                    mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .insert(endpoint, domain_id);

                    // If other mappings already exist in the domain (set up
                    // through other endpoints/containers), replay them for the
                    // newly attached endpoint, since mappings are applied per
                    // container rather than per domain.
                    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id)
                    {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            for (virt_start, addr_map) in &domain_mappings.mappings {
                                ext_map
                                    .map(*virt_start, addr_map.gpa, addr_map.size)
                                    .map_err(Error::ExternalMapping)?;
                            }
                        }
                    }

                    // Add new domain with no mapping if the entry didn't exist yet
                    let mut domains = mapping.domains.write().unwrap();
                    let domain = Domain {
                        mappings: BTreeMap::new(),
                        bypass,
                    };
                    domains.entry(domain_id).or_insert_with(|| domain);
                }
                VIRTIO_IOMMU_T_DETACH => {
                    if desc_size_left != size_of::<VirtioIommuReqDetach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidDetachRequest);
                    }

                    let req: VirtioIommuReqDetach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Detach request {:?}", req);

                    // Copy the fields out of the packed struct to avoid unaligned references.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;

                    // Remove endpoint associated with specific domain
                    detach_endpoint_from_domain(endpoint, domain_id, mapping, ext_mapping)?;
                }
                VIRTIO_IOMMU_T_MAP => {
                    if desc_size_left != size_of::<VirtioIommuReqMap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequest);
                    }

                    let req: VirtioIommuReqMap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Map request {:?}", req);

                    // Copy the fields out of the packed struct to avoid unaligned references.
                    let domain_id = req.domain;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidMapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // With the virtio-iommu, each endpoint gets its own VFIO
                    // container. As a result, each endpoint within the domain
                    // must be mapped separately, since mappings are applied at
                    // the container level, not at the domain level.
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - req.virt_start + 1;
                            ext_map
                                .map(req.virt_start, req.phys_start, size)
                                .map_err(Error::ExternalMapping)?;
                        }
                    }

                    // Add new mapping associated with the domain
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .insert(
                            req.virt_start,
                            Mapping {
                                gpa: req.phys_start,
                                size: req.virt_end - req.virt_start + 1,
                            },
                        );
                }
                VIRTIO_IOMMU_T_UNMAP => {
                    if desc_size_left != size_of::<VirtioIommuReqUnmap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequest);
                    }

                    let req: VirtioIommuReqUnmap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Unmap request {:?}", req);

                    // Copy the fields out of the packed struct to avoid unaligned references.
                    let domain_id = req.domain;
                    let virt_start = req.virt_start;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidUnmapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // Trigger external unmapping if necessary.
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - virt_start + 1;
                            ext_map
                                .unmap(virt_start, size)
                                .map_err(Error::ExternalUnmapping)?;
                        }
                    }

                    // Remove all mappings associated with the domain within the requested range
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .retain(|&x, _| (x < req.virt_start || x > req.virt_end));
                }
                VIRTIO_IOMMU_T_PROBE => {
                    if desc_size_left != size_of::<VirtioIommuReqProbe>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidProbeRequest);
                    }

                    let req: VirtioIommuReqProbe = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Probe request {:?}", req);

                    let probe_prop = VirtioIommuProbeProperty {
                        type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM,
                        length: size_of::<VirtioIommuProbeResvMem>() as u16,
                    };
                    reply.extend_from_slice(probe_prop.as_slice());

                    let resv_mem = VirtioIommuProbeResvMem {
                        subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
                        start: msi_iova_start,
                        end: msi_iova_end,
                        ..Default::default()
                    };
                    reply.extend_from_slice(resv_mem.as_slice());

                    hdr_len = PROBE_PROP_SIZE;
                }
                _ => {
                    status = VIRTIO_IOMMU_S_INVAL;
                    return Err(Error::InvalidRequest);
                }
            }
            Ok(())
        })();

        let status_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;

        // The status MUST always be writable
        if !status_desc.is_write_only() {
            return Err(Error::UnexpectedReadOnlyDescriptor);
        }

        if status_desc.len() < hdr_len + size_of::<VirtioIommuReqTail>() as u32 {
            return Err(Error::BufferLengthTooSmall);
        }

        let tail = VirtioIommuReqTail {
            status,
            ..Default::default()
        };
        reply.extend_from_slice(tail.as_slice());

        // Make sure we return the result of the request to the guest before
        // we return a potential error internally.
        desc_chain
            .memory()
            .write_slice(reply.as_slice(), status_desc.addr())
            .map_err(Error::GuestMemory)?;

        // Return the error if the result was not Ok().
        result?;

        Ok((hdr_len as usize) + size_of::<VirtioIommuReqTail>())
    }
}

fn detach_endpoint_from_domain(
    endpoint: u32,
    domain_id: u32,
    mapping: &Arc<IommuMapping>,
    ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
) -> result::Result<(), Error> {
    // Remove endpoint associated with specific domain
    mapping.endpoints.write().unwrap().remove(&endpoint);

    // Trigger external unmapping for the endpoint if necessary.
    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id) {
        if let Some(ext_map) = ext_mapping.get(&endpoint) {
            for (virt_start, addr_map) in &domain_mappings.mappings {
                ext_map
                    .unmap(*virt_start, addr_map.size)
                    .map_err(Error::ExternalUnmapping)?;
            }
        }
    }

    if mapping
        .endpoints
        .write()
        .unwrap()
        .iter()
        .filter(|(_, &d)| d == domain_id)
        .count()
        == 0
    {
        mapping.domains.write().unwrap().remove(&domain_id);
    }

    Ok(())
}

struct IommuEpollHandler {
    mem: GuestMemoryAtomic<GuestMemoryMmap>,
    request_queue: Queue,
    _event_queue: Queue,
    interrupt_cb: Arc<dyn VirtioInterrupt>,
    request_queue_evt: EventFd,
    _event_queue_evt: EventFd,
    kill_evt: EventFd,
    pause_evt: EventFd,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    msi_iova_space: (u64, u64),
}

impl IommuEpollHandler {
    fn request_queue(&mut self) -> Result<bool, Error> {
        let mut used_descs = false;
        while let Some(mut desc_chain) = self.request_queue.pop_descriptor_chain(self.mem.memory())
        {
            let len = Request::parse(
                &mut desc_chain,
                &self.mapping,
                &self.ext_mapping.lock().unwrap(),
                self.msi_iova_space,
            )?;

            self.request_queue
                .add_used(desc_chain.memory(), desc_chain.head_index(), len as u32)
                .map_err(Error::QueueAddUsed)?;

            used_descs = true;
        }

        Ok(used_descs)
    }

    fn signal_used_queue(&self, queue_index: u16) -> result::Result<(), DeviceError> {
        self.interrupt_cb
            .trigger(VirtioInterruptType::Queue(queue_index))
            .map_err(|e| {
                error!("Failed to signal used queue: {:?}", e);
                DeviceError::FailedSignalingUsedQueue(e)
            })
    }

    fn run(
        &mut self,
        paused: Arc<AtomicBool>,
        paused_sync: Arc<Barrier>,
    ) -> result::Result<(), EpollHelperError> {
        let mut helper = EpollHelper::new(&self.kill_evt, &self.pause_evt)?;
        helper.add_event(self.request_queue_evt.as_raw_fd(), REQUEST_Q_EVENT)?;
        helper.run(paused, paused_sync, self)?;

        Ok(())
    }
}

impl EpollHelperHandler for IommuEpollHandler {
    fn handle_event(
        &mut self,
        _helper: &mut EpollHelper,
        event: &epoll::Event,
    ) -> result::Result<(), EpollHelperError> {
        let ev_type = event.data as u16;
        match ev_type {
            REQUEST_Q_EVENT => {
                self.request_queue_evt.read().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!("Failed to get queue event: {:?}", e))
                })?;

                let needs_notification = self.request_queue().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!(
                        "Failed to process request queue: {:?}",
                        e
                    ))
                })?;
                if needs_notification {
                    self.signal_used_queue(0).map_err(|e| {
                        EpollHelperError::HandleEvent(anyhow!(
                            "Failed to signal used queue: {:?}",
                            e
                        ))
                    })?;
                }
            }
            _ => {
                return Err(EpollHelperError::HandleEvent(anyhow!(
                    "Unexpected event: {}",
                    ev_type
                )));
            }
        }
        Ok(())
    }
}

#[derive(Clone, Copy, Debug, Versionize)]
struct Mapping {
    gpa: u64,
    size: u64,
}

#[derive(Clone, Debug)]
struct Domain {
    mappings: BTreeMap<u64, Mapping>,
    bypass: bool,
}

#[derive(Debug)]
pub struct IommuMapping {
    // Domain to which each endpoint is attached.
    endpoints: Arc<RwLock<BTreeMap<u32, u32>>>,
    // Information related to each domain.
    domains: Arc<RwLock<BTreeMap<u32, Domain>>>,
    // Global flag indicating if endpoints that are not attached to any domain
    // are in bypass mode.
    bypass: AtomicBool,
}

impl DmaRemapping for IommuMapping {
    fn translate_gva(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GVA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                let range_start = if VIRTIO_IOMMU_PAGE_SIZE_MASK > addr {
                    0
                } else {
                    addr - VIRTIO_IOMMU_PAGE_SIZE_MASK
                };
                for (&key, &value) in domain
                    .mappings
                    .range((Included(&range_start), Included(&addr)))
                {
                    if addr >= key && addr < key + value.size {
                        let new_addr = addr - key + value.gpa;
                        debug!("Into GPA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::new(
            io::ErrorKind::Other,
            format!("failed to translate GVA addr 0x{addr:x}"),
        ))
    }

    fn translate_gpa(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GPA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                for (&key, &value) in domain.mappings.iter() {
                    if addr >= value.gpa && addr < value.gpa + value.size {
                        let new_addr = addr - value.gpa + key;
                        debug!("Into GVA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::new(
            io::ErrorKind::Other,
            format!("failed to translate GPA addr 0x{addr:x}"),
        ))
    }
}

#[derive(Debug)]
pub struct AccessPlatformMapping {
    id: u32,
    mapping: Arc<IommuMapping>,
}

impl AccessPlatformMapping {
    pub fn new(id: u32, mapping: Arc<IommuMapping>) -> Self {
        AccessPlatformMapping { id, mapping }
    }
}

impl AccessPlatform for AccessPlatformMapping {
    fn translate_gva(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gva(self.id, base)
    }
    fn translate_gpa(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gpa(self.id, base)
    }
}

pub struct Iommu {
    common: VirtioCommon,
    id: String,
    config: VirtioIommuConfig,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    seccomp_action: SeccompAction,
    exit_evt: EventFd,
    msi_iova_space: (u64, u64),
}

type EndpointsState = Vec<(u32, u32)>;
type DomainsState = Vec<(u32, (Vec<(u64, Mapping)>, bool))>;

#[derive(Versionize)]
pub struct IommuState {
    avail_features: u64,
    acked_features: u64,
    endpoints: EndpointsState,
    domains: DomainsState,
}

impl VersionMapped for IommuState {}

impl Iommu {
    pub fn new(
        id: String,
        seccomp_action: SeccompAction,
        exit_evt: EventFd,
        msi_iova_space: (u64, u64),
        state: Option<IommuState>,
    ) -> io::Result<(Self, Arc<IommuMapping>)> {
        let (avail_features, acked_features, endpoints, domains, paused) =
            if let Some(state) = state {
                info!("Restoring virtio-iommu {}", id);
                (
                    state.avail_features,
                    state.acked_features,
                    state.endpoints.into_iter().collect(),
                    state
                        .domains
                        .into_iter()
                        .map(|(k, v)| {
                            (
                                k,
                                Domain {
                                    mappings: v.0.into_iter().collect(),
                                    bypass: v.1,
                                },
                            )
                        })
                        .collect(),
                    true,
                )
            } else {
                let avail_features = 1u64 << VIRTIO_F_VERSION_1
                    | 1u64 << VIRTIO_IOMMU_F_MAP_UNMAP
                    | 1u64 << VIRTIO_IOMMU_F_PROBE
                    | 1u64 << VIRTIO_IOMMU_F_BYPASS_CONFIG;

                (avail_features, 0, BTreeMap::new(), BTreeMap::new(), false)
            };

        let config = VirtioIommuConfig {
            page_size_mask: VIRTIO_IOMMU_PAGE_SIZE_MASK,
            probe_size: PROBE_PROP_SIZE,
            ..Default::default()
        };

        let mapping = Arc::new(IommuMapping {
            endpoints: Arc::new(RwLock::new(endpoints)),
            domains: Arc::new(RwLock::new(domains)),
            bypass: AtomicBool::new(true),
        });

        Ok((
            Iommu {
                id,
                common: VirtioCommon {
                    device_type: VirtioDeviceType::Iommu as u32,
                    queue_sizes: QUEUE_SIZES.to_vec(),
                    avail_features,
                    acked_features,
                    paused_sync: Some(Arc::new(Barrier::new(2))),
                    min_queues: NUM_QUEUES as u16,
                    paused: Arc::new(AtomicBool::new(paused)),
                    ..Default::default()
                },
                config,
                mapping: mapping.clone(),
                ext_mapping: Arc::new(Mutex::new(BTreeMap::new())),
                seccomp_action,
                exit_evt,
                msi_iova_space,
            },
            mapping,
        ))
    }

    fn state(&self) -> IommuState {
        IommuState {
            avail_features: self.common.avail_features,
            acked_features: self.common.acked_features,
            endpoints: self
                .mapping
                .endpoints
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .collect(),
            domains: self
                .mapping
                .domains
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .map(|(k, v)| (k, (v.mappings.into_iter().collect(), v.bypass)))
                .collect(),
        }
    }

    fn update_bypass(&mut self) {
        // Use bypass from config if VIRTIO_IOMMU_F_BYPASS_CONFIG has been negotiated
        if !self
            .common
            .feature_acked(VIRTIO_IOMMU_F_BYPASS_CONFIG.into())
        {
            return;
        }

        let bypass = self.config.bypass == 1;
        info!("Updating bypass mode to {}", bypass);
        self.mapping.bypass.store(bypass, Ordering::Release);
    }

    pub fn add_external_mapping(&mut self, device_id: u32, mapping: Arc<dyn ExternalDmaMapping>) {
        self.ext_mapping.lock().unwrap().insert(device_id, mapping);
    }

    #[cfg(fuzzing)]
    pub fn wait_for_epoll_threads(&mut self) {
        self.common.wait_for_epoll_threads();
    }
}

impl Drop for Iommu {
    fn drop(&mut self) {
        if let Some(kill_evt) = self.common.kill_evt.take() {
            // Ignore the result because there is nothing we can do about it.
            let _ = kill_evt.write(1);
        }
        self.common.wait_for_epoll_threads();
    }
}

impl VirtioDevice for Iommu {
    fn device_type(&self) -> u32 {
        self.common.device_type
    }

    fn queue_max_sizes(&self) -> &[u16] {
        &self.common.queue_sizes
    }

    fn features(&self) -> u64 {
        self.common.avail_features
    }

    fn ack_features(&mut self, value: u64) {
        self.common.ack_features(value)
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        self.read_config_from_slice(self.config.as_slice(), offset, data);
    }

    fn write_config(&mut self, offset: u64, data: &[u8]) {
        // The "bypass" field is the only mutable field
        let bypass_offset =
            (&self.config.bypass as *const _ as u64) - (&self.config as *const _ as u64);
        if offset != bypass_offset || data.len() != std::mem::size_of_val(&self.config.bypass) {
            error!(
                "Attempt to write to read-only field: offset {:x} length {}",
                offset,
                data.len()
            );
            return;
        }

        self.config.bypass = data[0];

        self.update_bypass();
    }

    fn activate(
        &mut self,
        mem: GuestMemoryAtomic<GuestMemoryMmap>,
        interrupt_cb: Arc<dyn VirtioInterrupt>,
        mut queues: Vec<(usize, Queue, EventFd)>,
    ) -> ActivateResult {
        self.common.activate(&queues, &interrupt_cb)?;
        let (kill_evt, pause_evt) = self.common.dup_eventfds();

        let (_, request_queue, request_queue_evt) = queues.remove(0);
        let (_, _event_queue, _event_queue_evt) = queues.remove(0);

        let mut handler = IommuEpollHandler {
            mem,
            request_queue,
            _event_queue,
            interrupt_cb,
            request_queue_evt,
            _event_queue_evt,
            kill_evt,
            pause_evt,
            mapping: self.mapping.clone(),
            ext_mapping: self.ext_mapping.clone(),
            msi_iova_space: self.msi_iova_space,
        };

        let paused = self.common.paused.clone();
        let paused_sync = self.common.paused_sync.clone();
        let mut epoll_threads = Vec::new();
        spawn_virtio_thread(
            &self.id,
            &self.seccomp_action,
            Thread::VirtioIommu,
            &mut epoll_threads,
            &self.exit_evt,
            move || handler.run(paused, paused_sync.unwrap()),
        )?;

        self.common.epoll_threads = Some(epoll_threads);

        event!("virtio-device", "activated", "id", &self.id);
        Ok(())
    }

    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
        let result = self.common.reset();
        event!("virtio-device", "reset", "id", &self.id);
        result
    }
}

impl Pausable for Iommu {
    fn pause(&mut self) -> result::Result<(), MigratableError> {
        self.common.pause()
    }

    fn resume(&mut self) -> result::Result<(), MigratableError> {
        self.common.resume()
    }
}

impl Snapshottable for Iommu {
    fn id(&self) -> String {
        self.id.clone()
    }

    fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
        Snapshot::new_from_versioned_state(&self.state())
    }
}
impl Transportable for Iommu {}
impl Migratable for Iommu {}
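
#[cfg(test)]
mod tests {
    use super::*;

    // Illustrative test sketch (an editorial addition, not from the original
    // source): one endpoint attached to a non-bypass domain holding a single
    // 4KiB mapping. Addresses inside the mapping translate in both
    // directions, addresses outside fail.
    #[test]
    fn test_translate_single_mapping() {
        let mapping = IommuMapping {
            endpoints: Arc::new(RwLock::new(BTreeMap::new())),
            domains: Arc::new(RwLock::new(BTreeMap::new())),
            bypass: AtomicBool::new(false),
        };

        // Attach endpoint 3 to domain 1.
        mapping.endpoints.write().unwrap().insert(3, 1);

        // Map IOVA [0x10_0000, 0x10_0fff] to GPA [0x20_0000, 0x20_0fff].
        let mut mappings = BTreeMap::new();
        mappings.insert(
            0x10_0000,
            Mapping {
                gpa: 0x20_0000,
                size: 0x1000,
            },
        );
        mapping.domains.write().unwrap().insert(
            1,
            Domain {
                mappings,
                bypass: false,
            },
        );

        // An address inside the mapped range translates, keeping the offset.
        assert_eq!(mapping.translate_gva(3, 0x10_0010).unwrap(), 0x20_0010);
        // The reverse direction recovers the IOVA.
        assert_eq!(mapping.translate_gpa(3, 0x20_0010).unwrap(), 0x10_0010);
        // An address outside any mapping fails to translate.
        assert!(mapping.translate_gva(3, 0x30_0000).is_err());
        // An unknown endpoint is rejected as well, since global bypass is off.
        assert!(mapping.translate_gva(9, 0x10_0010).is_err());
    }

    // Illustrative test sketch: endpoints that are not attached to any domain
    // follow the global bypass flag.
    #[test]
    fn test_global_bypass() {
        let mapping = IommuMapping {
            endpoints: Arc::new(RwLock::new(BTreeMap::new())),
            domains: Arc::new(RwLock::new(BTreeMap::new())),
            bypass: AtomicBool::new(true),
        };

        // With global bypass enabled, an unattached endpoint is identity
        // mapped.
        assert_eq!(mapping.translate_gva(7, 0x1234).unwrap(), 0x1234);

        // Once global bypass is cleared, the same lookup fails.
        mapping.bypass.store(false, Ordering::Release);
        assert!(mapping.translate_gva(7, 0x1234).is_err());
    }
}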