xref: /cloud-hypervisor/virtio-devices/src/iommu.rs (revision 7d7bfb2034001d4cb15df2ddc56d2d350c8da30f)
1 // Copyright © 2019 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause
4 
5 use super::Error as DeviceError;
6 use super::{
7     ActivateResult, EpollHelper, EpollHelperError, EpollHelperHandler, VirtioCommon, VirtioDevice,
8     VirtioDeviceType, EPOLL_HELPER_EVENT_LAST, VIRTIO_F_VERSION_1,
9 };
10 use crate::seccomp_filters::Thread;
11 use crate::thread_helper::spawn_virtio_thread;
12 use crate::GuestMemoryMmap;
13 use crate::{DmaRemapping, VirtioInterrupt, VirtioInterruptType};
14 use seccompiler::SeccompAction;
15 use std::collections::BTreeMap;
16 use std::fmt::{self, Display};
17 use std::io;
18 use std::mem::size_of;
19 use std::ops::Bound::Included;
20 use std::os::unix::io::AsRawFd;
21 use std::result;
22 use std::sync::atomic::AtomicBool;
23 use std::sync::{Arc, Barrier, Mutex, RwLock};
24 use versionize::{VersionMap, Versionize, VersionizeResult};
25 use versionize_derive::Versionize;
26 use virtio_queue::{DescriptorChain, Queue};
27 use vm_device::dma_mapping::ExternalDmaMapping;
28 use vm_memory::{
29     Address, ByteValued, Bytes, GuestAddress, GuestMemoryAtomic, GuestMemoryError,
30     GuestMemoryLoadGuard,
31 };
32 use vm_migration::VersionMapped;
33 use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
34 use vm_virtio::AccessPlatform;
35 use vmm_sys_util::eventfd::EventFd;
36 
/// Queues sizes
const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];

/// New descriptors are pending on the request queue.
/// "requestq" is meant to be used anytime an action is required to be
/// performed on behalf of the guest driver.
const REQUEST_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 1;
/// New descriptors are pending on the event queue.
/// "eventq" lets the device report any fault or other asynchronous event to
/// the guest driver.
const EVENT_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 2;

/// PROBE properties size.
/// This is the minimal size to provide at least one RESV_MEM property.
/// Because virtio-iommu expects one MSI reserved region, we must provide it,
/// otherwise the driver in the guest will define a predefined one between
/// 0x8000000 and 0x80FFFFF, which is only relevant for ARM architecture, but
/// will conflict with x86.
const PROBE_PROP_SIZE: u32 =
    (size_of::<VirtioIommuProbeProperty>() + size_of::<VirtioIommuProbeResvMem>()) as u32;

/// Virtio IOMMU features
/// (feature bit numbers; shifted into the 64-bit feature set when offered).
#[allow(unused)]
const VIRTIO_IOMMU_F_INPUT_RANGE: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_F_DOMAIN_RANGE: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_F_MAP_UNMAP: u32 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_F_BYPASS: u32 = 3;
const VIRTIO_IOMMU_F_PROBE: u32 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_F_MMIO: u32 = 5;
#[allow(unused)]
const VIRTIO_IOMMU_F_BYPASS_CONFIG: u32 = 6;

// Support 2MiB and 4KiB page sizes.
// (2 << 20) | (4 << 10) = 0x200000 | 0x1000.
const VIRTIO_IOMMU_PAGE_SIZE_MASK: u64 = (2 << 20) | (4 << 10);
77 
// A 32-bit [start, end] range (inclusive), used for the config space
// `domain_range` field.
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuRange32 {
    start: u32,
    end: u32,
}

// A 64-bit [start, end] range (inclusive), used for the config space
// `input_range` field.
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuRange64 {
    start: u64,
    end: u64,
}

// Device configuration space exposed to the guest through `read_config()`.
// `#[repr(packed)]` keeps the exact wire layout expected by the guest driver.
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuConfig {
    page_size_mask: u64,
    input_range: VirtioIommuRange64,
    domain_range: VirtioIommuRange32,
    probe_size: u32,
    bypass: u8,
    _reserved: [u8; 7],
}
105 
/// Virtio IOMMU request type
const VIRTIO_IOMMU_T_ATTACH: u8 = 1;
const VIRTIO_IOMMU_T_DETACH: u8 = 2;
const VIRTIO_IOMMU_T_MAP: u8 = 3;
const VIRTIO_IOMMU_T_UNMAP: u8 = 4;
const VIRTIO_IOMMU_T_PROBE: u8 = 5;

