xref: /cloud-hypervisor/vmm/src/device_tree.rs (revision 7d7bfb2034001d4cb15df2ddc56d2d350c8da30f)
1 // Copyright © 2020 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 
5 use crate::device_manager::PciDeviceHandle;
6 use pci::PciBdf;
7 use std::collections::HashMap;
8 use std::sync::{Arc, Mutex};
9 use vm_device::Resource;
10 use vm_migration::Migratable;
11 
/// A single node of a [`DeviceTree`].
///
/// The tree structure is expressed by identifier: `parent` and `children` hold node
/// identifiers that are looked up in the owning [`DeviceTree`] map.
#[derive(Clone, Serialize, Deserialize)]
pub struct DeviceNode {
    /// Identifier of this node. By convention it is also the node's key in the
    /// [`DeviceTree`] map, but nothing enforces that.
    pub id: String,
    /// Resources recorded for this device (`vm_device::Resource` entries).
    pub resources: Vec<Resource>,
    /// Identifier of the parent node, or `None` for a root node. Root nodes are where
    /// the breadth-first traversal starts.
    pub parent: Option<String>,
    /// Identifiers of the child nodes; the breadth-first traversal visits them in this order.
    pub children: Vec<String>,
    /// Optional handle to the device object. Skipped by serde, so it is `None` after
    /// deserialization.
    #[serde(skip)]
    pub migratable: Option<Arc<Mutex<dyn Migratable>>>,
    /// PCI bus/device/function of the device, when it has one.
    pub pci_bdf: Option<PciBdf>,
    /// Optional PCI device handle. Skipped by serde, so it is `None` after deserialization.
    #[serde(skip)]
    pub pci_device_handle: Option<PciDeviceHandle>,
}
24 
25 impl DeviceNode {
26     pub fn new(id: String, migratable: Option<Arc<Mutex<dyn Migratable>>>) -> Self {
27         DeviceNode {
28             id,
29             resources: Vec::new(),
30             parent: None,
31             children: Vec::new(),
32             migratable,
33             pci_bdf: None,
34             pci_device_handle: None,
35         }
36     }
37 }
38 
/// Shorthand for building a `DeviceNode` from identifier variables.
///
/// * `device_node!(name)` expands to `DeviceNode::new(name.clone(), None)`.
/// * `device_node!(name, dev)` additionally stores `Arc::clone(&dev)`, coerced to
///   `Arc<Mutex<dyn Migratable>>`.
///
/// Both arguments are cloned rather than moved, so the caller keeps ownership. The
/// expansion refers to `DeviceNode`, `Arc`, `Mutex` and `Migratable` by their bare names,
/// so those must be in scope wherever the macro is invoked.
#[macro_export]
macro_rules! device_node {
    ($name:ident) => {
        DeviceNode::new($name.clone(), None)
    };
    ($name:ident, $dev:ident) => {
        DeviceNode::new($name.clone(), Some(Arc::clone(&$dev) as Arc<Mutex<dyn Migratable>>))
    };
}
51 
/// Collection of [`DeviceNode`]s keyed by identifier.
///
/// Parent/child links live inside the nodes as identifiers. The map itself does not check
/// that a key matches its node's `id` field, nor that referenced identifiers exist.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct DeviceTree(HashMap<String, DeviceNode>);
54 
55 impl DeviceTree {
56     pub fn new() -> Self {
57         DeviceTree(HashMap::new())
58     }
59     pub fn contains_key(&self, k: &str) -> bool {
60         self.0.contains_key(k)
61     }
62     pub fn get(&self, k: &str) -> Option<&DeviceNode> {
63         self.0.get(k)
64     }
65     pub fn get_mut(&mut self, k: &str) -> Option<&mut DeviceNode> {
66         self.0.get_mut(k)
67     }
68     pub fn insert(&mut self, k: String, v: DeviceNode) -> Option<DeviceNode> {
69         self.0.insert(k, v)
70     }
71     pub fn remove(&mut self, k: &str) -> Option<DeviceNode> {
72         self.0.remove(k)
73     }
74     pub fn iter(&self) -> std::collections::hash_map::Iter<String, DeviceNode> {
75         self.0.iter()
76     }
77     pub fn breadth_first_traversal(&self) -> BftIter {
78         BftIter::new(&self.0)
79     }
80     pub fn pci_devices(&self) -> Vec<&DeviceNode> {
81         self.0
82             .values()
83             .filter(|v| v.pci_bdf.is_some() && v.pci_device_handle.is_some())
84             .collect()
85     }
86 
87     pub fn remove_node_by_pci_bdf(&mut self, pci_bdf: PciBdf) -> Option<DeviceNode> {
88         let mut id = None;
89         for (k, v) in self.0.iter() {
90             if v.pci_bdf == Some(pci_bdf) {
91                 id = Some(k.clone());
92                 break;
93             }
94         }
95 
96         if let Some(id) = &id {
97             self.0.remove(id)
98         } else {
99             None
100         }
101     }
102 }
103 
104 // Breadth first traversal iterator.
105 pub struct BftIter<'a> {
106     nodes: Vec<&'a DeviceNode>,
107 }
108 
109 impl<'a> BftIter<'a> {
110     fn new(hash_map: &'a HashMap<String, DeviceNode>) -> Self {
111         let mut nodes = Vec::new();
112 
113         for (_, node) in hash_map.iter() {
114             if node.parent.is_none() {
115                 nodes.push(node);
116             }
117         }
118 
119         let mut node_layer = nodes.as_slice();
120         loop {
121             let mut next_node_layer = Vec::new();
122 
123             for node in node_layer.iter() {
124                 for child_node_id in node.children.iter() {
125                     if let Some(child_node) = hash_map.get(child_node_id) {
126                         next_node_layer.push(child_node);
127                     }
128                 }
129             }
130 
131             if next_node_layer.is_empty() {
132                 break;
133             }
134 
135             let pos = nodes.len();
136             nodes.extend(next_node_layer);
137 
138             node_layer = &nodes[pos..];
139         }
140 
141         BftIter { nodes }
142     }
143 }
144 
145 impl<'a> Iterator for BftIter<'a> {
146     type Item = &'a DeviceNode;
147 
148     fn next(&mut self) -> Option<Self::Item> {
149         if self.nodes.is_empty() {
150             None
151         } else {
152             Some(self.nodes.remove(0))
153         }
154     }
155 }
156 
157 impl<'a> DoubleEndedIterator for BftIter<'a> {
158     fn next_back(&mut self) -> Option<Self::Item> {
159         self.nodes.pop()
160     }
161 }
162 
// Unit tests for `DeviceTree` accessors and its breadth-first iterator.
#[cfg(test)]
mod tests {
    use super::{DeviceNode, DeviceTree};

