1 // Copyright © 2020 Intel Corporation 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 5 use crate::device_manager::PciDeviceHandle; 6 use pci::PciBdf; 7 use serde::{Deserialize, Serialize}; 8 use std::collections::HashMap; 9 use std::sync::{Arc, Mutex}; 10 use vm_device::Resource; 11 use vm_migration::Migratable; 12 13 #[derive(Clone, Serialize, Deserialize)] 14 pub struct DeviceNode { 15 pub id: String, 16 pub resources: Vec<Resource>, 17 pub parent: Option<String>, 18 pub children: Vec<String>, 19 #[serde(skip)] 20 pub migratable: Option<Arc<Mutex<dyn Migratable>>>, 21 pub pci_bdf: Option<PciBdf>, 22 #[serde(skip)] 23 pub pci_device_handle: Option<PciDeviceHandle>, 24 } 25 26 impl DeviceNode { 27 pub fn new(id: String, migratable: Option<Arc<Mutex<dyn Migratable>>>) -> Self { 28 DeviceNode { 29 id, 30 resources: Vec::new(), 31 parent: None, 32 children: Vec::new(), 33 migratable, 34 pci_bdf: None, 35 pci_device_handle: None, 36 } 37 } 38 } 39 40 #[macro_export] 41 macro_rules! device_node { 42 ($id:ident) => { 43 DeviceNode::new($id.clone(), None) 44 }; 45 ($id:ident, $device:ident) => { 46 DeviceNode::new( 47 $id.clone(), 48 Some(Arc::clone(&$device) as Arc<Mutex<dyn Migratable>>), 49 ) 50 }; 51 } 52 53 #[derive(Clone, Default, Serialize, Deserialize)] 54 pub struct DeviceTree(HashMap<String, DeviceNode>); 55 56 impl DeviceTree { 57 pub fn new() -> Self { 58 DeviceTree(HashMap::new()) 59 } 60 pub fn contains_key(&self, k: &str) -> bool { 61 self.0.contains_key(k) 62 } 63 pub fn get(&self, k: &str) -> Option<&DeviceNode> { 64 self.0.get(k) 65 } 66 pub fn get_mut(&mut self, k: &str) -> Option<&mut DeviceNode> { 67 self.0.get_mut(k) 68 } 69 pub fn insert(&mut self, k: String, v: DeviceNode) -> Option<DeviceNode> { 70 self.0.insert(k, v) 71 } 72 pub fn remove(&mut self, k: &str) -> Option<DeviceNode> { 73 self.0.remove(k) 74 } 75 pub fn iter(&self) -> std::collections::hash_map::Iter<String, DeviceNode> { 76 self.0.iter() 77 } 78 pub fn breadth_first_traversal(&self) -> BftIter { 79 BftIter::new(&self.0) 80 } 81 pub fn pci_devices(&self) -> Vec<&DeviceNode> { 82 self.0 83 .values() 84 .filter(|v| v.pci_bdf.is_some() && v.pci_device_handle.is_some()) 85 .collect() 86 } 87 88 pub fn remove_node_by_pci_bdf(&mut self, pci_bdf: PciBdf) -> Option<DeviceNode> { 89 let mut id = None; 90 for (k, v) in self.0.iter() { 91 if v.pci_bdf == Some(pci_bdf) { 92 id = Some(k.clone()); 93 break; 94 } 95 } 96 97 if let Some(id) = &id { 98 self.0.remove(id) 99 } else { 100 None 101 } 102 } 103 } 104 105 // Breadth first traversal iterator. 106 pub struct BftIter<'a> { 107 nodes: Vec<&'a DeviceNode>, 108 } 109 110 impl<'a> BftIter<'a> { 111 fn new(hash_map: &'a HashMap<String, DeviceNode>) -> Self { 112 let mut nodes = Vec::new(); 113 114 for (_, node) in hash_map.iter() { 115 if node.parent.is_none() { 116 nodes.push(node); 117 } 118 } 119 120 let mut node_layer = nodes.as_slice(); 121 loop { 122 let mut next_node_layer = Vec::new(); 123 124 for node in node_layer.iter() { 125 for child_node_id in node.children.iter() { 126 if let Some(child_node) = hash_map.get(child_node_id) { 127 next_node_layer.push(child_node); 128 } 129 } 130 } 131 132 if next_node_layer.is_empty() { 133 break; 134 } 135 136 let pos = nodes.len(); 137 nodes.extend(next_node_layer); 138 139 node_layer = &nodes[pos..]; 140 } 141 142 BftIter { nodes } 143 } 144 } 145 146 impl<'a> Iterator for BftIter<'a> { 147 type Item = &'a DeviceNode; 148 149 fn next(&mut self) -> Option<Self::Item> { 150 if self.nodes.is_empty() { 151 None 152 } else { 153 Some(self.nodes.remove(0)) 154 } 155 } 156 } 157 158 impl<'a> DoubleEndedIterator for BftIter<'a> { 159 fn next_back(&mut self) -> Option<Self::Item> { 160 self.nodes.pop() 161 } 162 } 163 164 #[cfg(test)] 165 mod tests { 166 use super::{DeviceNode, DeviceTree}; 167 168 #[test] 169 fn test_device_tree() { 170 // Check new() 171 let mut device_tree = DeviceTree::new(); 172 assert_eq!(device_tree.0.len(), 0); 173 174 // Check insert() 175 let id = String::from("id1"); 176 device_tree.insert(id.clone(), DeviceNode::new(id.clone(), None)); 177 assert_eq!(device_tree.0.len(), 1); 178 let node = device_tree.0.get(&id); 179 assert!(node.is_some()); 180 let node = node.unwrap(); 181 assert_eq!(node.id, id); 182 183 // Check get() 184 let id2 = String::from("id2"); 185 assert!(device_tree.get(&id).is_some()); 186 assert!(device_tree.get(&id2).is_none()); 187 188 // Check get_mut() 189 let node = device_tree.get_mut(&id).unwrap(); 190 node.id = id2.clone(); 191 let node = device_tree.0.get(&id).unwrap(); 192 assert_eq!(node.id, id2); 193 194 // Check remove() 195 let node = device_tree.remove(&id).unwrap(); 196 assert_eq!(node.id, id2); 197 assert_eq!(device_tree.0.len(), 0); 198 199 // Check iter() 200 let disk_id = String::from("disk0"); 201 let net_id = String::from("net0"); 202 let rng_id = String::from("rng0"); 203 let device_list = vec![ 204 (disk_id.clone(), device_node!(disk_id)), 205 (net_id.clone(), device_node!(net_id)), 206 (rng_id.clone(), device_node!(rng_id)), 207 ]; 208 device_tree.0.extend(device_list); 209 for (id, node) in device_tree.iter() { 210 if id == &disk_id { 211 assert_eq!(node.id, disk_id); 212 } else if id == &net_id { 213 assert_eq!(node.id, net_id); 214 } else if id == &rng_id { 215 assert_eq!(node.id, rng_id); 216 } else { 217 unreachable!() 218 } 219 } 220 221 // Check breadth_first_traversal() based on the following hierarchy 222 // 223 // 0 224 // | \ 225 // 1 2 226 // | | \ 227 // 3 4 5 228 // 229 let mut device_tree = DeviceTree::new(); 230 let child_1_id = String::from("child1"); 231 let child_2_id = String::from("child2"); 232 let child_3_id = String::from("child3"); 233 let parent_1_id = String::from("parent1"); 234 let parent_2_id = String::from("parent2"); 235 let root_id = String::from("root"); 236 let mut child_1_node = device_node!(child_1_id); 237 let mut child_2_node = device_node!(child_2_id); 238 let mut child_3_node = device_node!(child_3_id); 239 let mut parent_1_node = device_node!(parent_1_id); 240 let mut parent_2_node = device_node!(parent_2_id); 241 let mut root_node = device_node!(root_id); 242 child_1_node.parent = Some(parent_1_id.clone()); 243 child_2_node.parent = Some(parent_2_id.clone()); 244 child_3_node.parent = Some(parent_2_id.clone()); 245 parent_1_node.children = vec![child_1_id.clone()]; 246 parent_1_node.parent = Some(root_id.clone()); 247 parent_2_node.children = vec![child_2_id.clone(), child_3_id.clone()]; 248 parent_2_node.parent = Some(root_id.clone()); 249 root_node.children = vec![parent_1_id.clone(), parent_2_id.clone()]; 250 let device_list = vec![ 251 (child_1_id.clone(), child_1_node), 252 (child_2_id.clone(), child_2_node), 253 (child_3_id.clone(), child_3_node), 254 (parent_1_id.clone(), parent_1_node), 255 (parent_2_id.clone(), parent_2_node), 256 (root_id.clone(), root_node), 257 ]; 258 device_tree.0.extend(device_list); 259 260 let iter_vec = device_tree 261 .breadth_first_traversal() 262 .collect::<Vec<&DeviceNode>>(); 263 assert_eq!(iter_vec.len(), 6); 264 assert_eq!(iter_vec[0].id, root_id); 265 assert_eq!(iter_vec[1].id, parent_1_id); 266 assert_eq!(iter_vec[2].id, parent_2_id); 267 assert_eq!(iter_vec[3].id, child_1_id); 268 assert_eq!(iter_vec[4].id, child_2_id); 269 assert_eq!(iter_vec[5].id, child_3_id); 270 271 let iter_vec = device_tree 272 .breadth_first_traversal() 273 .rev() 274 .collect::<Vec<&DeviceNode>>(); 275 assert_eq!(iter_vec.len(), 6); 276 assert_eq!(iter_vec[5].id, root_id); 277 assert_eq!(iter_vec[4].id, parent_1_id); 278 assert_eq!(iter_vec[3].id, parent_2_id); 279 assert_eq!(iter_vec[2].id, child_1_id); 280 assert_eq!(iter_vec[1].id, child_2_id); 281 assert_eq!(iter_vec[0].id, child_3_id); 282 } 283 } 284