// Common header starting every request buffer; `type_` selects which
// type-specific payload immediately follows it.
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqHead {
    type_: u8,
    _reserved: [u8; 3],
}

/// Virtio IOMMU request status
const VIRTIO_IOMMU_S_OK: u8 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_S_IOERR: u8 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_S_UNSUPP: u8 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_S_DEVERR: u8 = 3;
#[allow(unused)]
const VIRTIO_IOMMU_S_INVAL: u8 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_S_RANGE: u8 = 5;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOENT: u8 = 6;
#[allow(unused)]
const VIRTIO_IOMMU_S_FAULT: u8 = 7;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOMEM: u8 = 8;

// Status footer written back to the guest at the end of every reply.
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuReqTail {
    status: u8,
    _reserved: [u8; 3],
}
146 
/// ATTACH request
// Associates `endpoint` with `domain`.
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqAttach {
    domain: u32,
    endpoint: u32,
    _reserved: [u8; 8],
}

/// DETACH request
// Removes the association between `endpoint` and `domain`.
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqDetach {
    domain: u32,
    endpoint: u32,
    _reserved: [u8; 8],
}

/// Virtio IOMMU request MAP flags
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MMIO: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MASK: u32 =
    VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE | VIRTIO_IOMMU_MAP_F_MMIO;

/// MAP request
// Maps IOVA range [virt_start, virt_end] (inclusive) to `phys_start` within
// `domain`. `_flags` is currently ignored by the device implementation.
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqMap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    phys_start: u64,
    _flags: u32,
}

/// UNMAP request
// Unmaps IOVA range [virt_start, virt_end] (inclusive) from `domain`.
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuReqUnmap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    _reserved: [u8; 4],
}
196 
/// Virtio IOMMU request PROBE types
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_NONE: u16 = 0;
const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_MASK: u16 = 0xfff;

/// PROBE request
// Asks the device for the properties of `endpoint`.
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuReqProbe {
    endpoint: u32,
    _reserved: [u64; 8],
}

// Header preceding each property returned in a PROBE reply; `length` is the
// size of the property payload that follows.
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuProbeProperty {
    type_: u16,
    length: u16,
}

/// Virtio IOMMU request PROBE property RESV_MEM subtypes
#[allow(unused)]
const VIRTIO_IOMMU_RESV_MEM_T_RESERVED: u8 = 0;
const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;

// RESV_MEM property payload: a reserved [start, end] region. Used here to
// report the MSI IOVA window to the guest.
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
#[allow(dead_code)]
struct VirtioIommuProbeResvMem {
    subtype: u8,
    _reserved: [u8; 3],
    start: u64,
    end: u64,
}
235 
/// Virtio IOMMU fault flags
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_EXEC: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_ADDRESS: u32 = 1 << 8;

/// Virtio IOMMU fault reasons
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_UNKNOWN: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_DOMAIN: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_MAPPING: u32 = 2;

