1 // Copyright (c) 2020 Intel Corporation. All rights reserved. 2 // 3 // SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause 4 5 use std::io; 6 use std::num::Wrapping; 7 use std::os::unix::io::{AsRawFd, RawFd}; 8 use std::sync::atomic::{AtomicU64, Ordering}; 9 use std::sync::Arc; 10 11 use rate_limiter::{RateLimiter, TokenType}; 12 use thiserror::Error; 13 use virtio_queue::{Queue, QueueOwnedT, QueueT}; 14 use vm_memory::bitmap::Bitmap; 15 use vm_memory::{Bytes, GuestMemory}; 16 use vm_virtio::{AccessPlatform, Translatable}; 17 18 use super::{register_listener, unregister_listener, vnet_hdr_len, Tap}; 19 20 #[derive(Clone)] 21 pub struct TxVirtio { 22 pub counter_bytes: Wrapping<u64>, 23 pub counter_frames: Wrapping<u64>, 24 iovecs: IovecBuffer, 25 } 26 27 impl Default for TxVirtio { 28 fn default() -> Self { 29 Self::new() 30 } 31 } 32 33 impl TxVirtio { 34 pub fn new() -> Self { 35 TxVirtio { 36 counter_bytes: Wrapping(0), 37 counter_frames: Wrapping(0), 38 iovecs: IovecBuffer::new(), 39 } 40 } 41 42 pub fn process_desc_chain<B: Bitmap + 'static>( 43 &mut self, 44 mem: &vm_memory::GuestMemoryMmap<B>, 45 tap: &Tap, 46 queue: &mut Queue, 47 rate_limiter: &mut Option<RateLimiter>, 48 access_platform: Option<&Arc<dyn AccessPlatform>>, 49 ) -> Result<bool, NetQueuePairError> { 50 let mut retry_write = false; 51 let mut rate_limit_reached = false; 52 53 while let Some(mut desc_chain) = queue.pop_descriptor_chain(mem) { 54 if rate_limit_reached { 55 queue.go_to_previous_position(); 56 break; 57 } 58 59 let mut next_desc = desc_chain.next(); 60 61 let mut iovecs = self.iovecs.borrow(); 62 while let Some(desc) = next_desc { 63 let desc_addr = desc 64 .addr() 65 .translate_gva(access_platform, desc.len() as usize); 66 if !desc.is_write_only() && desc.len() > 0 { 67 let buf = desc_chain 68 .memory() 69 .get_slice(desc_addr, desc.len() as usize) 70 .map_err(NetQueuePairError::GuestMemory)? 71 .ptr_guard_mut(); 72 let iovec = libc::iovec { 73 iov_base: buf.as_ptr() as *mut libc::c_void, 74 iov_len: desc.len() as libc::size_t, 75 }; 76 iovecs.push(iovec); 77 } else { 78 error!( 79 "Invalid descriptor chain: address = 0x{:x} length = {} write_only = {}", 80 desc_addr.0, 81 desc.len(), 82 desc.is_write_only() 83 ); 84 return Err(NetQueuePairError::DescriptorChainInvalid); 85 } 86 next_desc = desc_chain.next(); 87 } 88 89 let len = if !iovecs.is_empty() { 90 // SAFETY: FFI call with correct arguments 91 let result = unsafe { 92 libc::writev( 93 tap.as_raw_fd() as libc::c_int, 94 iovecs.as_ptr(), 95 iovecs.len() as libc::c_int, 96 ) 97 }; 98 99 if result < 0 { 100 let e = std::io::Error::last_os_error(); 101 102 /* EAGAIN */ 103 if e.kind() == std::io::ErrorKind::WouldBlock { 104 queue.go_to_previous_position(); 105 retry_write = true; 106 break; 107 } 108 error!("net: tx: failed writing to tap: {}", e); 109 return Err(NetQueuePairError::WriteTap(e)); 110 } 111 112 if (result as usize) < vnet_hdr_len() { 113 return Err(NetQueuePairError::InvalidVirtioNetHeader); 114 } 115 116 self.counter_bytes += Wrapping(result as u64 - vnet_hdr_len() as u64); 117 self.counter_frames += Wrapping(1); 118 119 result as u32 120 } else { 121 0 122 }; 123 124 // For the sake of simplicity (similar to the RX rate limiting), we always 125 // let the 'last' descriptor chain go-through even if it was over the rate 126 // limit, and simply stop processing oncoming `avail_desc` if any. 127 if let Some(rate_limiter) = rate_limiter { 128 rate_limit_reached = !rate_limiter.consume(1, TokenType::Ops) 129 || !rate_limiter.consume(len as u64, TokenType::Bytes); 130 } 131 132 queue 133 .add_used(desc_chain.memory(), desc_chain.head_index(), len) 134 .map_err(NetQueuePairError::QueueAddUsed)?; 135 136 if !queue 137 .enable_notification(mem) 138 .map_err(NetQueuePairError::QueueEnableNotification)? 139 { 140 break; 141 } 142 } 143 144 Ok(retry_write) 145 } 146 } 147 148 #[derive(Clone)] 149 pub struct RxVirtio { 150 pub counter_bytes: Wrapping<u64>, 151 pub counter_frames: Wrapping<u64>, 152 iovecs: IovecBuffer, 153 } 154 155 impl Default for RxVirtio { 156 fn default() -> Self { 157 Self::new() 158 } 159 } 160 161 impl RxVirtio { 162 pub fn new() -> Self { 163 RxVirtio { 164 counter_bytes: Wrapping(0), 165 counter_frames: Wrapping(0), 166 iovecs: IovecBuffer::new(), 167 } 168 } 169 170 pub fn process_desc_chain<B: Bitmap + 'static>( 171 &mut self, 172 mem: &vm_memory::GuestMemoryMmap<B>, 173 tap: &Tap, 174 queue: &mut Queue, 175 rate_limiter: &mut Option<RateLimiter>, 176 access_platform: Option<&Arc<dyn AccessPlatform>>, 177 ) -> Result<bool, NetQueuePairError> { 178 let mut exhausted_descs = true; 179 let mut rate_limit_reached = false; 180 181 while let Some(mut desc_chain) = queue.pop_descriptor_chain(mem) { 182 if rate_limit_reached { 183 exhausted_descs = false; 184 queue.go_to_previous_position(); 185 break; 186 } 187 188 let desc = desc_chain 189 .next() 190 .ok_or(NetQueuePairError::DescriptorChainTooShort)?; 191 192 let num_buffers_addr = desc_chain 193 .memory() 194 .checked_offset( 195 desc.addr() 196 .translate_gva(access_platform, desc.len() as usize), 197 10, 198 ) 199 .ok_or(NetQueuePairError::DescriptorInvalidHeader)?; 200 let mut next_desc = Some(desc); 201 202 let mut iovecs = self.iovecs.borrow(); 203 while let Some(desc) = next_desc { 204 let desc_addr = desc 205 .addr() 206 .translate_gva(access_platform, desc.len() as usize); 207 if desc.is_write_only() && desc.len() > 0 { 208 let buf = desc_chain 209 .memory() 210 .get_slice(desc_addr, desc.len() as usize) 211 .map_err(NetQueuePairError::GuestMemory)? 212 .ptr_guard_mut(); 213 let iovec = libc::iovec { 214 iov_base: buf.as_ptr() as *mut libc::c_void, 215 iov_len: desc.len() as libc::size_t, 216 }; 217 iovecs.push(iovec); 218 } else { 219 error!( 220 "Invalid descriptor chain: address = 0x{:x} length = {} write_only = {}", 221 desc_addr.0, 222 desc.len(), 223 desc.is_write_only() 224 ); 225 return Err(NetQueuePairError::DescriptorChainInvalid); 226 } 227 next_desc = desc_chain.next(); 228 } 229 230 let len = if !iovecs.is_empty() { 231 // SAFETY: FFI call with correct arguments 232 let result = unsafe { 233 libc::readv( 234 tap.as_raw_fd() as libc::c_int, 235 iovecs.as_ptr(), 236 iovecs.len() as libc::c_int, 237 ) 238 }; 239 if result < 0 { 240 let e = std::io::Error::last_os_error(); 241 exhausted_descs = false; 242 queue.go_to_previous_position(); 243 244 /* EAGAIN */ 245 if e.kind() == std::io::ErrorKind::WouldBlock { 246 break; 247 } 248 249 error!("net: rx: failed reading from tap: {}", e); 250 return Err(NetQueuePairError::ReadTap(e)); 251 } 252 253 if (result as usize) < vnet_hdr_len() { 254 return Err(NetQueuePairError::InvalidVirtioNetHeader); 255 } 256 257 // Write num_buffers to guest memory. We simply write 1 as we 258 // never spread the frame over more than one descriptor chain. 259 desc_chain 260 .memory() 261 .write_obj(1u16, num_buffers_addr) 262 .map_err(NetQueuePairError::GuestMemory)?; 263 264 self.counter_bytes += Wrapping(result as u64 - vnet_hdr_len() as u64); 265 self.counter_frames += Wrapping(1); 266 267 result as u32 268 } else { 269 0 270 }; 271 272 // For the sake of simplicity (keeping the handling of RX_QUEUE_EVENT and 273 // RX_TAP_EVENT totally asynchronous), we always let the 'last' descriptor 274 // chain go-through even if it was over the rate limit, and simply stop 275 // processing oncoming `avail_desc` if any. 276 if let Some(rate_limiter) = rate_limiter { 277 rate_limit_reached = !rate_limiter.consume(1, TokenType::Ops) 278 || !rate_limiter.consume(len as u64, TokenType::Bytes); 279 } 280 281 queue 282 .add_used(desc_chain.memory(), desc_chain.head_index(), len) 283 .map_err(NetQueuePairError::QueueAddUsed)?; 284 285 if !queue 286 .enable_notification(mem) 287 .map_err(NetQueuePairError::QueueEnableNotification)? 288 { 289 break; 290 } 291 } 292 293 Ok(exhausted_descs) 294 } 295 } 296 297 #[derive(Default, Clone)] 298 struct IovecBuffer(Vec<libc::iovec>); 299 300 // SAFETY: Implementing Send for IovecBuffer is safe as the pointer inside is iovec. 301 // The iovecs are usually constructed from virtio descriptors, which are safe to send across 302 // threads. 303 unsafe impl Send for IovecBuffer {} 304 // SAFETY: Implementing Sync for IovecBuffer is safe as the pointer inside is iovec. 305 // The iovecs are usually constructed from virtio descriptors, which are safe to access from 306 // multiple threads. 307 unsafe impl Sync for IovecBuffer {} 308 309 impl IovecBuffer { 310 fn new() -> Self { 311 // Here we use 4 as the default capacity because it is enough for most cases. 312 const DEFAULT_CAPACITY: usize = 4; 313 IovecBuffer(Vec::with_capacity(DEFAULT_CAPACITY)) 314 } 315 316 fn borrow(&mut self) -> IovecBufferBorrowed<'_> { 317 IovecBufferBorrowed(&mut self.0) 318 } 319 } 320 321 struct IovecBufferBorrowed<'a>(&'a mut Vec<libc::iovec>); 322 323 impl std::ops::Deref for IovecBufferBorrowed<'_> { 324 type Target = Vec<libc::iovec>; 325 326 fn deref(&self) -> &Self::Target { 327 self.0 328 } 329 } 330 331 impl std::ops::DerefMut for IovecBufferBorrowed<'_> { 332 fn deref_mut(&mut self) -> &mut Self::Target { 333 self.0 334 } 335 } 336 337 impl Drop for IovecBufferBorrowed<'_> { 338 fn drop(&mut self) { 339 // Clear the buffer to make sure old values are not used after 340 self.0.clear(); 341 } 342 } 343 344 #[derive(Default, Clone)] 345 pub struct NetCounters { 346 pub tx_bytes: Arc<AtomicU64>, 347 pub tx_frames: Arc<AtomicU64>, 348 pub rx_bytes: Arc<AtomicU64>, 349 pub rx_frames: Arc<AtomicU64>, 350 } 351 352 #[derive(Error, Debug)] 353 pub enum NetQueuePairError { 354 #[error("No memory configured")] 355 NoMemoryConfigured, 356 #[error("Error registering listener: {0}")] 357 RegisterListener(io::Error), 358 #[error("Error unregistering listener: {0}")] 359 UnregisterListener(io::Error), 360 #[error("Error writing to the TAP device: {0}")] 361 WriteTap(io::Error), 362 #[error("Error reading from the TAP device: {0}")] 363 ReadTap(io::Error), 364 #[error("Error related to guest memory: {0}")] 365 GuestMemory(vm_memory::GuestMemoryError), 366 #[error("Returned an error while iterating through the queue: {0}")] 367 QueueIteratorFailed(virtio_queue::Error), 368 #[error("Descriptor chain is too short")] 369 DescriptorChainTooShort, 370 #[error("Descriptor chain does not contain valid descriptors")] 371 DescriptorChainInvalid, 372 #[error("Failed to determine if queue needed notification: {0}")] 373 QueueNeedsNotification(virtio_queue::Error), 374 #[error("Failed to enable notification on the queue: {0}")] 375 QueueEnableNotification(virtio_queue::Error), 376 #[error("Failed to add used index to the queue: {0}")] 377 QueueAddUsed(virtio_queue::Error), 378 #[error("Descriptor with invalid virtio-net header")] 379 DescriptorInvalidHeader, 380 #[error("Invalid virtio-net header")] 381 InvalidVirtioNetHeader, 382 } 383 384 pub struct NetQueuePair { 385 pub tap: Tap, 386 // With epoll each FD must be unique. So in order to filter the 387 // events we need to get a second FD responding to the original 388 // device so that we can send EPOLLOUT and EPOLLIN to separate 389 // events. 390 pub tap_for_write_epoll: Tap, 391 pub rx: RxVirtio, 392 pub tx: TxVirtio, 393 pub epoll_fd: Option<RawFd>, 394 pub rx_tap_listening: bool, 395 pub tx_tap_listening: bool, 396 pub counters: NetCounters, 397 pub tap_rx_event_id: u16, 398 pub tap_tx_event_id: u16, 399 pub rx_desc_avail: bool, 400 pub rx_rate_limiter: Option<RateLimiter>, 401 pub tx_rate_limiter: Option<RateLimiter>, 402 pub access_platform: Option<Arc<dyn AccessPlatform>>, 403 } 404 405 impl NetQueuePair { 406 pub fn process_tx<B: Bitmap + 'static>( 407 &mut self, 408 mem: &vm_memory::GuestMemoryMmap<B>, 409 queue: &mut Queue, 410 ) -> Result<bool, NetQueuePairError> { 411 let tx_tap_retry = self.tx.process_desc_chain( 412 mem, 413 &self.tap, 414 queue, 415 &mut self.tx_rate_limiter, 416 self.access_platform.as_ref(), 417 )?; 418 419 // We got told to try again when writing to the tap. Wait for the TAP to be writable 420 if tx_tap_retry && !self.tx_tap_listening { 421 register_listener( 422 self.epoll_fd.unwrap(), 423 self.tap_for_write_epoll.as_raw_fd(), 424 epoll::Events::EPOLLOUT, 425 u64::from(self.tap_tx_event_id), 426 ) 427 .map_err(NetQueuePairError::RegisterListener)?; 428 self.tx_tap_listening = true; 429 info!("Writing to TAP returned EAGAIN. Listening for TAP to become writable."); 430 } else if !tx_tap_retry && self.tx_tap_listening { 431 unregister_listener( 432 self.epoll_fd.unwrap(), 433 self.tap_for_write_epoll.as_raw_fd(), 434 epoll::Events::EPOLLOUT, 435 u64::from(self.tap_tx_event_id), 436 ) 437 .map_err(NetQueuePairError::UnregisterListener)?; 438 self.tx_tap_listening = false; 439 info!("Writing to TAP succeeded. No longer listening for TAP to become writable."); 440 } 441 442 self.counters 443 .tx_bytes 444 .fetch_add(self.tx.counter_bytes.0, Ordering::AcqRel); 445 self.counters 446 .tx_frames 447 .fetch_add(self.tx.counter_frames.0, Ordering::AcqRel); 448 self.tx.counter_bytes = Wrapping(0); 449 self.tx.counter_frames = Wrapping(0); 450 451 queue 452 .needs_notification(mem) 453 .map_err(NetQueuePairError::QueueNeedsNotification) 454 } 455 456 pub fn process_rx<B: Bitmap + 'static>( 457 &mut self, 458 mem: &vm_memory::GuestMemoryMmap<B>, 459 queue: &mut Queue, 460 ) -> Result<bool, NetQueuePairError> { 461 self.rx_desc_avail = !self.rx.process_desc_chain( 462 mem, 463 &self.tap, 464 queue, 465 &mut self.rx_rate_limiter, 466 self.access_platform.as_ref(), 467 )?; 468 let rate_limit_reached = self 469 .rx_rate_limiter 470 .as_ref() 471 .is_some_and(|r| r.is_blocked()); 472 473 // Stop listening on the `RX_TAP_EVENT` when: 474 // 1) there is no available describes, or 475 // 2) the RX rate limit is reached. 476 if self.rx_tap_listening && (!self.rx_desc_avail || rate_limit_reached) { 477 unregister_listener( 478 self.epoll_fd.unwrap(), 479 self.tap.as_raw_fd(), 480 epoll::Events::EPOLLIN, 481 u64::from(self.tap_rx_event_id), 482 ) 483 .map_err(NetQueuePairError::UnregisterListener)?; 484 self.rx_tap_listening = false; 485 } 486 487 self.counters 488 .rx_bytes 489 .fetch_add(self.rx.counter_bytes.0, Ordering::AcqRel); 490 self.counters 491 .rx_frames 492 .fetch_add(self.rx.counter_frames.0, Ordering::AcqRel); 493 self.rx.counter_bytes = Wrapping(0); 494 self.rx.counter_frames = Wrapping(0); 495 496 queue 497 .needs_notification(mem) 498 .map_err(NetQueuePairError::QueueNeedsNotification) 499 } 500 } 501