1 // Copyright © 2020 Intel Corporation 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 5 use crate::device_manager::PciDeviceHandle; 6 use pci::PciBdf; 7 use std::collections::HashMap; 8 use std::sync::{Arc, Mutex}; 9 use vm_device::Resource; 10 use vm_migration::Migratable; 11 12 #[derive(Clone, Serialize, Deserialize)] 13 pub struct DeviceNode { 14 pub id: String, 15 pub resources: Vec<Resource>, 16 pub parent: Option<String>, 17 pub children: Vec<String>, 18 #[serde(skip)] 19 pub migratable: Option<Arc<Mutex<dyn Migratable>>>, 20 pub pci_bdf: Option<PciBdf>, 21 #[serde(skip)] 22 pub pci_device_handle: Option<PciDeviceHandle>, 23 } 24 25 impl DeviceNode { 26 pub fn new(id: String, migratable: Option<Arc<Mutex<dyn Migratable>>>) -> Self { 27 DeviceNode { 28 id, 29 resources: Vec::new(), 30 parent: None, 31 children: Vec::new(), 32 migratable, 33 pci_bdf: None, 34 pci_device_handle: None, 35 } 36 } 37 } 38 39 #[macro_export] 40 macro_rules! device_node { 41 ($id:ident) => { 42 DeviceNode::new($id.clone(), None) 43 }; 44 ($id:ident, $device:ident) => { 45 DeviceNode::new( 46 $id.clone(), 47 Some(Arc::clone(&$device) as Arc<Mutex<dyn Migratable>>), 48 ) 49 }; 50 } 51 52 #[derive(Clone, Default, Serialize, Deserialize)] 53 pub struct DeviceTree(HashMap<String, DeviceNode>); 54 55 impl DeviceTree { 56 pub fn new() -> Self { 57 DeviceTree(HashMap::new()) 58 } 59 pub fn contains_key(&self, k: &str) -> bool { 60 self.0.contains_key(k) 61 } 62 pub fn get(&self, k: &str) -> Option<&DeviceNode> { 63 self.0.get(k) 64 } 65 pub fn get_mut(&mut self, k: &str) -> Option<&mut DeviceNode> { 66 self.0.get_mut(k) 67 } 68 pub fn insert(&mut self, k: String, v: DeviceNode) -> Option<DeviceNode> { 69 self.0.insert(k, v) 70 } 71 pub fn remove(&mut self, k: &str) -> Option<DeviceNode> { 72 self.0.remove(k) 73 } 74 pub fn iter(&self) -> std::collections::hash_map::Iter<String, DeviceNode> { 75 self.0.iter() 76 } 77 pub fn breadth_first_traversal(&self) -> BftIter { 78 BftIter::new(&self.0) 79 } 80 pub fn pci_devices(&self) -> Vec<&DeviceNode> { 81 self.0 82 .values() 83 .filter(|v| v.pci_bdf.is_some() && v.pci_device_handle.is_some()) 84 .collect() 85 } 86 87 pub fn remove_node_by_pci_bdf(&mut self, pci_bdf: PciBdf) -> Option<DeviceNode> { 88 let mut id = None; 89 for (k, v) in self.0.iter() { 90 if v.pci_bdf == Some(pci_bdf) { 91 id = Some(k.clone()); 92 break; 93 } 94 } 95 96 if let Some(id) = &id { 97 self.0.remove(id) 98 } else { 99 None 100 } 101 } 102 } 103 104 // Breadth first traversal iterator. 105 pub struct BftIter<'a> { 106 nodes: Vec<&'a DeviceNode>, 107 } 108 109 impl<'a> BftIter<'a> { 110 fn new(hash_map: &'a HashMap<String, DeviceNode>) -> Self { 111 let mut nodes = Vec::new(); 112 113 for (_, node) in hash_map.iter() { 114 if node.parent.is_none() { 115 nodes.push(node); 116 } 117 } 118 119 let mut node_layer = nodes.as_slice(); 120 loop { 121 let mut next_node_layer = Vec::new(); 122 123 for node in node_layer.iter() { 124 for child_node_id in node.children.iter() { 125 if let Some(child_node) = hash_map.get(child_node_id) { 126 next_node_layer.push(child_node); 127 } 128 } 129 } 130 131 if next_node_layer.is_empty() { 132 break; 133 } 134 135 let pos = nodes.len(); 136 nodes.extend(next_node_layer); 137 138 node_layer = &nodes[pos..]; 139 } 140 141 BftIter { nodes } 142 } 143 } 144 145 impl<'a> Iterator for BftIter<'a> { 146 type Item = &'a DeviceNode; 147 148 fn next(&mut self) -> Option<Self::Item> { 149 if self.nodes.is_empty() { 150 None 151 } else { 152 Some(self.nodes.remove(0)) 153 } 154 } 155 } 156 157 impl<'a> DoubleEndedIterator for BftIter<'a> { 158 fn next_back(&mut self) -> Option<Self::Item> { 159 self.nodes.pop() 160 } 161 } 162 163 #[cfg(test)] 164 mod tests { 165 use super::{DeviceNode, DeviceTree}; 166 167 #[test] 168 fn test_device_tree() { 169 // Check new() 170 let mut device_tree = DeviceTree::new(); 171 assert_eq!(device_tree.0.len(), 0); 172 173 // Check insert() 174 let id = String::from("id1"); 175 device_tree.insert(id.clone(), DeviceNode::new(id.clone(), None)); 176 assert_eq!(device_tree.0.len(), 1); 177 let node = device_tree.0.get(&id); 178 assert!(node.is_some()); 179 let node = node.unwrap(); 180 assert_eq!(node.id, id); 181 182 // Check get() 183 let id2 = String::from("id2"); 184 assert!(device_tree.get(&id).is_some()); 185 assert!(device_tree.get(&id2).is_none()); 186 187 // Check get_mut() 188 let node = device_tree.get_mut(&id).unwrap(); 189 node.id = id2.clone(); 190 let node = device_tree.0.get(&id).unwrap(); 191 assert_eq!(node.id, id2); 192 193 // Check remove() 194 let node = device_tree.remove(&id).unwrap(); 195 assert_eq!(node.id, id2); 196 assert_eq!(device_tree.0.len(), 0); 197 198 // Check iter() 199 let disk_id = String::from("disk0"); 200 let net_id = String::from("net0"); 201 let rng_id = String::from("rng0"); 202 let device_list = vec![ 203 (disk_id.clone(), device_node!(disk_id)), 204 (net_id.clone(), device_node!(net_id)), 205 (rng_id.clone(), device_node!(rng_id)), 206 ]; 207 device_tree.0.extend(device_list); 208 for (id, node) in device_tree.iter() { 209 if id == &disk_id { 210 assert_eq!(node.id, disk_id); 211 } else if id == &net_id { 212 assert_eq!(node.id, net_id); 213 } else if id == &rng_id { 214 assert_eq!(node.id, rng_id); 215 } else { 216 unreachable!() 217 } 218 } 219 220 // Check breadth_first_traversal() based on the following hierarchy 221 // 222 // 0 223 // | \ 224 // 1 2 225 // | | \ 226 // 3 4 5 227 // 228 let mut device_tree = DeviceTree::new(); 229 let child_1_id = String::from("child1"); 230 let child_2_id = String::from("child2"); 231 let child_3_id = String::from("child3"); 232 let parent_1_id = String::from("parent1"); 233 let parent_2_id = String::from("parent2"); 234 let root_id = String::from("root"); 235 let mut child_1_node = device_node!(child_1_id); 236 let mut child_2_node = device_node!(child_2_id); 237 let mut child_3_node = device_node!(child_3_id); 238 let mut parent_1_node = device_node!(parent_1_id); 239 let mut parent_2_node = device_node!(parent_2_id); 240 let mut root_node = device_node!(root_id); 241 child_1_node.parent = Some(parent_1_id.clone()); 242 child_2_node.parent = Some(parent_2_id.clone()); 243 child_3_node.parent = Some(parent_2_id.clone()); 244 parent_1_node.children = vec![child_1_id.clone()]; 245 parent_1_node.parent = Some(root_id.clone()); 246 parent_2_node.children = vec![child_2_id.clone(), child_3_id.clone()]; 247 parent_2_node.parent = Some(root_id.clone()); 248 root_node.children = vec![parent_1_id.clone(), parent_2_id.clone()]; 249 let device_list = vec![ 250 (child_1_id.clone(), child_1_node), 251 (child_2_id.clone(), child_2_node), 252 (child_3_id.clone(), child_3_node), 253 (parent_1_id.clone(), parent_1_node), 254 (parent_2_id.clone(), parent_2_node), 255 (root_id.clone(), root_node), 256 ]; 257 device_tree.0.extend(device_list); 258 259 let iter_vec = device_tree 260 .breadth_first_traversal() 261 .collect::<Vec<&DeviceNode>>(); 262 assert_eq!(iter_vec.len(), 6); 263 assert_eq!(iter_vec[0].id, root_id); 264 assert_eq!(iter_vec[1].id, parent_1_id); 265 assert_eq!(iter_vec[2].id, parent_2_id); 266 assert_eq!(iter_vec[3].id, child_1_id); 267 assert_eq!(iter_vec[4].id, child_2_id); 268 assert_eq!(iter_vec[5].id, child_3_id); 269 270 let iter_vec = device_tree 271 .breadth_first_traversal() 272 .rev() 273 .collect::<Vec<&DeviceNode>>(); 274 assert_eq!(iter_vec.len(), 6); 275 assert_eq!(iter_vec[5].id, root_id); 276 assert_eq!(iter_vec[4].id, parent_1_id); 277 assert_eq!(iter_vec[3].id, parent_2_id); 278 assert_eq!(iter_vec[2].id, child_1_id); 279 assert_eq!(iter_vec[1].id, child_2_id); 280 assert_eq!(iter_vec[0].id, child_3_id); 281 } 282 } 283