xref: /cloud-hypervisor/virtio-devices/src/iommu.rs (revision 20296e909a88ab1dd9d94b6bd2160ca84c430f5f)
// Copyright © 2019 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause

use std::collections::BTreeMap;
use std::mem::size_of;
use std::os::unix::io::AsRawFd;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Barrier, Mutex, RwLock};
use std::{io, result};

use anyhow::anyhow;
use seccompiler::SeccompAction;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use virtio_queue::{DescriptorChain, Queue, QueueT};
use vm_device::dma_mapping::ExternalDmaMapping;
use vm_memory::{
    Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic,
    GuestMemoryError, GuestMemoryLoadGuard,
};
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform;
use vmm_sys_util::eventfd::EventFd;

use super::{
    ActivateResult, EpollHelper, EpollHelperError, EpollHelperHandler, Error as DeviceError,
    VirtioCommon, VirtioDevice, VirtioDeviceType, EPOLL_HELPER_EVENT_LAST, VIRTIO_F_VERSION_1,
};
use crate::seccomp_filters::Thread;
use crate::thread_helper::spawn_virtio_thread;
use crate::{DmaRemapping, GuestMemoryMmap, VirtioInterrupt, VirtioInterruptType};

/// Queue sizes
const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];

/// New descriptors are pending on the request queue.
/// "requestq" is meant to be used anytime an action is required to be
/// performed on behalf of the guest driver.
const REQUEST_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 1;
/// New descriptors are pending on the event queue.
/// "eventq" lets the device report any fault or other asynchronous event to
/// the guest driver.
const _EVENT_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 2;
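// NOTE: fault reporting through "eventq" is not wired up yet: the queue, its
// eventfd and the VirtioIommuFault layout below are carried along (see
// IommuEpollHandler) but the event is never registered with the epoll loop.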

/// PROBE properties size.
/// This is the minimal size to provide at least one RESV_MEM property.
/// Because virtio-iommu expects one MSI reserved region, we must provide it;
/// otherwise the guest driver falls back to a default region between
/// 0x8000000 and 0x80FFFFF, which is only relevant on ARM but would conflict
/// with the x86 address layout.
const PROBE_PROP_SIZE: u32 =
    (size_of::<VirtioIommuProbeProperty>() + size_of::<VirtioIommuProbeResvMem>()) as u32;
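// Concretely, with the packed layouts below: a 4-byte property header (two
// u16 fields) plus a 20-byte RESV_MEM body (1 + 3 + 8 + 8 bytes), i.e.
// PROBE_PROP_SIZE == 24.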

/// Virtio IOMMU features
#[allow(unused)]
const VIRTIO_IOMMU_F_INPUT_RANGE: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_F_DOMAIN_RANGE: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_F_MAP_UNMAP: u32 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_F_BYPASS: u32 = 3;
const VIRTIO_IOMMU_F_PROBE: u32 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_F_MMIO: u32 = 5;
const VIRTIO_IOMMU_F_BYPASS_CONFIG: u32 = 6;

// Support 2MiB and 4KiB page sizes.
const VIRTIO_IOMMU_PAGE_SIZE_MASK: u64 = (2 << 20) | (4 << 10);
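// Each set bit in page_size_mask advertises one supported mapping granularity
// of 2^bit bytes: (2 << 20) sets bit 21 (2 MiB) and (4 << 10) sets bit 12
// (4 KiB), so the mask is 0x0020_1000.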

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuRange32 {
    start: u32,
    end: u32,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuRange64 {
    start: u64,
    end: u64,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuConfig {
    page_size_mask: u64,
    input_range: VirtioIommuRange64,
    domain_range: VirtioIommuRange32,
    probe_size: u32,
    bypass: u8,
    _reserved: [u8; 7],
}

/// Virtio IOMMU request type
const VIRTIO_IOMMU_T_ATTACH: u8 = 1;
const VIRTIO_IOMMU_T_DETACH: u8 = 2;
const VIRTIO_IOMMU_T_MAP: u8 = 3;
const VIRTIO_IOMMU_T_UNMAP: u8 = 4;
const VIRTIO_IOMMU_T_PROBE: u8 = 5;

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqHead {
    type_: u8,
    _reserved: [u8; 3],
}

/// Virtio IOMMU request status
const VIRTIO_IOMMU_S_OK: u8 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_S_IOERR: u8 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_S_UNSUPP: u8 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_S_DEVERR: u8 = 3;
#[allow(unused)]
const VIRTIO_IOMMU_S_INVAL: u8 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_S_RANGE: u8 = 5;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOENT: u8 = 6;
#[allow(unused)]
const VIRTIO_IOMMU_S_FAULT: u8 = 7;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOMEM: u8 = 8;

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuReqTail {
    status: u8,
    _reserved: [u8; 3],
}

/// ATTACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqAttach {
    domain: u32,
    endpoint: u32,
    flags: u32,
    _reserved: [u8; 4],
}

const VIRTIO_IOMMU_ATTACH_F_BYPASS: u32 = 1;

/// DETACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqDetach {
    domain: u32,
    endpoint: u32,
    _reserved: [u8; 8],
}

/// Virtio IOMMU request MAP flags
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MMIO: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MASK: u32 =
    VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE | VIRTIO_IOMMU_MAP_F_MMIO;

/// MAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqMap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    phys_start: u64,
    _flags: u32,
}

/// UNMAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqUnmap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    _reserved: [u8; 4],
}

/// Virtio IOMMU request PROBE types
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_NONE: u16 = 0;
const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_MASK: u16 = 0xfff;

/// PROBE request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuReqProbe {
    endpoint: u32,
    _reserved: [u64; 8],
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuProbeProperty {
    type_: u16,
    length: u16,
}

/// Virtio IOMMU request PROBE property RESV_MEM subtypes
#[allow(unused)]
const VIRTIO_IOMMU_RESV_MEM_T_RESERVED: u8 = 0;
const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuProbeResvMem {
    subtype: u8,
    _reserved: [u8; 3],
    start: u64,
    end: u64,
}

/// Virtio IOMMU fault flags
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_EXEC: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_ADDRESS: u32 = 1 << 8;

/// Virtio IOMMU fault reasons
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_UNKNOWN: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_DOMAIN: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_MAPPING: u32 = 2;

/// Fault reporting through eventq
#[allow(unused)]
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuFault {
    reason: u8,
    reserved: [u8; 3],
    flags: u32,
    endpoint: u32,
    reserved2: [u8; 4],
    address: u64,
}

// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange32 {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange64 {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuConfig {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqHead {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqTail {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqAttach {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqDetach {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqMap {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqUnmap {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqProbe {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeProperty {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeResvMem {}
// SAFETY: data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuFault {}

#[derive(Error, Debug)]
enum Error {
    #[error("Guest gave us bad memory addresses: {0}")]
    GuestMemory(#[source] GuestMemoryError),
    #[error("Guest gave us a write-only descriptor that the protocol says to read from")]
    UnexpectedWriteOnlyDescriptor,
    #[error("Guest gave us a read-only descriptor that the protocol says to write to")]
    UnexpectedReadOnlyDescriptor,
    #[error("Guest gave us too few descriptors in a descriptor chain")]
    DescriptorChainTooShort,
    #[error("Guest gave us a buffer that was too short to use")]
    BufferLengthTooSmall,
    #[error("Guest sent us an invalid request")]
    InvalidRequest,
    #[error("Guest sent us an invalid ATTACH request")]
    InvalidAttachRequest,
    #[error("Guest sent us an invalid DETACH request")]
    InvalidDetachRequest,
    #[error("Guest sent us an invalid MAP request")]
    InvalidMapRequest,
    #[error("Cannot map because the domain is in bypass mode")]
    InvalidMapRequestBypassDomain,
    #[error("Cannot map because the domain is missing")]
    InvalidMapRequestMissingDomain,
    #[error("Guest sent us an invalid UNMAP request")]
    InvalidUnmapRequest,
    #[error("Cannot unmap because the domain is in bypass mode")]
    InvalidUnmapRequestBypassDomain,
    #[error("Cannot unmap because the domain is missing")]
    InvalidUnmapRequestMissingDomain,
    #[error("Guest sent us an invalid PROBE request")]
    InvalidProbeRequest,
    #[error("Failed to perform external mapping: {0}")]
    ExternalMapping(#[source] io::Error),
    #[error("Failed to perform external unmapping: {0}")]
    ExternalUnmapping(#[source] io::Error),
    #[error("Failed adding used index: {0}")]
    QueueAddUsed(#[source] virtio_queue::Error),
}

struct Request {}

impl Request {
    // Parse the available vring buffer. Based on the map of external
    // mappings required by devices such as VFIO or vhost-user ones, this
    // function may update the per-domain map of external mappings.
    // The VMM knows the device_id <=> mapping relationship before the VM
    // runs, but at runtime a new domain <=> mapping map is built from the
    // information the guest virtio-iommu driver provides (which supplies
    // the device_id <=> domain link).
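    //
    // For illustration (hypothetical IDs): the VMM starts out with
    // ext_mapping = { endpoint 3 => its VFIO container }. When the guest
    // sends ATTACH(domain = 1, endpoint = 3) followed by MAP(domain = 1, ...),
    // the mapping is recorded under domains[1] and also programmed into the
    // container through ext_mapping[3].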
    fn parse(
        desc_chain: &mut DescriptorChain<GuestMemoryLoadGuard<GuestMemoryMmap>>,
        mapping: &Arc<IommuMapping>,
        ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
        msi_iova_space: (u64, u64),
    ) -> result::Result<usize, Error> {
        let desc = desc_chain
            .next()
            .ok_or(Error::DescriptorChainTooShort)
            .inspect_err(|_| {
                error!("Missing head descriptor");
            })?;

        // The descriptor contains the request type which MUST be readable.
        if desc.is_write_only() {
            return Err(Error::UnexpectedWriteOnlyDescriptor);
        }

        if (desc.len() as usize) < size_of::<VirtioIommuReqHead>() {
            return Err(Error::InvalidRequest);
        }

        let req_head: VirtioIommuReqHead = desc_chain
            .memory()
            .read_obj(desc.addr())
            .map_err(Error::GuestMemory)?;
        let req_offset = size_of::<VirtioIommuReqHead>();
        let desc_size_left = (desc.len() as usize) - req_offset;
        let req_addr = if let Some(addr) = desc.addr().checked_add(req_offset as u64) {
            addr
        } else {
            return Err(Error::InvalidRequest);
        };

        let (msi_iova_start, msi_iova_end) = msi_iova_space;

        // Create the reply
        let mut reply: Vec<u8> = Vec::new();
        let mut status = VIRTIO_IOMMU_S_OK;
        let mut hdr_len = 0;

        let result = (|| {
            match req_head.type_ {
                VIRTIO_IOMMU_T_ATTACH => {
                    if desc_size_left != size_of::<VirtioIommuReqAttach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidAttachRequest);
                    }

                    let req: VirtioIommuReqAttach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Attach request 0x{:x?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;
                    let bypass =
                        (req.flags & VIRTIO_IOMMU_ATTACH_F_BYPASS) == VIRTIO_IOMMU_ATTACH_F_BYPASS;

                    let mut old_domain_id = domain_id;
                    if let Some(&id) = mapping.endpoints.read().unwrap().get(&endpoint) {
                        old_domain_id = id;
                    }

                    if old_domain_id != domain_id {
                        detach_endpoint_from_domain(endpoint, old_domain_id, mapping, ext_mapping)?;
                    }

                    // Add endpoint associated with specific domain
                    mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .insert(endpoint, domain_id);

                    // If any other mappings exist in the domain for other containers,
                    // make sure to issue these mappings for the new endpoint/container
                    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id)
                    {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            for (virt_start, addr_map) in &domain_mappings.mappings {
                                ext_map
                                    .map(*virt_start, addr_map.gpa, addr_map.size)
                                    .map_err(Error::ExternalMapping)?;
                            }
                        }
                    }

                    // Add new domain with no mapping if the entry didn't exist yet
                    let mut domains = mapping.domains.write().unwrap();
                    let domain = Domain {
                        mappings: BTreeMap::new(),
                        bypass,
                    };
                    domains.entry(domain_id).or_insert_with(|| domain);
                }
                VIRTIO_IOMMU_T_DETACH => {
                    if desc_size_left != size_of::<VirtioIommuReqDetach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidDetachRequest);
                    }

                    let req: VirtioIommuReqDetach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Detach request 0x{:x?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;

                    // Remove endpoint associated with specific domain
                    detach_endpoint_from_domain(endpoint, domain_id, mapping, ext_mapping)?;
                }
                VIRTIO_IOMMU_T_MAP => {
                    if desc_size_left != size_of::<VirtioIommuReqMap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequest);
                    }

                    let req: VirtioIommuReqMap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Map request 0x{:x?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidMapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // With virtio-iommu, every endpoint gets its own VFIO container. As a
                    // result, each endpoint within the domain must be mapped separately,
                    // since mappings are applied per container, not per domain.
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - req.virt_start + 1;
                            ext_map
                                .map(req.virt_start, req.phys_start, size)
                                .map_err(Error::ExternalMapping)?;
                        }
                    }

                    // Add new mapping associated with the domain
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .insert(
                            req.virt_start,
                            Mapping {
                                gpa: req.phys_start,
                                size: req.virt_end - req.virt_start + 1,
                            },
                        );
                }
                VIRTIO_IOMMU_T_UNMAP => {
                    if desc_size_left != size_of::<VirtioIommuReqUnmap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequest);
                    }

                    let req: VirtioIommuReqUnmap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Unmap request 0x{:x?}", req);

                    // Copy the value to use it as a proper reference.
                    let domain_id = req.domain;
                    let virt_start = req.virt_start;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidUnmapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // Trigger external unmapping if necessary.
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - virt_start + 1;
                            ext_map
                                .unmap(virt_start, size)
                                .map_err(Error::ExternalUnmapping)?;
                        }
                    }

                    // Remove all mappings associated with the domain within the requested range
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .retain(|&x, _| (x < req.virt_start || x > req.virt_end));
                }
                VIRTIO_IOMMU_T_PROBE => {
                    if desc_size_left != size_of::<VirtioIommuReqProbe>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidProbeRequest);
                    }

                    let req: VirtioIommuReqProbe = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Probe request 0x{:x?}", req);

                    let probe_prop = VirtioIommuProbeProperty {
                        type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM,
                        length: size_of::<VirtioIommuProbeResvMem>() as u16,
                    };
                    reply.extend_from_slice(probe_prop.as_slice());

                    let resv_mem = VirtioIommuProbeResvMem {
                        subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
                        start: msi_iova_start,
                        end: msi_iova_end,
                        ..Default::default()
                    };
                    reply.extend_from_slice(resv_mem.as_slice());

                    hdr_len = PROBE_PROP_SIZE;
                }
                _ => {
                    status = VIRTIO_IOMMU_S_INVAL;
                    return Err(Error::InvalidRequest);
                }
            }
            Ok(())
        })();

        let status_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;

        // The status MUST always be writable
        if !status_desc.is_write_only() {
            return Err(Error::UnexpectedReadOnlyDescriptor);
        }

        if status_desc.len() < hdr_len + size_of::<VirtioIommuReqTail>() as u32 {
            return Err(Error::BufferLengthTooSmall);
        }

        let tail = VirtioIommuReqTail {
            status,
            ..Default::default()
        };
        reply.extend_from_slice(tail.as_slice());

        // Make sure we return the result of the request to the guest before
        // we return a potential error internally.
        desc_chain
            .memory()
            .write_slice(reply.as_slice(), status_desc.addr())
            .map_err(Error::GuestMemory)?;

        // Return the error if the result was not Ok().
        result?;

        Ok((hdr_len as usize) + size_of::<VirtioIommuReqTail>())
    }
}

fn detach_endpoint_from_domain(
    endpoint: u32,
    domain_id: u32,
    mapping: &Arc<IommuMapping>,
    ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
) -> result::Result<(), Error> {
    // Remove endpoint associated with specific domain
    mapping.endpoints.write().unwrap().remove(&endpoint);

    // Trigger external unmapping for the endpoint if necessary.
    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id) {
        if let Some(ext_map) = ext_mapping.get(&endpoint) {
            for (virt_start, addr_map) in &domain_mappings.mappings {
                ext_map
                    .unmap(*virt_start, addr_map.size)
                    .map_err(Error::ExternalUnmapping)?;
            }
        }
    }

    if mapping
        .endpoints
        .write()
        .unwrap()
        .iter()
        .filter(|(_, &d)| d == domain_id)
        .count()
        == 0
    {
        mapping.domains.write().unwrap().remove(&domain_id);
    }

    Ok(())
}

struct IommuEpollHandler {
    mem: GuestMemoryAtomic<GuestMemoryMmap>,
    request_queue: Queue,
    _event_queue: Queue,
    interrupt_cb: Arc<dyn VirtioInterrupt>,
    request_queue_evt: EventFd,
    _event_queue_evt: EventFd,
    kill_evt: EventFd,
    pause_evt: EventFd,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    msi_iova_space: (u64, u64),
}

impl IommuEpollHandler {
    fn request_queue(&mut self) -> Result<bool, Error> {
        let mut used_descs = false;
        while let Some(mut desc_chain) = self.request_queue.pop_descriptor_chain(self.mem.memory())
        {
            let len = Request::parse(
                &mut desc_chain,
                &self.mapping,
                &self.ext_mapping.lock().unwrap(),
                self.msi_iova_space,
            )?;

            self.request_queue
                .add_used(desc_chain.memory(), desc_chain.head_index(), len as u32)
                .map_err(Error::QueueAddUsed)?;

            used_descs = true;
        }

        Ok(used_descs)
    }

    fn signal_used_queue(&self, queue_index: u16) -> result::Result<(), DeviceError> {
        self.interrupt_cb
            .trigger(VirtioInterruptType::Queue(queue_index))
            .map_err(|e| {
                error!("Failed to signal used queue: {:?}", e);
                DeviceError::FailedSignalingUsedQueue(e)
            })
    }

    fn run(
        &mut self,
        paused: Arc<AtomicBool>,
        paused_sync: Arc<Barrier>,
    ) -> result::Result<(), EpollHelperError> {
        let mut helper = EpollHelper::new(&self.kill_evt, &self.pause_evt)?;
        helper.add_event(self.request_queue_evt.as_raw_fd(), REQUEST_Q_EVENT)?;
        helper.run(paused, paused_sync, self)?;

        Ok(())
    }
}

impl EpollHelperHandler for IommuEpollHandler {
    fn handle_event(
        &mut self,
        _helper: &mut EpollHelper,
        event: &epoll::Event,
    ) -> result::Result<(), EpollHelperError> {
        let ev_type = event.data as u16;
        match ev_type {
            REQUEST_Q_EVENT => {
                self.request_queue_evt.read().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!("Failed to get queue event: {:?}", e))
                })?;

                let needs_notification = self.request_queue().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!(
                        "Failed to process request queue: {:?}",
                        e
                    ))
                })?;
                if needs_notification {
                    self.signal_used_queue(0).map_err(|e| {
                        EpollHelperError::HandleEvent(anyhow!(
                            "Failed to signal used queue: {:?}",
                            e
                        ))
                    })?;
                }
            }
            _ => {
                return Err(EpollHelperError::HandleEvent(anyhow!(
                    "Unexpected event: {}",
                    ev_type
                )));
            }
        }
        Ok(())
    }
}

#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
struct Mapping {
    gpa: u64,
    size: u64,
}

#[derive(Clone, Debug)]
struct Domain {
    mappings: BTreeMap<u64, Mapping>,
    bypass: bool,
}

#[derive(Debug)]
pub struct IommuMapping {
    // Domain related to an endpoint.
    endpoints: Arc<RwLock<BTreeMap<u32, u32>>>,
    // Information related to each domain.
    domains: Arc<RwLock<BTreeMap<u32, Domain>>>,
    // Global flag indicating if endpoints that are not attached to any domain
    // are in bypass mode.
    bypass: AtomicBool,
}

impl DmaRemapping for IommuMapping {
    fn translate_gva(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GVA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                for (&key, &value) in domain.mappings.iter() {
                    if addr >= key && addr < key + value.size {
                        let new_addr = addr - key + value.gpa;
                        debug!("Into GPA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::other(format!(
            "failed to translate GVA addr 0x{addr:x}"
        )))
    }

    fn translate_gpa(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GPA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                for (&key, &value) in domain.mappings.iter() {
                    if addr >= value.gpa && addr < value.gpa + value.size {
                        let new_addr = addr - value.gpa + key;
                        debug!("Into GVA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::other(format!(
            "failed to translate GPA addr 0x{addr:x}"
        )))
    }
}
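
// A minimal sketch (test-only, not part of the device) of the translation
// arithmetic above, assuming one endpoint (1) attached to one domain (2)
// holding a single mapping; IDs and values are made up for illustration.
#[cfg(test)]
mod mapping_tests {
    use super::*;

    #[test]
    fn translate_round_trip() {
        let mapping = IommuMapping {
            endpoints: Arc::new(RwLock::new(BTreeMap::from([(1u32, 2u32)]))),
            domains: Arc::new(RwLock::new(BTreeMap::from([(
                2u32,
                Domain {
                    mappings: BTreeMap::from([(
                        0x1000u64,
                        Mapping {
                            gpa: 0x8000,
                            size: 0x2000,
                        },
                    )]),
                    bypass: false,
                },
            )]))),
            bypass: AtomicBool::new(false),
        };

        // 0x1800 falls inside [0x1000, 0x3000): GPA = 0x1800 - 0x1000 + 0x8000.
        assert_eq!(mapping.translate_gva(1, 0x1800).unwrap(), 0x8800);
        // And translate_gpa() inverts it.
        assert_eq!(mapping.translate_gpa(1, 0x8800).unwrap(), 0x1800);
        // An unattached endpoint fails while the global bypass flag is off.
        assert!(mapping.translate_gva(7, 0x1800).is_err());
    }
}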

#[derive(Debug)]
pub struct AccessPlatformMapping {
    id: u32,
    mapping: Arc<IommuMapping>,
}

impl AccessPlatformMapping {
    pub fn new(id: u32, mapping: Arc<IommuMapping>) -> Self {
        AccessPlatformMapping { id, mapping }
    }
}

impl AccessPlatform for AccessPlatformMapping {
    fn translate_gva(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gva(self.id, base)
    }
    fn translate_gpa(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gpa(self.id, base)
    }
}

pub struct Iommu {
    common: VirtioCommon,
    id: String,
    config: VirtioIommuConfig,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    seccomp_action: SeccompAction,
    exit_evt: EventFd,
    msi_iova_space: (u64, u64),
}

type EndpointsState = Vec<(u32, u32)>;
type DomainsState = Vec<(u32, (Vec<(u64, Mapping)>, bool))>;
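// The snapshot aliases above flatten the runtime BTreeMaps into (key, value)
// vectors so the state serializes cleanly with serde; Iommu::new() rebuilds
// the maps with collect() when restoring.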

#[derive(Serialize, Deserialize)]
pub struct IommuState {
    avail_features: u64,
    acked_features: u64,
    endpoints: EndpointsState,
    domains: DomainsState,
}

impl Iommu {
    pub fn new(
        id: String,
        seccomp_action: SeccompAction,
        exit_evt: EventFd,
        msi_iova_space: (u64, u64),
        address_width_bits: u8,
        state: Option<IommuState>,
    ) -> io::Result<(Self, Arc<IommuMapping>)> {
        let (mut avail_features, acked_features, endpoints, domains, paused) =
            if let Some(state) = state {
                info!("Restoring virtio-iommu {}", id);
                (
                    state.avail_features,
                    state.acked_features,
                    state.endpoints.into_iter().collect(),
                    state
                        .domains
                        .into_iter()
                        .map(|(k, v)| {
                            (
                                k,
                                Domain {
                                    mappings: v.0.into_iter().collect(),
                                    bypass: v.1,
                                },
                            )
                        })
                        .collect(),
                    true,
                )
            } else {
                let avail_features = (1u64 << VIRTIO_F_VERSION_1)
                    | (1u64 << VIRTIO_IOMMU_F_MAP_UNMAP)
                    | (1u64 << VIRTIO_IOMMU_F_PROBE)
                    | (1u64 << VIRTIO_IOMMU_F_BYPASS_CONFIG);

                (avail_features, 0, BTreeMap::new(), BTreeMap::new(), false)
            };

        let mut config = VirtioIommuConfig {
            page_size_mask: VIRTIO_IOMMU_PAGE_SIZE_MASK,
            probe_size: PROBE_PROP_SIZE,
            ..Default::default()
        };

        if address_width_bits < 64 {
            avail_features |= 1u64 << VIRTIO_IOMMU_F_INPUT_RANGE;
            config.input_range = VirtioIommuRange64 {
                start: 0,
                end: (1u64 << address_width_bits) - 1,
            }
        }

        let mapping = Arc::new(IommuMapping {
            endpoints: Arc::new(RwLock::new(endpoints)),
            domains: Arc::new(RwLock::new(domains)),
            bypass: AtomicBool::new(true),
        });

        Ok((
            Iommu {
                id,
                common: VirtioCommon {
                    device_type: VirtioDeviceType::Iommu as u32,
                    queue_sizes: QUEUE_SIZES.to_vec(),
                    avail_features,
                    acked_features,
                    paused_sync: Some(Arc::new(Barrier::new(2))),
                    min_queues: NUM_QUEUES as u16,
                    paused: Arc::new(AtomicBool::new(paused)),
                    ..Default::default()
                },
                config,
                mapping: mapping.clone(),
                ext_mapping: Arc::new(Mutex::new(BTreeMap::new())),
                seccomp_action,
                exit_evt,
                msi_iova_space,
            },
            mapping,
        ))
    }

    fn state(&self) -> IommuState {
        IommuState {
            avail_features: self.common.avail_features,
            acked_features: self.common.acked_features,
            endpoints: self
                .mapping
                .endpoints
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .collect(),
            domains: self
                .mapping
                .domains
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .map(|(k, v)| (k, (v.mappings.into_iter().collect(), v.bypass)))
                .collect(),
        }
    }

    fn update_bypass(&mut self) {
        // Use bypass from config if VIRTIO_IOMMU_F_BYPASS_CONFIG has been negotiated
        if !self
            .common
            .feature_acked(VIRTIO_IOMMU_F_BYPASS_CONFIG.into())
        {
            return;
        }

        let bypass = self.config.bypass == 1;
        info!("Updating bypass mode to {}", bypass);
        self.mapping.bypass.store(bypass, Ordering::Release);
    }

    pub fn add_external_mapping(&mut self, device_id: u32, mapping: Arc<dyn ExternalDmaMapping>) {
        self.ext_mapping.lock().unwrap().insert(device_id, mapping);
    }

    #[cfg(fuzzing)]
    pub fn wait_for_epoll_threads(&mut self) {
        self.common.wait_for_epoll_threads();
    }
}

impl Drop for Iommu {
    fn drop(&mut self) {
        if let Some(kill_evt) = self.common.kill_evt.take() {
            // Ignore the result because there is nothing we can do about it.
            let _ = kill_evt.write(1);
        }
        self.common.wait_for_epoll_threads();
    }
}

impl VirtioDevice for Iommu {
    fn device_type(&self) -> u32 {
        self.common.device_type
    }

    fn queue_max_sizes(&self) -> &[u16] {
        &self.common.queue_sizes
    }

    fn features(&self) -> u64 {
        self.common.avail_features
    }

    fn ack_features(&mut self, value: u64) {
        self.common.ack_features(value)
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        self.read_config_from_slice(self.config.as_slice(), offset, data);
    }

    fn write_config(&mut self, offset: u64, data: &[u8]) {
        // The "bypass" field is the only mutable field
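        // Note: bypass_offset below is effectively offsetof(VirtioIommuConfig,
        // bypass), i.e. 36 with the packed layout above (8 + 16 + 8 + 4 bytes).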
        let bypass_offset =
            (&self.config.bypass as *const _ as u64) - (&self.config as *const _ as u64);
        if offset != bypass_offset || data.len() != std::mem::size_of_val(&self.config.bypass) {
            error!(
                "Attempt to write to read-only field: offset {:x} length {}",
                offset,
                data.len()
            );
            return;
        }

        self.config.bypass = data[0];

        self.update_bypass();
    }

    fn activate(
        &mut self,
        mem: GuestMemoryAtomic<GuestMemoryMmap>,
        interrupt_cb: Arc<dyn VirtioInterrupt>,
        mut queues: Vec<(usize, Queue, EventFd)>,
    ) -> ActivateResult {
        self.common.activate(&queues, &interrupt_cb)?;
        let (kill_evt, pause_evt) = self.common.dup_eventfds();

        let (_, request_queue, request_queue_evt) = queues.remove(0);
        let (_, _event_queue, _event_queue_evt) = queues.remove(0);

        let mut handler = IommuEpollHandler {
            mem,
            request_queue,
            _event_queue,
            interrupt_cb,
            request_queue_evt,
            _event_queue_evt,
            kill_evt,
            pause_evt,
            mapping: self.mapping.clone(),
            ext_mapping: self.ext_mapping.clone(),
            msi_iova_space: self.msi_iova_space,
        };

        let paused = self.common.paused.clone();
        let paused_sync = self.common.paused_sync.clone();
        let mut epoll_threads = Vec::new();
        spawn_virtio_thread(
            &self.id,
            &self.seccomp_action,
            Thread::VirtioIommu,
            &mut epoll_threads,
            &self.exit_evt,
            move || handler.run(paused, paused_sync.unwrap()),
        )?;

        self.common.epoll_threads = Some(epoll_threads);

        event!("virtio-device", "activated", "id", &self.id);
        Ok(())
    }

    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
        let result = self.common.reset();
        event!("virtio-device", "reset", "id", &self.id);
        result
    }
}

impl Pausable for Iommu {
    fn pause(&mut self) -> result::Result<(), MigratableError> {
        self.common.pause()
    }

    fn resume(&mut self) -> result::Result<(), MigratableError> {
        self.common.resume()
    }
}

impl Snapshottable for Iommu {
    fn id(&self) -> String {
        self.id.clone()
    }

    fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
        Snapshot::new_from_state(&self.state())
    }
}
impl Transportable for Iommu {}
impl Migratable for Iommu {}