1 // Copyright © 2020 Intel Corporation 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 5 use std::collections::HashMap; 6 use std::sync::{Arc, Mutex}; 7 8 use pci::PciBdf; 9 use serde::{Deserialize, Serialize}; 10 use vm_device::Resource; 11 use vm_migration::Migratable; 12 13 use crate::device_manager::PciDeviceHandle; 14 15 #[derive(Clone, Serialize, Deserialize)] 16 pub struct DeviceNode { 17 pub id: String, 18 pub resources: Vec<Resource>, 19 pub parent: Option<String>, 20 pub children: Vec<String>, 21 #[serde(skip)] 22 pub migratable: Option<Arc<Mutex<dyn Migratable>>>, 23 pub pci_bdf: Option<PciBdf>, 24 #[serde(skip)] 25 pub pci_device_handle: Option<PciDeviceHandle>, 26 } 27 28 impl DeviceNode { 29 pub fn new(id: String, migratable: Option<Arc<Mutex<dyn Migratable>>>) -> Self { 30 DeviceNode { 31 id, 32 resources: Vec::new(), 33 parent: None, 34 children: Vec::new(), 35 migratable, 36 pci_bdf: None, 37 pci_device_handle: None, 38 } 39 } 40 } 41 42 #[macro_export] 43 macro_rules! device_node { 44 ($id:ident) => { 45 DeviceNode::new($id.clone(), None) 46 }; 47 ($id:ident, $device:ident) => { 48 DeviceNode::new( 49 $id.clone(), 50 Some(Arc::clone(&$device) as Arc<Mutex<dyn Migratable>>), 51 ) 52 }; 53 } 54 55 #[derive(Clone, Default, Serialize, Deserialize)] 56 pub struct DeviceTree(HashMap<String, DeviceNode>); 57 58 impl DeviceTree { 59 pub fn new() -> Self { 60 DeviceTree(HashMap::new()) 61 } 62 pub fn contains_key(&self, k: &str) -> bool { 63 self.0.contains_key(k) 64 } 65 pub fn get(&self, k: &str) -> Option<&DeviceNode> { 66 self.0.get(k) 67 } 68 pub fn get_mut(&mut self, k: &str) -> Option<&mut DeviceNode> { 69 self.0.get_mut(k) 70 } 71 pub fn insert(&mut self, k: String, v: DeviceNode) -> Option<DeviceNode> { 72 self.0.insert(k, v) 73 } 74 pub fn remove(&mut self, k: &str) -> Option<DeviceNode> { 75 self.0.remove(k) 76 } 77 pub fn iter(&self) -> std::collections::hash_map::Iter<String, DeviceNode> { 78 self.0.iter() 79 } 80 pub fn breadth_first_traversal(&self) -> BftIter { 81 BftIter::new(&self.0) 82 } 83 pub fn pci_devices(&self) -> Vec<&DeviceNode> { 84 self.0 85 .values() 86 .filter(|v| v.pci_bdf.is_some() && v.pci_device_handle.is_some()) 87 .collect() 88 } 89 90 pub fn remove_node_by_pci_bdf(&mut self, pci_bdf: PciBdf) -> Option<DeviceNode> { 91 let mut id = None; 92 for (k, v) in self.0.iter() { 93 if v.pci_bdf == Some(pci_bdf) { 94 id = Some(k.clone()); 95 break; 96 } 97 } 98 99 if let Some(id) = &id { 100 self.0.remove(id) 101 } else { 102 None 103 } 104 } 105 } 106 107 // Breadth first traversal iterator. 108 pub struct BftIter<'a> { 109 nodes: Vec<&'a DeviceNode>, 110 } 111 112 impl<'a> BftIter<'a> { 113 fn new(hash_map: &'a HashMap<String, DeviceNode>) -> Self { 114 let mut nodes = Vec::with_capacity(hash_map.len()); 115 let mut i = 0; 116 117 for (_, node) in hash_map.iter() { 118 if node.parent.is_none() { 119 nodes.push(node); 120 } 121 } 122 123 while i < nodes.len() { 124 for child_node_id in nodes[i].children.iter() { 125 if let Some(child_node) = hash_map.get(child_node_id) { 126 nodes.push(child_node); 127 } 128 } 129 i += 1; 130 } 131 132 BftIter { nodes } 133 } 134 } 135 136 impl<'a> Iterator for BftIter<'a> { 137 type Item = &'a DeviceNode; 138 139 fn next(&mut self) -> Option<Self::Item> { 140 if self.nodes.is_empty() { 141 None 142 } else { 143 Some(self.nodes.remove(0)) 144 } 145 } 146 } 147 148 impl DoubleEndedIterator for BftIter<'_> { 149 fn next_back(&mut self) -> Option<Self::Item> { 150 self.nodes.pop() 151 } 152 } 153 154 #[cfg(test)] 155 mod tests { 156 use super::{DeviceNode, DeviceTree}; 157 158 #[test] 159 fn test_device_tree() { 160 // Check new() 161 let mut device_tree = DeviceTree::new(); 162 assert_eq!(device_tree.0.len(), 0); 163 164 // Check insert() 165 let id = String::from("id1"); 166 device_tree.insert(id.clone(), DeviceNode::new(id.clone(), None)); 167 assert_eq!(device_tree.0.len(), 1); 168 let node = device_tree.0.get(&id); 169 assert!(node.is_some()); 170 let node = node.unwrap(); 171 assert_eq!(node.id, id); 172 173 // Check get() 174 let id2 = String::from("id2"); 175 assert!(device_tree.get(&id).is_some()); 176 assert!(device_tree.get(&id2).is_none()); 177 178 // Check get_mut() 179 let node = device_tree.get_mut(&id).unwrap(); 180 node.id.clone_from(&id2); 181 let node = device_tree.0.get(&id).unwrap(); 182 assert_eq!(node.id, id2); 183 184 // Check remove() 185 let node = device_tree.remove(&id).unwrap(); 186 assert_eq!(node.id, id2); 187 assert_eq!(device_tree.0.len(), 0); 188 189 // Check iter() 190 let disk_id = String::from("disk0"); 191 let net_id = String::from("net0"); 192 let rng_id = String::from("rng0"); 193 let device_list = vec![ 194 (disk_id.clone(), device_node!(disk_id)), 195 (net_id.clone(), device_node!(net_id)), 196 (rng_id.clone(), device_node!(rng_id)), 197 ]; 198 device_tree.0.extend(device_list); 199 for (id, node) in device_tree.iter() { 200 if id == &disk_id { 201 assert_eq!(node.id, disk_id); 202 } else if id == &net_id { 203 assert_eq!(node.id, net_id); 204 } else if id == &rng_id { 205 assert_eq!(node.id, rng_id); 206 } else { 207 unreachable!() 208 } 209 } 210 211 // Check breadth_first_traversal() based on the following hierarchy 212 // 213 // 0 214 // | \ 215 // 1 2 216 // | | \ 217 // 3 4 5 218 // 219 let mut device_tree = DeviceTree::new(); 220 let child_1_id = String::from("child1"); 221 let child_2_id = String::from("child2"); 222 let child_3_id = String::from("child3"); 223 let parent_1_id = String::from("parent1"); 224 let parent_2_id = String::from("parent2"); 225 let root_id = String::from("root"); 226 let mut child_1_node = device_node!(child_1_id); 227 let mut child_2_node = device_node!(child_2_id); 228 let mut child_3_node = device_node!(child_3_id); 229 let mut parent_1_node = device_node!(parent_1_id); 230 let mut parent_2_node = device_node!(parent_2_id); 231 let mut root_node = device_node!(root_id); 232 child_1_node.parent = Some(parent_1_id.clone()); 233 child_2_node.parent = Some(parent_2_id.clone()); 234 child_3_node.parent = Some(parent_2_id.clone()); 235 parent_1_node.children = vec![child_1_id.clone()]; 236 parent_1_node.parent = Some(root_id.clone()); 237 parent_2_node.children = vec![child_2_id.clone(), child_3_id.clone()]; 238 parent_2_node.parent = Some(root_id.clone()); 239 root_node.children = vec![parent_1_id.clone(), parent_2_id.clone()]; 240 let device_list = vec![ 241 (child_1_id.clone(), child_1_node), 242 (child_2_id.clone(), child_2_node), 243 (child_3_id.clone(), child_3_node), 244 (parent_1_id.clone(), parent_1_node), 245 (parent_2_id.clone(), parent_2_node), 246 (root_id.clone(), root_node), 247 ]; 248 device_tree.0.extend(device_list); 249 250 let iter_vec = device_tree 251 .breadth_first_traversal() 252 .collect::<Vec<&DeviceNode>>(); 253 assert_eq!(iter_vec.len(), 6); 254 assert_eq!(iter_vec[0].id, root_id); 255 assert_eq!(iter_vec[1].id, parent_1_id); 256 assert_eq!(iter_vec[2].id, parent_2_id); 257 assert_eq!(iter_vec[3].id, child_1_id); 258 assert_eq!(iter_vec[4].id, child_2_id); 259 assert_eq!(iter_vec[5].id, child_3_id); 260 261 let iter_vec = device_tree 262 .breadth_first_traversal() 263 .rev() 264 .collect::<Vec<&DeviceNode>>(); 265 assert_eq!(iter_vec.len(), 6); 266 assert_eq!(iter_vec[5].id, root_id); 267 assert_eq!(iter_vec[4].id, parent_1_id); 268 assert_eq!(iter_vec[3].id, parent_2_id); 269 assert_eq!(iter_vec[2].id, child_1_id); 270 assert_eq!(iter_vec[1].id, child_2_id); 271 assert_eq!(iter_vec[0].id, child_3_id); 272 } 273 } 274