xref: /cloud-hypervisor/vmm/src/device_tree.rs (revision b440cb7d2330770cd415b63544a371d4caa2db3a)
1 // Copyright © 2020 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 
5 use crate::device_manager::PciDeviceHandle;
6 use pci::PciBdf;
7 use serde::{Deserialize, Serialize};
8 use std::collections::HashMap;
9 use std::sync::{Arc, Mutex};
10 use vm_device::Resource;
11 use vm_migration::Migratable;
12 
/// A single device entry of a [`DeviceTree`], linked to the other entries by id.
///
/// `migratable` and `pci_device_handle` are `#[serde(skip)]`: they are not
/// serialized, and after deserialization they hold their default value (`None`).
#[derive(Clone, Serialize, Deserialize)]
pub struct DeviceNode {
    /// Identifier of the device.
    pub id: String,
    /// Resources recorded for the device.
    pub resources: Vec<Resource>,
    /// Id of the parent node; `None` marks a root of the tree (the breadth-first
    /// traversal starts from such nodes).
    pub parent: Option<String>,
    /// Ids of the child nodes, in the order the traversal visits them.
    pub children: Vec<String>,
    /// The device as a `Migratable` trait object, if any (not serialized).
    #[serde(skip)]
    pub migratable: Option<Arc<Mutex<dyn Migratable>>>,
    /// PCI address (BDF) of the device, if it has one.
    pub pci_bdf: Option<PciBdf>,
    /// PCI device handle, if any (not serialized). `DeviceTree::pci_devices`
    /// only reports nodes where both this and `pci_bdf` are set.
    #[serde(skip)]
    pub pci_device_handle: Option<PciDeviceHandle>,
}
25 
26 impl DeviceNode {
27     pub fn new(id: String, migratable: Option<Arc<Mutex<dyn Migratable>>>) -> Self {
28         DeviceNode {
29             id,
30             resources: Vec::new(),
31             parent: None,
32             children: Vec::new(),
33             migratable,
34             pci_bdf: None,
35             pci_device_handle: None,
36         }
37     }
38 }
39 
/// Builds a [`DeviceNode`] from local identifiers.
///
/// * `device_node!(id)` expands to `DeviceNode::new(id.clone(), None)`: a node
///   without a migratable handle.
/// * `device_node!(id, device)` additionally stores a clone of the `Arc` held
///   in `device`, cast to `Arc<Mutex<dyn Migratable>>`; `device` must therefore
///   hold an `Arc<Mutex<_>>` coercible to that type.
///
/// Both arguments must be plain identifiers (`ident` fragments). `id` is cloned,
/// not moved, so it is still usable after the macro call.
///
/// NOTE(review): the expansion names `DeviceNode`, `Arc`, `Mutex` and
/// `Migratable` unqualified, so they must be in scope at the call site (the
/// one-argument form only needs `DeviceNode`).
#[macro_export]
macro_rules! device_node {
    // Node without a migratable handle.
    ($id:ident) => {
        DeviceNode::new($id.clone(), None)
    };
    // Node that also keeps `$device` as a migratable trait object.
    ($id:ident, $device:ident) => {
        DeviceNode::new(
            $id.clone(),
            Some(Arc::clone(&$device) as Arc<Mutex<dyn Migratable>>),
        )
    };
}
52 
/// All device nodes, stored in a map keyed by string id.
///
/// The tree shape is not held by the map itself but by the `parent` and
/// `children` ids stored inside each [`DeviceNode`]. The map does not check
/// that a key matches the `id` of the node stored under it.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct DeviceTree(HashMap<String, DeviceNode>);
55 
56 impl DeviceTree {
57     pub fn new() -> Self {
58         DeviceTree(HashMap::new())
59     }
60     pub fn contains_key(&self, k: &str) -> bool {
61         self.0.contains_key(k)
62     }
63     pub fn get(&self, k: &str) -> Option<&DeviceNode> {
64         self.0.get(k)
65     }
66     pub fn get_mut(&mut self, k: &str) -> Option<&mut DeviceNode> {
67         self.0.get_mut(k)
68     }
69     pub fn insert(&mut self, k: String, v: DeviceNode) -> Option<DeviceNode> {
70         self.0.insert(k, v)
71     }
72     pub fn remove(&mut self, k: &str) -> Option<DeviceNode> {
73         self.0.remove(k)
74     }
75     pub fn iter(&self) -> std::collections::hash_map::Iter<String, DeviceNode> {
76         self.0.iter()
77     }
78     pub fn breadth_first_traversal(&self) -> BftIter {
79         BftIter::new(&self.0)
80     }
81     pub fn pci_devices(&self) -> Vec<&DeviceNode> {
82         self.0
83             .values()
84             .filter(|v| v.pci_bdf.is_some() && v.pci_device_handle.is_some())
85             .collect()
86     }
87 
88     pub fn remove_node_by_pci_bdf(&mut self, pci_bdf: PciBdf) -> Option<DeviceNode> {
89         let mut id = None;
90         for (k, v) in self.0.iter() {
91             if v.pci_bdf == Some(pci_bdf) {
92                 id = Some(k.clone());
93                 break;
94             }
95         }
96 
97         if let Some(id) = &id {
98             self.0.remove(id)
99         } else {
100             None
101         }
102     }
103 }
104 
105 // Breadth first traversal iterator.
106 pub struct BftIter<'a> {
107     nodes: Vec<&'a DeviceNode>,
108 }
109 
110 impl<'a> BftIter<'a> {
111     fn new(hash_map: &'a HashMap<String, DeviceNode>) -> Self {
112         let mut nodes = Vec::new();
113 
114         for (_, node) in hash_map.iter() {
115             if node.parent.is_none() {
116                 nodes.push(node);
117             }
118         }
119 
120         let mut node_layer = nodes.as_slice();
121         loop {
122             let mut next_node_layer = Vec::new();
123 
124             for node in node_layer.iter() {
125                 for child_node_id in node.children.iter() {
126                     if let Some(child_node) = hash_map.get(child_node_id) {
127                         next_node_layer.push(child_node);
128                     }
129                 }
130             }
131 
132             if next_node_layer.is_empty() {
133                 break;
134             }
135 
136             let pos = nodes.len();
137             nodes.extend(next_node_layer);
138 
139             node_layer = &nodes[pos..];
140         }
141 
142         BftIter { nodes }
143     }
144 }
145 
146 impl<'a> Iterator for BftIter<'a> {
147     type Item = &'a DeviceNode;
148 
149     fn next(&mut self) -> Option<Self::Item> {
150         if self.nodes.is_empty() {
151             None
152         } else {
153             Some(self.nodes.remove(0))
154         }
155     }
156 }
157 
158 impl<'a> DoubleEndedIterator for BftIter<'a> {
159     fn next_back(&mut self) -> Option<Self::Item> {
160         self.nodes.pop()
161     }
162 }
163 
#[cfg(test)]
mod tests {
    use super::{DeviceNode, DeviceTree};