/// Fault reporting through eventq
// Currently unused: `event_queue()` never reports faults to the guest.
#[allow(unused)]
#[derive(Copy, Clone, Debug, Default)]
#[repr(packed)]
struct VirtioIommuFault {
    reason: u8,
    reserved: [u8; 3],
    flags: u32,
    endpoint: u32,
    reserved2: [u8; 4],
    address: u64,
}
266 
// SAFETY: these data structures only contain integers and have no implicit padding
// (they are all #[repr(packed)]), so interpreting them as raw bytes read from
// or written to guest memory is well defined.
unsafe impl ByteValued for VirtioIommuRange32 {}
unsafe impl ByteValued for VirtioIommuRange64 {}
unsafe impl ByteValued for VirtioIommuConfig {}
unsafe impl ByteValued for VirtioIommuReqHead {}
unsafe impl ByteValued for VirtioIommuReqTail {}
unsafe impl ByteValued for VirtioIommuReqAttach {}
unsafe impl ByteValued for VirtioIommuReqDetach {}
unsafe impl ByteValued for VirtioIommuReqMap {}
unsafe impl ByteValued for VirtioIommuReqUnmap {}
unsafe impl ByteValued for VirtioIommuReqProbe {}
unsafe impl ByteValued for VirtioIommuProbeProperty {}
unsafe impl ByteValued for VirtioIommuProbeResvMem {}
unsafe impl ByteValued for VirtioIommuFault {}
281 
// Errors that can occur while parsing and handling a virtio-iommu request.
#[derive(Debug)]
enum Error {
    /// Guest gave us bad memory addresses.
    GuestMemory(GuestMemoryError),
    /// Guest gave us a write only descriptor that protocol says to read from.
    UnexpectedWriteOnlyDescriptor,
    /// Guest gave us a read only descriptor that protocol says to write to.
    UnexpectedReadOnlyDescriptor,
    /// Guest gave us too few descriptors in a descriptor chain.
    DescriptorChainTooShort,
    /// Guest gave us a buffer that was too short to use.
    BufferLengthTooSmall,
    /// Guest sent us invalid request.
    InvalidRequest,
    /// Guest sent us invalid ATTACH request.
    InvalidAttachRequest,
    /// Guest sent us invalid DETACH request.
    InvalidDetachRequest,
    /// Guest sent us invalid MAP request.
    InvalidMapRequest,
    /// Guest sent us invalid UNMAP request.
    InvalidUnmapRequest,
    /// Guest sent us invalid PROBE request.
    InvalidProbeRequest,
    /// Failed to perform external mapping.
    ExternalMapping(io::Error),
    /// Failed to perform external unmapping.
    ExternalUnmapping(io::Error),
}
311 
312 impl Display for Error {
313     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
314         use self::Error::*;
315 
316         match self {
317             BufferLengthTooSmall => write!(f, "buffer length too small"),
318             DescriptorChainTooShort => write!(f, "descriptor chain too short"),
319             GuestMemory(e) => write!(f, "bad guest memory address: {}", e),
320             InvalidRequest => write!(f, "invalid request"),
321             InvalidAttachRequest => write!(f, "invalid attach request"),
322             InvalidDetachRequest => write!(f, "invalid detach request"),
323             InvalidMapRequest => write!(f, "invalid map request"),
324             InvalidUnmapRequest => write!(f, "invalid unmap request"),
325             InvalidProbeRequest => write!(f, "invalid probe request"),
326             UnexpectedReadOnlyDescriptor => write!(f, "unexpected read-only descriptor"),
327             UnexpectedWriteOnlyDescriptor => write!(f, "unexpected write-only descriptor"),
328             ExternalMapping(e) => write!(f, "failed performing external mapping: {}", e),
329             ExternalUnmapping(e) => write!(f, "failed performing external unmapping: {}", e),
330         }
331     }
332 }
333 
334 struct Request {}
335 
impl Request {
    // Parse the available vring buffer. Based on the hashmap table of external
    // mappings required from various devices such as VFIO or vhost-user ones,
    // this function might update the hashmap table of external mappings per
    // domain.
    // Basically, the VMM knows about the device_id <=> mapping relationship
    // before running the VM, but at runtime, a new domain <=> mapping hashmap
    // is created based on the information provided from the guest driver for
    // virtio-iommu (giving the link device_id <=> domain).
    //
    // A request chain is laid out as one readable descriptor carrying the
    // common header (VirtioIommuReqHead) immediately followed by the
    // type-specific payload, then one writable descriptor receiving the
    // reply (PROBE properties, if any, followed by VirtioIommuReqTail).
    //
    // On success, returns the number of bytes written to the reply
    // descriptor.
    fn parse(
        desc_chain: &mut DescriptorChain<GuestMemoryLoadGuard<GuestMemoryMmap>>,
        mapping: &Arc<IommuMapping>,
        ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
        ext_domain_mapping: &mut BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
        msi_iova_space: (u64, u64),
    ) -> result::Result<usize, Error> {
        let desc = desc_chain
            .next()
            .ok_or(Error::DescriptorChainTooShort)
            .map_err(|e| {
                error!("Missing head descriptor");
                e
            })?;

        // The descriptor contains the request type which MUST be readable.
        if desc.is_write_only() {
            return Err(Error::UnexpectedWriteOnlyDescriptor);
        }

        // The descriptor must be large enough to hold at least the common
        // request header.
        if (desc.len() as usize) < size_of::<VirtioIommuReqHead>() {
            return Err(Error::InvalidRequest);
        }

        let req_head: VirtioIommuReqHead = desc_chain
            .memory()
            .read_obj(desc.addr())
            .map_err(Error::GuestMemory)?;
        // The type-specific payload immediately follows the header in the
        // same descriptor.
        let req_offset = size_of::<VirtioIommuReqHead>();
        let desc_size_left = (desc.len() as usize) - req_offset;
        let req_addr = if let Some(addr) = desc.addr().checked_add(req_offset as u64) {
            addr
        } else {
            return Err(Error::InvalidRequest);
        };

        let (msi_iova_start, msi_iova_end) = msi_iova_space;

        // Create the reply
        let mut reply: Vec<u8> = Vec::new();

        // `hdr_len` is the number of reply bytes preceding the status tail
        // (non-zero only for PROBE, which returns properties).
        let hdr_len = match req_head.type_ {
            VIRTIO_IOMMU_T_ATTACH => {
                if desc_size_left != size_of::<VirtioIommuReqAttach>() {
                    return Err(Error::InvalidAttachRequest);
                }

                let req: VirtioIommuReqAttach = desc_chain
                    .memory()
                    .read_obj(req_addr as GuestAddress)
                    .map_err(Error::GuestMemory)?;
                debug!("Attach request {:?}", req);

                // Copy the value to use it as a proper reference.
                let domain = req.domain;
                let endpoint = req.endpoint;

                // Add endpoint associated with specific domain
                mapping.endpoints.write().unwrap().insert(endpoint, domain);

                // If the endpoint is part of the list of devices with an
                // external mapping, insert a new entry for the corresponding
                // domain, with the same reference to the trait.
                if let Some(map) = ext_mapping.get(&endpoint) {
                    ext_domain_mapping.insert(domain, map.clone());
                }

                // Add new domain with no mapping if the entry didn't exist yet
                let mut mappings = mapping.mappings.write().unwrap();
                mappings.entry(domain).or_insert_with(BTreeMap::new);

                0
            }
            VIRTIO_IOMMU_T_DETACH => {
                if desc_size_left != size_of::<VirtioIommuReqDetach>() {
                    return Err(Error::InvalidDetachRequest);
                }

                let req: VirtioIommuReqDetach = desc_chain
                    .memory()
                    .read_obj(req_addr as GuestAddress)
                    .map_err(Error::GuestMemory)?;
                debug!("Detach request {:?}", req);

                // Copy the value to use it as a proper reference.
                let domain = req.domain;
                let endpoint = req.endpoint;

                // If the endpoint is part of the list of devices with an
                // external mapping, remove the entry for the corresponding
                // domain.
                if ext_mapping.contains_key(&endpoint) {
                    ext_domain_mapping.remove(&domain);
                }

                // Remove endpoint associated with specific domain
                mapping.endpoints.write().unwrap().remove(&endpoint);

                0
            }
            VIRTIO_IOMMU_T_MAP => {
                if desc_size_left != size_of::<VirtioIommuReqMap>() {
                    return Err(Error::InvalidMapRequest);
                }

                let req: VirtioIommuReqMap = desc_chain
                    .memory()
                    .read_obj(req_addr as GuestAddress)
                    .map_err(Error::GuestMemory)?;
                debug!("Map request {:?}", req);

                // Copy the value to use it as a proper reference.
                let domain = req.domain;

                // Trigger external mapping if necessary.
                // NOTE(review): this runs before the domain existence check
                // below, so a MAP for an unknown domain with an external
                // mapping attached still triggers the external map — confirm
                // this ordering is intended.
                if let Some(ext_map) = ext_domain_mapping.get(&domain) {
                    // `virt_end` is inclusive, hence the +1 for the size.
                    let size = req.virt_end - req.virt_start + 1;
                    ext_map
                        .map(req.virt_start, req.phys_start, size)
                        .map_err(Error::ExternalMapping)?;
                }

                // Add new mapping associated with the domain
                if let Some(entry) = mapping.mappings.write().unwrap().get_mut(&domain) {
                    entry.insert(
                        req.virt_start,
                        Mapping {
                            gpa: req.phys_start,
                            size: req.virt_end - req.virt_start + 1,
                        },
                    );
                } else {
                    // MAP for a domain that was never attached.
                    return Err(Error::InvalidMapRequest);
                }

                0
            }
            VIRTIO_IOMMU_T_UNMAP => {
                if desc_size_left != size_of::<VirtioIommuReqUnmap>() {
                    return Err(Error::InvalidUnmapRequest);
                }

                let req: VirtioIommuReqUnmap = desc_chain
                    .memory()
                    .read_obj(req_addr as GuestAddress)
                    .map_err(Error::GuestMemory)?;
                debug!("Unmap request {:?}", req);

                // Copy the value to use it as a proper reference.
                let domain = req.domain;
                let virt_start = req.virt_start;

                // Trigger external unmapping if necessary.
                if let Some(ext_map) = ext_domain_mapping.get(&domain) {
                    // `virt_end` is inclusive, hence the +1 for the size.
                    let size = req.virt_end - virt_start + 1;
                    ext_map
                        .unmap(virt_start, size)
                        .map_err(Error::ExternalUnmapping)?;
                }

                // Remove the mapping associated with the domain
                if let Some(entry) = mapping.mappings.write().unwrap().get_mut(&domain) {
                    entry.remove(&virt_start);
                }

                0
            }
            VIRTIO_IOMMU_T_PROBE => {
                if desc_size_left != size_of::<VirtioIommuReqProbe>() {
                    return Err(Error::InvalidProbeRequest);
                }

                let req: VirtioIommuReqProbe = desc_chain
                    .memory()
                    .read_obj(req_addr as GuestAddress)
                    .map_err(Error::GuestMemory)?;
                debug!("Probe request {:?}", req);

                // Reply with a single RESV_MEM property describing the MSI
                // IOVA region (see PROBE_PROP_SIZE comment above for why).
                let probe_prop = VirtioIommuProbeProperty {
                    type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM,
                    length: size_of::<VirtioIommuProbeResvMem>() as u16,
                };
                reply.extend_from_slice(probe_prop.as_slice());

                let resv_mem = VirtioIommuProbeResvMem {
                    subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
                    start: msi_iova_start,
                    end: msi_iova_end,
                    ..Default::default()
                };
                reply.extend_from_slice(resv_mem.as_slice());

                PROBE_PROP_SIZE
            }
            _ => return Err(Error::InvalidRequest),
        };

        let status_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;

        // The status MUST always be writable
        if !status_desc.is_write_only() {
            return Err(Error::UnexpectedReadOnlyDescriptor);
        }

        // The reply descriptor must fit the properties (PROBE only) plus the
        // status tail.
        if status_desc.len() < hdr_len + size_of::<VirtioIommuReqTail>() as u32 {
            return Err(Error::BufferLengthTooSmall);
        }

        let tail = VirtioIommuReqTail {
            status: VIRTIO_IOMMU_S_OK,
            ..Default::default()
        };
        reply.extend_from_slice(tail.as_slice());

        desc_chain
            .memory()
            .write_slice(reply.as_slice(), status_desc.addr())
            .map_err(Error::GuestMemory)?;

        Ok((hdr_len as usize) + size_of::<VirtioIommuReqTail>())
    }
}
567 
// Epoll-based worker servicing the two virtio-iommu queues.
struct IommuEpollHandler {
    // queues[0] is the request queue, queues[1] the event queue.
    queues: Vec<Queue<GuestMemoryAtomic<GuestMemoryMmap>>>,
    // Callback used to inject queue interrupts into the guest.
    interrupt_cb: Arc<dyn VirtioInterrupt>,
    // One eventfd per queue, signaled on guest queue notification.
    queue_evts: Vec<EventFd>,
    // Signaled to ask the worker thread to exit.
    kill_evt: EventFd,
    // Signaled to ask the worker thread to pause.
    pause_evt: EventFd,
    // Shared endpoint/domain/mapping tables.
    mapping: Arc<IommuMapping>,
    // device_id -> external DMA mapping handler (VFIO, vhost-user, ...).
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    // domain -> external DMA mapping handler, built at runtime from
    // ATTACH/DETACH requests.
    ext_domain_mapping: BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
    // (start, end) of the guest MSI IOVA region reported through PROBE.
    msi_iova_space: (u64, u64),
}
579 
580 impl IommuEpollHandler {
581     fn request_queue(&mut self) -> bool {
582         let mut used_desc_heads = [(0, 0); QUEUE_SIZE as usize];
583         let mut used_count = 0;
584         for mut desc_chain in self.queues[0].iter().unwrap() {
585             let len = match Request::parse(
586                 &mut desc_chain,
587                 &self.mapping,
588                 &self.ext_mapping.lock().unwrap(),
589                 &mut self.ext_domain_mapping,
590                 self.msi_iova_space,
591             ) {
592                 Ok(len) => len as u32,
593                 Err(e) => {
594                     error!("failed parsing descriptor: {}", e);
595                     0
596                 }
597             };
598 
599             used_desc_heads[used_count] = (desc_chain.head_index(), len);
600             used_count += 1;
601         }
602 
603         for &(desc_index, len) in &used_desc_heads[..used_count] {
604             self.queues[0].add_used(desc_index, len).unwrap();
605         }
606         used_count > 0
607     }
608 
609     fn event_queue(&mut self) -> bool {
610         false
611     }
612 
613     fn signal_used_queue(&self, queue_index: u16) -> result::Result<(), DeviceError> {
614         self.interrupt_cb
615             .trigger(VirtioInterruptType::Queue(queue_index))
616             .map_err(|e| {
617                 error!("Failed to signal used queue: {:?}", e);
618                 DeviceError::FailedSignalingUsedQueue(e)
619             })
620     }
621 
622     fn run(
623         &mut self,
624         paused: Arc<AtomicBool>,
625         paused_sync: Arc<Barrier>,
626     ) -> result::Result<(), EpollHelperError> {
627         let mut helper = EpollHelper::new(&self.kill_evt, &self.pause_evt)?;
628         helper.add_event(self.queue_evts[0].as_raw_fd(), REQUEST_Q_EVENT)?;
629         helper.add_event(self.queue_evts[1].as_raw_fd(), EVENT_Q_EVENT)?;
630         helper.run(paused, paused_sync, self)?;
631 
632         Ok(())
633     }
634 }
635 
636 impl EpollHelperHandler for IommuEpollHandler {
637     fn handle_event(&mut self, _helper: &mut EpollHelper, event: &epoll::Event) -> bool {
638         let ev_type = event.data as u16;
639         match ev_type {
640             REQUEST_Q_EVENT => {
641                 if let Err(e) = self.queue_evts[0].read() {
642                     error!("Failed to get queue event: {:?}", e);
643                     return true;
644                 } else if self.request_queue() {
645                     if let Err(e) = self.signal_used_queue(0) {
646                         error!("Failed to signal used queue: {:?}", e);
647                         return true;
648                     }
649                 }
650             }
651             EVENT_Q_EVENT => {
652                 if let Err(e) = self.queue_evts[1].read() {
653                     error!("Failed to get queue event: {:?}", e);
654                     return true;
655                 } else if self.event_queue() {
656                     if let Err(e) = self.signal_used_queue(1) {
657                         error!("Failed to signal used queue: {:?}", e);
658                         return true;
659                     }
660                 }
661             }
662             _ => {
663                 error!("Unexpected event: {}", ev_type);
664                 return true;
665             }
666         }
667         false
668     }
669 }
670 
// A single IOVA -> GPA mapping: `size` bytes starting at guest physical
// address `gpa`. Stored in the per-domain BTreeMap keyed by IOVA start.
#[derive(Clone, Copy, Debug, Versionize)]
struct Mapping {
    gpa: u64,
    size: u64,
}
676 
// Shared translation tables: written by the epoll worker when handling guest
// requests, read by devices needing address translation.
#[derive(Debug)]
pub struct IommuMapping {
    // Domain related to an endpoint.
    endpoints: Arc<RwLock<BTreeMap<u32, u32>>>,
    // List of mappings per domain, keyed by IOVA start address.
    mappings: Arc<RwLock<BTreeMap<u32, BTreeMap<u64, Mapping>>>>,
}
684 
685 impl DmaRemapping for IommuMapping {
686     fn translate_gva(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
687         debug!("Translate GVA addr 0x{:x}", addr);
688         if let Some(domain) = self.endpoints.read().unwrap().get(&id) {
689             if let Some(mapping) = self.mappings.read().unwrap().get(domain) {
690                 let range_start = if VIRTIO_IOMMU_PAGE_SIZE_MASK > addr {
691                     0
692                 } else {
693                     addr - VIRTIO_IOMMU_PAGE_SIZE_MASK
694                 };
695                 for (&key, &value) in mapping.range((Included(&range_start), Included(&addr))) {
696                     if addr >= key && addr < key + value.size {
697                         let new_addr = addr - key + value.gpa;
698                         debug!("Into GPA addr 0x{:x}", new_addr);
699                         return Ok(new_addr);
700                     }
701                 }
702             }
703         }
704 
705         Err(io::Error::new(
706             io::ErrorKind::Other,
707             format!("failed to translate GVA addr 0x{:x}", addr),
708         ))
709     }
710 
711     fn translate_gpa(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
712         debug!("Translate GPA addr 0x{:x}", addr);
713         if let Some(domain) = self.endpoints.read().unwrap().get(&id) {
714             if let Some(mapping) = self.mappings.read().unwrap().get(domain) {
715                 for (&key, &value) in mapping.iter() {
716                     if addr >= value.gpa && addr < value.gpa + value.size {
717                         let new_addr = addr - value.gpa + key;
718                         debug!("Into GVA addr 0x{:x}", new_addr);
719                         return Ok(new_addr);
720                     }
721                 }
722             }
723         }
724 
725         Err(io::Error::new(
726             io::ErrorKind::Other,
727             format!("failed to translate GPA addr 0x{:x}", addr),
728         ))
729     }
730 }
731 
// Per-device translation handle: couples an endpoint id with the shared
// IOMMU mapping so a virtio device can translate its own addresses.
#[derive(Debug)]
pub struct AccessPlatformMapping {
    id: u32,
    mapping: Arc<IommuMapping>,
}
737 
738 impl AccessPlatformMapping {
739     pub fn new(id: u32, mapping: Arc<IommuMapping>) -> Self {
740         AccessPlatformMapping { id, mapping }
741     }
742 }
743 
744 impl AccessPlatform for AccessPlatformMapping {
745     fn translate_gva(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
746         self.mapping.translate_gva(self.id, base)
747     }
748     fn translate_gpa(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
749         self.mapping.translate_gpa(self.id, base)
750     }
751 }
752 
// The virtio-iommu device.
pub struct Iommu {
    // State shared by all virtio devices (features, queues, worker control).
    common: VirtioCommon,
    // Device identifier, also used to name the snapshot section.
    id: String,
    // Configuration space exposed to the guest.
    config: VirtioIommuConfig,
    // Shared translation tables, also handed out to translating devices.
    mapping: Arc<IommuMapping>,
    // device_id -> external DMA mapping handler (VFIO, vhost-user, ...).
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    // Seccomp policy applied to the worker thread.
    seccomp_action: SeccompAction,
    // Signaled if the worker thread fails, to tear down the VMM.
    exit_evt: EventFd,
    // (start, end) of the guest MSI IOVA region reported through PROBE.
    msi_iova_space: (u64, u64),
}
763 
// Serializable device state used for snapshot/restore.
#[derive(Versionize)]
struct IommuState {
    // Features advertised to the guest.
    avail_features: u64,
    // Features acknowledged by the guest.
    acked_features: u64,
    // Flattened endpoint -> domain associations.
    endpoints: Vec<(u32, u32)>,
    // Flattened per-domain mappings: (domain, [(iova_start, mapping)]).
    mappings: Vec<(u32, Vec<(u64, Mapping)>)>,
}

