// Copyright © 2019 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause

use std::collections::BTreeMap;
use std::mem::size_of;
use std::os::unix::io::AsRawFd;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Barrier, Mutex, RwLock};
use std::{io, result};

use anyhow::anyhow;
use seccompiler::SeccompAction;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use virtio_queue::{DescriptorChain, Queue, QueueT};
use vm_device::dma_mapping::ExternalDmaMapping;
use vm_memory::{
    Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic,
    GuestMemoryError, GuestMemoryLoadGuard,
};
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform;
use vmm_sys_util::eventfd::EventFd;

use super::{
    ActivateResult, EpollHelper, EpollHelperError, EpollHelperHandler, Error as DeviceError,
    VirtioCommon, VirtioDevice, VirtioDeviceType, EPOLL_HELPER_EVENT_LAST, VIRTIO_F_VERSION_1,
};
use crate::seccomp_filters::Thread;
use crate::thread_helper::spawn_virtio_thread;
use crate::{DmaRemapping, GuestMemoryMmap, VirtioInterrupt, VirtioInterruptType};

/// Queue sizes
const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];

/// New descriptors are pending on the request queue.
/// "requestq" is used whenever an action must be performed on behalf of the
/// guest driver.
const REQUEST_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 1;
/// New descriptors are pending on the event queue.
/// "eventq" lets the device report faults or other asynchronous events to the
/// guest driver.
const _EVENT_Q_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 2;
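// Note: the event queue is not serviced yet. The leading underscores on
// `_EVENT_Q_EVENT` (and on the `_event_queue` fields further down) reflect
// that the device currently never pushes fault reports to the guest.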

/// PROBE properties size.
/// This is the minimal size to provide at least one RESV_MEM property.
/// Because virtio-iommu expects one MSI reserved region, we must provide it,
/// otherwise the guest driver will fall back to its default MSI region
/// (0x8000000 through 0x80FFFFF), which is only relevant for the ARM
/// architecture and would conflict with x86.
const PROBE_PROP_SIZE: u32 =
    (size_of::<VirtioIommuProbeProperty>() + size_of::<VirtioIommuProbeResvMem>()) as u32;
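// For reference: with `repr(C, packed)` there is no compiler-inserted padding,
// so this works out to a 4-byte property header (type + length) plus a 20-byte
// RESV_MEM payload, i.e. 24 bytes in total.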

/// Virtio IOMMU features
#[allow(unused)]
const VIRTIO_IOMMU_F_INPUT_RANGE: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_F_DOMAIN_RANGE: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_F_MAP_UNMAP: u32 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_F_BYPASS: u32 = 3;
const VIRTIO_IOMMU_F_PROBE: u32 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_F_MMIO: u32 = 5;
const VIRTIO_IOMMU_F_BYPASS_CONFIG: u32 = 6;

// Support 2MiB and 4KiB page sizes.
const VIRTIO_IOMMU_PAGE_SIZE_MASK: u64 = (2 << 20) | (4 << 10);
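// (That is, bit 21 for 2 MiB and bit 12 for 4 KiB, giving a mask of 0x0020_1000.)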

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuRange32 {
    start: u32,
    end: u32,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuRange64 {
    start: u64,
    end: u64,
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuConfig {
    page_size_mask: u64,
    input_range: VirtioIommuRange64,
    domain_range: VirtioIommuRange32,
    probe_size: u32,
    bypass: u8,
    _reserved: [u8; 7],
}

/// Virtio IOMMU request type
const VIRTIO_IOMMU_T_ATTACH: u8 = 1;
const VIRTIO_IOMMU_T_DETACH: u8 = 2;
const VIRTIO_IOMMU_T_MAP: u8 = 3;
const VIRTIO_IOMMU_T_UNMAP: u8 = 4;
const VIRTIO_IOMMU_T_PROBE: u8 = 5;

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqHead {
    type_: u8,
    _reserved: [u8; 3],
}

/// Virtio IOMMU request status
const VIRTIO_IOMMU_S_OK: u8 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_S_IOERR: u8 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_S_UNSUPP: u8 = 2;
#[allow(unused)]
const VIRTIO_IOMMU_S_DEVERR: u8 = 3;
#[allow(unused)]
const VIRTIO_IOMMU_S_INVAL: u8 = 4;
#[allow(unused)]
const VIRTIO_IOMMU_S_RANGE: u8 = 5;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOENT: u8 = 6;
#[allow(unused)]
const VIRTIO_IOMMU_S_FAULT: u8 = 7;
#[allow(unused)]
const VIRTIO_IOMMU_S_NOMEM: u8 = 8;

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuReqTail {
    status: u8,
    _reserved: [u8; 3],
}

/// ATTACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqAttach {
    domain: u32,
    endpoint: u32,
    flags: u32,
    _reserved: [u8; 4],
}

const VIRTIO_IOMMU_ATTACH_F_BYPASS: u32 = 1;

/// DETACH request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqDetach {
    domain: u32,
    endpoint: u32,
    _reserved: [u8; 8],
}

/// Virtio IOMMU request MAP flags
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MMIO: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_MAP_F_MASK: u32 =
    VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE | VIRTIO_IOMMU_MAP_F_MMIO;

/// MAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqMap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    phys_start: u64,
    _flags: u32,
}

/// UNMAP request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuReqUnmap {
    domain: u32,
    virt_start: u64,
    virt_end: u64,
    _reserved: [u8; 4],
}

/// Virtio IOMMU request PROBE types
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_NONE: u16 = 0;
const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_PROBE_T_MASK: u16 = 0xfff;

/// PROBE request
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuReqProbe {
    endpoint: u32,
    _reserved: [u64; 8],
}

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuProbeProperty {
    type_: u16,
    length: u16,
}

/// Virtio IOMMU request PROBE property RESV_MEM subtypes
#[allow(unused)]
const VIRTIO_IOMMU_RESV_MEM_T_RESERVED: u8 = 0;
const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
#[allow(dead_code)]
struct VirtioIommuProbeResvMem {
    subtype: u8,
    _reserved: [u8; 3],
    start: u64,
    end: u64,
}

/// Virtio IOMMU fault flags
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_READ: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_WRITE: u32 = 1 << 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_EXEC: u32 = 1 << 2;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_F_ADDRESS: u32 = 1 << 8;

/// Virtio IOMMU fault reasons
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_UNKNOWN: u32 = 0;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_DOMAIN: u32 = 1;
#[allow(unused)]
const VIRTIO_IOMMU_FAULT_R_MAPPING: u32 = 2;

/// Fault reporting through eventq
#[allow(unused)]
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuFault {
    reason: u8,
    reserved: [u8; 3],
    flags: u32,
    endpoint: u32,
    reserved2: [u8; 4],
    address: u64,
}

// SAFETY: this data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange32 {}
// SAFETY: this data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuRange64 {}
// SAFETY: this data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuConfig {}
// SAFETY: this data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqHead {}
// SAFETY: this data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqTail {}
// SAFETY: this data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqAttach {}
// SAFETY: this data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqDetach {}
// SAFETY: this data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqMap {}
// SAFETY: this data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqUnmap {}
// SAFETY: this data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuReqProbe {}
// SAFETY: this data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeProperty {}
// SAFETY: this data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuProbeResvMem {}
// SAFETY: this data structure only contains integers and has no implicit padding
unsafe impl ByteValued for VirtioIommuFault {}

#[derive(Error, Debug)]
enum Error {
    #[error("Guest gave us bad memory addresses")]
    GuestMemory(#[source] GuestMemoryError),
    #[error("Guest gave us a write-only descriptor that the protocol says to read from")]
    UnexpectedWriteOnlyDescriptor,
    #[error("Guest gave us a read-only descriptor that the protocol says to write to")]
    UnexpectedReadOnlyDescriptor,
    #[error("Guest gave us too few descriptors in a descriptor chain")]
    DescriptorChainTooShort,
    #[error("Guest gave us a buffer that was too short to use")]
    BufferLengthTooSmall,
    #[error("Guest sent us an invalid request")]
    InvalidRequest,
    #[error("Guest sent us an invalid ATTACH request")]
    InvalidAttachRequest,
    #[error("Guest sent us an invalid DETACH request")]
    InvalidDetachRequest,
    #[error("Guest sent us an invalid MAP request")]
    InvalidMapRequest,
    #[error("Cannot map because the domain is in bypass mode")]
    InvalidMapRequestBypassDomain,
    #[error("Cannot map because the domain is missing")]
    InvalidMapRequestMissingDomain,
    #[error("Guest sent us an invalid UNMAP request")]
    InvalidUnmapRequest,
    #[error("Cannot unmap because the domain is in bypass mode")]
    InvalidUnmapRequestBypassDomain,
    #[error("Cannot unmap because the domain is missing")]
    InvalidUnmapRequestMissingDomain,
    #[error("Guest sent us an invalid PROBE request")]
    InvalidProbeRequest,
    #[error("Failed to perform external mapping")]
    ExternalMapping(#[source] io::Error),
    #[error("Failed to perform external unmapping")]
    ExternalUnmapping(#[source] io::Error),
    #[error("Failed to add used index")]
    QueueAddUsed(#[source] virtio_queue::Error),
}

struct Request {}

impl Request {
    // Parse the available vring buffer. Based on the table of external
    // mappings required by various devices (such as VFIO or vhost-user ones),
    // this function may update the per-domain table of external mappings.
    // The VMM knows the device_id <=> mapping relationship before the VM
    // runs, but the domain <=> mapping table is only built at runtime, from
    // the information provided by the guest virtio-iommu driver (which
    // supplies the device_id <=> domain link).
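    //
    // Layout expected by this parser: a device-readable descriptor carrying
    // VirtioIommuReqHead plus the type-specific payload, followed by a
    // device-writable descriptor that receives the reply (the optional PROBE
    // properties and then VirtioIommuReqTail).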
    fn parse(
        desc_chain: &mut DescriptorChain<GuestMemoryLoadGuard<GuestMemoryMmap>>,
        mapping: &Arc<IommuMapping>,
        ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
        msi_iova_space: (u64, u64),
    ) -> result::Result<usize, Error> {
        let desc = desc_chain
            .next()
            .ok_or(Error::DescriptorChainTooShort)
            .inspect_err(|_| {
                error!("Missing head descriptor");
            })?;

        // The descriptor contains the request type which MUST be readable.
        if desc.is_write_only() {
            return Err(Error::UnexpectedWriteOnlyDescriptor);
        }

        if (desc.len() as usize) < size_of::<VirtioIommuReqHead>() {
            return Err(Error::InvalidRequest);
        }

        let req_head: VirtioIommuReqHead = desc_chain
            .memory()
            .read_obj(desc.addr())
            .map_err(Error::GuestMemory)?;
        let req_offset = size_of::<VirtioIommuReqHead>();
        let desc_size_left = (desc.len() as usize) - req_offset;
        let req_addr = if let Some(addr) = desc.addr().checked_add(req_offset as u64) {
            addr
        } else {
            return Err(Error::InvalidRequest);
        };

        let (msi_iova_start, msi_iova_end) = msi_iova_space;

        // Create the reply
        let mut reply: Vec<u8> = Vec::new();
        let mut status = VIRTIO_IOMMU_S_OK;
        let mut hdr_len = 0;

        let result = (|| {
            match req_head.type_ {
                VIRTIO_IOMMU_T_ATTACH => {
                    if desc_size_left != size_of::<VirtioIommuReqAttach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidAttachRequest);
                    }

                    let req: VirtioIommuReqAttach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Attach request 0x{:x?}", req);

                    // Copy the fields out of the packed struct so they can be
                    // referenced safely.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;
                    let bypass =
                        (req.flags & VIRTIO_IOMMU_ATTACH_F_BYPASS) == VIRTIO_IOMMU_ATTACH_F_BYPASS;

                    let mut old_domain_id = domain_id;
                    if let Some(&id) = mapping.endpoints.read().unwrap().get(&endpoint) {
                        old_domain_id = id;
                    }

                    if old_domain_id != domain_id {
                        detach_endpoint_from_domain(endpoint, old_domain_id, mapping, ext_mapping)?;
                    }

                    // Add endpoint associated with specific domain
                    mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .insert(endpoint, domain_id);

                    // If any other mappings exist in the domain for other containers,
                    // make sure to issue these mappings for the new endpoint/container
                    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id)
                    {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            for (virt_start, addr_map) in &domain_mappings.mappings {
                                ext_map
                                    .map(*virt_start, addr_map.gpa, addr_map.size)
                                    .map_err(Error::ExternalMapping)?;
                            }
                        }
                    }

                    // Add new domain with no mapping if the entry didn't exist yet
                    let mut domains = mapping.domains.write().unwrap();
                    let domain = Domain {
                        mappings: BTreeMap::new(),
                        bypass,
                    };
                    domains.entry(domain_id).or_insert_with(|| domain);
                }
                VIRTIO_IOMMU_T_DETACH => {
                    if desc_size_left != size_of::<VirtioIommuReqDetach>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidDetachRequest);
                    }

                    let req: VirtioIommuReqDetach = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Detach request 0x{:x?}", req);

                    // Copy the fields out of the packed struct so they can be
                    // referenced safely.
                    let domain_id = req.domain;
                    let endpoint = req.endpoint;

                    // Remove endpoint associated with specific domain
                    detach_endpoint_from_domain(endpoint, domain_id, mapping, ext_mapping)?;
                }
                VIRTIO_IOMMU_T_MAP => {
                    if desc_size_left != size_of::<VirtioIommuReqMap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequest);
                    }

                    let req: VirtioIommuReqMap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Map request 0x{:x?}", req);

                    // Copy the field out of the packed struct so it can be
                    // referenced safely.
                    let domain_id = req.domain;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidMapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidMapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // With virtio-iommu, every endpoint gets its own VFIO container.
                    // As a result, each endpoint within the domain must be mapped
                    // separately: mappings are applied per container, not per domain.
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - req.virt_start + 1;
                            ext_map
                                .map(req.virt_start, req.phys_start, size)
                                .map_err(Error::ExternalMapping)?;
                        }
                    }

                    // Add new mapping associated with the domain
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .insert(
                            req.virt_start,
                            Mapping {
                                gpa: req.phys_start,
                                size: req.virt_end - req.virt_start + 1,
                            },
                        );
                }
                VIRTIO_IOMMU_T_UNMAP => {
                    if desc_size_left != size_of::<VirtioIommuReqUnmap>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequest);
                    }

                    let req: VirtioIommuReqUnmap = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Unmap request 0x{:x?}", req);

                    // Copy the fields out of the packed struct so they can be
                    // referenced safely.
                    let domain_id = req.domain;
                    let virt_start = req.virt_start;

                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
                        if domain.bypass {
                            status = VIRTIO_IOMMU_S_INVAL;
                            return Err(Error::InvalidUnmapRequestBypassDomain);
                        }
                    } else {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidUnmapRequestMissingDomain);
                    }

                    // Find the list of endpoints attached to the given domain.
                    let endpoints: Vec<u32> = mapping
                        .endpoints
                        .write()
                        .unwrap()
                        .iter()
                        .filter(|(_, &d)| d == domain_id)
                        .map(|(&e, _)| e)
                        .collect();

                    // Trigger external unmapping if necessary.
                    for endpoint in endpoints {
                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
                            let size = req.virt_end - virt_start + 1;
                            ext_map
                                .unmap(virt_start, size)
                                .map_err(Error::ExternalUnmapping)?;
                        }
                    }

                    // Remove all mappings associated with the domain within the requested range
                    mapping
                        .domains
                        .write()
                        .unwrap()
                        .get_mut(&domain_id)
                        .unwrap()
                        .mappings
                        .retain(|&x, _| (x < req.virt_start || x > req.virt_end));
                }
                VIRTIO_IOMMU_T_PROBE => {
                    if desc_size_left != size_of::<VirtioIommuReqProbe>() {
                        status = VIRTIO_IOMMU_S_INVAL;
                        return Err(Error::InvalidProbeRequest);
                    }

                    let req: VirtioIommuReqProbe = desc_chain
                        .memory()
                        .read_obj(req_addr)
                        .map_err(Error::GuestMemory)?;
                    debug!("Probe request 0x{:x?}", req);

                    let probe_prop = VirtioIommuProbeProperty {
                        type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM,
                        length: size_of::<VirtioIommuProbeResvMem>() as u16,
                    };
                    reply.extend_from_slice(probe_prop.as_slice());

                    let resv_mem = VirtioIommuProbeResvMem {
                        subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
                        start: msi_iova_start,
                        end: msi_iova_end,
                        ..Default::default()
                    };
                    reply.extend_from_slice(resv_mem.as_slice());

                    hdr_len = PROBE_PROP_SIZE;
                }
                _ => {
                    status = VIRTIO_IOMMU_S_INVAL;
                    return Err(Error::InvalidRequest);
                }
            }
            Ok(())
        })();

        let status_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;

        // The status MUST always be writable
        if !status_desc.is_write_only() {
            return Err(Error::UnexpectedReadOnlyDescriptor);
        }

        if status_desc.len() < hdr_len + size_of::<VirtioIommuReqTail>() as u32 {
            return Err(Error::BufferLengthTooSmall);
        }

        let tail = VirtioIommuReqTail {
            status,
            ..Default::default()
        };
        reply.extend_from_slice(tail.as_slice());

        // Make sure we return the result of the request to the guest before
        // we return a potential error internally.
        desc_chain
            .memory()
            .write_slice(reply.as_slice(), status_desc.addr())
            .map_err(Error::GuestMemory)?;

        // Return the error if the result was not Ok().
        result?;

        Ok((hdr_len as usize) + size_of::<VirtioIommuReqTail>())
    }
}

fn detach_endpoint_from_domain(
    endpoint: u32,
    domain_id: u32,
    mapping: &Arc<IommuMapping>,
    ext_mapping: &BTreeMap<u32, Arc<dyn ExternalDmaMapping>>,
) -> result::Result<(), Error> {
    // Remove endpoint associated with specific domain
    mapping.endpoints.write().unwrap().remove(&endpoint);

    // Trigger external unmapping for the endpoint if necessary.
    if let Some(domain_mappings) = &mapping.domains.read().unwrap().get(&domain_id) {
        if let Some(ext_map) = ext_mapping.get(&endpoint) {
            for (virt_start, addr_map) in &domain_mappings.mappings {
                ext_map
                    .unmap(*virt_start, addr_map.size)
                    .map_err(Error::ExternalUnmapping)?;
            }
        }
    }

    if mapping
        .endpoints
        .write()
        .unwrap()
        .iter()
        .filter(|(_, &d)| d == domain_id)
        .count()
        == 0
    {
        mapping.domains.write().unwrap().remove(&domain_id);
    }

    Ok(())
}

struct IommuEpollHandler {
    mem: GuestMemoryAtomic<GuestMemoryMmap>,
    request_queue: Queue,
    _event_queue: Queue,
    interrupt_cb: Arc<dyn VirtioInterrupt>,
    request_queue_evt: EventFd,
    _event_queue_evt: EventFd,
    kill_evt: EventFd,
    pause_evt: EventFd,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    msi_iova_space: (u64, u64),
}

impl IommuEpollHandler {
    fn request_queue(&mut self) -> Result<bool, Error> {
        let mut used_descs = false;
        while let Some(mut desc_chain) = self.request_queue.pop_descriptor_chain(self.mem.memory())
        {
            let len = Request::parse(
                &mut desc_chain,
                &self.mapping,
                &self.ext_mapping.lock().unwrap(),
                self.msi_iova_space,
            )?;

            self.request_queue
                .add_used(desc_chain.memory(), desc_chain.head_index(), len as u32)
                .map_err(Error::QueueAddUsed)?;

            used_descs = true;
        }

        Ok(used_descs)
    }

    fn signal_used_queue(&self, queue_index: u16) -> result::Result<(), DeviceError> {
        self.interrupt_cb
            .trigger(VirtioInterruptType::Queue(queue_index))
            .map_err(|e| {
                error!("Failed to signal used queue: {:?}", e);
                DeviceError::FailedSignalingUsedQueue(e)
            })
    }

    fn run(
        &mut self,
        paused: Arc<AtomicBool>,
        paused_sync: Arc<Barrier>,
    ) -> result::Result<(), EpollHelperError> {
        let mut helper = EpollHelper::new(&self.kill_evt, &self.pause_evt)?;
        helper.add_event(self.request_queue_evt.as_raw_fd(), REQUEST_Q_EVENT)?;
        helper.run(paused, paused_sync, self)?;

        Ok(())
    }
}

impl EpollHelperHandler for IommuEpollHandler {
    fn handle_event(
        &mut self,
        _helper: &mut EpollHelper,
        event: &epoll::Event,
    ) -> result::Result<(), EpollHelperError> {
        let ev_type = event.data as u16;
        match ev_type {
            REQUEST_Q_EVENT => {
                self.request_queue_evt.read().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!("Failed to get queue event: {:?}", e))
                })?;

                let needs_notification = self.request_queue().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!(
                        "Failed to process request queue: {:?}",
                        e
                    ))
                })?;
                if needs_notification {
                    self.signal_used_queue(0).map_err(|e| {
                        EpollHelperError::HandleEvent(anyhow!(
                            "Failed to signal used queue: {:?}",
                            e
                        ))
                    })?;
                }
            }
            _ => {
                return Err(EpollHelperError::HandleEvent(anyhow!(
                    "Unexpected event: {}",
                    ev_type
                )));
            }
        }
        Ok(())
    }
}

#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
struct Mapping {
    gpa: u64,
    size: u64,
}

#[derive(Clone, Debug)]
struct Domain {
    mappings: BTreeMap<u64, Mapping>,
    bypass: bool,
}

#[derive(Debug)]
pub struct IommuMapping {
    // Domain each endpoint is attached to.
    endpoints: Arc<RwLock<BTreeMap<u32, u32>>>,
    // Information related to each domain.
    domains: Arc<RwLock<BTreeMap<u32, Domain>>>,
    // Global flag indicating if endpoints that are not attached to any domain
    // are in bypass mode.
    bypass: AtomicBool,
}

impl DmaRemapping for IommuMapping {
    fn translate_gva(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GVA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                for (&key, &value) in domain.mappings.iter() {
                    if addr >= key && addr < key + value.size {
                        let new_addr = addr - key + value.gpa;
                        debug!("Into GPA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::other(format!(
            "failed to translate GVA addr 0x{addr:x}"
        )))
    }

    fn translate_gpa(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error> {
        debug!("Translate GPA addr 0x{:x}", addr);
        if let Some(domain_id) = self.endpoints.read().unwrap().get(&id) {
            if let Some(domain) = self.domains.read().unwrap().get(domain_id) {
                // Directly return identity mapping in case the domain is in
                // bypass mode.
                if domain.bypass {
                    return Ok(addr);
                }

                for (&key, &value) in domain.mappings.iter() {
                    if addr >= value.gpa && addr < value.gpa + value.size {
                        let new_addr = addr - value.gpa + key;
                        debug!("Into GVA addr 0x{:x}", new_addr);
                        return Ok(new_addr);
                    }
                }
            }
        } else if self.bypass.load(Ordering::Acquire) {
            return Ok(addr);
        }

        Err(io::Error::other(format!(
            "failed to translate GPA addr 0x{addr:x}"
        )))
    }
}

#[derive(Debug)]
pub struct AccessPlatformMapping {
    id: u32,
    mapping: Arc<IommuMapping>,
}

impl AccessPlatformMapping {
    pub fn new(id: u32, mapping: Arc<IommuMapping>) -> Self {
        AccessPlatformMapping { id, mapping }
    }
}

impl AccessPlatform for AccessPlatformMapping {
    fn translate_gva(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gva(self.id, base)
    }
    fn translate_gpa(&self, base: u64, _size: u64) -> std::result::Result<u64, std::io::Error> {
        self.mapping.translate_gpa(self.id, base)
    }
}

pub struct Iommu {
    common: VirtioCommon,
    id: String,
    config: VirtioIommuConfig,
    mapping: Arc<IommuMapping>,
    ext_mapping: Arc<Mutex<BTreeMap<u32, Arc<dyn ExternalDmaMapping>>>>,
    seccomp_action: SeccompAction,
    exit_evt: EventFd,
    msi_iova_space: (u64, u64),
}

type EndpointsState = Vec<(u32, u32)>;
type DomainsState = Vec<(u32, (Vec<(u64, Mapping)>, bool))>;

#[derive(Serialize, Deserialize)]
pub struct IommuState {
    avail_features: u64,
    acked_features: u64,
    endpoints: EndpointsState,
    domains: DomainsState,
}
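
// Note on the snapshot format: the runtime BTreeMaps are stored in the
// snapshot as Vec<(key, value)> pairs (EndpointsState/DomainsState), and
// Iommu::new() rebuilds the maps from those pairs when restoring.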

impl Iommu {
    pub fn new(
        id: String,
        seccomp_action: SeccompAction,
        exit_evt: EventFd,
        msi_iova_space: (u64, u64),
        address_width_bits: u8,
        state: Option<IommuState>,
    ) -> io::Result<(Self, Arc<IommuMapping>)> {
        let (mut avail_features, acked_features, endpoints, domains, paused) =
            if let Some(state) = state {
                info!("Restoring virtio-iommu {}", id);
                (
                    state.avail_features,
                    state.acked_features,
                    state.endpoints.into_iter().collect(),
                    state
                        .domains
                        .into_iter()
                        .map(|(k, v)| {
                            (
                                k,
                                Domain {
                                    mappings: v.0.into_iter().collect(),
                                    bypass: v.1,
                                },
                            )
                        })
                        .collect(),
                    true,
                )
            } else {
                let avail_features = (1u64 << VIRTIO_F_VERSION_1)
                    | (1u64 << VIRTIO_IOMMU_F_MAP_UNMAP)
                    | (1u64 << VIRTIO_IOMMU_F_PROBE)
                    | (1u64 << VIRTIO_IOMMU_F_BYPASS_CONFIG);

                (avail_features, 0, BTreeMap::new(), BTreeMap::new(), false)
            };

        let mut config = VirtioIommuConfig {
            page_size_mask: VIRTIO_IOMMU_PAGE_SIZE_MASK,
            probe_size: PROBE_PROP_SIZE,
            ..Default::default()
        };

        if address_width_bits < 64 {
            avail_features |= 1u64 << VIRTIO_IOMMU_F_INPUT_RANGE;
            config.input_range = VirtioIommuRange64 {
                start: 0,
                end: (1u64 << address_width_bits) - 1,
            }
        }

        let mapping = Arc::new(IommuMapping {
            endpoints: Arc::new(RwLock::new(endpoints)),
            domains: Arc::new(RwLock::new(domains)),
            bypass: AtomicBool::new(true),
        });

        Ok((
            Iommu {
                id,
                common: VirtioCommon {
                    device_type: VirtioDeviceType::Iommu as u32,
                    queue_sizes: QUEUE_SIZES.to_vec(),
                    avail_features,
                    acked_features,
                    paused_sync: Some(Arc::new(Barrier::new(2))),
                    min_queues: NUM_QUEUES as u16,
                    paused: Arc::new(AtomicBool::new(paused)),
                    ..Default::default()
                },
                config,
                mapping: mapping.clone(),
                ext_mapping: Arc::new(Mutex::new(BTreeMap::new())),
                seccomp_action,
                exit_evt,
                msi_iova_space,
            },
            mapping,
        ))
    }

    fn state(&self) -> IommuState {
        IommuState {
            avail_features: self.common.avail_features,
            acked_features: self.common.acked_features,
            endpoints: self
                .mapping
                .endpoints
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .collect(),
            domains: self
                .mapping
                .domains
                .read()
                .unwrap()
                .clone()
                .into_iter()
                .map(|(k, v)| (k, (v.mappings.into_iter().collect(), v.bypass)))
                .collect(),
        }
    }

    fn update_bypass(&mut self) {
        // Use bypass from config if VIRTIO_IOMMU_F_BYPASS_CONFIG has been negotiated
        if !self
            .common
            .feature_acked(VIRTIO_IOMMU_F_BYPASS_CONFIG.into())
        {
            return;
        }

        let bypass = self.config.bypass == 1;
        info!("Updating bypass mode to {}", bypass);
        self.mapping.bypass.store(bypass, Ordering::Release);
    }

    pub fn add_external_mapping(&mut self, device_id: u32, mapping: Arc<dyn ExternalDmaMapping>) {
        self.ext_mapping.lock().unwrap().insert(device_id, mapping);
    }

    #[cfg(fuzzing)]
    pub fn wait_for_epoll_threads(&mut self) {
        self.common.wait_for_epoll_threads();
    }
}

impl Drop for Iommu {
    fn drop(&mut self) {
        if let Some(kill_evt) = self.common.kill_evt.take() {
            // Ignore the result because there is nothing we can do about it.
            let _ = kill_evt.write(1);
        }
        self.common.wait_for_epoll_threads();
    }
}

impl VirtioDevice for Iommu {
    fn device_type(&self) -> u32 {
        self.common.device_type
    }

    fn queue_max_sizes(&self) -> &[u16] {
        &self.common.queue_sizes
    }

    fn features(&self) -> u64 {
        self.common.avail_features
    }

    fn ack_features(&mut self, value: u64) {
        self.common.ack_features(value)
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        self.read_config_from_slice(self.config.as_slice(), offset, data);
    }

    fn write_config(&mut self, offset: u64, data: &[u8]) {
        // The "bypass" field is the only mutable field
        let bypass_offset =
            (&self.config.bypass as *const _ as u64) - (&self.config as *const _ as u64);
        if offset != bypass_offset || data.len() != std::mem::size_of_val(&self.config.bypass) {
            error!(
                "Attempt to write to read-only field: offset {:x} length {}",
                offset,
                data.len()
            );
            return;
        }

        self.config.bypass = data[0];

        self.update_bypass();
    }

    fn activate(
        &mut self,
        mem: GuestMemoryAtomic<GuestMemoryMmap>,
        interrupt_cb: Arc<dyn VirtioInterrupt>,
        mut queues: Vec<(usize, Queue, EventFd)>,
    ) -> ActivateResult {
        self.common.activate(&queues, &interrupt_cb)?;
        let (kill_evt, pause_evt) = self.common.dup_eventfds();

        let (_, request_queue, request_queue_evt) = queues.remove(0);
        let (_, _event_queue, _event_queue_evt) = queues.remove(0);

        let mut handler = IommuEpollHandler {
            mem,
            request_queue,
            _event_queue,
            interrupt_cb,
            request_queue_evt,
            _event_queue_evt,
            kill_evt,
            pause_evt,
            mapping: self.mapping.clone(),
            ext_mapping: self.ext_mapping.clone(),
            msi_iova_space: self.msi_iova_space,
        };

        let paused = self.common.paused.clone();
        let paused_sync = self.common.paused_sync.clone();
        let mut epoll_threads = Vec::new();
        spawn_virtio_thread(
            &self.id,
            &self.seccomp_action,
            Thread::VirtioIommu,
            &mut epoll_threads,
            &self.exit_evt,
            move || handler.run(paused, paused_sync.unwrap()),
        )?;

        self.common.epoll_threads = Some(epoll_threads);

        event!("virtio-device", "activated", "id", &self.id);
        Ok(())
    }

    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
        let result = self.common.reset();
        event!("virtio-device", "reset", "id", &self.id);
        result
    }
}

impl Pausable for Iommu {
    fn pause(&mut self) -> result::Result<(), MigratableError> {
        self.common.pause()
    }

    fn resume(&mut self) -> result::Result<(), MigratableError> {
        self.common.resume()
    }
}

impl Snapshottable for Iommu {
    fn id(&self) -> String {
        self.id.clone()
    }

    fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
        Snapshot::new_from_state(&self.state())
    }
}
impl Transportable for Iommu {}
impl Migratable for Iommu {}

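// What follows is an illustrative unit-test sketch, not part of the original
// device code: it exercises IommuMapping::translate_gva()/translate_gpa() with
// a hand-built endpoint/domain table, assuming the private fields keep the
// shape defined above.
#[cfg(test)]
mod mapping_tests {
    use std::collections::BTreeMap;
    use std::sync::atomic::AtomicBool;
    use std::sync::{Arc, RwLock};

    use super::{Domain, IommuMapping, Mapping};
    use crate::DmaRemapping;

    // Endpoint 3 attached to domain 1, which maps a single 4 KiB page:
    // IOVA 0x1000 -> GPA 0x8000.
    fn build_mapping() -> IommuMapping {
        let mut mappings = BTreeMap::new();
        mappings.insert(
            0x1000,
            Mapping {
                gpa: 0x8000,
                size: 0x1000,
            },
        );

        let mut domains = BTreeMap::new();
        domains.insert(
            1,
            Domain {
                mappings,
                bypass: false,
            },
        );

        let mut endpoints = BTreeMap::new();
        endpoints.insert(3, 1);

        IommuMapping {
            endpoints: Arc::new(RwLock::new(endpoints)),
            domains: Arc::new(RwLock::new(domains)),
            bypass: AtomicBool::new(false),
        }
    }

    #[test]
    fn translates_addresses_within_a_mapped_range() {
        let mapping = build_mapping();
        // GVA -> GPA and back; the offset within the page is preserved.
        assert_eq!(mapping.translate_gva(3, 0x1010).unwrap(), 0x8010);
        assert_eq!(mapping.translate_gpa(3, 0x8010).unwrap(), 0x1010);
    }

    #[test]
    fn unmapped_address_is_rejected() {
        let mapping = build_mapping();
        assert!(mapping.translate_gva(3, 0x4000).is_err());
    }
}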