// xref: /cloud-hypervisor/virtio-devices/src/iommu.rs (revision eea9bcea38e0c5649f444c829f3a4f9c22aa486c)
// Copyright © 2019 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause

use super::Error as DeviceError;
use super::{
    ActivateResult, EpollHelper, EpollHelperError, EpollHelperHandler, VirtioCommon, VirtioDevice,
    VirtioDeviceType, EPOLL_HELPER_EVENT_LAST, VIRTIO_F_VERSION_1,
};
use crate::seccomp_filters::Thread;
use crate::thread_helper::spawn_virtio_thread;
use crate::GuestMemoryMmap;
use crate::{DmaRemapping, VirtioInterrupt, VirtioInterruptType};
use anyhow::anyhow;
use seccompiler::SeccompAction;
use std::collections::BTreeMap;
use std::fmt::{self, Display};
use std::io;
use std::mem::size_of;
use std::ops::Bound::Included;
use std::os::unix::io::AsRawFd;
use std::result;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Barrier, Mutex, RwLock};
use versionize::{VersionMap, Versionize, VersionizeResult};
use versionize_derive::Versionize;
use virtio_queue::{DescriptorChain, Queue, QueueT};
use vm_device::dma_mapping::ExternalDmaMapping;
use vm_memory::{
    Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic,
    GuestMemoryError, GuestMemoryLoadGuard,
};
use vm_migration::VersionMapped;
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform;
use vmm_sys_util::eventfd::EventFd;

/// Queue sizes
const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];

/// New descriptors are pending on the request queue.
/// "requestq" is used whenever an action must be performed on behalf of the
/// guest driver.
const REQUEST_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 1;
/// New descriptors are pending on the event queue.
/// "eventq" lets the device report any fault or other asynchronous event to
/// the guest driver.
const EVENT_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 2;

/// PROBE properties size.
/// This is the minimal size needed to provide at least one RESV_MEM property.
/// Because virtio-iommu expects one MSI reserved region, we must provide it,
/// otherwise the guest driver will fall back to a predefined region between
/// 0x8000000 and 0x80FFFFF, which is only relevant on the ARM architecture
/// but conflicts with x86.
const PROBE_PROP_SIZE: u32 =
    (size_of::<VirtioIommuProbeProperty>() + size_of::<VirtioIommuProbeResvMem>()) as u32;
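
// Illustrative compile-time check (an addition of this write-up, not in the
// original source): one property header (2 + 2 bytes) plus one RESV_MEM
// payload (1 + 3 + 8 + 8 bytes) yields 24 bytes in total.
const _: () = assert!(PROBE_PROP_SIZE == 24);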

/// Virtio IOMMU features
#[allow(unused)]
const VIRTIO_IOMMU_F_INPUT_RANGE: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_F_DOMAIN_RANGE: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_F_MAP_UNMAP: u32 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_F_BYPASS: u32 = 3;
const VIRTIO_IOMMU_F_PROBE: u32 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_F_MMIO: u32 = 5;
const VIRTIO_IOMMU_F_BYPASS_CONFIG: u32 = 6;

// Support 2MiB and 4KiB page sizes.
const VIRTIO_IOMMU_PAGE_SIZE_MASK: u64 = (2 << 20) | (4 << 10);
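
// Illustrative compile-time check (an addition of this write-up, not in the
// original source): each bit N set in the mask advertises support for
// 2^N-byte pages, so the mask carries the 4 KiB (bit 12) and 2 MiB (bit 21)
// bits.
const _: () = assert!(VIRTIO_IOMMU_PAGE_SIZE_MASK == (1 << 21) | (1 << 12));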

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuRange32 {
    start: u32,
    end: u32,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuRange64 {
    start: u64,
    end: u64,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuConfig {
    page_size_mask: u64,
    input_range: VirtioIommuRange64,
    domain_range: VirtioIommuRange32,
    probe_size: u32,
    bypass: u8,
    _reserved: [u8; 7],
}
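
// Illustrative compile-time check (an addition of this write-up, not in the
// original source): with #[repr(packed)], the config is exactly 44 bytes
// (8 + 16 + 8 + 4 + 1 + 7), matching the packed field layout above.
const _: () = assert!(size_of::<VirtioIommuConfig>() == 44);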

/// Virtio IOMMU request type
const VIRTIO_IOMMU_T_ATTACH: u8 = 1;
const VIRTIO_IOMMU_T_DETACH: u8 = 2;
const VIRTIO_IOMMU_T_MAP: u8 = 3;
const VIRTIO_IOMMU_T_UNMAP: u8 = 4;
const VIRTIO_IOMMU_T_PROBE: u8 = 5;

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqHead {
    type_: u8,
    _reserved: [u8; 3],
}

/// Virtio IOMMU request status
const VIRTIO_IOMMU_S_OK: u8 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_S_IOERR: u8 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_S_UNSUPP: u8 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_S_DEVERR: u8 = 3;
#[allow(unused)]
const VIRTIO_IOMMU_S_INVAL: u8 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_S_RANGE: u8 = 5;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOENT: u8 = 6;
#[allow(unused)]
const VIRTIO_IOMMU_S_FAULT: u8 = 7;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOMEM: u8 = 8;

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuReqTail {
    status: u8,
    _reserved: [u8; 3],
}

/// ATTACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqAttach {
    domain: u32,
    endpoint: u32,
    flags: u32,
    _reserved: [u8; 4],
}

const VIRTIO_IOMMU_ATTACH_F_BYPASS: u32 = 1;

/// DETACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqDetach {
    domain: u32,
    endpoint: u32,
    _reserved: [u8; 8],
}

/// Virtio IOMMU request MAP flags
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MMIO: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MASK: u32 =
    VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE | VIRTIO_IOMMU_MAP_F_MMIO;

/// MAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqMap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    phys_start: u64,
    _flags: u32,
}

/// UNMAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqUnmap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    _reserved: [u8; 4],
}

/// Virtio IOMMU request PROBE types
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_NONE: u16 = 0;
const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_MASK: u16 = 0xfff;

/// PROBE request
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuReqProbe {
    endpoint: u32,
    _reserved: [u64; 8],
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuProbeProperty {
    type_: u16,
    length: u16,
}

/// Virtio IOMMU request PROBE property RESV_MEM subtypes
#[allow(unused)]
const VIRTIO_IOMMU_RESV_MEM_T_RESERVED: u8 = 0;
const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;

#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuProbeResvMem {
    subtype: u8,
    _reserved: [u8; 3],
    start: u64,
    end: u64,
}

/// Virtio IOMMU fault flags
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_EXEC: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_ADDRESS: u32 = 1 << 8;

/// Virtio IOMMU fault reasons
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_UNKNOWN: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_DOMAIN: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_MAPPING: u32 = 2;

/// Fault reporting through eventq
#[allow(unused)]
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuFault {
    reason: u8,
    reserved: [u8; 3],
    flags: u32,
    endpoint: u32,
    reserved2: [u8; 4],
    address: u64,
}

// SAFETY: these data structures only contain integers and have no implicit padding
unsafe impl ByteValued for VirtioIommuRange32 {}
unsafe impl ByteValued for VirtioIommuRange64 {}
unsafe impl ByteValued for VirtioIommuConfig {}
unsafe impl ByteValued for VirtioIommuReqHead {}
unsafe impl ByteValued for VirtioIommuReqTail {}
unsafe impl ByteValued for VirtioIommuReqAttach {}
unsafe impl ByteValued for VirtioIommuReqDetach {}
unsafe impl ByteValued for VirtioIommuReqMap {}
unsafe impl ByteValued for VirtioIommuReqUnmap {}
unsafe impl ByteValued for VirtioIommuReqProbe {}
unsafe impl ByteValued for VirtioIommuProbeProperty {}
unsafe impl ByteValued for VirtioIommuProbeResvMem {}
unsafe impl ByteValued for VirtioIommuFault {}
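
// Illustrative compile-time checks (additions of this write-up, not in the
// original source): #[repr(packed)] means size_of reflects the exact wire
// format with no implicit padding, which is what makes the ByteValued impls
// above sound.
const _: () = assert!(size_of::<VirtioIommuReqHead>() == 4);
const _: () = assert!(size_of::<VirtioIommuReqAttach>() == 16);
const _: () = assert!(size_of::<VirtioIommuReqMap>() == 32);
const _: () = assert!(size_of::<VirtioIommuReqTail>() == 4);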

#[derive(Debug)]
enum Error {
    /// Guest gave us bad memory addresses.
    GuestMemory(GuestMemoryError),
    /// Guest gave us a write-only descriptor that the protocol says to read from.
    UnexpectedWriteOnlyDescriptor,
    /// Guest gave us a read-only descriptor that the protocol says to write to.
    UnexpectedReadOnlyDescriptor,
    /// Guest gave us too few descriptors in a descriptor chain.
    DescriptorChainTooShort,
    /// Guest gave us a buffer that was too short to use.
    BufferLengthTooSmall,
    /// Guest sent us an invalid request.
    InvalidRequest,
    /// Guest sent us an invalid ATTACH request.
    InvalidAttachRequest,
    /// Guest sent us an invalid DETACH request.
    InvalidDetachRequest,
    /// Guest sent us an invalid MAP request.
    InvalidMapRequest,
    /// Invalid to map because the domain is in bypass mode.
    InvalidMapRequestBypassDomain,
    /// Invalid to map because the domain is missing.
    InvalidMapRequestMissingDomain,
    /// Guest sent us an invalid UNMAP request.
    InvalidUnmapRequest,
    /// Invalid to unmap because the domain is in bypass mode.
    InvalidUnmapRequestBypassDomain,
    /// Invalid to unmap because the domain is missing.
    InvalidUnmapRequestMissingDomain,
    /// Guest sent us an invalid PROBE request.
    InvalidProbeRequest,
    /// Failed to perform external mapping.
    ExternalMapping(io::Error),
    /// Failed to perform external unmapping.
    ExternalUnmapping(io::Error),
}

impl Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        use self::Error::*;

        match self {
            BufferLengthTooSmall => write!(f, "buffer length too small"),
            DescriptorChainTooShort => write!(f, "descriptor chain too short"),
            GuestMemory(e) => write!(f, "bad guest memory address: {}", e),
            InvalidRequest => write!(f, "invalid request"),
            InvalidAttachRequest => write!(f, "invalid attach request"),
            InvalidDetachRequest => write!(f, "invalid detach request"),
            InvalidMapRequest => write!(f, "invalid map request"),
            InvalidMapRequestBypassDomain => {
                write!(f, "invalid map request because the domain is in bypass mode")
            }
            InvalidMapRequestMissingDomain => {
                write!(f, "invalid map request because the domain is missing")
            }
            InvalidUnmapRequest => write!(f, "invalid unmap request"),
            InvalidUnmapRequestBypassDomain => {
                write!(f, "invalid unmap request because the domain is in bypass mode")
            }
            InvalidUnmapRequestMissingDomain => {
                write!(f, "invalid unmap request because the domain is missing")
            }
            InvalidProbeRequest => write!(f, "invalid probe request"),
            UnexpectedReadOnlyDescriptor => write!(f, "unexpected read-only descriptor"),
            UnexpectedWriteOnlyDescriptor => write!(f, "unexpected write-only descriptor"),
            ExternalMapping(e) => write!(f, "failed performing external mapping: {}", e),
            ExternalUnmapping(e) => write!(f, "failed performing external unmapping: {}", e),
        }
    }
}

struct Request {}

impl Request {
    // Parse the available vring buffer. Based on the table of external
    // mappings required by various devices such as VFIO or vhost-user ones,
    // this function may update the per-domain table of external mappings.
    // In short, the VMM knows the device_id <=> mapping relationship before
    // the VM runs, but at runtime a new domain <=> mapping table is built
    // from the information the guest virtio-iommu driver provides (which
    // links device_id <=> domain).
    fn parse(
        desc_chain: &mut DescriptorChain<GuestMemoryLoadGuard<GuestMemoryMmap>>,
        mapping: &Arc<IommuMapping>,
        ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
        msi_iova_space: (u64, u64),
    ) -> result::Result<usize, Error> {
        let desc = desc_chain
            .next()
            .ok_or(Error::DescriptorChainTooShort)
            .map_err(|e| {
                error!("Missing head descriptor");
                e
            })?;

        // The descriptor contains the request type which MUST be readable.
        if desc.is_write_only() {
            return Err(Error::UnexpectedWriteOnlyDescriptor);
        }

        if (desc.len() as usize) < size_of::<VirtioIommuReqHead>() {
            return Err(Error::InvalidRequest);
        }

        let req_head: VirtioIommuReqHead = desc_chain
            .memory()
            .read_obj(desc.addr())
            .map_err(Error::GuestMemory)?;
        let req_offset = size_of::<VirtioIommuReqHead>();
        let desc_size_left = (desc.len() as usize) - req_offset;
        let req_addr = if let Some(addr) = desc.addr().checked_add(req_offset as u64) {
            addr
        } else {
            return Err(Error::InvalidRequest);
        };

        let (msi_iova_start, msi_iova_end) = msi_iova_space;

        // Create the reply
        let mut reply: Vec<u8> = Vec::new();
        let mut status = VIRTIO_IOMMU_S_OK;
        let mut hdr_len = 0;

        let result = (|| {
            match req_head.type_ {
                VIRTIO_IOMMU_T_ATTACH => {
                    if desc_size_left != size_of::<VirtioIommuReqAttach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidAttachRequest);
                    }

                    let req: VirtioIommuReqAttach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Attach request {:?}", req);

                    // Copy the fields into local variables, since references
                    // to fields of a packed struct are not allowed.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;
                    let bypass =
                        (req.flags & VIRTIO_IOMMU_ATTACH_F_BYPASS) == VIRTIO_IOMMU_ATTACH_F_BYPASS;

                    // Add endpoint associated with specific domain
                    mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .insert(endpoint, domain_id);

                    // Add new domain with no mapping if the entry didn't exist yet
                    let mut domains = mapping.domains.write().unwrap();
                    let domain = Domain {
                        mappings: BTreeMap::new(),
                        bypass,
                    };
                    domains.entry(domain_id).or_insert_with(|| domain);
                }
                VIRTIO_IOMMU_T_DETACH => {
                    if desc_size_left != size_of::<VirtioIommuReqDetach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidDetachRequest);
                    }

                    let req: VirtioIommuReqDetach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Detach request {:?}", req);

                    // Copy the fields into local variables, since references
                    // to fields of a packed struct are not allowed.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;

                    // Remove endpoint associated with specific domain
                    mapping.endpoints.write().unwrap().remove(&endpoint);

                    // After all endpoints have been successfully detached from a
                    // domain, the domain can be removed. This means we must remove
                    // the mappings associated with this domain.
                    if mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .count()
                        == 0
                    {
                        mapping.domains.write().unwrap().remove(&domain_id);
                    }
                }
                VIRTIO_IOMMU_T_MAP => {
                    if desc_size_left != size_of::<VirtioIommuReqMap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequest);
                    }

                    let req: VirtioIommuReqMap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Map request {:?}", req);

                    // Copy the field into a local variable, since references
                    // to fields of a packed struct are not allowed.
                    let domain_id = req.domain;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidMapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // Trigger external mapping if necessary.
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - req.virt_start + 1;
                            ext_map
                                .map(req.virt_start, req.phys_start, size)
                                .map_err(Error::ExternalMapping)?;
                        }
                    }

                    // Add new mapping associated with the domain
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .insert(
                            req.virt_start,
                            Mapping {
                                gpa: req.phys_start,
                                size: req.virt_end - req.virt_start + 1,
                            },
                        );
                }
                VIRTIO_IOMMU_T_UNMAP => {
                    if desc_size_left != size_of::<VirtioIommuReqUnmap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequest);
                    }

                    let req: VirtioIommuReqUnmap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Unmap request {:?}", req);

                    // Copy the fields into local variables, since references
                    // to fields of a packed struct are not allowed.
                    let domain_id = req.domain;
                    let virt_start = req.virt_start;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidUnmapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // Trigger external unmapping if necessary.
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - virt_start + 1;
                            ext_map
                                .unmap(virt_start, size)
                                .map_err(Error::ExternalUnmapping)?;
                        }
                    }

                    // Remove mapping associated with the domain
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .remove(&virt_start);
                }
                VIRTIO_IOMMU_T_PROBE => {
                    if desc_size_left != size_of::<VirtioIommuReqProbe>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidProbeRequest);
                    }

                    let req: VirtioIommuReqProbe = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Probe request {:?}", req);

                    let probe_prop = VirtioIommuProbeProperty {
                        type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM,
                        length: size_of::<VirtioIommuProbeResvMem>() as u16,
                    };
                    reply.extend_from_slice(probe_prop.as_slice());

                    let resv_mem = VirtioIommuProbeResvMem {
                        subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
                        start: msi_iova_start,
                        end: msi_iova_end,
                        ..Default::default()
                    };
                    reply.extend_from_slice(resv_mem.as_slice());

                    hdr_len = PROBE_PROP_SIZE;
                }
                _ => {
                    status = VIRTIO_IOMMU_S_INVAL;
                    return Err(Error::InvalidRequest);
                }
            }
            Ok(())
        })();

        let status_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;

        // The status MUST always be writable
        if !status_desc.is_write_only() {
            return Err(Error::UnexpectedReadOnlyDescriptor);
        }

        if status_desc.len() < hdr_len + size_of::<VirtioIommuReqTail>() as u32 {
            return Err(Error::BufferLengthTooSmall);
        }

        let tail = VirtioIommuReqTail {
            status,
            ..Default::default()
        };
        reply.extend_from_slice(tail.as_slice());

        // Make sure we return the result of the request to the guest before
        // we return a potential error internally.
        desc_chain
            .memory()
            .write_slice(reply.as_slice(), status_desc.addr())
            .map_err(Error::GuestMemory)?;

        // Return the error if the result was not Ok().
        result?;

        Ok((hdr_len as usize) + size_of::<VirtioIommuReqTail>())
    }
}
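
// Request wire format handled by Request::parse, for reference (derived from
// the structs above, not a comment in the original source): the descriptor
// chain carries
//
//   device-readable:  VirtioIommuReqHead | type-specific payload
//   device-writable:  [PROBE properties] | VirtioIommuReqTail
//
// e.g. a MAP request is 4 (head) + 32 (VirtioIommuReqMap) readable bytes,
// answered with a 4-byte writable tail carrying the status.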

struct IommuEpollHandler {
    mem: GuestMemoryAtomic<GuestMemoryMmap>,
    queues: Vec<Queue>,
    interrupt_cb: Arc<dyn VirtioInterrupt>,
    queue_evts: Vec<EventFd>,
    kill_evt: EventFd,
    pause_evt: EventFd,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    msi_iova_space: (u64, u64),
}

impl IommuEpollHandler {
    fn request_queue(&mut self) -> bool {
        let mut used_descs = false;
        while let Some(mut desc_chain) = self.queues[0].pop_descriptor_chain(self.mem.memory()) {
            let len = match Request::parse(
                &mut desc_chain,
                &self.mapping,
                &self.ext_mapping.lock().unwrap(),
                self.msi_iova_space,
            ) {
                Ok(len) => len as u32,
                Err(e) => {
                    error!("failed parsing descriptor: {}", e);
                    0
                }
            };

            self.queues[0]
                .add_used(desc_chain.memory(), desc_chain.head_index(), len)
                .unwrap();
            used_descs = true;
        }

        used_descs
    }

    fn event_queue(&mut self) -> bool {
        false
    }

    fn signal_used_queue(&self, queue_index: u16) -> result::Result<(), DeviceError> {
        self.interrupt_cb
            .trigger(VirtioInterruptType::Queue(queue_index))
            .map_err(|e| {
                error!("Failed to signal used queue: {:?}", e);
                DeviceError::FailedSignalingUsedQueue(e)
            })
    }

    fn run(
        &mut self,
        paused: Arc<AtomicBool>,
        paused_sync: Arc<Barrier>,
    ) -> result::Result<(), EpollHelperError> {
        let mut helper = EpollHelper::new(&self.kill_evt, &self.pause_evt)?;
        helper.add_event(self.queue_evts[0].as_raw_fd(), REQUEST_Q_EVENT)?;
        helper.add_event(self.queue_evts[1].as_raw_fd(), EVENT_Q_EVENT)?;
        helper.run(paused, paused_sync, self)?;

        Ok(())
    }
}

impl EpollHelperHandler for IommuEpollHandler {
    fn handle_event(
        &mut self,
        _helper: &mut EpollHelper,
        event: &epoll::Event,
    ) -> result::Result<(), EpollHelperError> {
        let ev_type = event.data as u16;
        match ev_type {
            REQUEST_Q_EVENT => {
                self.queue_evts[0].read().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!("Failed to get queue event: {:?}", e))
                })?;

                if self.request_queue() {
                    self.signal_used_queue(0).map_err(|e| {
                        EpollHelperError::HandleEvent(anyhow!(
                            "Failed to signal used queue: {:?}",
                            e
                        ))
                    })?;
                }
            }
            EVENT_Q_EVENT => {
                self.queue_evts[1].read().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!("Failed to get queue event: {:?}", e))
                })?;

                if self.event_queue() {
                    self.signal_used_queue(1).map_err(|e| {
                        EpollHelperError::HandleEvent(anyhow!(
                            "Failed to signal used queue: {:?}",
                            e
                        ))
                    })?;
                }
            }
            _ => {
                return Err(EpollHelperError::HandleEvent(anyhow!(
                    "Unexpected event: {}",
                    ev_type
                )));
            }
        }
        Ok(())
    }
}

#[derive(Clone, Copy, Debug, Versionize)]
struct Mapping {
    gpa: u64,
    size: u64,
}

#[derive(Clone, Debug)]
struct Domain {
    mappings: BTreeMap<u64, Mapping>,
    bypass: bool,
}

#[derive(Debug)]
pub struct IommuMapping {
    // Domain related to an endpoint.
    endpoints: Arc<RwLock<BTreeMap<u32, u32>>>,
    // Information related to each domain.
    domains: Arc<RwLock<BTreeMap<u32, Domain>>>,
    // Global flag indicating if endpoints that are not attached to any domain
    // are in bypass mode.
    bypass: AtomicBool,
}
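
// Translation walk, for reference (a note added by this write-up): endpoint
// id -> domain id (`endpoints`), then domain id -> Domain (`domains`), then
// an IOVA range lookup inside `Domain::mappings`. For instance, an endpoint
// attached to a domain holding { virt_start: 0x1000, gpa: 0x8000, size:
// 0x1000 } translates GVA 0x1234 to GPA 0x8234. Note that translate_gpa()
// below must scan all mappings linearly, since the map is keyed by IOVA,
// while translate_gva() can use a BTreeMap range query.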

impl DmaRemapping for IommuMapping {
    fn translate_gva(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GVA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                let range_start = if VIRTIO_IOMMU_PAGE_SIZE_MASK > addr {
                    0
                } else {
                    addr - VIRTIO_IOMMU_PAGE_SIZE_MASK
                };
                for (&key, &value) in domain
                    .mappings
                    .range((Included(&range_start), Included(&addr)))
                {
                    if addr >= key && addr < key + value.size {
                        let new_addr = addr - key + value.gpa;
                        debug!("Into GPA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::new(
            io::ErrorKind::Other,
            format!("failed to translate GVA addr 0x{:x}", addr),
        ))
    }

    fn translate_gpa(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GPA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                for (&key, &value) in domain.mappings.iter() {
                    if addr >= value.gpa && addr < value.gpa + value.size {
                        let new_addr = addr - value.gpa + key;
                        debug!("Into GVA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::new(
            io::ErrorKind::Other,
            format!("failed to translate GPA addr 0x{:x}", addr),
        ))
    }
}

#[derive(Debug)]
pub struct AccessPlatformMapping {
    id: u32,
    mapping: Arc<IommuMapping>,
}

impl AccessPlatformMapping {
    pub fn new(id: u32, mapping: Arc<IommuMapping>) -> Self {
        AccessPlatformMapping { id, mapping }
    }
}

impl AccessPlatform for AccessPlatformMapping {
    fn translate_gva(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gva(self.id, base)
    }
    fn translate_gpa(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gpa(self.id, base)
    }
}

pub struct Iommu {
    common: VirtioCommon,
    id: String,
    config: VirtioIommuConfig,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    seccomp_action: SeccompAction,
    exit_evt: EventFd,
    msi_iova_space: (u64, u64),
}

type EndpointsState = Vec<(u32, u32)>;
type DomainsState = Vec<(u32, (Vec<(u64, Mapping)>, bool))>;

#[derive(Versionize)]
struct IommuState {
    avail_features: u64,
    acked_features: u64,
    endpoints: EndpointsState,
    domains: DomainsState,
}

impl VersionMapped for IommuState {}

impl Iommu {
    pub fn new(
        id: String,
        seccomp_action: SeccompAction,
        exit_evt: EventFd,
        msi_iova_space: (u64, u64),
    ) -> io::Result<(Self, Arc<IommuMapping>)> {
        let config = VirtioIommuConfig {
            page_size_mask: VIRTIO_IOMMU_PAGE_SIZE_MASK,
            probe_size: PROBE_PROP_SIZE,
            ..Default::default()
        };

        let mapping = Arc::new(IommuMapping {
            endpoints: Arc::new(RwLock::new(BTreeMap::new())),
            domains: Arc::new(RwLock::new(BTreeMap::new())),
            bypass: AtomicBool::new(true),
        });

        Ok((
            Iommu {
                id,
                common: VirtioCommon {
                    device_type: VirtioDeviceType::Iommu as u32,
                    queue_sizes: QUEUE_SIZES.to_vec(),
                    avail_features: 1u64 << VIRTIO_F_VERSION_1
                        | 1u64 << VIRTIO_IOMMU_F_MAP_UNMAP
                        | 1u64 << VIRTIO_IOMMU_F_PROBE
                        | 1u64 << VIRTIO_IOMMU_F_BYPASS_CONFIG,
                    paused_sync: Some(Arc::new(Barrier::new(2))),
                    ..Default::default()
                },
                config,
                mapping: mapping.clone(),
                ext_mapping: Arc::new(Mutex::new(BTreeMap::new())),
                seccomp_action,
                exit_evt,
                msi_iova_space,
            },
            mapping,
        ))
    }

    fn state(&self) -> IommuState {
        IommuState {
            avail_features: self.common.avail_features,
            acked_features: self.common.acked_features,
            endpoints: self
                .mapping
                .endpoints
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .collect(),
            domains: self
                .mapping
                .domains
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .map(|(k, v)| (k, (v.mappings.into_iter().collect(), v.bypass)))
                .collect(),
        }
    }

    fn set_state(&mut self, state: &IommuState) {
        self.common.avail_features = state.avail_features;
        self.common.acked_features = state.acked_features;
        *(self.mapping.endpoints.write().unwrap()) = state.endpoints.clone().into_iter().collect();
        *(self.mapping.domains.write().unwrap()) = state
            .domains
            .clone()
            .into_iter()
            .map(|(k, v)| {
                (
                    k,
                    Domain {
                        mappings: v.0.into_iter().collect(),
                        bypass: v.1,
                    },
                )
            })
            .collect();
    }

    fn update_bypass(&mut self) {
        // Use bypass from config if VIRTIO_IOMMU_F_BYPASS_CONFIG has been negotiated
        if !self
            .common
            .feature_acked(VIRTIO_IOMMU_F_BYPASS_CONFIG.into())
        {
            return;
        }

        let bypass = self.config.bypass == 1;
        info!("Updating bypass mode to {}", bypass);
        self.mapping.bypass.store(bypass, Ordering::Release);
    }

    pub fn add_external_mapping(&mut self, device_id: u32, mapping: Arc<dyn ExternalDmaMapping>) {
        self.ext_mapping.lock().unwrap().insert(device_id, mapping);
    }
}

impl Drop for Iommu {
    fn drop(&mut self) {
        if let Some(kill_evt) = self.common.kill_evt.take() {
            // Ignore the result because there is nothing we can do about it.
            let _ = kill_evt.write(1);
        }
    }
}

impl VirtioDevice for Iommu {
    fn device_type(&self) -> u32 {
        self.common.device_type
    }

    fn queue_max_sizes(&self) -> &[u16] {
        &self.common.queue_sizes
    }

    fn features(&self) -> u64 {
        self.common.avail_features
    }

    fn ack_features(&mut self, value: u64) {
        self.common.ack_features(value)
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        self.read_config_from_slice(self.config.as_slice(), offset, data);
    }

    fn write_config(&mut self, offset: u64, data: &[u8]) {
        // The "bypass" field is the only mutable field
        let bypass_offset =
            (&self.config.bypass as *const _ as u64) - (&self.config as *const _ as u64);
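        // With the packed layout above, `bypass` sits at byte offset 36
        // (8 + 16 + 8 + 4) into VirtioIommuConfig, so only a one-byte write
        // at that offset is accepted (a note added by this write-up).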
        if offset != bypass_offset || data.len() != std::mem::size_of_val(&self.config.bypass) {
            error!(
                "Attempt to write to read-only field: offset {:x} length {}",
                offset,
                data.len()
            );
            return;
        }

        self.config.bypass = data[0];

        self.update_bypass();
    }

    fn activate(
        &mut self,
        mem: GuestMemoryAtomic<GuestMemoryMmap>,
        interrupt_cb: Arc<dyn VirtioInterrupt>,
        queues: Vec<(usize, Queue, EventFd)>,
    ) -> ActivateResult {
        self.common.activate(&queues, &interrupt_cb)?;
        let (kill_evt, pause_evt) = self.common.dup_eventfds();

        let mut virtqueues = Vec::new();
        let mut queue_evts = Vec::new();
        for (_, queue, queue_evt) in queues {
            virtqueues.push(queue);
            queue_evts.push(queue_evt);
        }

        let mut handler = IommuEpollHandler {
            mem,
            queues: virtqueues,
            interrupt_cb,
            queue_evts,
            kill_evt,
            pause_evt,
            mapping: self.mapping.clone(),
            ext_mapping: self.ext_mapping.clone(),
            msi_iova_space: self.msi_iova_space,
        };

        let paused = self.common.paused.clone();
        let paused_sync = self.common.paused_sync.clone();
        let mut epoll_threads = Vec::new();
        spawn_virtio_thread(
            &self.id,
            &self.seccomp_action,
            Thread::VirtioIommu,
            &mut epoll_threads,
            &self.exit_evt,
            move || handler.run(paused, paused_sync.unwrap()),
        )?;

        self.common.epoll_threads = Some(epoll_threads);

        event!("virtio-device", "activated", "id", &self.id);
        Ok(())
    }

    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
        let result = self.common.reset();
        event!("virtio-device", "reset", "id", &self.id);
        result
    }
}

impl Pausable for Iommu {
    fn pause(&mut self) -> result::Result<(), MigratableError> {
        self.common.pause()
    }

    fn resume(&mut self) -> result::Result<(), MigratableError> {
        self.common.resume()
    }
}

impl Snapshottable for Iommu {
    fn id(&self) -> String {
        self.id.clone()
    }

    fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
        Snapshot::new_from_versioned_state(&self.id, &self.state())
    }

    fn restore(&mut self, snapshot: Snapshot) -> std::result::Result<(), MigratableError> {
        self.set_state(&snapshot.to_versioned_state(&self.id)?);
        Ok(())
    }
}
impl Transportable for Iommu {}
impl Migratable for Iommu {}
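
// A minimal illustrative test of the translation path above (an addition of
// this write-up, not part of the original source): one endpoint attached to
// one domain holding a single 4 KiB mapping, with global bypass disabled.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn translate_gva_simple_mapping() {
        let mapping = IommuMapping {
            endpoints: Arc::new(RwLock::new(BTreeMap::new())),
            domains: Arc::new(RwLock::new(BTreeMap::new())),
            bypass: AtomicBool::new(false),
        };

        // Attach endpoint 1 to domain 0, then map IOVA 0x10_0000 -> GPA 0x20_0000.
        mapping.endpoints.write().unwrap().insert(1, 0);
        let mut mappings = BTreeMap::new();
        mappings.insert(
            0x10_0000,
            Mapping {
                gpa: 0x20_0000,
                size: 0x1000,
            },
        );
        mapping.domains.write().unwrap().insert(
            0,
            Domain {
                mappings,
                bypass: false,
            },
        );

        // An address inside the mapping translates, keeping the page offset.
        assert_eq!(mapping.translate_gva(1, 0x10_0042).unwrap(), 0x20_0042);
        // An address outside any mapping fails to translate.
        assert!(mapping.translate_gva(1, 0x30_0000).is_err());
        // An unknown endpoint also fails, since global bypass is disabled.
        assert!(mapping.translate_gva(2, 0x10_0042).is_err());
    }
}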