1 // Copyright © 2020 Intel Corporation 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 5 use crate::device_manager::PciDeviceHandle; 6 use pci::PciBdf; 7 use serde::{Deserialize, Serialize}; 8 use std::collections::HashMap; 9 use std::sync::{Arc, Mutex}; 10 use vm_device::Resource; 11 use vm_migration::Migratable; 12 13 #[derive(Clone, Serialize, Deserialize)] 14 pub struct DeviceNode { 15 pub id: String, 16 pub resources: Vec<Resource>, 17 pub parent: Option<String>, 18 pub children: Vec<String>, 19 #[serde(skip)] 20 pub migratable: Option<Arc<Mutex<dyn Migratable>>>, 21 pub pci_bdf: Option<PciBdf>, 22 #[serde(skip)] 23 pub pci_device_handle: Option<PciDeviceHandle>, 24 } 25 26 impl DeviceNode { 27 pub fn new(id: String, migratable: Option<Arc<Mutex<dyn Migratable>>>) -> Self { 28 DeviceNode { 29 id, 30 resources: Vec::new(), 31 parent: None, 32 children: Vec::new(), 33 migratable, 34 pci_bdf: None, 35 pci_device_handle: None, 36 } 37 } 38 } 39 40 #[macro_export] 41 macro_rules! device_node { 42 ($id:ident) => { 43 DeviceNode::new($id.clone(), None) 44 }; 45 ($id:ident, $device:ident) => { 46 DeviceNode::new( 47 $id.clone(), 48 Some(Arc::clone(&$device) as Arc<Mutex<dyn Migratable>>), 49 ) 50 }; 51 } 52 53 #[derive(Clone, Default, Serialize, Deserialize)] 54 pub struct DeviceTree(HashMap<String, DeviceNode>); 55 56 impl DeviceTree { 57 pub fn new() -> Self { 58 DeviceTree(HashMap::new()) 59 } 60 pub fn contains_key(&self, k: &str) -> bool { 61 self.0.contains_key(k) 62 } 63 pub fn get(&self, k: &str) -> Option<&DeviceNode> { 64 self.0.get(k) 65 } 66 pub fn get_mut(&mut self, k: &str) -> Option<&mut DeviceNode> { 67 self.0.get_mut(k) 68 } 69 pub fn insert(&mut self, k: String, v: DeviceNode) -> Option<DeviceNode> { 70 self.0.insert(k, v) 71 } 72 pub fn remove(&mut self, k: &str) -> Option<DeviceNode> { 73 self.0.remove(k) 74 } 75 pub fn iter(&self) -> std::collections::hash_map::Iter<String, DeviceNode> { 76 self.0.iter() 77 } 78 pub fn breadth_first_traversal(&self) -> BftIter { 79 BftIter::new(&self.0) 80 } 81 pub fn pci_devices(&self) -> Vec<&DeviceNode> { 82 self.0 83 .values() 84 .filter(|v| v.pci_bdf.is_some() && v.pci_device_handle.is_some()) 85 .collect() 86 } 87 88 pub fn remove_node_by_pci_bdf(&mut self, pci_bdf: PciBdf) -> Option<DeviceNode> { 89 let mut id = None; 90 for (k, v) in self.0.iter() { 91 if v.pci_bdf == Some(pci_bdf) { 92 id = Some(k.clone()); 93 break; 94 } 95 } 96 97 if let Some(id) = &id { 98 self.0.remove(id) 99 } else { 100 None 101 } 102 } 103 } 104 105 // Breadth first traversal iterator. 106 pub struct BftIter<'a> { 107 nodes: Vec<&'a DeviceNode>, 108 } 109 110 impl<'a> BftIter<'a> { 111 fn new(hash_map: &'a HashMap<String, DeviceNode>) -> Self { 112 let mut nodes = Vec::with_capacity(hash_map.len()); 113 let mut i = 0; 114 115 for (_, node) in hash_map.iter() { 116 if node.parent.is_none() { 117 nodes.push(node); 118 } 119 } 120 121 while i < nodes.len() { 122 for child_node_id in nodes[i].children.iter() { 123 if let Some(child_node) = hash_map.get(child_node_id) { 124 nodes.push(child_node); 125 } 126 } 127 i += 1; 128 } 129 130 BftIter { nodes } 131 } 132 } 133 134 impl<'a> Iterator for BftIter<'a> { 135 type Item = &'a DeviceNode; 136 137 fn next(&mut self) -> Option<Self::Item> { 138 if self.nodes.is_empty() { 139 None 140 } else { 141 Some(self.nodes.remove(0)) 142 } 143 } 144 } 145 146 impl<'a> DoubleEndedIterator for BftIter<'a> { 147 fn next_back(&mut self) -> Option<Self::Item> { 148 self.nodes.pop() 149 } 150 } 151 152 #[cfg(test)] 153 mod tests { 154 use super::{DeviceNode, DeviceTree}; 155 156 #[test] 157 fn test_device_tree() { 158 // Check new() 159 let mut device_tree = DeviceTree::new(); 160 assert_eq!(device_tree.0.len(), 0); 161 162 // Check insert() 163 let id = String::from("id1"); 164 device_tree.insert(id.clone(), DeviceNode::new(id.clone(), None)); 165 assert_eq!(device_tree.0.len(), 1); 166 let node = device_tree.0.get(&id); 167 assert!(node.is_some()); 168 let node = node.unwrap(); 169 assert_eq!(node.id, id); 170 171 // Check get() 172 let id2 = String::from("id2"); 173 assert!(device_tree.get(&id).is_some()); 174 assert!(device_tree.get(&id2).is_none()); 175 176 // Check get_mut() 177 let node = device_tree.get_mut(&id).unwrap(); 178 node.id = id2.clone(); 179 let node = device_tree.0.get(&id).unwrap(); 180 assert_eq!(node.id, id2); 181 182 // Check remove() 183 let node = device_tree.remove(&id).unwrap(); 184 assert_eq!(node.id, id2); 185 assert_eq!(device_tree.0.len(), 0); 186 187 // Check iter() 188 let disk_id = String::from("disk0"); 189 let net_id = String::from("net0"); 190 let rng_id = String::from("rng0"); 191 let device_list = vec![ 192 (disk_id.clone(), device_node!(disk_id)), 193 (net_id.clone(), device_node!(net_id)), 194 (rng_id.clone(), device_node!(rng_id)), 195 ]; 196 device_tree.0.extend(device_list); 197 for (id, node) in device_tree.iter() { 198 if id == &disk_id { 199 assert_eq!(node.id, disk_id); 200 } else if id == &net_id { 201 assert_eq!(node.id, net_id); 202 } else if id == &rng_id { 203 assert_eq!(node.id, rng_id); 204 } else { 205 unreachable!() 206 } 207 } 208 209 // Check breadth_first_traversal() based on the following hierarchy 210 // 211 // 0 212 // | \ 213 // 1 2 214 // | | \ 215 // 3 4 5 216 // 217 let mut device_tree = DeviceTree::new(); 218 let child_1_id = String::from("child1"); 219 let child_2_id = String::from("child2"); 220 let child_3_id = String::from("child3"); 221 let parent_1_id = String::from("parent1"); 222 let parent_2_id = String::from("parent2"); 223 let root_id = String::from("root"); 224 let mut child_1_node = device_node!(child_1_id); 225 let mut child_2_node = device_node!(child_2_id); 226 let mut child_3_node = device_node!(child_3_id); 227 let mut parent_1_node = device_node!(parent_1_id); 228 let mut parent_2_node = device_node!(parent_2_id); 229 let mut root_node = device_node!(root_id); 230 child_1_node.parent = Some(parent_1_id.clone()); 231 child_2_node.parent = Some(parent_2_id.clone()); 232 child_3_node.parent = Some(parent_2_id.clone()); 233 parent_1_node.children = vec![child_1_id.clone()]; 234 parent_1_node.parent = Some(root_id.clone()); 235 parent_2_node.children = vec![child_2_id.clone(), child_3_id.clone()]; 236 parent_2_node.parent = Some(root_id.clone()); 237 root_node.children = vec![parent_1_id.clone(), parent_2_id.clone()]; 238 let device_list = vec![ 239 (child_1_id.clone(), child_1_node), 240 (child_2_id.clone(), child_2_node), 241 (child_3_id.clone(), child_3_node), 242 (parent_1_id.clone(), parent_1_node), 243 (parent_2_id.clone(), parent_2_node), 244 (root_id.clone(), root_node), 245 ]; 246 device_tree.0.extend(device_list); 247 248 let iter_vec = device_tree 249 .breadth_first_traversal() 250 .collect::<Vec<&DeviceNode>>(); 251 assert_eq!(iter_vec.len(), 6); 252 assert_eq!(iter_vec[0].id, root_id); 253 assert_eq!(iter_vec[1].id, parent_1_id); 254 assert_eq!(iter_vec[2].id, parent_2_id); 255 assert_eq!(iter_vec[3].id, child_1_id); 256 assert_eq!(iter_vec[4].id, child_2_id); 257 assert_eq!(iter_vec[5].id, child_3_id); 258 259 let iter_vec = device_tree 260 .breadth_first_traversal() 261 .rev() 262 .collect::<Vec<&DeviceNode>>(); 263 assert_eq!(iter_vec.len(), 6); 264 assert_eq!(iter_vec[5].id, root_id); 265 assert_eq!(iter_vec[4].id, parent_1_id); 266 assert_eq!(iter_vec[3].id, parent_2_id); 267 assert_eq!(iter_vec[2].id, child_1_id); 268 assert_eq!(iter_vec[1].id, child_2_id); 269 assert_eq!(iter_vec[0].id, child_3_id); 270 } 271 } 272