impl VersionMapped for IommuState {}
773 
774 impl Iommu {
775     pub fn new(
776         id: String,
777         seccomp_action: SeccompAction,
778         exit_evt: EventFd,
779         msi_iova_space: (u64, u64),
780     ) -> io::Result<(Self, Arc<IommuMapping>)> {
781         let config = VirtioIommuConfig {
782             page_size_mask: VIRTIO_IOMMU_PAGE_SIZE_MASK,
783             probe_size: PROBE_PROP_SIZE,
784             ..Default::default()
785         };
786 
787         let mapping = Arc::new(IommuMapping {
788             endpoints: Arc::new(RwLock::new(BTreeMap::new())),
789             mappings: Arc::new(RwLock::new(BTreeMap::new())),
790         });
791 
792         Ok((
793             Iommu {
794                 id,
795                 common: VirtioCommon {
796                     device_type: VirtioDeviceType::Iommu as u32,
797                     queue_sizes: QUEUE_SIZES.to_vec(),
798                     avail_features: 1u64 << VIRTIO_F_VERSION_1
799                         | 1u64 << VIRTIO_IOMMU_F_MAP_UNMAP
800                         | 1u64 << VIRTIO_IOMMU_F_PROBE,
801                     paused_sync: Some(Arc::new(Barrier::new(2))),
802                     ..Default::default()
803                 },
804                 config,
805                 mapping: mapping.clone(),
806                 ext_mapping: Arc::new(Mutex::new(BTreeMap::new())),
807                 seccomp_action,
808                 exit_evt,
809                 msi_iova_space,
810             },
811             mapping,
812         ))
813     }
814 
815     fn state(&self) -> IommuState {
816         IommuState {
817             avail_features: self.common.avail_features,
818             acked_features: self.common.acked_features,
819             endpoints: self
820                 .mapping
821                 .endpoints
822                 .read()
823                 .unwrap()
824                 .clone()
825                 .into_iter()
826                 .collect(),
827             mappings: self
828                 .mapping
829                 .mappings
830                 .read()
831                 .unwrap()
832                 .clone()
833                 .into_iter()
834                 .map(|(k, v)| (k, v.into_iter().collect()))
835                 .collect(),
836         }
837     }
838 
839     fn set_state(&mut self, state: &IommuState) {
840         self.common.avail_features = state.avail_features;
841         self.common.acked_features = state.acked_features;
842         *(self.mapping.endpoints.write().unwrap()) = state.endpoints.clone().into_iter().collect();
843         *(self.mapping.mappings.write().unwrap()) = state
844             .mappings
845             .clone()
846             .into_iter()
847             .map(|(k, v)| (k, v.into_iter().collect()))
848             .collect();
849     }
850 
851     pub fn add_external_mapping(&mut self, device_id: u32, mapping: Arc<dyn ExternalDmaMapping>) {
852         self.ext_mapping.lock().unwrap().insert(device_id, mapping);
853     }
854 }
855 
856 impl Drop for Iommu {
857     fn drop(&mut self) {
858         if let Some(kill_evt) = self.common.kill_evt.take() {
859             // Ignore the result because there is nothing we can do about it.
860             let _ = kill_evt.write(1);
861         }
862     }
863 }
864 
impl VirtioDevice for Iommu {
    /// Virtio device type (IOMMU).
    fn device_type(&self) -> u32 {
        self.common.device_type
    }