    #[test]
    fn test_device_tree() {
        // Check new()
        let mut device_tree = DeviceTree::new();
        assert_eq!(device_tree.0.len(), 0);

        // Check insert()
        let id = String::from("id1");
        device_tree.insert(id.clone(), DeviceNode::new(id.clone(), None));
        assert_eq!(device_tree.0.len(), 1);
        let node = device_tree.0.get(&id);
        assert!(node.is_some());
        let node = node.unwrap();
        assert_eq!(node.id, id);

        // Check get()
        let id2 = String::from("id2");
        assert!(device_tree.get(&id).is_some());
        assert!(device_tree.get(&id2).is_none());

        // Check get_mut()
        // Only the node's `id` field changes; the map key stays `id`, so the lookup below
        // still uses `id`.
        let node = device_tree.get_mut(&id).unwrap();
        node.id = id2.clone();
        let node = device_tree.0.get(&id).unwrap();
        assert_eq!(node.id, id2);

        // Check remove()
        let node = device_tree.remove(&id).unwrap();
        assert_eq!(node.id, id2);
        assert_eq!(device_tree.0.len(), 0);

        // Check iter()
        // Iteration order is unspecified, so each yielded key is matched against the three
        // inserted ids.
        let disk_id = String::from("disk0");
        let net_id = String::from("net0");
        let rng_id = String::from("rng0");
        let device_list = vec![
            (disk_id.clone(), device_node!(disk_id)),
            (net_id.clone(), device_node!(net_id)),
            (rng_id.clone(), device_node!(rng_id)),
        ];
        device_tree.0.extend(device_list);
        for (id, node) in device_tree.iter() {
            if id == &disk_id {
                assert_eq!(node.id, disk_id);
            } else if id == &net_id {
                assert_eq!(node.id, net_id);
            } else if id == &rng_id {
                assert_eq!(node.id, rng_id);
            } else {
                unreachable!()
            }
        }

        // Check breadth_first_traversal() based on the following hierarchy
        //
        // 0
        // | \
        // 1  2
        // |  | \
        // 3  4  5
        //
        // where 0 = root, 1 = parent1, 2 = parent2, 3 = child1, 4 = child2, 5 = child3.
        let mut device_tree = DeviceTree::new();
        let child_1_id = String::from("child1");
        let child_2_id = String::from("child2");
        let child_3_id = String::from("child3");
        let parent_1_id = String::from("parent1");
        let parent_2_id = String::from("parent2");
        let root_id = String::from("root");
        let mut child_1_node = device_node!(child_1_id);
        let mut child_2_node = device_node!(child_2_id);
        let mut child_3_node = device_node!(child_3_id);
        let mut parent_1_node = device_node!(parent_1_id);
        let mut parent_2_node = device_node!(parent_2_id);
        let mut root_node = device_node!(root_id);
        child_1_node.parent = Some(parent_1_id.clone());
        child_2_node.parent = Some(parent_2_id.clone());
        child_3_node.parent = Some(parent_2_id.clone());
        parent_1_node.children = vec![child_1_id.clone()];
        parent_1_node.parent = Some(root_id.clone());
        parent_2_node.children = vec![child_2_id.clone(), child_3_id.clone()];
        parent_2_node.parent = Some(root_id.clone());
        root_node.children = vec![parent_1_id.clone(), parent_2_id.clone()];
        let device_list = vec![
            (child_1_id.clone(), child_1_node),
            (child_2_id.clone(), child_2_node),
            (child_3_id.clone(), child_3_node),
            (parent_1_id.clone(), parent_1_node),
            (parent_2_id.clone(), parent_2_node),
            (root_id.clone(), root_node),
        ];
        device_tree.0.extend(device_list);

        // Forward order: the single root, then each layer in `children` order.
        let iter_vec = device_tree
            .breadth_first_traversal()
            .collect::<Vec<&DeviceNode>>();
        assert_eq!(iter_vec.len(), 6);
        assert_eq!(iter_vec[0].id, root_id);
        assert_eq!(iter_vec[1].id, parent_1_id);
        assert_eq!(iter_vec[2].id, parent_2_id);
        assert_eq!(iter_vec[3].id, child_1_id);
        assert_eq!(iter_vec[4].id, child_2_id);
        assert_eq!(iter_vec[5].id, child_3_id);

        // Reverse order (DoubleEndedIterator): exactly the forward sequence reversed.
        let iter_vec = device_tree
            .breadth_first_traversal()
            .rev()
            .collect::<Vec<&DeviceNode>>();
        assert_eq!(iter_vec.len(), 6);
        assert_eq!(iter_vec[5].id, root_id);
        assert_eq!(iter_vec[4].id, parent_1_id);
        assert_eq!(iter_vec[3].id, parent_2_id);
        assert_eq!(iter_vec[2].id, child_1_id);
        assert_eq!(iter_vec[1].id, child_2_id);
        assert_eq!(iter_vec[0].id, child_3_id);
    }
}
283