xref: /cloud-hypervisor/virtio-devices/src/iommu.rs (revision f6cd3bd86ded632da437b6dd6077f4237d2f71fe)
// Copyright © 2019 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause

use super::Error as DeviceError;
use super::{
    ActivateResult, EpollHelper, EpollHelperError, EpollHelperHandler, VirtioCommon, VirtioDevice,
    VirtioDeviceType, EPOLL_HELPER_EVENT_LAST, VIRTIO_F_VERSION_1,
};
use crate::seccomp_filters::Thread;
use crate::thread_helper::spawn_virtio_thread;
use crate::GuestMemoryMmap;
use crate::{DmaRemapping, VirtioInterrupt, VirtioInterruptType};
use anyhow::anyhow;
use seccompiler::SeccompAction;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::io;
use std::mem::size_of;
use std::ops::Bound::Included;
use std::os::unix::io::AsRawFd;
use std::result;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Barrier, Mutex, RwLock};
use thiserror::Error;
use virtio_queue::{DescriptorChain, Queue, QueueT};
use vm_device::dma_mapping::ExternalDmaMapping;
use vm_memory::{
    Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic,
    GuestMemoryError, GuestMemoryLoadGuard,
};
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform;
use vmm_sys_util::eventfd::EventFd;

/// Queue sizes
const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];

/// New descriptors are pending on the request queue.
/// "requestq" carries any request that must be performed on behalf of the
/// guest driver.
const REQUEST_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 1;
/// New descriptors are pending on the event queue.
/// "eventq" lets the device report any fault or other asynchronous event to
/// the guest driver.
const _EVENT_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 2;

/// PROBE properties size.
/// This is the minimal size needed to provide at least one RESV_MEM property.
/// Because virtio-iommu expects one MSI reserved region, we must provide it,
/// otherwise the guest driver falls back to its default region between
/// 0x8000000 and 0x80FFFFF, which is only relevant for the ARM architecture
/// and would conflict with x86.
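/// With the structures defined below, this amounts to a 4-byte property
/// header plus a 20-byte RESV_MEM payload, i.e. 24 bytes in total.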
const PROBE_PROP_SIZE: u32 =
    (size_of::<VirtioIommuProbeProperty>() + size_of::<VirtioIommuProbeResvMem>()) as u32;

/// Virtio IOMMU features
#[allow(unused)]
const VIRTIO_IOMMU_F_INPUT_RANGE: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_F_DOMAIN_RANGE: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_F_MAP_UNMAP: u32 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_F_BYPASS: u32 = 3;
const VIRTIO_IOMMU_F_PROBE: u32 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_F_MMIO: u32 = 5;
const VIRTIO_IOMMU_F_BYPASS_CONFIG: u32 = 6;

// Support 2MiB and 4KiB page sizes.
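// Each bit set in the mask advertises one supported page size: (4 << 10) sets
// bit 12 (4KiB) and (2 << 20) sets bit 21 (2MiB), i.e. the mask is 0x20_1000.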
const VIRTIO_IOMMU_PAGE_SIZE_MASK: u64 = (2 << 20) | (4 << 10);

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuRange32 {
    start: u32,
    end: u32,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuRange64 {
    start: u64,
    end: u64,
}

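// As a sanity check on the packed config below: page_size_mask sits at offset
// 0, input_range at 8, domain_range at 24, probe_size at 32 and bypass at 36,
// for a total of 44 bytes. write_config() further down recomputes the "bypass"
// offset from the field addresses rather than hard-coding 36.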
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuConfig {
    page_size_mask: u64,
    input_range: VirtioIommuRange64,
    domain_range: VirtioIommuRange32,
    probe_size: u32,
    bypass: u8,
    _reserved: [u8; 7],
}

/// Virtio IOMMU request type
const VIRTIO_IOMMU_T_ATTACH: u8 = 1;
const VIRTIO_IOMMU_T_DETACH: u8 = 2;
const VIRTIO_IOMMU_T_MAP: u8 = 3;
const VIRTIO_IOMMU_T_UNMAP: u8 = 4;
const VIRTIO_IOMMU_T_PROBE: u8 = 5;

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqHead {
    type_: u8,
    _reserved: [u8; 3],
}

/// Virtio IOMMU request status
const VIRTIO_IOMMU_S_OK: u8 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_S_IOERR: u8 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_S_UNSUPP: u8 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_S_DEVERR: u8 = 3;
#[allow(unused)]
const VIRTIO_IOMMU_S_INVAL: u8 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_S_RANGE: u8 = 5;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOENT: u8 = 6;
#[allow(unused)]
const VIRTIO_IOMMU_S_FAULT: u8 = 7;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOMEM: u8 = 8;

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuReqTail {
    status: u8,
    _reserved: [u8; 3],
}

/// ATTACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqAttach {
    domain: u32,
    endpoint: u32,
    flags: u32,
    _reserved: [u8; 4],
}

const VIRTIO_IOMMU_ATTACH_F_BYPASS: u32 = 1;

/// DETACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqDetach {
    domain: u32,
    endpoint: u32,
    _reserved: [u8; 8],
}

/// Virtio IOMMU request MAP flags
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MMIO: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MASK: u32 =
    VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE | VIRTIO_IOMMU_MAP_F_MMIO;

/// MAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqMap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    phys_start: u64,
    _flags: u32,
}

/// UNMAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqUnmap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    _reserved: [u8; 4],
}

/// Virtio IOMMU request PROBE types
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_NONE: u16 = 0;
const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_MASK: u16 = 0xfff;

/// PROBE request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuReqProbe {
    endpoint: u32,
    _reserved: [u64; 8],
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuProbeProperty {
    type_: u16,
    length: u16,
}

/// Virtio IOMMU request PROBE property RESV_MEM subtypes
#[allow(unused)]
const VIRTIO_IOMMU_RESV_MEM_T_RESERVED: u8 = 0;
const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuProbeResvMem {
    subtype: u8,
    _reserved: [u8; 3],
    start: u64,
    end: u64,
}

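// Taken together, the PROBE reply for the one MSI region advertised here is a
// VirtioIommuProbeProperty { type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM, length: 20 }
// header immediately followed by a VirtioIommuProbeResvMem payload with
// subtype VIRTIO_IOMMU_RESV_MEM_T_MSI, which is what Request::parse() builds
// when handling VIRTIO_IOMMU_T_PROBE.
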
/// Virtio IOMMU fault flags
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_EXEC: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_ADDRESS: u32 = 1 << 8;

/// Virtio IOMMU fault reasons
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_UNKNOWN: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_DOMAIN: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_MAPPING: u32 = 2;

/// Fault reporting through eventq
#[allow(unused)]
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuFault {
    reason: u8,
    reserved: [u8; 3],
    flags: u32,
    endpoint: u32,
    reserved2: [u8; 4],
    address: u64,
}

// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange32 {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange64 {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuConfig {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqHead {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqTail {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqAttach {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqDetach {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqMap {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqUnmap {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqProbe {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeProperty {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeResvMem {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuFault {}

#[derive(Error, Debug)]
enum Error {
    #[error("Guest gave us bad memory addresses: {0}")]
    GuestMemory(GuestMemoryError),
    #[error("Guest gave us a write-only descriptor that the protocol says to read from")]
    UnexpectedWriteOnlyDescriptor,
    #[error("Guest gave us a read-only descriptor that the protocol says to write to")]
    UnexpectedReadOnlyDescriptor,
    #[error("Guest gave us too few descriptors in a descriptor chain")]
    DescriptorChainTooShort,
    #[error("Guest gave us a buffer that was too short to use")]
    BufferLengthTooSmall,
    #[error("Guest sent us an invalid request")]
    InvalidRequest,
    #[error("Guest sent us an invalid ATTACH request")]
    InvalidAttachRequest,
    #[error("Guest sent us an invalid DETACH request")]
    InvalidDetachRequest,
    #[error("Guest sent us an invalid MAP request")]
    InvalidMapRequest,
    #[error("Cannot map because the domain is in bypass mode")]
    InvalidMapRequestBypassDomain,
    #[error("Cannot map because the domain is missing")]
    InvalidMapRequestMissingDomain,
    #[error("Guest sent us an invalid UNMAP request")]
    InvalidUnmapRequest,
    #[error("Cannot unmap because the domain is in bypass mode")]
    InvalidUnmapRequestBypassDomain,
    #[error("Cannot unmap because the domain is missing")]
    InvalidUnmapRequestMissingDomain,
    #[error("Guest sent us an invalid PROBE request")]
    InvalidProbeRequest,
    #[error("Failed to perform external mapping: {0}")]
    ExternalMapping(io::Error),
    #[error("Failed to perform external unmapping: {0}")]
    ExternalUnmapping(io::Error),
    #[error("Failed adding used index: {0}")]
    QueueAddUsed(virtio_queue::Error),
}


struct Request {}

impl Request {
    // Parse the available vring buffer. Based on the map of external
    // mappings required by various devices (such as VFIO or vhost-user
    // ones), this function may update the per-domain map of external
    // mappings.
    // The VMM knows the device_id <=> mapping relationship before the VM
    // runs, but the domain <=> mapping map is built at runtime from the
    // information provided by the guest virtio-iommu driver (which links a
    // device_id to a domain).
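    //
    // As an illustrative example (arbitrary values): if endpoint 3 is
    // attached to domain 1 and the guest then maps virt [0x1000, 0x1fff] to
    // phys 0x8000, this function calls
    // ext_mapping[&3].map(0x1000, 0x8000, 0x1000) and records
    // domains[1].mappings[0x1000] = Mapping { gpa: 0x8000, size: 0x1000 }.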
    fn parse(
        desc_chain: &mut DescriptorChain<GuestMemoryLoadGuard<GuestMemoryMmap>>,
        mapping: &Arc<IommuMapping>,
        ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
        msi_iova_space: (u64, u64),
    ) -> result::Result<usize, Error> {
        let desc = desc_chain
            .next()
            .ok_or(Error::DescriptorChainTooShort)
            .map_err(|e| {
                error!("Missing head descriptor");
                e
            })?;

        // The descriptor contains the request type which MUST be readable.
        if desc.is_write_only() {
            return Err(Error::UnexpectedWriteOnlyDescriptor);
        }

        if (desc.len() as usize) < size_of::<VirtioIommuReqHead>() {
            return Err(Error::InvalidRequest);
        }

        let req_head: VirtioIommuReqHead = desc_chain
            .memory()
            .read_obj(desc.addr())
            .map_err(Error::GuestMemory)?;
        let req_offset = size_of::<VirtioIommuReqHead>();
        let desc_size_left = (desc.len() as usize) - req_offset;
        let req_addr = if let Some(addr) = desc.addr().checked_add(req_offset as u64) {
            addr
        } else {
            return Err(Error::InvalidRequest);
        };

        let (msi_iova_start, msi_iova_end) = msi_iova_space;

        // Create the reply
        let mut reply: Vec<u8> = Vec::new();
        let mut status = VIRTIO_IOMMU_S_OK;
        let mut hdr_len = 0;

        let result = (|| {
            match req_head.type_ {
                VIRTIO_IOMMU_T_ATTACH => {
                    if desc_size_left != size_of::<VirtioIommuReqAttach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidAttachRequest);
                    }

                    let req: VirtioIommuReqAttach = desc_chain
                        .memory()
                        .read_obj(req_addr as GuestAddress)
                        .map_err(Error::GuestMemory)?;
                    debug!("Attach request {:?}", req);

                    // Copy the fields out of the packed struct to avoid unaligned references.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;
                    let bypass =
                        (req.flags & VIRTIO_IOMMU_ATTACH_F_BYPASS) == VIRTIO_IOMMU_ATTACH_F_BYPASS;

                    let mut old_domain_id = domain_id;
                    if let Some(&id) = mapping.endpoints.read().unwrap().get(&endpoint) {
                        old_domain_id = id;
                    }

                    if old_domain_id != domain_id {
                        detach_endpoint_from_domain(endpoint, old_domain_id, mapping, ext_mapping)?;
                    }

                    // Add endpoint associated with specific domain
                    mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .insert(endpoint, domain_id);

                    // If the domain already contains mappings (established through
                    // other endpoints/containers), replay them for the new endpoint's container.
                    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id)
                    {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            for (virt_start, addr_map) in &domain_mappings.mappings {
                                ext_map
                                    .map(*virt_start, addr_map.gpa, addr_map.size)
                                    .map_err(Error::ExternalMapping)?;
                            }
                        }
                    }

                    // Add new domain with no mapping if the entry didn't exist yet
                    let mut domains = mapping.domains.write().unwrap();
                    let domain = Domain {
                        mappings: BTreeMap::new(),
                        bypass,
                    };
                    domains.entry(domain_id).or_insert_with(|| domain);
                }
                VIRTIO_IOMMU_T_DETACH => {
                    if desc_size_left != size_of::<VirtioIommuReqDetach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidDetachRequest);
                    }

                    let req: VirtioIommuReqDetach = desc_chain
                        .memory()
                        .read_obj(req_addr as GuestAddress)
                        .map_err(Error::GuestMemory)?;
                    debug!("Detach request {:?}", req);

                    // Copy the fields out of the packed struct to avoid unaligned references.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;

                    // Remove endpoint associated with specific domain
                    detach_endpoint_from_domain(endpoint, domain_id, mapping, ext_mapping)?;
                }
                VIRTIO_IOMMU_T_MAP => {
                    if desc_size_left != size_of::<VirtioIommuReqMap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequest);
                    }

                    let req: VirtioIommuReqMap = desc_chain
                        .memory()
                        .read_obj(req_addr as GuestAddress)
                        .map_err(Error::GuestMemory)?;
                    debug!("Map request {:?}", req);

                    // Copy the fields out of the packed struct to avoid unaligned references.
                    let domain_id = req.domain;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidMapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // With the virtio-iommu, every endpoint gets its own VFIO
                    // container, so each endpoint attached to the domain must be
                    // mapped separately: mappings are applied per container, not
                    // per domain.
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - req.virt_start + 1;
                            ext_map
                                .map(req.virt_start, req.phys_start, size)
                                .map_err(Error::ExternalMapping)?;
                        }
                    }

                    // Add new mapping associated with the domain
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .insert(
                            req.virt_start,
                            Mapping {
                                gpa: req.phys_start,
                                size: req.virt_end - req.virt_start + 1,
                            },
                        );
                }
                VIRTIO_IOMMU_T_UNMAP => {
                    if desc_size_left != size_of::<VirtioIommuReqUnmap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequest);
                    }

                    let req: VirtioIommuReqUnmap = desc_chain
                        .memory()
                        .read_obj(req_addr as GuestAddress)
                        .map_err(Error::GuestMemory)?;
                    debug!("Unmap request {:?}", req);

                    // Copy the fields out of the packed struct to avoid unaligned references.
                    let domain_id = req.domain;
                    let virt_start = req.virt_start;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidUnmapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // Trigger external unmapping if necessary.
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - virt_start + 1;
                            ext_map
                                .unmap(virt_start, size)
                                .map_err(Error::ExternalUnmapping)?;
                        }
                    }

                    // Remove all mappings associated with the domain within the requested range
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .retain(|&x, _| (x < req.virt_start || x > req.virt_end));
                }
                VIRTIO_IOMMU_T_PROBE => {
                    if desc_size_left != size_of::<VirtioIommuReqProbe>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidProbeRequest);
                    }

                    let req: VirtioIommuReqProbe = desc_chain
                        .memory()
                        .read_obj(req_addr as GuestAddress)
                        .map_err(Error::GuestMemory)?;
                    debug!("Probe request {:?}", req);

                    let probe_prop = VirtioIommuProbeProperty {
                        type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM,
                        length: size_of::<VirtioIommuProbeResvMem>() as u16,
                    };
                    reply.extend_from_slice(probe_prop.as_slice());

                    let resv_mem = VirtioIommuProbeResvMem {
                        subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
                        start: msi_iova_start,
                        end: msi_iova_end,
                        ..Default::default()
                    };
                    reply.extend_from_slice(resv_mem.as_slice());

                    hdr_len = PROBE_PROP_SIZE;
                }
                _ => {
                    status = VIRTIO_IOMMU_S_INVAL;
                    return Err(Error::InvalidRequest);
                }
            }
            Ok(())
        })();

        let status_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;

        // The status MUST always be writable
        if !status_desc.is_write_only() {
            return Err(Error::UnexpectedReadOnlyDescriptor);
        }

        if status_desc.len() < hdr_len + size_of::<VirtioIommuReqTail>() as u32 {
            return Err(Error::BufferLengthTooSmall);
        }

        let tail = VirtioIommuReqTail {
            status,
            ..Default::default()
        };
        reply.extend_from_slice(tail.as_slice());

        // Make sure we return the result of the request to the guest before
        // we return a potential error internally.
        desc_chain
            .memory()
            .write_slice(reply.as_slice(), status_desc.addr())
            .map_err(Error::GuestMemory)?;

        // Return the error if the result was not Ok().
        result?;

        Ok((hdr_len as usize) + size_of::<VirtioIommuReqTail>())
    }
}

fn detach_endpoint_from_domain(
    endpoint: u32,
    domain_id: u32,
    mapping: &Arc<IommuMapping>,
    ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
) -> result::Result<(), Error> {
    // Remove endpoint associated with specific domain
    mapping.endpoints.write().unwrap().remove(&endpoint);

    // Trigger external unmapping for the endpoint if necessary.
    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id) {
        if let Some(ext_map) = ext_mapping.get(&endpoint) {
            for (virt_start, addr_map) in &domain_mappings.mappings {
                ext_map
                    .unmap(*virt_start, addr_map.size)
                    .map_err(Error::ExternalUnmapping)?;
            }
        }
    }

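    // Garbage-collect the domain once no endpoint refers to it anymore.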
    if mapping
        .endpoints
        .write()
        .unwrap()
        .iter()
        .filter(|(_, &d)| d == domain_id)
        .count()
        == 0
    {
        mapping.domains.write().unwrap().remove(&domain_id);
    }

    Ok(())
}

struct IommuEpollHandler {
    mem: GuestMemoryAtomic<GuestMemoryMmap>,
    request_queue: Queue,
    _event_queue: Queue,
    interrupt_cb: Arc<dyn VirtioInterrupt>,
    request_queue_evt: EventFd,
    _event_queue_evt: EventFd,
    kill_evt: EventFd,
    pause_evt: EventFd,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    msi_iova_space: (u64, u64),
}

impl IommuEpollHandler {
    fn request_queue(&mut self) -> Result<bool, Error> {
        let mut used_descs = false;
        while let Some(mut desc_chain) = self.request_queue.pop_descriptor_chain(self.mem.memory())
        {
            let len = Request::parse(
                &mut desc_chain,
                &self.mapping,
                &self.ext_mapping.lock().unwrap(),
                self.msi_iova_space,
            )?;

            self.request_queue
                .add_used(desc_chain.memory(), desc_chain.head_index(), len as u32)
                .map_err(Error::QueueAddUsed)?;

            used_descs = true;
        }

        Ok(used_descs)
    }

    fn signal_used_queue(&self, queue_index: u16) -> result::Result<(), DeviceError> {
        self.interrupt_cb
            .trigger(VirtioInterruptType::Queue(queue_index))
            .map_err(|e| {
                error!("Failed to signal used queue: {:?}", e);
                DeviceError::FailedSignalingUsedQueue(e)
            })
    }

    fn run(
        &mut self,
        paused: Arc<AtomicBool>,
        paused_sync: Arc<Barrier>,
    ) -> result::Result<(), EpollHelperError> {
        let mut helper = EpollHelper::new(&self.kill_evt, &self.pause_evt)?;
        helper.add_event(self.request_queue_evt.as_raw_fd(), REQUEST_Q_EVENT)?;
        helper.run(paused, paused_sync, self)?;

        Ok(())
    }
}

impl EpollHelperHandler for IommuEpollHandler {
    fn handle_event(
        &mut self,
        _helper: &mut EpollHelper,
        event: &epoll::Event,
    ) -> result::Result<(), EpollHelperError> {
        let ev_type = event.data as u16;
        match ev_type {
            REQUEST_Q_EVENT => {
                self.request_queue_evt.read().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!("Failed to get queue event: {:?}", e))
                })?;

                let needs_notification = self.request_queue().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!(
                        "Failed to process request queue: {:?}",
                        e
                    ))
                })?;
                if needs_notification {
                    self.signal_used_queue(0).map_err(|e| {
                        EpollHelperError::HandleEvent(anyhow!(
                            "Failed to signal used queue: {:?}",
                            e
                        ))
                    })?;
                }
            }
            _ => {
                return Err(EpollHelperError::HandleEvent(anyhow!(
                    "Unexpected event: {}",
                    ev_type
                )));
            }
        }
        Ok(())
    }
}

#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
struct Mapping {
    gpa: u64,
    size: u64,
}

#[derive(Clone, Debug)]
struct Domain {
    mappings: BTreeMap<u64, Mapping>,
    bypass: bool,
}

#[derive(Debug)]
pub struct IommuMapping {
    // Domain to which each endpoint is attached.
    endpoints: Arc<RwLock<BTreeMap<u32, u32>>>,
    // Information related to each domain.
    domains: Arc<RwLock<BTreeMap<u32, Domain>>>,
    // Global flag indicating whether endpoints that are not attached to any
    // domain are in bypass mode.
    bypass: AtomicBool,
}

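// As an illustrative example of the two-level lookup (arbitrary values): with
// endpoints = {3: 1} and domains = {1: Domain { mappings: {0x1000: Mapping {
// gpa: 0x8000, size: 0x1000 }}, bypass: false }}, translate_gva(3, 0x1234)
// resolves endpoint 3 to domain 1, finds the mapping starting at 0x1000 and
// returns 0x1234 - 0x1000 + 0x8000 = 0x8234.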
impl DmaRemapping for IommuMapping {
    fn translate_gva(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GVA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                let range_start = if VIRTIO_IOMMU_PAGE_SIZE_MASK > addr {
                    0
                } else {
                    addr - VIRTIO_IOMMU_PAGE_SIZE_MASK
                };
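                // Note: this bounded range search only visits mappings whose
                // virt_start lies within VIRTIO_IOMMU_PAGE_SIZE_MASK bytes
                // below addr, which implicitly assumes no single mapping is
                // larger than the page-size mask value.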
                for (&key, &value) in domain
                    .mappings
                    .range((Included(&range_start), Included(&addr)))
                {
                    if addr >= key && addr < key + value.size {
                        let new_addr = addr - key + value.gpa;
                        debug!("Into GPA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::new(
            io::ErrorKind::Other,
            format!("failed to translate GVA addr 0x{addr:x}"),
        ))
    }

    fn translate_gpa(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GPA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                for (&key, &value) in domain.mappings.iter() {
                    if addr >= value.gpa && addr < value.gpa + value.size {
                        let new_addr = addr - value.gpa + key;
                        debug!("Into GVA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::new(
            io::ErrorKind::Other,
            format!("failed to translate GPA addr 0x{addr:x}"),
        ))
    }
}

#[derive(Debug)]
pub struct AccessPlatformMapping {
    id: u32,
    mapping: Arc<IommuMapping>,
}

impl AccessPlatformMapping {
    pub fn new(id: u32, mapping: Arc<IommuMapping>) -> Self {
        AccessPlatformMapping { id, mapping }
    }
}

impl AccessPlatform for AccessPlatformMapping {
    fn translate_gva(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gva(self.id, base)
    }
    fn translate_gpa(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gpa(self.id, base)
    }
}

pub struct Iommu {
    common: VirtioCommon,
    id: String,
    config: VirtioIommuConfig,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    seccomp_action: SeccompAction,
    exit_evt: EventFd,
    msi_iova_space: (u64, u64),
}

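// The endpoint and domain BTreeMaps are flattened into Vecs of pairs for the
// snapshot, which keeps the (non-string) integer keys friendly to any serde
// format; new() rebuilds the maps via collect() on restore.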
type EndpointsState = Vec<(u32, u32)>;
type DomainsState = Vec<(u32, (Vec<(u64, Mapping)>, bool))>;

#[derive(Serialize, Deserialize)]
pub struct IommuState {
    avail_features: u64,
    acked_features: u64,
    endpoints: EndpointsState,
    domains: DomainsState,
}

impl Iommu {
    pub fn new(
        id: String,
        seccomp_action: SeccompAction,
        exit_evt: EventFd,
        msi_iova_space: (u64, u64),
        state: Option<IommuState>,
    ) -> io::Result<(Self, Arc<IommuMapping>)> {
        let (avail_features, acked_features, endpoints, domains, paused) =
            if let Some(state) = state {
                info!("Restoring virtio-iommu {}", id);
                (
                    state.avail_features,
                    state.acked_features,
                    state.endpoints.into_iter().collect(),
                    state
                        .domains
                        .into_iter()
                        .map(|(k, v)| {
                            (
                                k,
                                Domain {
                                    mappings: v.0.into_iter().collect(),
                                    bypass: v.1,
                                },
                            )
                        })
                        .collect(),
                    true,
                )
            } else {
                let avail_features = 1u64 << VIRTIO_F_VERSION_1
                    | 1u64 << VIRTIO_IOMMU_F_MAP_UNMAP
                    | 1u64 << VIRTIO_IOMMU_F_PROBE
                    | 1u64 << VIRTIO_IOMMU_F_BYPASS_CONFIG;

                (avail_features, 0, BTreeMap::new(), BTreeMap::new(), false)
            };

        let config = VirtioIommuConfig {
            page_size_mask: VIRTIO_IOMMU_PAGE_SIZE_MASK,
            probe_size: PROBE_PROP_SIZE,
            ..Default::default()
        };

        let mapping = Arc::new(IommuMapping {
            endpoints: Arc::new(RwLock::new(endpoints)),
            domains: Arc::new(RwLock::new(domains)),
            bypass: AtomicBool::new(true),
        });

        Ok((
            Iommu {
                id,
                common: VirtioCommon {
                    device_type: VirtioDeviceType::Iommu as u32,
                    queue_sizes: QUEUE_SIZES.to_vec(),
                    avail_features,
                    acked_features,
                    paused_sync: Some(Arc::new(Barrier::new(2))),
                    min_queues: NUM_QUEUES as u16,
                    paused: Arc::new(AtomicBool::new(paused)),
                    ..Default::default()
                },
                config,
                mapping: mapping.clone(),
                ext_mapping: Arc::new(Mutex::new(BTreeMap::new())),
                seccomp_action,
                exit_evt,
                msi_iova_space,
            },
            mapping,
        ))
    }

    fn state(&self) -> IommuState {
        IommuState {
            avail_features: self.common.avail_features,
            acked_features: self.common.acked_features,
            endpoints: self
                .mapping
                .endpoints
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .collect(),
            domains: self
                .mapping
                .domains
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .map(|(k, v)| (k, (v.mappings.into_iter().collect(), v.bypass)))
                .collect(),
        }
    }

    fn update_bypass(&mut self) {
        // Use bypass from config if VIRTIO_IOMMU_F_BYPASS_CONFIG has been negotiated
        if !self
            .common
            .feature_acked(VIRTIO_IOMMU_F_BYPASS_CONFIG.into())
        {
            return;
        }

        let bypass = self.config.bypass == 1;
        info!("Updating bypass mode to {}", bypass);
        self.mapping.bypass.store(bypass, Ordering::Release);
    }

    pub fn add_external_mapping(&mut self, device_id: u32, mapping: Arc<dyn ExternalDmaMapping>) {
        self.ext_mapping.lock().unwrap().insert(device_id, mapping);
    }

    #[cfg(fuzzing)]
    pub fn wait_for_epoll_threads(&mut self) {
        self.common.wait_for_epoll_threads();
    }
}

impl Drop for Iommu {
    fn drop(&mut self) {
        if let Some(kill_evt) = self.common.kill_evt.take() {
            // Ignore the result because there is nothing we can do about it.
            let _ = kill_evt.write(1);
        }
        self.common.wait_for_epoll_threads();
    }
}

impl VirtioDevice for Iommu {
    fn device_type(&self) -> u32 {
        self.common.device_type
    }

    fn queue_max_sizes(&self) -> &[u16] {
        &self.common.queue_sizes
    }

    fn features(&self) -> u64 {
        self.common.avail_features
    }

    fn ack_features(&mut self, value: u64) {
        self.common.ack_features(value)
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        self.read_config_from_slice(self.config.as_slice(), offset, data);
    }

    fn write_config(&mut self, offset: u64, data: &[u8]) {
        // The "bypass" field is the only mutable field
        let bypass_offset =
            (&self.config.bypass as *const _ as u64) - (&self.config as *const _ as u64);
        if offset != bypass_offset || data.len() != std::mem::size_of_val(&self.config.bypass) {
            error!(
                "Attempt to write to read-only field: offset {:x} length {}",
                offset,
                data.len()
            );
            return;
        }

        self.config.bypass = data[0];

        self.update_bypass();
    }

    fn activate(
        &mut self,
        mem: GuestMemoryAtomic<GuestMemoryMmap>,
        interrupt_cb: Arc<dyn VirtioInterrupt>,
        mut queues: Vec<(usize, Queue, EventFd)>,
    ) -> ActivateResult {
        self.common.activate(&queues, &interrupt_cb)?;
        let (kill_evt, pause_evt) = self.common.dup_eventfds();

        let (_, request_queue, request_queue_evt) = queues.remove(0);
        let (_, _event_queue, _event_queue_evt) = queues.remove(0);

        let mut handler = IommuEpollHandler {
            mem,
            request_queue,
            _event_queue,
            interrupt_cb,
            request_queue_evt,
            _event_queue_evt,
            kill_evt,
            pause_evt,
            mapping: self.mapping.clone(),
            ext_mapping: self.ext_mapping.clone(),
            msi_iova_space: self.msi_iova_space,
        };

        let paused = self.common.paused.clone();
        let paused_sync = self.common.paused_sync.clone();
        let mut epoll_threads = Vec::new();
        spawn_virtio_thread(
            &self.id,
            &self.seccomp_action,
            Thread::VirtioIommu,
            &mut epoll_threads,
            &self.exit_evt,
            move || handler.run(paused, paused_sync.unwrap()),
        )?;

        self.common.epoll_threads = Some(epoll_threads);

        event!("virtio-device", "activated", "id", &self.id);
        Ok(())
    }

    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
        let result = self.common.reset();
        event!("virtio-device", "reset", "id", &self.id);
        result
    }
}

impl Pausable for Iommu {
    fn pause(&mut self) -> result::Result<(), MigratableError> {
        self.common.pause()
    }

    fn resume(&mut self) -> result::Result<(), MigratableError> {
        self.common.resume()
    }
}

impl Snapshottable for Iommu {
    fn id(&self) -> String {
        self.id.clone()
    }

    fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
        Snapshot::new_from_state(&self.state())
    }
}
impl Transportable for Iommu {}
impl Migratable for Iommu {}
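
// The test below is an illustrative addition rather than part of the original
// revision: a minimal sketch exercising the DmaRemapping translations on a
// hand-built IommuMapping (the private fields are reachable because the test
// module lives in this file). The endpoint/domain IDs and addresses are
// arbitrary example values.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn translate_through_single_mapping() {
        let mapping = IommuMapping {
            endpoints: Arc::new(RwLock::new(BTreeMap::new())),
            domains: Arc::new(RwLock::new(BTreeMap::new())),
            bypass: AtomicBool::new(false),
        };

        // Attach endpoint 3 to domain 1 and map virt [0x1000, 0x1fff] to
        // gpa 0x8000, mirroring what an ATTACH followed by a MAP request does.
        mapping.endpoints.write().unwrap().insert(3, 1);
        mapping.domains.write().unwrap().insert(
            1,
            Domain {
                mappings: [(0x1000u64, Mapping { gpa: 0x8000, size: 0x1000 })]
                    .into_iter()
                    .collect(),
                bypass: false,
            },
        );

        assert_eq!(mapping.translate_gva(3, 0x1234).unwrap(), 0x8234);
        assert_eq!(mapping.translate_gpa(3, 0x8234).unwrap(), 0x1234);
        // An unknown endpoint must fail while global bypass is disabled.
        assert!(mapping.translate_gva(9, 0x1234).is_err());
    }
}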
1158