    /// Maximum size of each virtqueue (requestq, eventq).
    fn queue_max_sizes(&self) -> &[u16] {
        &self.common.queue_sizes
    }

    /// Feature bits offered to the guest.
    fn features(&self) -> u64 {
        self.common.avail_features
    }

    /// Record the feature bits acknowledged by the guest.
    fn ack_features(&mut self, value: u64) {
        self.common.ack_features(value)
    }

    /// Let the guest read the device configuration space.
    fn read_config(&self, offset: u64, data: &mut [u8]) {
        self.read_config_from_slice(self.config.as_slice(), offset, data);
    }

    /// Start the device: spawn the epoll worker thread servicing both queues.
    fn activate(
        &mut self,
        _mem: GuestMemoryAtomic<GuestMemoryMmap>,
        interrupt_cb: Arc<dyn VirtioInterrupt>,
        queues: Vec<Queue<GuestMemoryAtomic<GuestMemoryMmap>>>,
        queue_evts: Vec<EventFd>,
    ) -> ActivateResult {
        self.common.activate(&queues, &queue_evts, &interrupt_cb)?;
        // Duplicate the kill/pause eventfds so the handler owns its copies.
        let (kill_evt, pause_evt) = self.common.dup_eventfds();
        let mut handler = IommuEpollHandler {
            queues,
            interrupt_cb,
            queue_evts,
            kill_evt,
            pause_evt,
            mapping: self.mapping.clone(),
            ext_mapping: self.ext_mapping.clone(),
            // Domain -> external mapping table, populated at runtime from
            // ATTACH/DETACH requests.
            ext_domain_mapping: BTreeMap::new(),
            msi_iova_space: self.msi_iova_space,
        };

        let paused = self.common.paused.clone();
        let paused_sync = self.common.paused_sync.clone();
        let mut epoll_threads = Vec::new();
        spawn_virtio_thread(
            &self.id,
            &self.seccomp_action,
            Thread::VirtioIommu,
            &mut epoll_threads,
            &self.exit_evt,
            move || {
                if let Err(e) = handler.run(paused, paused_sync.unwrap()) {
                    error!("Error running worker: {:?}", e);
                }
            },
        )?;

        self.common.epoll_threads = Some(epoll_threads);

        event!("virtio-device", "activated", "id", &self.id);
        Ok(())
    }