    /// Basic map operations: new(), insert(), contains_key(), get(), get_mut()
    /// and remove().
    #[test]
    fn test_device_tree() {
        // Check new()
        let mut device_tree = DeviceTree::new();
        assert_eq!(device_tree.0.len(), 0);

        // Check insert(): a fresh key has no previous value.
        let id = String::from("id1");
        assert!(device_tree
            .insert(id.clone(), DeviceNode::new(id.clone(), None))
            .is_none());
        assert_eq!(device_tree.0.len(), 1);
        let node = device_tree.0.get(&id);
        assert!(node.is_some());
        let node = node.unwrap();
        assert_eq!(node.id, id);

        // Check contains_key() and get()
        let id2 = String::from("id2");
        assert!(device_tree.contains_key(&id));
        assert!(!device_tree.contains_key(&id2));
        assert!(device_tree.get(&id).is_some());
        assert!(device_tree.get(&id2).is_none());

        // Check get_mut(): the node is modified in place, under its original key.
        let node = device_tree.get_mut(&id).unwrap();
        node.id = id2.clone();
        let node = device_tree.0.get(&id).unwrap();
        assert_eq!(node.id, id2);

        // Check remove(): removing the same key twice finds nothing the second time.
        let node = device_tree.remove(&id).unwrap();
        assert_eq!(node.id, id2);
        assert_eq!(device_tree.0.len(), 0);
        assert!(device_tree.remove(&id).is_none());
    }

