xref: /cloud-hypervisor/net_util/src/queue_pair.rs (revision 6f8bd27cf7629733582d930519e98d19e90afb16)
// Copyright (c) 2020 Intel Corporation. All rights reserved.
//
// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause

use super::{register_listener, unregister_listener, vnet_hdr_len, Tap};
use crate::GuestMemoryMmap;
use rate_limiter::{RateLimiter, TokenType};
use std::io;
use std::num::Wrapping;
use std::os::unix::io::{AsRawFd, RawFd};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use thiserror::Error;
use virtio_queue::{Queue, QueueOwnedT, QueueT};
use vm_memory::{Bytes, GuestMemory};
use vm_virtio::{AccessPlatform, Translatable};

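/// TX (guest to tap) byte and frame counters, accumulated per run and
/// flushed into the shared `NetCounters` by `NetQueuePair::process_tx()`.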
#[derive(Clone)]
pub struct TxVirtio {
    pub counter_bytes: Wrapping<u64>,
    pub counter_frames: Wrapping<u64>,
}

impl Default for TxVirtio {
    fn default() -> Self {
        Self::new()
    }
}

impl TxVirtio {
    pub fn new() -> Self {
        TxVirtio {
            counter_bytes: Wrapping(0),
            counter_frames: Wrapping(0),
        }
    }

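    /// Drains the TX queue: gathers each available descriptor chain's
    /// readable buffers into an iovec list and writes the frame to the tap
    /// with a single writev(2). Returns Ok(true) when the tap returned
    /// EAGAIN, meaning the caller should retry once the tap is writable.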
    pub fn process_desc_chain(
        &mut self,
        mem: &GuestMemoryMmap,
        tap: &mut Tap,
        queue: &mut Queue,
        rate_limiter: &mut Option<RateLimiter>,
        access_platform: Option<&Arc<dyn AccessPlatform>>,
    ) -> Result<bool, NetQueuePairError> {
        let mut retry_write = false;
        let mut rate_limit_reached = false;

        while let Some(mut desc_chain) = queue.pop_descriptor_chain(mem) {
            if rate_limit_reached {
                queue.go_to_previous_position();
                break;
            }

            let mut next_desc = desc_chain.next();

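            // Gather every readable (driver-filled) buffer of the chain so
            // the whole frame, vnet header included, can be handed to the
            // tap in a single writev(2) call.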
            let mut iovecs = Vec::new();
            while let Some(desc) = next_desc {
                let desc_addr = desc
                    .addr()
                    .translate_gva(access_platform, desc.len() as usize);
                if !desc.is_write_only() && desc.len() > 0 {
                    let buf = desc_chain
                        .memory()
                        .get_slice(desc_addr, desc.len() as usize)
                        .map_err(NetQueuePairError::GuestMemory)?
                        .as_ptr();
                    let iovec = libc::iovec {
                        iov_base: buf as *mut libc::c_void,
                        iov_len: desc.len() as libc::size_t,
                    };
                    iovecs.push(iovec);
                } else {
                    error!(
                        "Invalid descriptor chain: address = 0x{:x} length = {} write_only = {}",
                        desc_addr.0,
                        desc.len(),
                        desc.is_write_only()
                    );
                    return Err(NetQueuePairError::DescriptorChainInvalid);
                }
                next_desc = desc_chain.next();
            }

            let len = if !iovecs.is_empty() {
                // SAFETY: FFI call with correct arguments
                let result = unsafe {
                    libc::writev(
                        tap.as_raw_fd() as libc::c_int,
                        iovecs.as_ptr() as *const libc::iovec,
                        iovecs.len() as libc::c_int,
                    )
                };

                if result < 0 {
                    let e = std::io::Error::last_os_error();

                    /* EAGAIN */
                    if e.kind() == std::io::ErrorKind::WouldBlock {
                        queue.go_to_previous_position();
                        retry_write = true;
                        break;
                    }
                    error!("net: tx: failed writing to tap: {}", e);
                    return Err(NetQueuePairError::WriteTap(e));
                }

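                // The vnet header is not part of the transmitted frame, so
                // subtract its length from the byte counter.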
                self.counter_bytes += Wrapping(result as u64 - vnet_hdr_len() as u64);
                self.counter_frames += Wrapping(1);

                result as u32
            } else {
                0
            };

            // For the sake of simplicity (similar to the RX rate limiting), we
            // always let the 'last' descriptor chain go through even if it was
            // over the rate limit, and simply stop processing incoming
            // `avail_desc` entries, if any.
            if let Some(rate_limiter) = rate_limiter {
                rate_limit_reached = !rate_limiter.consume(1, TokenType::Ops)
                    || !rate_limiter.consume(len as u64, TokenType::Bytes);
            }

            queue
                .add_used(desc_chain.memory(), desc_chain.head_index(), len)
                .map_err(NetQueuePairError::QueueAddUsed)?;

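            // enable_notification() returns true if new descriptors were made
            // available in the meantime; keep draining in that case to avoid
            // missing a submission that raced with re-enabling notifications.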
            if !queue
                .enable_notification(mem)
                .map_err(NetQueuePairError::QueueEnableNotification)?
            {
                break;
            }
        }

        Ok(retry_write)
    }
}

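/// RX (tap to guest) byte and frame counters, accumulated per run and
/// flushed into the shared `NetCounters` by `NetQueuePair::process_rx()`.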
#[derive(Clone)]
pub struct RxVirtio {
    pub counter_bytes: Wrapping<u64>,
    pub counter_frames: Wrapping<u64>,
}

impl Default for RxVirtio {
    fn default() -> Self {
        Self::new()
    }
}

impl RxVirtio {
    pub fn new() -> Self {
        RxVirtio {
            counter_bytes: Wrapping(0),
            counter_frames: Wrapping(0),
        }
    }

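    /// Fills the RX queue: reads one frame from the tap into each available
    /// descriptor chain's writable buffers with a single readv(2). Returns
    /// Ok(true) when the available descriptors were exhausted, i.e. no guest
    /// buffers remain for further frames.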
    pub fn process_desc_chain(
        &mut self,
        mem: &GuestMemoryMmap,
        tap: &mut Tap,
        queue: &mut Queue,
        rate_limiter: &mut Option<RateLimiter>,
        access_platform: Option<&Arc<dyn AccessPlatform>>,
    ) -> Result<bool, NetQueuePairError> {
        let mut exhausted_descs = true;
        let mut rate_limit_reached = false;

        while let Some(mut desc_chain) = queue.pop_descriptor_chain(mem) {
            if rate_limit_reached {
                exhausted_descs = false;
                queue.go_to_previous_position();
                break;
            }

            let desc = desc_chain
                .next()
                .ok_or(NetQueuePairError::DescriptorChainTooShort)?;

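            // Locate the num_buffers field of the leading
            // virtio_net_hdr_mrg_rxbuf: it sits at offset 10, right after the
            // fixed virtio_net_hdr fields.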
            let num_buffers_addr = desc_chain
                .memory()
                .checked_offset(
                    desc.addr()
                        .translate_gva(access_platform, desc.len() as usize),
                    10,
                )
                .ok_or(NetQueuePairError::DescriptorInvalidHeader)?;
            let mut next_desc = Some(desc);

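            // Gather every writable (device-fillable) buffer of the chain so
            // a frame can be scattered into it with a single readv(2) call.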
            let mut iovecs = Vec::new();
            while let Some(desc) = next_desc {
                let desc_addr = desc
                    .addr()
                    .translate_gva(access_platform, desc.len() as usize);
                if desc.is_write_only() && desc.len() > 0 {
                    let buf = desc_chain
                        .memory()
                        .get_slice(desc_addr, desc.len() as usize)
                        .map_err(NetQueuePairError::GuestMemory)?
                        .as_ptr();
                    let iovec = libc::iovec {
                        iov_base: buf as *mut libc::c_void,
                        iov_len: desc.len() as libc::size_t,
                    };
                    iovecs.push(iovec);
                } else {
                    error!(
                        "Invalid descriptor chain: address = 0x{:x} length = {} write_only = {}",
                        desc_addr.0,
                        desc.len(),
                        desc.is_write_only()
                    );
                    return Err(NetQueuePairError::DescriptorChainInvalid);
                }
                next_desc = desc_chain.next();
            }

            let len = if !iovecs.is_empty() {
                // SAFETY: FFI call with correct arguments
                let result = unsafe {
                    libc::readv(
                        tap.as_raw_fd() as libc::c_int,
                        iovecs.as_ptr() as *const libc::iovec,
                        iovecs.len() as libc::c_int,
                    )
                };
                if result < 0 {
                    let e = std::io::Error::last_os_error();
                    exhausted_descs = false;
                    queue.go_to_previous_position();

                    /* EAGAIN */
                    if e.kind() == std::io::ErrorKind::WouldBlock {
                        break;
                    }

                    error!("net: rx: failed reading from tap: {}", e);
                    return Err(NetQueuePairError::ReadTap(e));
                }

                // Write num_buffers to guest memory. We simply write 1 as we
                // never spread the frame over more than one descriptor chain.
                desc_chain
                    .memory()
                    .write_obj(1u16, num_buffers_addr)
                    .map_err(NetQueuePairError::GuestMemory)?;

                self.counter_bytes += Wrapping(result as u64 - vnet_hdr_len() as u64);
                self.counter_frames += Wrapping(1);

                result as u32
            } else {
                0
            };

            // For the sake of simplicity (keeping the handling of RX_QUEUE_EVENT
            // and RX_TAP_EVENT fully asynchronous), we always let the 'last'
            // descriptor chain go through even if it was over the rate limit,
            // and simply stop processing incoming `avail_desc` entries, if any.
            if let Some(rate_limiter) = rate_limiter {
                rate_limit_reached = !rate_limiter.consume(1, TokenType::Ops)
                    || !rate_limiter.consume(len as u64, TokenType::Bytes);
            }

            queue
                .add_used(desc_chain.memory(), desc_chain.head_index(), len)
                .map_err(NetQueuePairError::QueueAddUsed)?;

            if !queue
                .enable_notification(mem)
                .map_err(NetQueuePairError::QueueEnableNotification)?
            {
                break;
            }
        }

        Ok(exhausted_descs)
    }
}

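/// Cumulative device-level statistics, shared via `Arc` with whoever exposes
/// them; the per-run `TxVirtio`/`RxVirtio` counters are flushed into these.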
#[derive(Default, Clone)]
pub struct NetCounters {
    pub tx_bytes: Arc<AtomicU64>,
    pub tx_frames: Arc<AtomicU64>,
    pub rx_bytes: Arc<AtomicU64>,
    pub rx_frames: Arc<AtomicU64>,
}

#[derive(Error, Debug)]
pub enum NetQueuePairError {
    #[error("No memory configured")]
    NoMemoryConfigured,
    #[error("Error registering listener: {0}")]
    RegisterListener(io::Error),
    #[error("Error unregistering listener: {0}")]
    UnregisterListener(io::Error),
    #[error("Error writing to the TAP device: {0}")]
    WriteTap(io::Error),
    #[error("Error reading from the TAP device: {0}")]
    ReadTap(io::Error),
    #[error("Error related to guest memory: {0}")]
    GuestMemory(vm_memory::GuestMemoryError),
    #[error("Error iterating through the queue: {0}")]
    QueueIteratorFailed(virtio_queue::Error),
    #[error("Descriptor chain is too short")]
    DescriptorChainTooShort,
    #[error("Descriptor chain does not contain valid descriptors")]
    DescriptorChainInvalid,
    #[error("Failed to determine if queue needed notification: {0}")]
    QueueNeedsNotification(virtio_queue::Error),
    #[error("Failed to enable notification on the queue: {0}")]
    QueueEnableNotification(virtio_queue::Error),
    #[error("Failed to add used index to the queue: {0}")]
    QueueAddUsed(virtio_queue::Error),
    #[error("Descriptor with invalid virtio-net header")]
    DescriptorInvalidHeader,
}

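/// A single RX/TX virtio-net queue pair bound to one tap device. A minimal
/// sketch of how a caller might drive it from an epoll loop follows; every
/// name other than `process_tx`/`process_rx` is an assumption for
/// illustration, not part of this crate:
///
/// ```ignore
/// // On a TX queue event or once the tap becomes writable:
/// if pair.process_tx(&mem, &mut tx_queue)? {
///     signal_used_queue(TX_QUEUE_IDX); // hypothetical guest notification
/// }
/// // On an RX queue event or once the tap becomes readable:
/// if pair.process_rx(&mem, &mut rx_queue)? {
///     signal_used_queue(RX_QUEUE_IDX);
/// }
/// ```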
pub struct NetQueuePair {
    pub tap: Tap,
    // Each FD registered with epoll must be unique, so to tell the events
    // apart we need a second FD referring to the same tap device; this lets
    // EPOLLOUT and EPOLLIN be delivered as separate events.
    pub tap_for_write_epoll: Tap,
    pub rx: RxVirtio,
    pub tx: TxVirtio,
    pub epoll_fd: Option<RawFd>,
    pub rx_tap_listening: bool,
    pub tx_tap_listening: bool,
    pub counters: NetCounters,
    pub tap_rx_event_id: u16,
    pub tap_tx_event_id: u16,
    pub rx_desc_avail: bool,
    pub rx_rate_limiter: Option<RateLimiter>,
    pub tx_rate_limiter: Option<RateLimiter>,
    pub access_platform: Option<Arc<dyn AccessPlatform>>,
}

impl NetQueuePair {
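    /// Processes the TX queue, then registers (on EAGAIN) or unregisters the
    /// EPOLLOUT listener on the write tap FD accordingly. Returns whether
    /// the guest needs a used-ring notification.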
    pub fn process_tx(
        &mut self,
        mem: &GuestMemoryMmap,
        queue: &mut Queue,
    ) -> Result<bool, NetQueuePairError> {
        let tx_tap_retry = self.tx.process_desc_chain(
            mem,
            &mut self.tap,
            queue,
            &mut self.tx_rate_limiter,
            self.access_platform.as_ref(),
        )?;

        // Writing to the tap returned EAGAIN: register for EPOLLOUT so we
        // retry once the tap becomes writable again.
        if tx_tap_retry && !self.tx_tap_listening {
            register_listener(
                self.epoll_fd.unwrap(),
                self.tap_for_write_epoll.as_raw_fd(),
                epoll::Events::EPOLLOUT,
                u64::from(self.tap_tx_event_id),
            )
            .map_err(NetQueuePairError::RegisterListener)?;
            self.tx_tap_listening = true;
            info!("Writing to TAP returned EAGAIN. Listening for TAP to become writable.");
        } else if !tx_tap_retry && self.tx_tap_listening {
            unregister_listener(
                self.epoll_fd.unwrap(),
                self.tap_for_write_epoll.as_raw_fd(),
                epoll::Events::EPOLLOUT,
                u64::from(self.tap_tx_event_id),
            )
            .map_err(NetQueuePairError::UnregisterListener)?;
            self.tx_tap_listening = false;
            info!("Writing to TAP succeeded. No longer listening for TAP to become writable.");
        }

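        // Flush the per-run TX counters into the shared atomics and reset
        // them for the next run.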
        self.counters
            .tx_bytes
            .fetch_add(self.tx.counter_bytes.0, Ordering::AcqRel);
        self.counters
            .tx_frames
            .fetch_add(self.tx.counter_frames.0, Ordering::AcqRel);
        self.tx.counter_bytes = Wrapping(0);
        self.tx.counter_frames = Wrapping(0);

        queue
            .needs_notification(mem)
            .map_err(NetQueuePairError::QueueNeedsNotification)
    }

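    /// Processes the RX queue, then stops listening on the tap when either
    /// no guest buffers remain or the RX rate limit has been reached.
    /// Returns whether the guest needs a used-ring notification.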
    pub fn process_rx(
        &mut self,
        mem: &GuestMemoryMmap,
        queue: &mut Queue,
    ) -> Result<bool, NetQueuePairError> {
        self.rx_desc_avail = !self.rx.process_desc_chain(
            mem,
            &mut self.tap,
            queue,
            &mut self.rx_rate_limiter,
            self.access_platform.as_ref(),
        )?;
        let rate_limit_reached = self
            .rx_rate_limiter
            .as_ref()
            .map_or(false, |r| r.is_blocked());

        // Stop listening on `RX_TAP_EVENT` when:
        // 1) there are no available descriptors, or
        // 2) the RX rate limit has been reached.
        if self.rx_tap_listening && (!self.rx_desc_avail || rate_limit_reached) {
            unregister_listener(
                self.epoll_fd.unwrap(),
                self.tap.as_raw_fd(),
                epoll::Events::EPOLLIN,
                u64::from(self.tap_rx_event_id),
            )
            .map_err(NetQueuePairError::UnregisterListener)?;
            self.rx_tap_listening = false;
        }

        self.counters
            .rx_bytes
            .fetch_add(self.rx.counter_bytes.0, Ordering::AcqRel);
        self.counters
            .rx_frames
            .fetch_add(self.rx.counter_frames.0, Ordering::AcqRel);
        self.rx.counter_bytes = Wrapping(0);
        self.rx.counter_frames = Wrapping(0);

        queue
            .needs_notification(mem)
            .map_err(NetQueuePairError::QueueNeedsNotification)
    }
}