    /// Reset the device to its initial (deactivated) state.
    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
        let result = self.common.reset();
        event!("virtio-device", "reset", "id", &self.id);
        result
    }
}
935 
impl Pausable for Iommu {
    /// Pause the epoll worker thread.
    fn pause(&mut self) -> result::Result<(), MigratableError> {
        self.common.pause()
    }

    /// Resume the epoll worker thread.
    fn resume(&mut self) -> result::Result<(), MigratableError> {
        self.common.resume()
    }
}
945 
impl Snapshottable for Iommu {
    /// Identifier used to name the snapshot section.
    fn id(&self) -> String {
        self.id.clone()
    }

    /// Serialize the device state (features + translation tables).
    fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
        Snapshot::new_from_versioned_state(&self.id, &self.state())
    }

    /// Restore the device state from a snapshot.
    // NOTE(review): only the internal tables are restored here; external DMA
    // mappings appear to be re-registered by the VMM after restore — confirm
    // against callers.
    fn restore(&mut self, snapshot: Snapshot) -> std::result::Result<(), MigratableError> {
        self.set_state(&snapshot.to_versioned_state(&self.id)?);
        Ok(())
    }
}
// Default (no-op) trait implementations: no transport-specific or extra
// migration behavior is needed beyond Pausable/Snapshottable.
impl Transportable for Iommu {}
impl Migratable for Iommu {}
962