    /// iter() visits every node exactly once; pci_devices() ignores non-PCI nodes.
    #[test]
    fn test_device_tree_iter() {
        let mut device_tree = DeviceTree::new();
        let disk_id = String::from("disk0");
        let net_id = String::from("net0");
        let rng_id = String::from("rng0");
        let device_list = vec![
            (disk_id.clone(), device_node!(disk_id)),
            (net_id.clone(), device_node!(net_id)),
            (rng_id.clone(), device_node!(rng_id)),
        ];
        device_tree.0.extend(device_list);

        // The loop below would also pass on an empty iterator, so pin the count.
        assert_eq!(device_tree.iter().count(), 3);
        for (id, node) in device_tree.iter() {
            if id == &disk_id {
                assert_eq!(node.id, disk_id);
            } else if id == &net_id {
                assert_eq!(node.id, net_id);
            } else if id == &rng_id {
                assert_eq!(node.id, rng_id);
            } else {
                unreachable!()
            }
        }

        // None of these nodes carries PCI information.
        assert!(device_tree.pci_devices().is_empty());
    }

    /// breadth_first_traversal(), forwards and backwards.
    #[test]
    fn test_breadth_first_traversal() {
        // An empty tree yields nothing.
        assert_eq!(DeviceTree::new().breadth_first_traversal().count(), 0);

        // Check breadth_first_traversal() based on the following hierarchy
        //
        // 0
        // | \
        // 1  2
        // |  | \
        // 3  4  5
        //
        let mut device_tree = DeviceTree::new();
        let child_1_id = String::from("child1");
        let child_2_id = String::from("child2");
        let child_3_id = String::from("child3");
        let parent_1_id = String::from("parent1");
        let parent_2_id = String::from("parent2");
        let root_id = String::from("root");
        let mut child_1_node = device_node!(child_1_id);
        let mut child_2_node = device_node!(child_2_id);
        let mut child_3_node = device_node!(child_3_id);
        let mut parent_1_node = device_node!(parent_1_id);
        let mut parent_2_node = device_node!(parent_2_id);
        let mut root_node = device_node!(root_id);
        child_1_node.parent = Some(parent_1_id.clone());
        child_2_node.parent = Some(parent_2_id.clone());
        child_3_node.parent = Some(parent_2_id.clone());
        parent_1_node.children = vec![child_1_id.clone()];
        parent_1_node.parent = Some(root_id.clone());
        parent_2_node.children = vec![child_2_id.clone(), child_3_id.clone()];
        parent_2_node.parent = Some(root_id.clone());
        root_node.children = vec![parent_1_id.clone(), parent_2_id.clone()];
        let device_list = vec![
            (child_1_id.clone(), child_1_node),
            (child_2_id.clone(), child_2_node),
            (child_3_id.clone(), child_3_node),
            (parent_1_id.clone(), parent_1_node),
            (parent_2_id.clone(), parent_2_node),
            (root_id.clone(), root_node),
        ];
        device_tree.0.extend(device_list);

        let iter_vec = device_tree
            .breadth_first_traversal()
            .collect::<Vec<&DeviceNode>>();
        assert_eq!(iter_vec.len(), 6);
        assert_eq!(iter_vec[0].id, root_id);
        assert_eq!(iter_vec[1].id, parent_1_id);
        assert_eq!(iter_vec[2].id, parent_2_id);
        assert_eq!(iter_vec[3].id, child_1_id);
        assert_eq!(iter_vec[4].id, child_2_id);
        assert_eq!(iter_vec[5].id, child_3_id);

        let iter_vec = device_tree
            .breadth_first_traversal()
            .rev()
            .collect::<Vec<&DeviceNode>>();
        assert_eq!(iter_vec.len(), 6);
        assert_eq!(iter_vec[5].id, root_id);
        assert_eq!(iter_vec[4].id, parent_1_id);
        assert_eq!(iter_vec[3].id, parent_2_id);
        assert_eq!(iter_vec[2].id, child_1_id);
        assert_eq!(iter_vec[1].id, child_2_id);
        assert_eq!(iter_vec[0].id, child_3_id);
    }
}
284