xref: /cloud-hypervisor/vmm/src/device_tree.rs (revision 894a4dee6ec35bdd9df848a3a88bf5b96771c943)
1 // Copyright © 2020 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 
5 use std::collections::HashMap;
6 use std::sync::{Arc, Mutex};
7 
8 use pci::PciBdf;
9 use serde::{Deserialize, Serialize};
10 use vm_device::Resource;
11 use vm_migration::Migratable;
12 
13 use crate::device_manager::PciDeviceHandle;
14 
15 #[derive(Clone, Serialize, Deserialize)]
16 pub struct DeviceNode {
17     pub id: String,
18     pub resources: Vec<Resource>,
19     pub parent: Option<String>,
20     pub children: Vec<String>,
21     #[serde(skip)]
22     pub migratable: Option<Arc<Mutex<dyn Migratable>>>,
23     pub pci_bdf: Option<PciBdf>,
24     #[serde(skip)]
25     pub pci_device_handle: Option<PciDeviceHandle>,
26 }
27 
28 impl DeviceNode {
29     pub fn new(id: String, migratable: Option<Arc<Mutex<dyn Migratable>>>) -> Self {
30         DeviceNode {
31             id,
32             resources: Vec::new(),
33             parent: None,
34             children: Vec::new(),
35             migratable,
36             pci_bdf: None,
37             pci_device_handle: None,
38         }
39     }
40 }
41 
/// Builds a `DeviceNode` for the device identified by `$id`.
///
/// * `device_node!(id)` creates a node without a migratable handle.
/// * `device_node!(id, device)` additionally stores `Arc::clone(&device)`,
///   coerced to `Arc<Mutex<dyn Migratable>>`.
///
/// Both arguments accept arbitrary expressions (e.g. `self.id`, `ids[i]`), not
/// just identifiers. `$id` is cloned, so it should be a `String` (or
/// `&String`), and `$device` is only borrowed.
///
/// The expansion names `DeviceNode`, `Arc`, `Mutex` and `Migratable` without
/// paths, so those names must be in scope at the call site.
#[macro_export]
macro_rules! device_node {
    ($id:expr) => {
        DeviceNode::new($id.clone(), None)
    };
    ($id:expr, $device:expr) => {
        DeviceNode::new(
            $id.clone(),
            Some(Arc::clone(&$device) as Arc<Mutex<dyn Migratable>>),
        )
    };
}
54 
55 #[derive(Clone, Default, Serialize, Deserialize)]
56 pub struct DeviceTree(HashMap<String, DeviceNode>);
57 
58 impl DeviceTree {
59     pub fn new() -> Self {
60         DeviceTree(HashMap::new())
61     }
62     pub fn contains_key(&self, k: &str) -> bool {
63         self.0.contains_key(k)
64     }
65     pub fn get(&self, k: &str) -> Option<&DeviceNode> {
66         self.0.get(k)
67     }
68     pub fn get_mut(&mut self, k: &str) -> Option<&mut DeviceNode> {
69         self.0.get_mut(k)
70     }
71     pub fn insert(&mut self, k: String, v: DeviceNode) -> Option<DeviceNode> {
72         self.0.insert(k, v)
73     }
74     pub fn remove(&mut self, k: &str) -> Option<DeviceNode> {
75         self.0.remove(k)
76     }
77     pub fn iter(&self) -> std::collections::hash_map::Iter<String, DeviceNode> {
78         self.0.iter()
79     }
80     pub fn breadth_first_traversal(&self) -> BftIter {
81         BftIter::new(&self.0)
82     }
83     pub fn pci_devices(&self) -> Vec<&DeviceNode> {
84         self.0
85             .values()
86             .filter(|v| v.pci_bdf.is_some() && v.pci_device_handle.is_some())
87             .collect()
88     }
89 
90     pub fn remove_node_by_pci_bdf(&mut self, pci_bdf: PciBdf) -> Option<DeviceNode> {
91         let mut id = None;
92         for (k, v) in self.0.iter() {
93             if v.pci_bdf == Some(pci_bdf) {
94                 id = Some(k.clone());
95                 break;
96             }
97         }
98 
99         if let Some(id) = &id {
100             self.0.remove(id)
101         } else {
102             None
103         }
104     }
105 }
106 
107 // Breadth first traversal iterator.
108 pub struct BftIter<'a> {
109     nodes: Vec<&'a DeviceNode>,
110 }
111 
112 impl<'a> BftIter<'a> {
113     fn new(hash_map: &'a HashMap<String, DeviceNode>) -> Self {
114         let mut nodes = Vec::with_capacity(hash_map.len());
115         let mut i = 0;
116 
117         for (_, node) in hash_map.iter() {
118             if node.parent.is_none() {
119                 nodes.push(node);
120             }
121         }
122 
123         while i < nodes.len() {
124             for child_node_id in nodes[i].children.iter() {
125                 if let Some(child_node) = hash_map.get(child_node_id) {
126                     nodes.push(child_node);
127                 }
128             }
129             i += 1;
130         }
131 
132         BftIter { nodes }
133     }
134 }
135 
136 impl<'a> Iterator for BftIter<'a> {
137     type Item = &'a DeviceNode;
138 
139     fn next(&mut self) -> Option<Self::Item> {
140         if self.nodes.is_empty() {
141             None
142         } else {
143             Some(self.nodes.remove(0))
144         }
145     }
146 }
147 
148 impl DoubleEndedIterator for BftIter<'_> {
149     fn next_back(&mut self) -> Option<Self::Item> {
150         self.nodes.pop()
151     }
152 }
153 
#[cfg(test)]
mod tests {
    use super::{DeviceNode, DeviceTree};

