xref: /cloud-hypervisor/virtio-devices/src/iommu.rs (revision 2fe7f54ece2a8f0461ed29aaeab41614f1f2da75)
1 // Copyright © 2019 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause
4 
5 use std::collections::BTreeMap;
6 use std::mem::size_of;
7 use std::ops::Bound::Included;
8 use std::os::unix::io::AsRawFd;
9 use std::sync::atomic::{AtomicBool, Ordering};
10 use std::sync::{Arc, Barrier, Mutex, RwLock};
11 use std::{io, result};
12 
13 use anyhow::anyhow;
14 use seccompiler::SeccompAction;
15 use serde::{Deserialize, Serialize};
16 use thiserror::Error;
17 use virtio_queue::{DescriptorChain, Queue, QueueT};
18 use vm_device::dma_mapping::ExternalDmaMapping;
19 use vm_memory::{
20     Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic,
21     GuestMemoryError, GuestMemoryLoadGuard,
22 };
23 use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
24 use vm_virtio::AccessPlatform;
25 use vmm_sys_util::eventfd::EventFd;
26 
27 use super::{
28     ActivateResult, EpollHelper, EpollHelperError, EpollHelperHandler, Error as DeviceError,
29     VirtioCommon, VirtioDevice, VirtioDeviceType, EPOLL_HELPER_EVENT_LAST, VIRTIO_F_VERSION_1,
30 };
31 use crate::seccomp_filters::Thread;
32 use crate::thread_helper::spawn_virtio_thread;
33 use crate::{DmaRemapping, GuestMemoryMmap, VirtioInterrupt, VirtioInterruptType};
34 
35 /// Queue sizes
36 const QUEUE_SIZE: u16 = 256;
37 const NUM_QUEUES: usize = 2;
38 const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];
39 
40 /// New descriptors are pending on the request queue.
41 /// "requestq" is used whenever an action needs to be performed on
42 /// behalf of the guest driver.
43 const REQUEST_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 1;
44 /// New descriptors are pending on the event queue.
45 /// "eventq" lets the device report any fault or other asynchronous event to
46 /// the guest driver.
47 const _EVENT_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 2;
48 
49 /// PROBE properties size.
50 /// This is the minimal size needed to provide at least one RESV_MEM property.
51 /// Because the virtio-iommu driver expects one MSI reserved region, we must
52 /// provide it; otherwise the guest driver falls back to a predefined region
53 /// between 0x8000000 and 0x80FFFFF, which is only relevant on ARM but would
54 /// conflict with the x86 address space.
55 const PROBE_PROP_SIZE: u32 =
56     (size_of::<VirtioIommuProbeProperty>() + size_of::<VirtioIommuProbeResvMem>()) as u32;
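
// Editor's note: a compile-time sanity check of the arithmetic above, added
// for illustration. With #[repr(C, packed)], the property header is 4 bytes
// (two u16 fields) and the RESV_MEM property body is 20 bytes
// (u8 + [u8; 3] + two u64), i.e. 24 bytes total.
const _: () = assert!(PROBE_PROP_SIZE == 24);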
57 
58 /// Virtio IOMMU features
59 #[allow(unused)]
60 const VIRTIO_IOMMU_F_INPUT_RANGE: u32 = 0;
61 #[allow(unused)]
62 const VIRTIO_IOMMU_F_DOMAIN_RANGE: u32 = 1;
63 #[allow(unused)]
64 const VIRTIO_IOMMU_F_MAP_UNMAP: u32 = 2;
65 #[allow(unused)]
66 const VIRTIO_IOMMU_F_BYPASS: u32 = 3;
67 const VIRTIO_IOMMU_F_PROBE: u32 = 4;
68 #[allow(unused)]
69 const VIRTIO_IOMMU_F_MMIO: u32 = 5;
70 const VIRTIO_IOMMU_F_BYPASS_CONFIG: u32 = 6;
71 
72 // Support 2MiB and 4KiB page sizes.
73 const VIRTIO_IOMMU_PAGE_SIZE_MASK: u64 = (2 << 20) | (4 << 10);
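
// Editor's note: a compile-time sanity check, added for illustration. The mask
// encodes bit 21 (2 MiB) and bit 12 (4 KiB), i.e. 0x20_0000 | 0x1000.
const _: () = assert!(VIRTIO_IOMMU_PAGE_SIZE_MASK == 0x20_1000);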
74 
75 #[derive(Copy, Clone, Debug, Default)]
76 #[repr(C, packed)]
77 #[allow(dead_code)]
78 struct VirtioIommuRange32 {
79     start: u32,
80     end: u32,
81 }
82 
83 #[derive(Copy, Clone, Debug, Default)]
84 #[repr(C, packed)]
85 #[allow(dead_code)]
86 struct VirtioIommuRange64 {
87     start: u64,
88     end: u64,
89 }
90 
91 #[derive(Copy, Clone, Debug, Default)]
92 #[repr(C, packed)]
93 #[allow(dead_code)]
94 struct VirtioIommuConfig {
95     page_size_mask: u64,
96     input_range: VirtioIommuRange64,
97     domain_range: VirtioIommuRange32,
98     probe_size: u32,
99     bypass: u8,
100     _reserved: [u8; 7],
101 }
102 
103 /// Virtio IOMMU request type
104 const VIRTIO_IOMMU_T_ATTACH: u8 = 1;
105 const VIRTIO_IOMMU_T_DETACH: u8 = 2;
106 const VIRTIO_IOMMU_T_MAP: u8 = 3;
107 const VIRTIO_IOMMU_T_UNMAP: u8 = 4;
108 const VIRTIO_IOMMU_T_PROBE: u8 = 5;
109 
110 #[derive(Copy, Clone, Debug, Default)]
111 #[repr(C, packed)]
112 struct VirtioIommuReqHead {
113     type_: u8,
114     _reserved: [u8; 3],
115 }
116 
117 /// Virtio IOMMU request status
118 const VIRTIO_IOMMU_S_OK: u8 = 0;
119 #[allow(unused)]
120 const VIRTIO_IOMMU_S_IOERR: u8 = 1;
121 #[allow(unused)]
122 const VIRTIO_IOMMU_S_UNSUPP: u8 = 2;
123 #[allow(unused)]
124 const VIRTIO_IOMMU_S_DEVERR: u8 = 3;
125 #[allow(unused)]
126 const VIRTIO_IOMMU_S_INVAL: u8 = 4;
127 #[allow(unused)]
128 const VIRTIO_IOMMU_S_RANGE: u8 = 5;
129 #[allow(unused)]
130 const VIRTIO_IOMMU_S_NOENT: u8 = 6;
131 #[allow(unused)]
132 const VIRTIO_IOMMU_S_FAULT: u8 = 7;
133 #[allow(unused)]
134 const VIRTIO_IOMMU_S_NOMEM: u8 = 8;
135 
136 #[derive(Copy, Clone, Debug, Default)]
137 #[repr(C, packed)]
138 #[allow(dead_code)]
139 struct VirtioIommuReqTail {
140     status: u8,
141     _reserved: [u8; 3],
142 }
143 
144 /// ATTACH request
145 #[derive(Copy, Clone, Debug, Default)]
146 #[repr(C, packed)]
147 struct VirtioIommuReqAttach {
148     domain: u32,
149     endpoint: u32,
150     flags: u32,
151     _reserved: [u8; 4],
152 }
153 
154 const VIRTIO_IOMMU_ATTACH_F_BYPASS: u32 = 1;
155 
156 /// DETACH request
157 #[derive(Copy, Clone, Debug, Default)]
158 #[repr(C, packed)]
159 struct VirtioIommuReqDetach {
160     domain: u32,
161     endpoint: u32,
162     _reserved: [u8; 8],
163 }
164 
165 /// Virtio IOMMU request MAP flags
166 #[allow(unused)]
167 const VIRTIO_IOMMU_MAP_F_READ: u32 = 1;
168 #[allow(unused)]
169 const VIRTIO_IOMMU_MAP_F_WRITE: u32 = 1 << 1;
170 #[allow(unused)]
171 const VIRTIO_IOMMU_MAP_F_MMIO: u32 = 1 << 2;
172 #[allow(unused)]
173 const VIRTIO_IOMMU_MAP_F_MASK: u32 =
174     VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE | VIRTIO_IOMMU_MAP_F_MMIO;
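
// Editor's note: a small check added for illustration. The three MAP flags
// occupy the low three bits, so the combined mask is 0b111.
const _: () = assert!(VIRTIO_IOMMU_MAP_F_MASK == 0b111);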
175 
176 /// MAP request
177 #[derive(Copy, Clone, Debug, Default)]
178 #[repr(C, packed)]
179 struct VirtioIommuReqMap {
180     domain: u32,
181     virt_start: u64,
182     virt_end: u64,
183     phys_start: u64,
184     _flags: u32,
185 }
186 
187 /// UNMAP request
188 #[derive(Copy, Clone, Debug, Default)]
189 #[repr(C, packed)]
190 struct VirtioIommuReqUnmap {
191     domain: u32,
192     virt_start: u64,
193     virt_end: u64,
194     _reserved: [u8; 4],
195 }
196 
197 /// Virtio IOMMU request PROBE types
198 #[allow(unused)]
199 const VIRTIO_IOMMU_PROBE_T_NONE: u16 = 0;
200 const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
201 #[allow(unused)]
202 const VIRTIO_IOMMU_PROBE_T_MASK: u16 = 0xfff;
203 
204 /// PROBE request
205 #[derive(Copy, Clone, Debug, Default)]
206 #[repr(C, packed)]
207 #[allow(dead_code)]
208 struct VirtioIommuReqProbe {
209     endpoint: u32,
210     _reserved: [u64; 8],
211 }
212 
213 #[derive(Copy, Clone, Debug, Default)]
214 #[repr(C, packed)]
215 #[allow(dead_code)]
216 struct VirtioIommuProbeProperty {
217     type_: u16,
218     length: u16,
219 }
220 
221 /// Virtio IOMMU request PROBE property RESV_MEM subtypes
222 #[allow(unused)]
223 const VIRTIO_IOMMU_RESV_MEM_T_RESERVED: u8 = 0;
224 const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;
225 
226 #[derive(Copy, Clone, Debug, Default)]
227 #[repr(C, packed)]
228 #[allow(dead_code)]
229 struct VirtioIommuProbeResvMem {
230     subtype: u8,
231     _reserved: [u8; 3],
232     start: u64,
233     end: u64,
234 }
235 
236 /// Virtio IOMMU fault flags
237 #[allow(unused)]
238 const VIRTIO_IOMMU_FAULT_F_READ: u32 = 1;
239 #[allow(unused)]
240 const VIRTIO_IOMMU_FAULT_F_WRITE: u32 = 1 << 1;
241 #[allow(unused)]
242 const VIRTIO_IOMMU_FAULT_F_EXEC: u32 = 1 << 2;
243 #[allow(unused)]
244 const VIRTIO_IOMMU_FAULT_F_ADDRESS: u32 = 1 << 8;
245 
246 /// Virtio IOMMU fault reasons
247 #[allow(unused)]
248 const VIRTIO_IOMMU_FAULT_R_UNKNOWN: u32 = 0;
249 #[allow(unused)]
250 const VIRTIO_IOMMU_FAULT_R_DOMAIN: u32 = 1;
251 #[allow(unused)]
252 const VIRTIO_IOMMU_FAULT_R_MAPPING: u32 = 2;
253 
254 /// Fault reporting through eventq
255 #[allow(unused)]
256 #[derive(Copy, Clone, Debug, Default)]
257 #[repr(C, packed)]
258 struct VirtioIommuFault {
259     reason: u8,
260     reserved: [u8; 3],
261     flags: u32,
262     endpoint: u32,
263     reserved2: [u8; 4],
264     address: u64,
265 }
266 
267 // SAFETY: the data structure only contains integers and has no implicit padding
268 unsafe impl ByteValued for VirtioIommuRange32 {}
269 // SAFETY: the data structure only contains integers and has no implicit padding
270 unsafe impl ByteValued for VirtioIommuRange64 {}
271 // SAFETY: the data structure only contains integers and has no implicit padding
272 unsafe impl ByteValued for VirtioIommuConfig {}
273 // SAFETY: the data structure only contains integers and has no implicit padding
274 unsafe impl ByteValued for VirtioIommuReqHead {}
275 // SAFETY: the data structure only contains integers and has no implicit padding
276 unsafe impl ByteValued for VirtioIommuReqTail {}
277 // SAFETY: the data structure only contains integers and has no implicit padding
278 unsafe impl ByteValued for VirtioIommuReqAttach {}
279 // SAFETY: the data structure only contains integers and has no implicit padding
280 unsafe impl ByteValued for VirtioIommuReqDetach {}
281 // SAFETY: the data structure only contains integers and has no implicit padding
282 unsafe impl ByteValued for VirtioIommuReqMap {}
283 // SAFETY: the data structure only contains integers and has no implicit padding
284 unsafe impl ByteValued for VirtioIommuReqUnmap {}
285 // SAFETY: the data structure only contains integers and has no implicit padding
286 unsafe impl ByteValued for VirtioIommuReqProbe {}
287 // SAFETY: the data structure only contains integers and has no implicit padding
288 unsafe impl ByteValued for VirtioIommuProbeProperty {}
289 // SAFETY: the data structure only contains integers and has no implicit padding
290 unsafe impl ByteValued for VirtioIommuProbeResvMem {}
291 // SAFETY: the data structure only contains integers and has no implicit padding
292 unsafe impl ByteValued for VirtioIommuFault {}
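
// Editor's sketch: compile-time layout checks backing the SAFETY comments
// above. With #[repr(C, packed)] the sizes follow directly from the field
// types, with no implicit padding anywhere.
const _: () = assert!(size_of::<VirtioIommuConfig>() == 44);
const _: () = assert!(size_of::<VirtioIommuReqHead>() == 4);
const _: () = assert!(size_of::<VirtioIommuReqAttach>() == 16);
const _: () = assert!(size_of::<VirtioIommuReqMap>() == 32);
const _: () = assert!(size_of::<VirtioIommuReqUnmap>() == 24);
const _: () = assert!(size_of::<VirtioIommuFault>() == 24);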
293 
294 #[derive(Error, Debug)]
295 enum Error {
296     #[error("Guest gave us bad memory addresses: {0}")]
297     GuestMemory(GuestMemoryError),
298     #[error("Guest gave us a write-only descriptor that the protocol says to read from")]
299     UnexpectedWriteOnlyDescriptor,
300     #[error("Guest gave us a read-only descriptor that the protocol says to write to")]
301     UnexpectedReadOnlyDescriptor,
302     #[error("Guest gave us too few descriptors in a descriptor chain")]
303     DescriptorChainTooShort,
304     #[error("Guest gave us a buffer that was too short to use")]
305     BufferLengthTooSmall,
306     #[error("Guest sent us an invalid request")]
307     InvalidRequest,
308     #[error("Guest sent us an invalid ATTACH request")]
309     InvalidAttachRequest,
310     #[error("Guest sent us an invalid DETACH request")]
311     InvalidDetachRequest,
312     #[error("Guest sent us an invalid MAP request")]
313     InvalidMapRequest,
314     #[error("Cannot map because the domain is in bypass mode")]
315     InvalidMapRequestBypassDomain,
316     #[error("Cannot map because the domain is missing")]
317     InvalidMapRequestMissingDomain,
318     #[error("Guest sent us an invalid UNMAP request")]
319     InvalidUnmapRequest,
320     #[error("Cannot unmap because the domain is in bypass mode")]
321     InvalidUnmapRequestBypassDomain,
322     #[error("Cannot unmap because the domain is missing")]
323     InvalidUnmapRequestMissingDomain,
324     #[error("Guest sent us an invalid PROBE request")]
325     InvalidProbeRequest,
326     #[error("Failed to perform external mapping: {0}")]
327     ExternalMapping(io::Error),
328     #[error("Failed to perform external unmapping: {0}")]
329     ExternalUnmapping(io::Error),
330     #[error("Failed adding used index: {0}")]
331     QueueAddUsed(virtio_queue::Error),
332 }
333 
334 struct Request {}
335 
336 impl Request {
337     // Parse the available vring buffer. Based on the table of external
338     // mappings required by devices such as VFIO or vhost-user ones, this
339     // function may update the per-domain table of external mappings.
340     // Basically, the VMM knows about the device_id <=> mapping
341     // relationship before running the VM, but at runtime a new
342     // domain <=> mapping table is created based on the information
343     // provided by the guest virtio-iommu driver (giving the
344     // device_id <=> domain link).
345     fn parse(
346         desc_chain: &mut DescriptorChain<GuestMemoryLoadGuard<GuestMemoryMmap>>,
347         mapping: &Arc<IommuMapping>,
348         ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
349         msi_iova_space: (u64, u64),
350     ) -> result::Result<usize, Error> {
351         let desc = desc_chain
352             .next()
353             .ok_or(Error::DescriptorChainTooShort)
354             .inspect_err(|_| {
355                 error!("Missing head descriptor");
356             })?;
357 
358         // The descriptor contains the request type which MUST be readable.
359         if desc.is_write_only() {
360             return Err(Error::UnexpectedWriteOnlyDescriptor);
361         }
362 
363         if (desc.len() as usize) < size_of::<VirtioIommuReqHead>() {
364             return Err(Error::InvalidRequest);
365         }
366 
367         let req_head: VirtioIommuReqHead = desc_chain
368             .memory()
369             .read_obj(desc.addr())
370             .map_err(Error::GuestMemory)?;
371         let req_offset = size_of::<VirtioIommuReqHead>();
372         let desc_size_left = (desc.len() as usize) - req_offset;
373         let req_addr = if let Some(addr) = desc.addr().checked_add(req_offset as u64) {
374             addr
375         } else {
376             return Err(Error::InvalidRequest);
377         };
378 
379         let (msi_iova_start, msi_iova_end) = msi_iova_space;
380 
381         // Create the reply
382         let mut reply: Vec<u8> = Vec::new();
383         let mut status = VIRTIO_IOMMU_S_OK;
384         let mut hdr_len = 0;
385 
386         let result = (|| {
387             match req_head.type_ {
388                 VIRTIO_IOMMU_T_ATTACH => {
389                     if desc_size_left != size_of::<VirtioIommuReqAttach>() {
390                         status = VIRTIO_IOMMU_S_INVAL;
391                         return Err(Error::InvalidAttachRequest);
392                     }
393 
394                     let req: VirtioIommuReqAttach = desc_chain
395                         .memory()
396                         .read_obj(req_addr)
397                         .map_err(Error::GuestMemory)?;
398                     debug!("Attach request {:?}", req);
399 
400                     // Copy the value to use it as a proper reference.
401                     let domain_id = req.domain;
402                     let endpoint = req.endpoint;
403                     let bypass =
404                         (req.flags & VIRTIO_IOMMU_ATTACH_F_BYPASS) == VIRTIO_IOMMU_ATTACH_F_BYPASS;
405 
406                     let mut old_domain_id = domain_id;
407                     if let Some(&id) = mapping.endpoints.read().unwrap().get(&endpoint) {
408                         old_domain_id = id;
409                     }
410 
411                     if old_domain_id != domain_id {
412                         detach_endpoint_from_domain(endpoint, old_domain_id, mapping, ext_mapping)?;
413                     }
414 
415                     // Add endpoint associated with specific domain
416                     mapping
417                         .endpoints
418                         .write()
419                         .unwrap()
420                         .insert(endpoint, domain_id);
421 
422                     // If any other mappings exist in the domain for other containers,
423                     // make sure to issue these mappings for the new endpoint/container
424                     if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id)
425                     {
426                         if let Some(ext_map) = ext_mapping.get(&endpoint) {
427                             for (virt_start, addr_map) in &domain_mappings.mappings {
428                                 ext_map
429                                     .map(*virt_start, addr_map.gpa, addr_map.size)
430                                     .map_err(Error::ExternalMapping)?;
431                             }
432                         }
433                     }
434 
435                     // Add new domain with no mapping if the entry didn't exist yet
436                     let mut domains = mapping.domains.write().unwrap();
437                     let domain = Domain {
438                         mappings: BTreeMap::new(),
439                         bypass,
440                     };
441                     domains.entry(domain_id).or_insert(domain);
442                 }
443                 VIRTIO_IOMMU_T_DETACH => {
444                     if desc_size_left != size_of::<VirtioIommuReqDetach>() {
445                         status = VIRTIO_IOMMU_S_INVAL;
446                         return Err(Error::InvalidDetachRequest);
447                     }
448 
449                     let req: VirtioIommuReqDetach = desc_chain
450                         .memory()
451                         .read_obj(req_addr)
452                         .map_err(Error::GuestMemory)?;
453                     debug!("Detach request {:?}", req);
454 
455                     // Copy the value to use it as a proper reference.
456                     let domain_id = req.domain;
457                     let endpoint = req.endpoint;
458 
459                     // Remove endpoint associated with specific domain
460                     detach_endpoint_from_domain(endpoint, domain_id, mapping, ext_mapping)?;
461                 }
462                 VIRTIO_IOMMU_T_MAP => {
463                     if desc_size_left != size_of::<VirtioIommuReqMap>() {
464                         status = VIRTIO_IOMMU_S_INVAL;
465                         return Err(Error::InvalidMapRequest);
466                     }
467 
468                     let req: VirtioIommuReqMap = desc_chain
469                         .memory()
470                         .read_obj(req_addr)
471                         .map_err(Error::GuestMemory)?;
472                     debug!("Map request {:?}", req);
473 
474                     // Copy the value to use it as a proper reference.
475                     let domain_id = req.domain;
476 
477                     if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
478                         if domain.bypass {
479                             status = VIRTIO_IOMMU_S_INVAL;
480                             return Err(Error::InvalidMapRequestBypassDomain);
481                         }
482                     } else {
483                         status = VIRTIO_IOMMU_S_INVAL;
484                         return Err(Error::InvalidMapRequestMissingDomain);
485                     }
486 
487                     // Find the list of endpoints attached to the given domain.
488                     let endpoints: Vec<u32> = mapping
489                         .endpoints
490                         .write()
491                         .unwrap()
492                         .iter()
493                         .filter(|(_, &d)| d == domain_id)
494                         .map(|(&e, _)| e)
495                         .collect();
496 
497                     // With virtio-iommu, every endpoint gets its own VFIO container. As a
498                     // result, each endpoint within the domain must be mapped separately,
499                     // since the mapping is applied per container, not per domain.
500                     for endpoint in endpoints {
501                         if let Some(ext_map) = ext_mapping.get(&endpoint) {
502                             let size = req.virt_end - req.virt_start + 1;
503                             ext_map
504                                 .map(req.virt_start, req.phys_start, size)
505                                 .map_err(Error::ExternalMapping)?;
506                         }
507                     }
508 
509                     // Add new mapping associated with the domain
510                     mapping
511                         .domains
512                         .write()
513                         .unwrap()
514                         .get_mut(&domain_id)
515                         .unwrap()
516                         .mappings
517                         .insert(
518                             req.virt_start,
519                             Mapping {
520                                 gpa: req.phys_start,
521                                 size: req.virt_end - req.virt_start + 1,
522                             },
523                         );
524                 }
525                 VIRTIO_IOMMU_T_UNMAP => {
526                     if desc_size_left != size_of::<VirtioIommuReqUnmap>() {
527                         status = VIRTIO_IOMMU_S_INVAL;
528                         return Err(Error::InvalidUnmapRequest);
529                     }
530 
531                     let req: VirtioIommuReqUnmap = desc_chain
532                         .memory()
533                         .read_obj(req_addr)
534                         .map_err(Error::GuestMemory)?;
535                     debug!("Unmap request {:?}", req);
536 
537                     // Copy the value to use it as a proper reference.
538                     let domain_id = req.domain;
539                     let virt_start = req.virt_start;
540 
541                     if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
542                         if domain.bypass {
543                             status = VIRTIO_IOMMU_S_INVAL;
544                             return Err(Error::InvalidUnmapRequestBypassDomain);
545                         }
546                     } else {
547                         status = VIRTIO_IOMMU_S_INVAL;
548                         return Err(Error::InvalidUnmapRequestMissingDomain);
549                     }
550 
551                     // Find the list of endpoints attached to the given domain.
552                     let endpoints: Vec<u32> = mapping
553                         .endpoints
554                         .write()
555                         .unwrap()
556                         .iter()
557                         .filter(|(_, &d)| d == domain_id)
558                         .map(|(&e, _)| e)
559                         .collect();
560 
561                     // Trigger external unmapping if necessary.
562                     for endpoint in endpoints {
563                         if let Some(ext_map) = ext_mapping.get(&endpoint) {
564                             let size = req.virt_end - virt_start + 1;
565                             ext_map
566                                 .unmap(virt_start, size)
567                                 .map_err(Error::ExternalUnmapping)?;
568                         }
569                     }
570 
571                     // Remove all mappings associated with the domain within the requested range
572                     mapping
573                         .domains
574                         .write()
575                         .unwrap()
576                         .get_mut(&domain_id)
577                         .unwrap()
578                         .mappings
579                         .retain(|&x, _| (x < req.virt_start || x > req.virt_end));
580                 }
581                 VIRTIO_IOMMU_T_PROBE => {
582                     if desc_size_left != size_of::<VirtioIommuReqProbe>() {
583                         status = VIRTIO_IOMMU_S_INVAL;
584                         return Err(Error::InvalidProbeRequest);
585                     }
586 
587                     let req: VirtioIommuReqProbe = desc_chain
588                         .memory()
589                         .read_obj(req_addr)
590                         .map_err(Error::GuestMemory)?;
591                     debug!("Probe request {:?}", req);
592 
593                     let probe_prop = VirtioIommuProbeProperty {
594                         type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM,
595                         length: size_of::<VirtioIommuProbeResvMem>() as u16,
596                     };
597                     reply.extend_from_slice(probe_prop.as_slice());
598 
599                     let resv_mem = VirtioIommuProbeResvMem {
600                         subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
601                         start: msi_iova_start,
602                         end: msi_iova_end,
603                         ..Default::default()
604                     };
605                     reply.extend_from_slice(resv_mem.as_slice());
606 
607                     hdr_len = PROBE_PROP_SIZE;
608                 }
609                 _ => {
610                     status = VIRTIO_IOMMU_S_INVAL;
611                     return Err(Error::InvalidRequest);
612                 }
613             }
614             Ok(())
615         })();
616 
617         let status_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;
618 
619         // The status MUST always be writable
620         if !status_desc.is_write_only() {
621             return Err(Error::UnexpectedReadOnlyDescriptor);
622         }
623 
624         if status_desc.len() < hdr_len + size_of::<VirtioIommuReqTail>() as u32 {
625             return Err(Error::BufferLengthTooSmall);
626         }
627 
628         let tail = VirtioIommuReqTail {
629             status,
630             ..Default::default()
631         };
632         reply.extend_from_slice(tail.as_slice());
633 
634         // Make sure we return the result of the request to the guest before
635         // we return a potential error internally.
636         desc_chain
637             .memory()
638             .write_slice(reply.as_slice(), status_desc.addr())
639             .map_err(Error::GuestMemory)?;
640 
641         // Return the error if the result was not Ok().
642         result?;
643 
644         Ok((hdr_len as usize) + size_of::<VirtioIommuReqTail>())
645     }
646 }
647 
648 fn detach_endpoint_from_domain(
649     endpoint: u32,
650     domain_id: u32,
651     mapping: &Arc<IommuMapping>,
652     ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
653 ) -> result::Result<(), Error> {
654     // Remove endpoint associated with specific domain
655     mapping.endpoints.write().unwrap().remove(&endpoint);
656 
657     // Trigger external unmapping for the endpoint if necessary.
658     if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id) {
659         if let Some(ext_map) = ext_mapping.get(&endpoint) {
660             for (virt_start, addr_map) in &domain_mappings.mappings {
661                 ext_map
662                     .unmap(*virt_start, addr_map.size)
663                     .map_err(Error::ExternalUnmapping)?;
664             }
665         }
666     }
667 
668     if mapping
669         .endpoints
670         .write()
671         .unwrap()
672         .iter()
673         .filter(|(_, &d)| d == domain_id)
674         .count()
675         == 0
676     {
677         mapping.domains.write().unwrap().remove(&domain_id);
678     }
679 
680     Ok(())
681 }
682 
683 struct IommuEpollHandler {
684     mem: GuestMemoryAtomic<GuestMemoryMmap>,
685     request_queue: Queue,
686     _event_queue: Queue,
687     interrupt_cb: Arc<dyn VirtioInterrupt>,
688     request_queue_evt: EventFd,
689     _event_queue_evt: EventFd,
690     kill_evt: EventFd,
691     pause_evt: EventFd,
692     mapping: Arc<IommuMapping>,
693     ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
694     msi_iova_space: (u64, u64),
695 }
696 
697 impl IommuEpollHandler {
698     fn request_queue(&mut self) -> Result<bool, Error> {
699         let mut used_descs = false;
700         while let Some(mut desc_chain) = self.request_queue.pop_descriptor_chain(self.mem.memory())
701         {
702             let len = Request::parse(
703                 &mut desc_chain,
704                 &self.mapping,
705                 &self.ext_mapping.lock().unwrap(),
706                 self.msi_iova_space,
707             )?;
708 
709             self.request_queue
710                 .add_used(desc_chain.memory(), desc_chain.head_index(), len as u32)
711                 .map_err(Error::QueueAddUsed)?;
712 
713             used_descs = true;
714         }
715 
716         Ok(used_descs)
717     }
718 
719     fn signal_used_queue(&self, queue_index: u16) -> result::Result<(), DeviceError> {
720         self.interrupt_cb
721             .trigger(VirtioInterruptType::Queue(queue_index))
722             .map_err(|e| {
723                 error!("Failed to signal used queue: {:?}", e);
724                 DeviceError::FailedSignalingUsedQueue(e)
725             })
726     }
727 
728     fn run(
729         &mut self,
730         paused: Arc<AtomicBool>,
731         paused_sync: Arc<Barrier>,
732     ) -> result::Result<(), EpollHelperError> {
733         let mut helper = EpollHelper::new(&self.kill_evt, &self.pause_evt)?;
734         helper.add_event(self.request_queue_evt.as_raw_fd(), REQUEST_Q_EVENT)?;
735         helper.run(paused, paused_sync, self)?;
736 
737         Ok(())
738     }
739 }
740 
741 impl EpollHelperHandler for IommuEpollHandler {
742     fn handle_event(
743         &mut self,
744         _helper: &mut EpollHelper,
745         event: &epoll::Event,
746     ) -> result::Result<(), EpollHelperError> {
747         let ev_type = event.data as u16;
748         match ev_type {
749             REQUEST_Q_EVENT => {
750                 self.request_queue_evt.read().map_err(|e| {
751                     EpollHelperError::HandleEvent(anyhow!("Failed to get queue event: {:?}", e))
752                 })?;
753 
754                 let needs_notification = self.request_queue().map_err(|e| {
755                     EpollHelperError::HandleEvent(anyhow!(
756                         "Failed to process request queue: {:?}",
757                         e
758                     ))
759                 })?;
760                 if needs_notification {
761                     self.signal_used_queue(0).map_err(|e| {
762                         EpollHelperError::HandleEvent(anyhow!(
763                             "Failed to signal used queue: {:?}",
764                             e
765                         ))
766                     })?;
767                 }
768             }
769             _ => {
770                 return Err(EpollHelperError::HandleEvent(anyhow!(
771                     "Unexpected event: {}",
772                     ev_type
773                 )));
774             }
775         }
776         Ok(())
777     }
778 }
779 
780 #[derive(Clone, Copy, Debug, Serialize, Deserialize)]
781 struct Mapping {
782     gpa: u64,
783     size: u64,
784 }
785 
786 #[derive(Clone, Debug)]
787 struct Domain {
788     mappings: BTreeMap<u64, Mapping>,
789     bypass: bool,
790 }
791 
792 #[derive(Debug)]
793 pub struct IommuMapping {
794     // Domain to which each endpoint is attached.
795     endpoints: Arc<RwLock<BTreeMap<u32, u32>>>,
796     // Information related to each domain.
797     domains: Arc<RwLock<BTreeMap<u32, Domain>>>,
798     // Global flag indicating if endpoints that are not attached to any domain
799     // are in bypass mode.
800     bypass: AtomicBool,
801 }
802 
803 impl DmaRemapping for IommuMapping {
804     fn translate_gva(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
805         debug!("Translate GVA addr 0x{:x}", addr);
806         if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
807             if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
808                 // Directly return identity mapping in case the domain is in
809                 // bypass mode.
810                 if domain.bypass {
811                     return Ok(addr);
812                 }
813 
814                 let range_start = if VIRTIO_IOMMU_PAGE_SIZE_MASK > addr {
815                     0
816                 } else {
817                     addr - VIRTIO_IOMMU_PAGE_SIZE_MASK
818                 };
819                 for (&key, &value) in domain
820                     .mappings
821                     .range((Included(&range_start), Included(&addr)))
822                 {
823                     if addr >= key && addr < key + value.size {
824                         let new_addr = addr - key + value.gpa;
825                         debug!("Into GPA addr 0x{:x}", new_addr);
826                         return Ok(new_addr);
827                     }
828                 }
829             }
830         } else if self.bypass.load(Ordering::Acquire) {
831             return Ok(addr);
832         }
833 
834         Err(io::Error::new(
835             io::ErrorKind::Other,
836             format!("failed to translate GVA addr 0x{addr:x}"),
837         ))
838     }
839 
840     fn translate_gpa(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
841         debug!("Translate GPA addr 0x{:x}", addr);
842         if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
843             if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
844                 // Directly return identity mapping in case the domain is in
845                 // bypass mode.
846                 if domain.bypass {
847                     return Ok(addr);
848                 }
849 
850                 for (&key, &value) in domain.mappings.iter() {
851                     if addr >= value.gpa && addr < value.gpa + value.size {
852                         let new_addr = addr - value.gpa + key;
853                         debug!("Into GVA addr 0x{:x}", new_addr);
854                         return Ok(new_addr);
855                     }
856                 }
857             }
858         } else if self.bypass.load(Ordering::Acquire) {
859             return Ok(addr);
860         }
861 
862         Err(io::Error::new(
863             io::ErrorKind::Other,
864             format!("failed to translate GPA addr 0x{addr:x}"),
865         ))
866     }
867 }
868 
869 #[derive(Debug)]
870 pub struct AccessPlatformMapping {
871     id: u32,
872     mapping: Arc<IommuMapping>,
873 }
874 
875 impl AccessPlatformMapping {
876     pub fn new(id: u32, mapping: Arc<IommuMapping>) -> Self {
877         AccessPlatformMapping { id, mapping }
878     }
879 }
880 
881 impl AccessPlatform for AccessPlatformMapping {
882     fn translate_gva(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
883         self.mapping.translate_gva(self.id, base)
884     }
885     fn translate_gpa(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
886         self.mapping.translate_gpa(self.id, base)
887     }
888 }
889 
890 pub struct Iommu {
891     common: VirtioCommon,
892     id: String,
893     config: VirtioIommuConfig,
894     mapping: Arc<IommuMapping>,
895     ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
896     seccomp_action: SeccompAction,
897     exit_evt: EventFd,
898     msi_iova_space: (u64, u64),
899 }
900 
901 type EndpointsState = Vec<(u32, u32)>;
902 type DomainsState = Vec<(u32, (Vec<(u64, Mapping)>, bool))>;
903 
904 #[derive(Serialize, Deserialize)]
905 pub struct IommuState {
906     avail_features: u64,
907     acked_features: u64,
908     endpoints: EndpointsState,
909     domains: DomainsState,
910 }
911 
912 impl Iommu {
913     pub fn new(
914         id: String,
915         seccomp_action: SeccompAction,
916         exit_evt: EventFd,
917         msi_iova_space: (u64, u64),
918         state: Option<IommuState>,
919     ) -> io::Result<(Self, Arc<IommuMapping>)> {
920         let (avail_features, acked_features, endpoints, domains, paused) =
921             if let Some(state) = state {
922                 info!("Restoring virtio-iommu {}", id);
923                 (
924                     state.avail_features,
925                     state.acked_features,
926                     state.endpoints.into_iter().collect(),
927                     state
928                         .domains
929                         .into_iter()
930                         .map(|(k, v)| {
931                             (
932                                 k,
933                                 Domain {
934                                     mappings: v.0.into_iter().collect(),
935                                     bypass: v.1,
936                                 },
937                             )
938                         })
939                         .collect(),
940                     true,
941                 )
942             } else {
943                 let avail_features = (1u64 << VIRTIO_F_VERSION_1)
944                     | (1u64 << VIRTIO_IOMMU_F_MAP_UNMAP)
945                     | (1u64 << VIRTIO_IOMMU_F_PROBE)
946                     | (1u64 << VIRTIO_IOMMU_F_BYPASS_CONFIG);
947 
948                 (avail_features, 0, BTreeMap::new(), BTreeMap::new(), false)
949             };
950 
951         let config = VirtioIommuConfig {
952             page_size_mask: VIRTIO_IOMMU_PAGE_SIZE_MASK,
953             probe_size: PROBE_PROP_SIZE,
954             ..Default::default()
955         };
956 
957         let mapping = Arc::new(IommuMapping {
958             endpoints: Arc::new(RwLock::new(endpoints)),
959             domains: Arc::new(RwLock::new(domains)),
960             bypass: AtomicBool::new(true),
961         });
962 
963         Ok((
964             Iommu {
965                 id,
966                 common: VirtioCommon {
967                     device_type: VirtioDeviceType::Iommu as u32,
968                     queue_sizes: QUEUE_SIZES.to_vec(),
969                     avail_features,
970                     acked_features,
971                     paused_sync: Some(Arc::new(Barrier::new(2))),
972                     min_queues: NUM_QUEUES as u16,
973                     paused: Arc::new(AtomicBool::new(paused)),
974                     ..Default::default()
975                 },
976                 config,
977                 mapping: mapping.clone(),
978                 ext_mapping: Arc::new(Mutex::new(BTreeMap::new())),
979                 seccomp_action,
980                 exit_evt,
981                 msi_iova_space,
982             },
983             mapping,
984         ))
985     }
986 
987     fn state(&self) -> IommuState {
988         IommuState {
989             avail_features: self.common.avail_features,
990             acked_features: self.common.acked_features,
991             endpoints: self
992                 .mapping
993                 .endpoints
994                 .read()
995                 .unwrap()
996                 .clone()
997                 .into_iter()
998                 .collect(),
999             domains: self
1000                 .mapping
1001                 .domains
1002                 .read()
1003                 .unwrap()
1004                 .clone()
1005                 .into_iter()
1006                 .map(|(k, v)| (k, (v.mappings.into_iter().collect(), v.bypass)))
1007                 .collect(),
1008         }
1009     }
1010 
1011     fn update_bypass(&mut self) {
1012         // Use bypass from config if VIRTIO_IOMMU_F_BYPASS_CONFIG has been negotiated
1013         if !self
1014             .common
1015             .feature_acked(VIRTIO_IOMMU_F_BYPASS_CONFIG.into())
1016         {
1017             return;
1018         }
1019 
1020         let bypass = self.config.bypass == 1;
1021         info!("Updating bypass mode to {}", bypass);
1022         self.mapping.bypass.store(bypass, Ordering::Release);
1023     }
1024 
1025     pub fn add_external_mapping(&mut self, device_id: u32, mapping: Arc<dyn ExternalDmaMapping>) {
1026         self.ext_mapping.lock().unwrap().insert(device_id, mapping);
1027     }
1028 
1029     #[cfg(fuzzing)]
1030     pub fn wait_for_epoll_threads(&mut self) {
1031         self.common.wait_for_epoll_threads();
1032     }
1033 }
1034 
1035 impl Drop for Iommu {
1036     fn drop(&mut self) {
1037         if let Some(kill_evt) = self.common.kill_evt.take() {
1038             // Ignore the result because there is nothing we can do about it.
1039             let _ = kill_evt.write(1);
1040         }
1041         self.common.wait_for_epoll_threads();
1042     }
1043 }
1044 
1045 impl VirtioDevice for Iommu {
1046     fn device_type(&self) -> u32 {
1047         self.common.device_type
1048     }
1049 
1050     fn queue_max_sizes(&self) -> &[u16] {
1051         &self.common.queue_sizes
1052     }
1053 
1054     fn features(&self) -> u64 {
1055         self.common.avail_features
1056     }
1057 
1058     fn ack_features(&mut self, value: u64) {
1059         self.common.ack_features(value)
1060     }
1061 
1062     fn read_config(&self, offset: u64, data: &mut [u8]) {
1063         self.read_config_from_slice(self.config.as_slice(), offset, data);
1064     }
1065 
1066     fn write_config(&mut self, offset: u64, data: &[u8]) {
1067         // The "bypass" field is the only mutable field
1068         let bypass_offset =
1069             (&self.config.bypass as *const _ as u64) - (&self.config as *const _ as u64);
1070         if offset != bypass_offset || data.len() != std::mem::size_of_val(&self.config.bypass) {
1071             error!(
1072                 "Attempt to write to read-only field: offset {:x} length {}",
1073                 offset,
1074                 data.len()
1075             );
1076             return;
1077         }
1078 
1079         self.config.bypass = data[0];
1080 
1081         self.update_bypass();
1082     }
1083 
1084     fn activate(
1085         &mut self,
1086         mem: GuestMemoryAtomic<GuestMemoryMmap>,
1087         interrupt_cb: Arc<dyn VirtioInterrupt>,
1088         mut queues: Vec<(usize, Queue, EventFd)>,
1089     ) -> ActivateResult {
1090         self.common.activate(&queues, &interrupt_cb)?;
1091         let (kill_evt, pause_evt) = self.common.dup_eventfds();
1092 
1093         let (_, request_queue, request_queue_evt) = queues.remove(0);
1094         let (_, _event_queue, _event_queue_evt) = queues.remove(0);
1095 
1096         let mut handler = IommuEpollHandler {
1097             mem,
1098             request_queue,
1099             _event_queue,
1100             interrupt_cb,
1101             request_queue_evt,
1102             _event_queue_evt,
1103             kill_evt,
1104             pause_evt,
1105             mapping: self.mapping.clone(),
1106             ext_mapping: self.ext_mapping.clone(),
1107             msi_iova_space: self.msi_iova_space,
1108         };
1109 
1110         let paused = self.common.paused.clone();
1111         let paused_sync = self.common.paused_sync.clone();
1112         let mut epoll_threads = Vec::new();
1113         spawn_virtio_thread(
1114             &self.id,
1115             &self.seccomp_action,
1116             Thread::VirtioIommu,
1117             &mut epoll_threads,
1118             &self.exit_evt,
1119             move || handler.run(paused, paused_sync.unwrap()),
1120         )?;
1121 
1122         self.common.epoll_threads = Some(epoll_threads);
1123 
1124         event!("virtio-device", "activated", "id", &self.id);
1125         Ok(())
1126     }
1127 
1128     fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
1129         let result = self.common.reset();
1130         event!("virtio-device", "reset", "id", &self.id);
1131         result
1132     }
1133 }
1134 
1135 impl Pausable for Iommu {
1136     fn pause(&mut self) -> result::Result<(), MigratableError> {
1137         self.common.pause()
1138     }
1139 
1140     fn resume(&mut self) -> result::Result<(), MigratableError> {
1141         self.common.resume()
1142     }
1143 }
1144 
1145 impl Snapshottable for Iommu {
1146     fn id(&self) -> String {
1147         self.id.clone()
1148     }
1149 
1150     fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
1151         Snapshot::new_from_state(&self.state())
1152     }
1153 }
1154 impl Transportable for Iommu {}
1155 impl Migratable for Iommu {}
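
// Editor's sketch: a small unit test illustrating the translation logic of
// IommuMapping above. The endpoint ID (3), domain ID (1) and addresses are
// made-up values for illustration; one endpoint is attached to a domain that
// maps a single 4 KiB page at IOVA 0x20_0000 onto GPA 0x10_0000.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_iommu_mapping_translation() {
        let mut mappings = BTreeMap::new();
        mappings.insert(
            0x20_0000,
            Mapping {
                gpa: 0x10_0000,
                size: 0x1000,
            },
        );
        let mut domains = BTreeMap::new();
        domains.insert(
            1,
            Domain {
                mappings,
                bypass: false,
            },
        );
        let mut endpoints = BTreeMap::new();
        endpoints.insert(3, 1);

        let mapping = IommuMapping {
            endpoints: Arc::new(RwLock::new(endpoints)),
            domains: Arc::new(RwLock::new(domains)),
            bypass: AtomicBool::new(false),
        };

        // GVA -> GPA translation within the mapped page.
        assert_eq!(mapping.translate_gva(3, 0x20_0010).unwrap(), 0x10_0010);
        // GPA -> GVA is the inverse lookup.
        assert_eq!(mapping.translate_gpa(3, 0x10_0010).unwrap(), 0x20_0010);
        // An address outside any mapping fails, since neither the domain nor
        // the global bypass flag allows identity mapping here.
        assert!(mapping.translate_gva(3, 0x30_0000).is_err());

        // AccessPlatformMapping is the thin per-endpoint wrapper handed out
        // to other devices; it forwards to the same translation logic.
        let ap = AccessPlatformMapping::new(3, Arc::new(mapping));
        assert_eq!(ap.translate_gva(0x20_0010, 1).unwrap(), 0x10_0010);
    }
}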
1156