    /// Covers the map-style accessors of `DeviceTree` and checks that
    /// `breadth_first_traversal()` yields nodes level by level, both forwards
    /// and (through `.rev()`) backwards.
    #[test]
    fn test_device_tree() {
        // new() yields an empty tree.
        let mut tree = DeviceTree::new();
        assert_eq!(tree.0.len(), 0);

        // insert() stores the node under the given key.
        let first_id = String::from("id1");
        tree.insert(first_id.clone(), DeviceNode::new(first_id.clone(), None));
        assert_eq!(tree.0.len(), 1);
        let inserted = tree.0.get(&first_id);
        assert!(inserted.is_some());
        assert_eq!(inserted.unwrap().id, first_id);

        // get() finds present keys and misses absent ones.
        let second_id = String::from("id2");
        assert!(tree.get(&first_id).is_some());
        assert!(tree.get(&second_id).is_none());

        // get_mut() hands out a reference whose writes land in the map.
        tree.get_mut(&first_id).unwrap().id.clone_from(&second_id);
        assert_eq!(tree.0.get(&first_id).unwrap().id, second_id);

        // remove() returns the (modified) node and leaves the map empty.
        let removed = tree.remove(&first_id).unwrap();
        assert_eq!(removed.id, second_id);
        assert_eq!(tree.0.len(), 0);

        // iter() visits only the inserted keys, each paired with a node
        // carrying that same id.
        let disk_id = String::from("disk0");
        let net_id = String::from("net0");
        let rng_id = String::from("rng0");
        tree.0.extend(vec![
            (disk_id.clone(), device_node!(disk_id)),
            (net_id.clone(), device_node!(net_id)),
            (rng_id.clone(), device_node!(rng_id)),
        ]);
        let known_ids = [&disk_id, &net_id, &rng_id];
        for (key, node) in tree.iter() {
            if !known_ids.contains(&key) {
                unreachable!()
            }
            assert_eq!(&node.id, key);
        }

        // breadth_first_traversal() over the following hierarchy:
        //
        //            root
        //           /    \
        //     parent1    parent2
        //        |       /     \
        //     child1  child2  child3
        //
        // Builds a `(key, node)` map entry with the given parent/children links.
        fn entry(id: &str, parent: Option<&str>, children: &[&str]) -> (String, DeviceNode) {
            let id = id.to_owned();
            let mut node = device_node!(id);
            node.parent = parent.map(str::to_owned);
            node.children = children.iter().map(|child| (*child).to_owned()).collect();
            (id, node)
        }

        let mut tree = DeviceTree::new();
        tree.0.extend(vec![
            entry("child1", Some("parent1"), &[]),
            entry("child2", Some("parent2"), &[]),
            entry("child3", Some("parent2"), &[]),
            entry("parent1", Some("root"), &["child1"]),
            entry("parent2", Some("root"), &["child2", "child3"]),
            entry("root", None, &["parent1", "parent2"]),
        ]);

        let expected = ["root", "parent1", "parent2", "child1", "child2", "child3"];

        let forward: Vec<&DeviceNode> = tree.breadth_first_traversal().collect();
        assert_eq!(forward.len(), expected.len());
        for (node, id) in forward.iter().zip(expected.iter()) {
            assert_eq!(node.id, *id);
        }

        let backward: Vec<&DeviceNode> = tree.breadth_first_traversal().rev().collect();
        assert_eq!(backward.len(), expected.len());
        for (node, id) in backward.iter().zip(expected.iter().rev()) {
            assert_eq!(node.id, *id);
        }
    }
}
274