xref: /cloud-hypervisor/virtio-devices/src/vsock/unix/muxer.rs (revision 19d36c765fdf00be749d95b3e61028bc302d6d73)
1 // Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 //! `VsockMuxer` is the device-facing component of the Unix domain sockets vsock backend. I.e.
6 //! by implementing the `VsockBackend` trait, it abstracts away the gory details of translating
7 //! between AF_VSOCK and AF_UNIX, and presents a clean interface to the rest of the vsock
8 //! device model.
9 //!
10 //! The vsock muxer has two main roles:
11 //!
12 //! ## Vsock connection multiplexer
13 //!
14 //! It's the muxer's job to create, manage, and terminate `VsockConnection` objects. The
15 //! muxer also routes packets to their owning connections. It does so via a connection
16 //! `HashMap`, keyed by what is basically a (host_port, guest_port) tuple.
17 //!
18 //! Vsock packet traffic needs to be inspected, in order to detect connection request
19 //! packets (leading to the creation of a new connection), and connection reset packets
20 //! (leading to the termination of an existing connection). All other packets, though, must
21 //! belong to an existing connection and, as such, the muxer simply forwards them.
22 //!
23 //! ## Event dispatcher
24 //!
//! There are three event categories that the vsock backend is interested in:
26 //! 1. A new host-initiated connection is ready to be accepted from the listening host Unix
27 //!    socket;
28 //! 2. Data is available for reading from a newly-accepted host-initiated connection (i.e.
29 //!    the host is ready to issue a vsock connection request, informing us of the
30 //!    destination port to which it wants to connect);
31 //! 3. Some event was triggered for a connected Unix socket, that belongs to a
32 //!    `VsockConnection`.
33 //!
34 //! The muxer gets notified about all of these events, because, as a `VsockEpollListener`
35 //! implementor, it gets to register a nested epoll FD into the main VMM epoll()ing loop. All
36 //! other pollable FDs are then registered under this nested epoll FD.
37 //!
38 //! To route all these events to their handlers, the muxer uses another `HashMap` object,
39 //! mapping `RawFd`s to `EpollListener`s.
40 
41 use std::collections::{HashMap, HashSet};
42 use std::fs::File;
43 use std::io::{self, ErrorKind, Read};
44 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
45 use std::os::unix::net::{UnixListener, UnixStream};
46 
47 use super::super::csm::ConnState;
48 use super::super::defs::uapi;
49 use super::super::packet::VsockPacket;
50 use super::super::{
51     Result as VsockResult, VsockBackend, VsockChannel, VsockEpollListener, VsockError,
52 };
53 use super::muxer_killq::MuxerKillQ;
54 use super::muxer_rxq::MuxerRxQ;
55 use super::{defs, Error, MuxerConnection, Result};
56 
/// Key under which a `MuxerConnection` is stored in the muxer's connection map.
///
/// It pairs the host-side (local) vsock port with the guest-side (peer) vsock port.
///
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnMapKey {
    local_port: u32,
    peer_port: u32,
}

/// An entry of the muxer's RX queue, describing where the next guest-bound packet
/// should come from.
///
#[derive(Debug, Clone, Copy)]
pub enum MuxerRx {
    /// Let the connection stored under the wrapped key fill in the packet.
    ConnRx(ConnMapKey),
    /// Let the muxer itself build an RST packet, from `local_port` to `peer_port`.
    RstPkt { local_port: u32, peer_port: u32 },
}
75 
76 /// An epoll listener, registered under the muxer's nested epoll FD.
77 ///
78 enum EpollListener {
79     /// The listener is a `MuxerConnection`, identified by `key`, and interested in the events
80     /// in `evset`. Since `MuxerConnection` implements `VsockEpollListener`, notifications will
81     /// be forwarded to the listener via `VsockEpollListener::notify()`.
82     Connection {
83         key: ConnMapKey,
84         evset: epoll::Events,
85     },
86     /// A listener interested in new host-initiated connections.
87     HostSock,
88     /// A listener interested in reading host "connect \<port>" commands from a freshly
89     /// connected host socket.
90     LocalStream(UnixStream),
91 }
92 
/// Buffer for a "CONNECT" command whose bytes have only partially arrived so far.
///
#[derive(Default)]
struct PartiallyReadCommand {
    /// Storage for the command bytes received up to now.
    buf: [u8; 32],
    /// Number of valid bytes at the start of `buf`.
    len: usize,
}
101 
102 /// The vsock connection multiplexer.
103 ///
104 pub struct VsockMuxer {
105     /// Guest CID.
106     cid: u64,
107     /// A hash map used to store the active connections.
108     conn_map: HashMap<ConnMapKey, MuxerConnection>,
109     /// A hash map used to store epoll event listeners / handlers.
110     listener_map: HashMap<RawFd, EpollListener>,
111     /// A hash map used to store partially read "connect" commands.
112     partial_command_map: HashMap<RawFd, PartiallyReadCommand>,
113     /// The RX queue. Items in this queue are consumed by `VsockMuxer::recv_pkt()`, and
114     /// produced
115     /// - by `VsockMuxer::send_pkt()` (e.g. RST in response to a connection request packet);
116     ///   and
117     /// - in response to EPOLLIN events (e.g. data available to be read from an AF_UNIX
118     ///   socket).
119     rxq: MuxerRxQ,
120     /// A queue used for terminating connections that are taking too long to shut down.
121     killq: MuxerKillQ,
122     /// The Unix socket, through which host-initiated connections are accepted.
123     host_sock: UnixListener,
124     /// The file system path of the host-side Unix socket. This is used to figure out the path
125     /// to Unix sockets listening on specific ports. I.e. "\<this path>_\<port number>".
126     host_sock_path: String,
127     /// The nested epoll File, used to register epoll listeners.
128     epoll_file: File,
129     /// A hash set used to keep track of used host-side (local) ports, in order to assign local
130     /// ports to host-initiated connections.
131     local_port_set: HashSet<u32>,
132     /// The last used host-side port.
133     local_port_last: u32,
134 }
135 
136 impl VsockChannel for VsockMuxer {
137     /// Deliver a vsock packet to the guest vsock driver.
138     ///
139     /// Returns:
140     /// - `Ok(())`: `pkt` has been successfully filled in; or
141     /// - `Err(VsockError::NoData)`: there was no available data with which to fill in the
142     ///   packet.
143     ///
144     fn recv_pkt(&mut self, pkt: &mut VsockPacket) -> VsockResult<()> {
145         // We'll look for instructions on how to build the RX packet in the RX queue. If the
146         // queue is empty, that doesn't necessarily mean we don't have any pending RX, since
147         // the queue might be out-of-sync. If that's the case, we'll attempt to sync it first,
148         // and then try to pop something out again.
149         if self.rxq.is_empty() && !self.rxq.is_synced() {
150             self.rxq = MuxerRxQ::from_conn_map(&self.conn_map);
151         }
152 
153         while let Some(rx) = self.rxq.peek() {
154             let res = match rx {
155                 // We need to build an RST packet, going from `local_port` to `peer_port`.
156                 MuxerRx::RstPkt {
157                     local_port,
158                     peer_port,
159                 } => {
160                     pkt.set_op(uapi::VSOCK_OP_RST)
161                         .set_src_cid(uapi::VSOCK_HOST_CID)
162                         .set_dst_cid(self.cid)
163                         .set_src_port(local_port)
164                         .set_dst_port(peer_port)
165                         .set_len(0)
166                         .set_type(uapi::VSOCK_TYPE_STREAM)
167                         .set_flags(0)
168                         .set_buf_alloc(0)
169                         .set_fwd_cnt(0);
170                     self.rxq.pop().unwrap();
171                     return Ok(());
172                 }
173 
174                 // We'll defer building the packet to this connection, since it has something
175                 // to say.
176                 MuxerRx::ConnRx(key) => {
177                     let mut conn_res = Err(VsockError::NoData);
178                     let mut do_pop = true;
179                     self.apply_conn_mutation(key, |conn| {
180                         conn_res = conn.recv_pkt(pkt);
181                         do_pop = !conn.has_pending_rx();
182                     });
183                     if do_pop {
184                         self.rxq.pop().unwrap();
185                     }
186                     conn_res
187                 }
188             };
189 
190             if res.is_ok() {
191                 // Inspect traffic, looking for RST packets, since that means we have to
192                 // terminate and remove this connection from the active connection pool.
193                 //
194                 if pkt.op() == uapi::VSOCK_OP_RST {
195                     self.remove_connection(ConnMapKey {
196                         local_port: pkt.src_port(),
197                         peer_port: pkt.dst_port(),
198                     });
199                 }
200 
201                 debug!("vsock muxer: RX pkt: {:?}", pkt.hdr());
202                 return Ok(());
203             }
204         }
205 
206         Err(VsockError::NoData)
207     }
208 
209     /// Deliver a guest-generated packet to its destination in the vsock backend.
210     ///
211     /// This absorbs unexpected packets, handles RSTs (by dropping connections), and forwards
212     /// all the rest to their owning `MuxerConnection`.
213     ///
214     /// Returns:
215     /// always `Ok(())` - the packet has been consumed, and its virtio TX buffers can be
216     /// returned to the guest vsock driver.
217     ///
218     fn send_pkt(&mut self, pkt: &VsockPacket) -> VsockResult<()> {
219         let conn_key = ConnMapKey {
220             local_port: pkt.dst_port(),
221             peer_port: pkt.src_port(),
222         };
223 
224         debug!(
225             "vsock: muxer.send[rxq.len={}]: {:?}",
226             self.rxq.len(),
227             pkt.hdr()
228         );
229 
230         // If this packet has an unsupported type (!=stream), we must send back an RST.
231         //
232         if pkt.type_() != uapi::VSOCK_TYPE_STREAM {
233             self.enq_rst(pkt.dst_port(), pkt.src_port());
234             return Ok(());
235         }
236 
237         // We don't know how to handle packets addressed to other CIDs. We only handle the host
238         // part of the guest - host communication here.
239         if pkt.dst_cid() != uapi::VSOCK_HOST_CID {
240             info!(
241                 "vsock: dropping guest packet for unknown CID: {:?}",
242                 pkt.hdr()
243             );
244             return Ok(());
245         }
246 
247         if !self.conn_map.contains_key(&conn_key) {
248             // This packet can't be routed to any active connection (based on its src and dst
249             // ports).  The only orphan / unroutable packets we know how to handle are
250             // connection requests.
251             if pkt.op() == uapi::VSOCK_OP_REQUEST {
252                 // Oh, this is a connection request!
253                 self.handle_peer_request_pkt(pkt);
254             } else {
255                 // Send back an RST, to let the drive know we weren't expecting this packet.
256                 self.enq_rst(pkt.dst_port(), pkt.src_port());
257             }
258             return Ok(());
259         }
260 
261         // Right, we know where to send this packet, then (to `conn_key`).
262         // However, if this is an RST, we have to forcefully terminate the connection, so
263         // there's no point in forwarding it the packet.
264         if pkt.op() == uapi::VSOCK_OP_RST {
265             self.remove_connection(conn_key);
266             return Ok(());
267         }
268 
269         // Alright, everything looks in order - forward this packet to its owning connection.
270         let mut res: VsockResult<()> = Ok(());
271         self.apply_conn_mutation(conn_key, |conn| {
272             res = conn.send_pkt(pkt);
273         });
274 
275         res
276     }
277 
278     /// Check if the muxer has any pending RX data, with which to fill a guest-provided RX
279     /// buffer.
280     ///
281     fn has_pending_rx(&self) -> bool {
282         !self.rxq.is_empty() || !self.rxq.is_synced()
283     }
284 }
285 
286 impl VsockEpollListener for VsockMuxer {
287     /// Get the FD to be registered for polling upstream (in the main VMM epoll loop, in this
288     /// case).
289     ///
290     /// This will be the muxer's nested epoll FD.
291     ///
292     fn get_polled_fd(&self) -> RawFd {
293         self.epoll_file.as_raw_fd()
294     }
295 
296     /// Get the epoll events to be polled upstream.
297     ///
298     /// Since the polled FD is a nested epoll FD, we're only interested in EPOLLIN events (i.e.
299     /// some event occurred on one of the FDs registered under our epoll FD).
300     ///
301     fn get_polled_evset(&self) -> epoll::Events {
302         epoll::Events::EPOLLIN
303     }
304 
305     /// Notify the muxer about a pending event having occurred under its nested epoll FD.
306     ///
307     fn notify(&mut self, _: epoll::Events) {
308         debug!("vsock: muxer received kick");
309 
310         let mut epoll_events = vec![epoll::Event::new(epoll::Events::empty(), 0); 32];
311         'epoll: loop {
312             match epoll::wait(self.epoll_file.as_raw_fd(), 0, epoll_events.as_mut_slice()) {
313                 Ok(ev_cnt) => {
314                     for evt in epoll_events.iter().take(ev_cnt) {
315                         self.handle_event(
316                             evt.data as RawFd,
317                             // It's ok to unwrap here, since the `evt.events` is filled
318                             // in by `epoll::wait()`, and therefore contains only valid epoll
319                             // flags.
320                             epoll::Events::from_bits(evt.events).unwrap(),
321                         );
322                     }
323                 }
324                 Err(e) => {
325                     if e.kind() == io::ErrorKind::Interrupted {
326                         // It's well defined from the epoll_wait() syscall
327                         // documentation that the epoll loop can be interrupted
328                         // before any of the requested events occurred or the
329                         // timeout expired. In both those cases, epoll_wait()
330                         // returns an error of type EINTR, but this should not
331                         // be considered as a regular error. Instead it is more
332                         // appropriate to retry, by calling into epoll_wait().
333                         continue;
334                     }
335                     warn!("vsock: failed to consume muxer epoll event: {}", e);
336                 }
337             }
338             break 'epoll;
339         }
340     }
341 }
342 
343 impl VsockBackend for VsockMuxer {}
344 
345 impl VsockMuxer {
346     /// Muxer constructor.
347     ///
348     pub fn new(cid: u32, host_sock_path: String) -> Result<Self> {
349         // Create the nested epoll FD. This FD will be added to the VMM `EpollContext`, at
350         // device activation time.
351         let epoll_fd = epoll::create(true).map_err(Error::EpollFdCreate)?;
352         // Use 'File' to enforce closing on 'epoll_fd'
353         // SAFETY: epoll_fd is a valid fd
354         let epoll_file = unsafe { File::from_raw_fd(epoll_fd) };
355 
356         // Open/bind/listen on the host Unix socket, so we can accept host-initiated
357         // connections.
358         let host_sock = UnixListener::bind(&host_sock_path)
359             .and_then(|sock| sock.set_nonblocking(true).map(|_| sock))
360             .map_err(Error::UnixBind)?;
361 
362         let mut muxer = Self {
363             cid: cid.into(),
364             host_sock,
365             host_sock_path,
366             epoll_file,
367             rxq: MuxerRxQ::new(),
368             conn_map: HashMap::with_capacity(defs::MAX_CONNECTIONS),
369             listener_map: HashMap::with_capacity(defs::MAX_CONNECTIONS + 1),
370             partial_command_map: Default::default(),
371             killq: MuxerKillQ::new(),
372             local_port_last: (1u32 << 30) - 1,
373             local_port_set: HashSet::with_capacity(defs::MAX_CONNECTIONS),
374         };
375 
376         muxer.add_listener(muxer.host_sock.as_raw_fd(), EpollListener::HostSock)?;
377         Ok(muxer)
378     }
379 
380     /// Handle/dispatch an epoll event to its listener.
381     ///
382     fn handle_event(&mut self, fd: RawFd, event_set: epoll::Events) {
383         debug!(
384             "vsock: muxer processing event: fd={}, event_set={:?}",
385             fd, event_set
386         );
387 
388         match self.listener_map.get_mut(&fd) {
389             // This event needs to be forwarded to a `MuxerConnection` that is listening for
390             // it.
391             //
392             Some(EpollListener::Connection { key, evset: _ }) => {
393                 let key_copy = *key;
394                 // The handling of this event will most probably mutate the state of the
395                 // receiving connection. We'll need to check for new pending RX, event set
396                 // mutation, and all that, so we're wrapping the event delivery inside those
397                 // checks.
398                 self.apply_conn_mutation(key_copy, |conn| {
399                     conn.notify(event_set);
400                 });
401             }
402 
403             // A new host-initiated connection is ready to be accepted.
404             //
405             Some(EpollListener::HostSock) => {
406                 if self.conn_map.len() == defs::MAX_CONNECTIONS {
407                     // If we're already maxed-out on connections, we'll just accept and
408                     // immediately discard this potentially new one.
409                     warn!("vsock: connection limit reached; refusing new host connection");
410                     self.host_sock.accept().map(|_| 0).unwrap_or(0);
411                     return;
412                 }
413                 self.host_sock
414                     .accept()
415                     .map_err(Error::UnixAccept)
416                     .and_then(|(stream, _)| {
417                         stream
418                             .set_nonblocking(true)
419                             .map(|_| stream)
420                             .map_err(Error::UnixAccept)
421                     })
422                     .and_then(|stream| {
423                         // Before forwarding this connection to a listening AF_VSOCK socket on
424                         // the guest side, we need to know the destination port. We'll read
425                         // that port from a "connect" command received on this socket, so the
426                         // next step is to ask to be notified the moment we can read from it.
427                         self.add_listener(stream.as_raw_fd(), EpollListener::LocalStream(stream))
428                     })
429                     .unwrap_or_else(|err| {
430                         warn!("vsock: unable to accept local connection: {:?}", err);
431                     });
432             }
433 
434             // Data is ready to be read from a host-initiated connection. That would be the
435             // "connect" command that we're expecting.
436             Some(EpollListener::LocalStream(_)) => {
437                 if let Some(EpollListener::LocalStream(stream)) = self.listener_map.get_mut(&fd) {
438                     let port = Self::read_local_stream_port(&mut self.partial_command_map, stream);
439 
440                     if let Err(Error::UnixRead(ref e)) = port {
441                         if e.kind() == ErrorKind::WouldBlock {
442                             return;
443                         }
444                     }
445 
446                     let stream = match self.remove_listener(fd) {
447                         Some(EpollListener::LocalStream(s)) => s,
448                         _ => unreachable!(),
449                     };
450 
451                     port.and_then(|peer_port| {
452                         let local_port = self.allocate_local_port();
453 
454                         self.add_connection(
455                             ConnMapKey {
456                                 local_port,
457                                 peer_port,
458                             },
459                             MuxerConnection::new_local_init(
460                                 stream,
461                                 uapi::VSOCK_HOST_CID,
462                                 self.cid,
463                                 local_port,
464                                 peer_port,
465                             ),
466                         )
467                     })
468                     .unwrap_or_else(|err| {
469                         info!("vsock: error adding local-init connection: {:?}", err);
470                     })
471                 }
472             }
473 
474             _ => {
475                 info!(
476                     "vsock: unexpected event: fd={:?}, event_set={:?}",
477                     fd, event_set
478                 );
479             }
480         }
481     }
482 
483     /// Parse a host "connect" command, and extract the destination vsock port.
484     ///
485     fn read_local_stream_port(
486         partial_command_map: &mut HashMap<RawFd, PartiallyReadCommand>,
487         stream: &mut UnixStream,
488     ) -> Result<u32> {
489         let command = partial_command_map.entry(stream.as_raw_fd()).or_default();
490 
491         // This is the minimum number of bytes that we should be able to read, when parsing a
492         // valid connection request. I.e. `b"connect 0\n".len()`.
493         const MIN_COMMAND_LEN: usize = 10;
494 
495         // Bring in the minimum number of bytes that we should be able to read.
496         stream
497             .read_exact(&mut command.buf[command.len..MIN_COMMAND_LEN])
498             .map_err(Error::UnixRead)?;
499         command.len = MIN_COMMAND_LEN;
500 
501         // Now, finish reading the destination port number, by bringing in one byte at a time,
502         // until we reach an EOL terminator (or our buffer space runs out).  Yeah, not
503         // particularly proud of this approach, but it will have to do for now.
504         while command.buf[command.len - 1] != b'\n' && command.len < command.buf.len() {
505             command.len += stream
506                 .read(&mut command.buf[command.len..=command.len])
507                 .map_err(Error::UnixRead)?;
508         }
509 
510         let command = partial_command_map.remove(&stream.as_raw_fd()).unwrap();
511 
512         let mut word_iter = std::str::from_utf8(&command.buf[..command.len])
513             .map_err(Error::ConvertFromUtf8)?
514             .split_whitespace();
515 
516         word_iter
517             .next()
518             .ok_or(Error::InvalidPortRequest)
519             .and_then(|word| {
520                 if word.to_lowercase() == "connect" {
521                     Ok(())
522                 } else {
523                     Err(Error::InvalidPortRequest)
524                 }
525             })
526             .and_then(|_| word_iter.next().ok_or(Error::InvalidPortRequest))
527             .and_then(|word| word.parse::<u32>().map_err(Error::ParseInteger))
528             .map_err(|e| Error::ReadStreamPort(Box::new(e)))
529     }
530 
531     /// Add a new connection to the active connection pool.
532     ///
533     fn add_connection(&mut self, key: ConnMapKey, conn: MuxerConnection) -> Result<()> {
534         // We might need to make room for this new connection, so let's sweep the kill queue
535         // first.  It's fine to do this here because:
536         // - unless the kill queue is out of sync, this is a pretty inexpensive operation; and
537         // - we are under no pressure to respect any accurate timing for connection
538         //   termination.
539         self.sweep_killq();
540 
541         if self.conn_map.len() >= defs::MAX_CONNECTIONS {
542             info!(
543                 "vsock: muxer connection limit reached ({})",
544                 defs::MAX_CONNECTIONS
545             );
546             return Err(Error::TooManyConnections);
547         }
548 
549         self.add_listener(
550             conn.get_polled_fd(),
551             EpollListener::Connection {
552                 key,
553                 evset: conn.get_polled_evset(),
554             },
555         )
556         .map(|_| {
557             if conn.has_pending_rx() {
558                 // We can safely ignore any error in adding a connection RX indication. Worst
559                 // case scenario, the RX queue will get desynchronized, but we'll handle that
560                 // the next time we need to yield an RX packet.
561                 self.rxq.push(MuxerRx::ConnRx(key));
562             }
563             self.conn_map.insert(key, conn);
564         })
565     }
566 
567     /// Remove a connection from the active connection poll.
568     ///
569     fn remove_connection(&mut self, key: ConnMapKey) {
570         if let Some(conn) = self.conn_map.remove(&key) {
571             self.remove_listener(conn.get_polled_fd());
572         }
573         self.free_local_port(key.local_port);
574     }
575 
576     /// Schedule a connection for immediate termination.
577     /// I.e. as soon as we can also let our peer know we're dropping the connection, by sending
578     /// it an RST packet.
579     ///
580     fn kill_connection(&mut self, key: ConnMapKey) {
581         let mut had_rx = false;
582         self.conn_map.entry(key).and_modify(|conn| {
583             had_rx = conn.has_pending_rx();
584             conn.kill();
585         });
586         // This connection will now have an RST packet to yield, so we need to add it to the RX
587         // queue.  However, there's no point in doing that if it was already in the queue.
588         if !had_rx {
589             // We can safely ignore any error in adding a connection RX indication. Worst case
590             // scenario, the RX queue will get desynchronized, but we'll handle that the next
591             // time we need to yield an RX packet.
592             self.rxq.push(MuxerRx::ConnRx(key));
593         }
594     }
595 
596     /// Register a new epoll listener under the muxer's nested epoll FD.
597     ///
598     fn add_listener(&mut self, fd: RawFd, listener: EpollListener) -> Result<()> {
599         let evset = match listener {
600             EpollListener::Connection { evset, .. } => evset,
601             EpollListener::LocalStream(_) => epoll::Events::EPOLLIN,
602             EpollListener::HostSock => epoll::Events::EPOLLIN,
603         };
604 
605         epoll::ctl(
606             self.epoll_file.as_raw_fd(),
607             epoll::ControlOptions::EPOLL_CTL_ADD,
608             fd,
609             epoll::Event::new(evset, fd as u64),
610         )
611         .map(|_| {
612             self.listener_map.insert(fd, listener);
613         })
614         .map_err(Error::EpollAdd)?;
615 
616         Ok(())
617     }
618 
619     /// Remove (and return) a previously registered epoll listener.
620     ///
621     fn remove_listener(&mut self, fd: RawFd) -> Option<EpollListener> {
622         let maybe_listener = self.listener_map.remove(&fd);
623 
624         if maybe_listener.is_some() {
625             epoll::ctl(
626                 self.epoll_file.as_raw_fd(),
627                 epoll::ControlOptions::EPOLL_CTL_DEL,
628                 fd,
629                 epoll::Event::new(epoll::Events::empty(), 0),
630             )
631             .unwrap_or_else(|err| {
632                 warn!(
633                     "vosck muxer: error removing epoll listener for fd {:?}: {:?}",
634                     fd, err
635                 );
636             });
637         }
638 
639         maybe_listener
640     }
641 
642     /// Allocate a host-side port to be assigned to a new host-initiated connection.
643     ///
644     ///
645     fn allocate_local_port(&mut self) -> u32 {
646         // TODO: this doesn't seem very space-efficient.
647         // Maybe rewrite this to limit port range and use a bitmap?
648         //
649 
650         loop {
651             self.local_port_last = (self.local_port_last + 1) & !(1 << 31) | (1 << 30);
652             if self.local_port_set.insert(self.local_port_last) {
653                 break;
654             }
655         }
656         self.local_port_last
657     }
658 
659     /// Mark a previously used host-side port as free.
660     ///
661     fn free_local_port(&mut self, port: u32) {
662         self.local_port_set.remove(&port);
663     }
664 
665     /// Handle a new connection request coming from our peer (the guest vsock driver).
666     ///
667     /// This will attempt to connect to a host-side Unix socket, expected to be listening at
668     /// the file system path corresponding to the destination port. If successful, a new
669     /// connection object will be created and added to the connection pool. On failure, a new
670     /// RST packet will be scheduled for delivery to the guest.
671     ///
672     fn handle_peer_request_pkt(&mut self, pkt: &VsockPacket) {
673         let port_path = format!("{}_{}", self.host_sock_path, pkt.dst_port());
674 
675         UnixStream::connect(port_path)
676             .and_then(|stream| stream.set_nonblocking(true).map(|_| stream))
677             .map_err(Error::UnixConnect)
678             .and_then(|stream| {
679                 self.add_connection(
680                     ConnMapKey {
681                         local_port: pkt.dst_port(),
682                         peer_port: pkt.src_port(),
683                     },
684                     MuxerConnection::new_peer_init(
685                         stream,
686                         uapi::VSOCK_HOST_CID,
687                         self.cid,
688                         pkt.dst_port(),
689                         pkt.src_port(),
690                         pkt.buf_alloc(),
691                     ),
692                 )
693             })
694             .unwrap_or_else(|_| self.enq_rst(pkt.dst_port(), pkt.src_port()));
695     }
696 
697     /// Perform an action that might mutate a connection's state.
698     ///
699     /// This is used as shorthand for repetitive tasks that need to be performed after a
700     /// connection object mutates. E.g.
701     /// - update the connection's epoll listener;
702     /// - schedule the connection to be queried for RX data;
703     /// - kill the connection if an unrecoverable error occurs.
704     ///
705     fn apply_conn_mutation<F>(&mut self, key: ConnMapKey, mut_fn: F)
706     where
707         F: FnOnce(&mut MuxerConnection),
708     {
709         if let Some(conn) = self.conn_map.get_mut(&key) {
710             let had_rx = conn.has_pending_rx();
711             let was_expiring = conn.will_expire();
712             let prev_state = conn.state();
713 
714             mut_fn(conn);
715 
716             // If this is a host-initiated connection that has just become established, we'll have
717             // to send an ack message to the host end.
718             if prev_state == ConnState::LocalInit && conn.state() == ConnState::Established {
719                 let msg = format!("OK {}\n", key.local_port);
720                 match conn.send_bytes_raw(msg.as_bytes()) {
721                     Ok(written) if written == msg.len() => (),
722                     Ok(_) => {
723                         // If we can't write a dozen bytes to a pristine connection something
724                         // must be really wrong. Killing it.
725                         conn.kill();
726                         warn!("vsock: unable to fully write connection ack msg.");
727                     }
728                     Err(err) => {
729                         conn.kill();
730                         warn!("vsock: unable to ack host connection: {:?}", err);
731                     }
732                 };
733             }
734 
735             // If the connection wasn't previously scheduled for RX, add it to our RX queue.
736             if !had_rx && conn.has_pending_rx() {
737                 self.rxq.push(MuxerRx::ConnRx(key));
738             }
739 
740             // If the connection wasn't previously scheduled for termination, add it to the
741             // kill queue.
742             if !was_expiring && conn.will_expire() {
743                 // It's safe to unwrap here, since `conn.will_expire()` already guaranteed that
744                 // an `conn.expiry` is available.
745                 self.killq.push(key, conn.expiry().unwrap());
746             }
747 
748             let fd = conn.get_polled_fd();
749             let new_evset = conn.get_polled_evset();
750             if new_evset.is_empty() {
751                 // If the connection no longer needs epoll notifications, remove its listener
752                 // from our list.
753                 self.remove_listener(fd);
754                 return;
755             }
756             if let Some(EpollListener::Connection { evset, .. }) = self.listener_map.get_mut(&fd) {
757                 if *evset != new_evset {
758                     // If the set of events that the connection is interested in has changed,
759                     // we need to update its epoll listener.
760                     debug!(
761                         "vsock: updating listener for (lp={}, pp={}): old={:?}, new={:?}",
762                         key.local_port, key.peer_port, *evset, new_evset
763                     );
764 
765                     *evset = new_evset;
766                     epoll::ctl(
767                         self.epoll_file.as_raw_fd(),
768                         epoll::ControlOptions::EPOLL_CTL_MOD,
769                         fd,
770                         epoll::Event::new(new_evset, fd as u64),
771                     )
772                     .unwrap_or_else(|err| {
773                         // This really shouldn't happen, like, ever. However, "famous last
774                         // words" and all that, so let's just kill it with fire, and walk away.
775                         self.kill_connection(key);
776                         error!(
777                             "vsock: error updating epoll listener for (lp={}, pp={}): {:?}",
778                             key.local_port, key.peer_port, err
779                         );
780                     });
781                 }
782             } else {
783                 // The connection had previously asked to be removed from the listener map (by
784                 // returning an empty event set via `get_polled_fd()`), but now wants back in.
785                 self.add_listener(
786                     fd,
787                     EpollListener::Connection {
788                         key,
789                         evset: new_evset,
790                     },
791                 )
792                 .unwrap_or_else(|err| {
793                     self.kill_connection(key);
794                     error!(
795                         "vsock: error updating epoll listener for (lp={}, pp={}): {:?}",
796                         key.local_port, key.peer_port, err
797                     );
798                 });
799             }
800         }
801     }
802 
803     /// Check if any connections have timed out, and if so, schedule them for immediate
804     /// termination.
805     ///
806     fn sweep_killq(&mut self) {
807         while let Some(key) = self.killq.pop() {
808             // Connections don't get removed from the kill queue when their kill timer is
809             // disarmed, since that would be a costly operation. This means we must check if
810             // the connection has indeed expired, prior to killing it.
811             let mut kill = false;
812             self.conn_map
813                 .entry(key)
814                 .and_modify(|conn| kill = conn.has_expired());
815             if kill {
816                 self.kill_connection(key);
817             }
818         }
819 
820         if self.killq.is_empty() && !self.killq.is_synced() {
821             self.killq = MuxerKillQ::from_conn_map(&self.conn_map);
822             // If we've just re-created the kill queue, we can sweep it again; maybe there's
823             // more to kill.
824             self.sweep_killq();
825         }
826     }
827 
828     /// Enqueue an RST packet into `self.rxq`.
829     ///
830     /// Enqueue errors aren't propagated up the call chain, since there is nothing we can do to
831     /// handle them. We do, however, log a warning, since not being able to enqueue an RST
832     /// packet means we have to drop it, which is not normal operation.
833     ///
834     fn enq_rst(&mut self, local_port: u32, peer_port: u32) {
835         let pushed = self.rxq.push(MuxerRx::RstPkt {
836             local_port,
837             peer_port,
838         });
839         if !pushed {
840             warn!(
841                 "vsock: muxer.rxq full; dropping RST packet for lp={}, pp={}",
842                 local_port, peer_port
843             );
844         }
845     }
846 }
847 
#[cfg(test)]
mod tests {
    use std::io::Write;
    use std::path::{Path, PathBuf};

    use virtio_queue::QueueOwnedT;

    use super::super::super::csm::defs as csm_defs;
    use super::super::super::tests::TestContext as VsockTestContext;
    use super::*;

    /// CID of the emulated peer (guest); also the CID the muxer under test is created with.
    const PEER_CID: u32 = 3;
    /// Buffer space advertised by the emulated peer in every packet built by `init_pkt()`.
    const PEER_BUF_ALLOC: u32 = 64 * 1024;

    /// Test fixture: a vsock device test context, a reusable packet, and the muxer under test.
    struct MuxerTestContext {
        _vsock_test_ctx: VsockTestContext,
        pkt: VsockPacket,
        muxer: VsockMuxer,
    }

    impl Drop for MuxerTestContext {
        /// Removes the muxer's host-side Unix socket file, so tests leave no socket behind.
        fn drop(&mut self) {
            std::fs::remove_file(self.muxer.host_sock_path.as_str()).unwrap();
        }
    }

    impl MuxerTestContext {
        /// Creates a fixture whose muxer uses the host socket path `test_vsock_<name>.sock`.
        ///
        /// `pkt` is built with `VsockPacket::from_rx_virtq_head()` from the first descriptor
        /// chain available on queue 0 of the test device.
        fn new(name: &str) -> Self {
            let vsock_test_ctx = VsockTestContext::new();
            let mut handler_ctx = vsock_test_ctx.create_epoll_handler_context();
            let pkt = VsockPacket::from_rx_virtq_head(
                &mut handler_ctx.handler.queues[0]
                    .iter(&vsock_test_ctx.mem)
                    .unwrap()
                    .next()
                    .unwrap(),
                None,
            )
            .unwrap();
            let uds_path = format!("test_vsock_{name}.sock");
            let muxer = VsockMuxer::new(PEER_CID, uds_path).unwrap();

            Self {
                _vsock_test_ctx: vsock_test_ctx,
                pkt,
                muxer,
            }
        }

        /// Zeroes the packet header, then fills it in as a stream packet with opcode `op`.
        ///
        /// The packet is addressed from the peer (`PEER_CID`, `peer_port`) to the host
        /// (`VSOCK_HOST_CID`, `local_port`) and advertises `PEER_BUF_ALLOC` bytes of buffer
        /// space. Returns the packet so callers can tweak further fields.
        fn init_pkt(&mut self, local_port: u32, peer_port: u32, op: u16) -> &mut VsockPacket {
            for b in self.pkt.hdr_mut() {
                *b = 0;
            }
            self.pkt
                .set_type(uapi::VSOCK_TYPE_STREAM)
                .set_src_cid(PEER_CID.into())
                .set_dst_cid(uapi::VSOCK_HOST_CID)
                .set_src_port(peer_port)
                .set_dst_port(local_port)
                .set_op(op)
                .set_buf_alloc(PEER_BUF_ALLOC)
        }

        /// Builds a peer-to-host `VSOCK_OP_RW` packet whose payload is `data`.
        ///
        /// Panics if `data` does not fit in the packet buffer.
        fn init_data_pkt(
            &mut self,
            local_port: u32,
            peer_port: u32,
            data: &[u8],
        ) -> &mut VsockPacket {
            assert!(data.len() <= self.pkt.buf().unwrap().len());
            self.init_pkt(local_port, peer_port, uapi::VSOCK_OP_RW)
                .set_len(data.len() as u32);
            self.pkt.buf_mut().unwrap()[..data.len()].copy_from_slice(data);
            &mut self.pkt
        }

        /// Hands the current packet to `VsockMuxer::send_pkt()` (guest -> host direction).
        fn send(&mut self) {
            self.muxer.send_pkt(&self.pkt).unwrap();
        }

        /// Has `VsockMuxer::recv_pkt()` fill in the current packet (host -> guest direction).
        fn recv(&mut self) {
            self.muxer.recv_pkt(&mut self.pkt).unwrap();
        }

        /// Fakes an `EPOLLIN` notification on the muxer's nested epoll fd, standing in for the
        /// VMM event loop.
        fn notify_muxer(&mut self) {
            self.muxer.notify(epoll::Events::EPOLLIN);
        }

        /// Returns how many `(LocalStream, Connection)` epoll listeners the muxer has
        /// registered.
        fn count_epoll_listeners(&self) -> (usize, usize) {
            let mut local_lsn_count = 0usize;
            let mut conn_lsn_count = 0usize;
            // NOTE(review): `key` is an `EpollListener` value, not a map key.
            for key in self.muxer.listener_map.values() {
                match key {
                    EpollListener::LocalStream(_) => local_lsn_count += 1,
                    EpollListener::Connection { .. } => conn_lsn_count += 1,
                    _ => (),
                };
            }
            (local_lsn_count, conn_lsn_count)
        }

        /// Binds a Unix listener at `<host_sock_path>_<port>`. The tests use this as the host
        /// endpoint that the muxer connects to when the peer requests host port `port`.
        fn create_local_listener(&self, port: u32) -> LocalListener {
            LocalListener::new(format!("{}_{}", self.muxer.host_sock_path, port))
        }

        /// Drives a complete host-initiated connection handshake.
        ///
        /// The handshake runs as follows:
        /// - connect to the muxer's host socket and send `CONNECT <peer_port>\n`;
        /// - expect the muxer to emit a `VSOCK_OP_REQUEST` towards the peer;
        /// - answer it with a `VSOCK_OP_RESPONSE`;
        /// - expect the host stream to receive `OK <local_port>\n`.
        ///
        /// Epoll listener counts are asserted along the way. Returns the host-side stream and
        /// the local port the muxer allocated for the connection.
        fn local_connect(&mut self, peer_port: u32) -> (UnixStream, u32) {
            let (init_local_lsn_count, init_conn_lsn_count) = self.count_epoll_listeners();

            let mut stream = UnixStream::connect(self.muxer.host_sock_path.clone()).unwrap();
            stream.set_nonblocking(true).unwrap();
            // The muxer would now get notified of a new connection having arrived at its Unix
            // socket, so it can accept it.
            self.notify_muxer();

            // Just after having accepted a new local connection, the muxer should've added a new
            // `LocalStream` listener to its `listener_map`.
            let (local_lsn_count, _) = self.count_epoll_listeners();
            assert_eq!(local_lsn_count, init_local_lsn_count + 1);

            let buf = format!("CONNECT {peer_port}\n");
            stream.write_all(buf.as_bytes()).unwrap();
            // The muxer would now get notified that data is available for reading from the locally
            // initiated connection.
            self.notify_muxer();

            // Successfully reading and parsing the connection request should have removed the
            // LocalStream epoll listener and added a Connection epoll listener.
            let (local_lsn_count, conn_lsn_count) = self.count_epoll_listeners();
            assert_eq!(local_lsn_count, init_local_lsn_count);
            assert_eq!(conn_lsn_count, init_conn_lsn_count + 1);

            // A LocalInit connection should've been added to the muxer connection map. A new
            // local port should also have been allocated for the new LocalInit connection.
            let local_port = self.muxer.local_port_last;
            let key = ConnMapKey {
                local_port,
                peer_port,
            };
            assert!(self.muxer.conn_map.contains_key(&key));
            assert!(self.muxer.local_port_set.contains(&local_port));

            // A connection request for the peer should now be available from the muxer.
            assert!(self.muxer.has_pending_rx());
            self.recv();
            assert_eq!(self.pkt.op(), uapi::VSOCK_OP_REQUEST);
            assert_eq!(self.pkt.dst_port(), peer_port);
            assert_eq!(self.pkt.src_port(), local_port);

            self.init_pkt(local_port, peer_port, uapi::VSOCK_OP_RESPONSE);
            self.send();

            let mut buf = [0u8; 32];
            let len = stream.read(&mut buf[..]).unwrap();
            assert_eq!(&buf[..len], format!("OK {local_port}\n").as_bytes());

            (stream, local_port)
        }
    }

    /// A non-blocking Unix listener bound to `path`; its socket file is removed on drop.
    struct LocalListener {
        path: PathBuf,
        sock: UnixListener,
    }
    impl LocalListener {
        /// Binds a listener at `path` and switches it to non-blocking mode.
        fn new<P: AsRef<Path> + Clone>(path: P) -> Self {
            let path_buf = path.as_ref().to_path_buf();
            let sock = UnixListener::bind(path).unwrap();
            sock.set_nonblocking(true).unwrap();
            Self {
                path: path_buf,
                sock,
            }
        }
        /// Accepts a pending connection and makes it non-blocking.
        ///
        /// Since the listener is non-blocking, this panics if no connection is pending.
        fn accept(&mut self) -> UnixStream {
            let (stream, _) = self.sock.accept().unwrap();
            stream.set_nonblocking(true).unwrap();
            stream
        }
    }
    impl Drop for LocalListener {
        fn drop(&mut self) {
            std::fs::remove_file(&self.path).unwrap();
        }
    }

    /// The muxer exposes its nested epoll fd and only asks for `EPOLLIN` on it.
    #[test]
    fn test_muxer_epoll_listener() {
        let ctx = MuxerTestContext::new("muxer_epoll_listener");
        assert_eq!(ctx.muxer.get_polled_fd(), ctx.muxer.epoll_file.as_raw_fd());
        assert_eq!(ctx.muxer.get_polled_evset(), epoll::Events::EPOLLIN);
    }

    /// Malformed or orphan peer packets get an RST; packets not addressed to the host are
    /// dropped.
    #[test]
    fn test_bad_peer_pkt() {
        const LOCAL_PORT: u32 = 1026;
        const PEER_PORT: u32 = 1025;
        const SOCK_DGRAM: u16 = 2;

        let mut ctx = MuxerTestContext::new("bad_peer_pkt");
        ctx.init_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST)
            .set_type(SOCK_DGRAM);
        ctx.send();

        // The guest sent a SOCK_DGRAM packet. Per the vsock spec, we need to reply with an RST
        // packet, since vsock only supports stream sockets.
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST);
        assert_eq!(ctx.pkt.src_cid(), uapi::VSOCK_HOST_CID);
        assert_eq!(ctx.pkt.dst_cid(), PEER_CID as u64);
        assert_eq!(ctx.pkt.src_port(), LOCAL_PORT);
        assert_eq!(ctx.pkt.dst_port(), PEER_PORT);

        // Any orphan (i.e. without a connection), non-RST packet, should be replied to with an
        // RST.
        let bad_ops = [
            uapi::VSOCK_OP_RESPONSE,
            uapi::VSOCK_OP_CREDIT_REQUEST,
            uapi::VSOCK_OP_CREDIT_UPDATE,
            uapi::VSOCK_OP_SHUTDOWN,
            uapi::VSOCK_OP_RW,
        ];
        for op in bad_ops.iter() {
            ctx.init_pkt(LOCAL_PORT, PEER_PORT, *op);
            ctx.send();
            assert!(ctx.muxer.has_pending_rx());
            ctx.recv();
            assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST);
            assert_eq!(ctx.pkt.src_port(), LOCAL_PORT);
            assert_eq!(ctx.pkt.dst_port(), PEER_PORT);
        }

        // Any packet addressed to anything other than VSOCK_HOST_CID should get dropped.
        assert!(!ctx.muxer.has_pending_rx());
        ctx.init_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST)
            .set_dst_cid(uapi::VSOCK_HOST_CID + 1);
        ctx.send();
        assert!(!ctx.muxer.has_pending_rx());
    }

    /// Guest-initiated connections: refusal when no host listener exists, then acceptance
    /// and data flow in both directions.
    #[test]
    fn test_peer_connection() {
        const LOCAL_PORT: u32 = 1026;
        const PEER_PORT: u32 = 1025;

        let mut ctx = MuxerTestContext::new("peer_connection");

        // Test peer connection refused.
        ctx.init_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST);
        ctx.send();
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST);
        assert_eq!(ctx.pkt.len(), 0);
        assert_eq!(ctx.pkt.src_cid(), uapi::VSOCK_HOST_CID);
        assert_eq!(ctx.pkt.dst_cid(), PEER_CID as u64);
        assert_eq!(ctx.pkt.src_port(), LOCAL_PORT);
        assert_eq!(ctx.pkt.dst_port(), PEER_PORT);

        // Test peer connection accepted.
        let mut listener = ctx.create_local_listener(LOCAL_PORT);
        ctx.init_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST);
        ctx.send();
        assert_eq!(ctx.muxer.conn_map.len(), 1);
        let mut stream = listener.accept();
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE);
        assert_eq!(ctx.pkt.len(), 0);
        assert_eq!(ctx.pkt.src_cid(), uapi::VSOCK_HOST_CID);
        assert_eq!(ctx.pkt.dst_cid(), PEER_CID as u64);
        assert_eq!(ctx.pkt.src_port(), LOCAL_PORT);
        assert_eq!(ctx.pkt.dst_port(), PEER_PORT);
        let key = ConnMapKey {
            local_port: LOCAL_PORT,
            peer_port: PEER_PORT,
        };
        assert!(ctx.muxer.conn_map.contains_key(&key));

        // Test guest -> host data flow.
        let data = [1, 2, 3, 4];
        ctx.init_data_pkt(LOCAL_PORT, PEER_PORT, &data);
        ctx.send();
        let mut buf = vec![0; data.len()];
        stream.read_exact(buf.as_mut_slice()).unwrap();
        assert_eq!(buf.as_slice(), data);

        // Test host -> guest data flow.
        let data = [5u8, 6, 7, 8];
        stream.write_all(&data).unwrap();

        // When data is available on the local stream, an EPOLLIN event would normally be delivered
        // to the muxer's nested epoll FD. For testing only, we can fake that event notification
        // here.
        ctx.notify_muxer();
        // After being notified, the muxer should've figured out that RX data was available for one
        // of its connections, so it should now be reporting that it can fill in an RX packet.
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RW);
        assert_eq!(ctx.pkt.buf().unwrap()[..data.len()], data);
        assert_eq!(ctx.pkt.src_port(), LOCAL_PORT);
        assert_eq!(ctx.pkt.dst_port(), PEER_PORT);

        assert!(!ctx.muxer.has_pending_rx());
    }

    /// Host-initiated connection: handshake, then data flow in both directions.
    #[test]
    fn test_local_connection() {
        let mut ctx = MuxerTestContext::new("local_connection");
        let peer_port = 1025;
        let (mut stream, local_port) = ctx.local_connect(peer_port);

        // Test guest -> host data flow.
        let data = [1, 2, 3, 4];
        ctx.init_data_pkt(local_port, peer_port, &data);
        ctx.send();

        let mut buf = vec![0u8; data.len()];
        stream.read_exact(buf.as_mut_slice()).unwrap();
        assert_eq!(buf.as_slice(), &data);

        // Test host -> guest data flow.
        let data = [5, 6, 7, 8];
        stream.write_all(&data).unwrap();
        ctx.notify_muxer();

        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RW);
        assert_eq!(ctx.pkt.src_port(), local_port);
        assert_eq!(ctx.pkt.dst_port(), peer_port);
        assert_eq!(ctx.pkt.buf().unwrap()[..data.len()], data);
    }

    /// Closing the host end of a connection triggers a full SHUTDOWN towards the peer; the
    /// peer's RST then removes the connection and frees its local port.
    #[test]
    fn test_local_close() {
        let peer_port = 1025;
        let mut ctx = MuxerTestContext::new("local_close");
        let local_port;
        {
            let (_stream, local_port_) = ctx.local_connect(peer_port);
            local_port = local_port_;
        }
        // Local var `_stream` was now dropped, thus closing the local stream. After the muxer gets
        // notified via EPOLLIN, it should attempt to gracefully shutdown the connection, issuing a
        // VSOCK_OP_SHUTDOWN with both no-more-send and no-more-recv indications set.
        ctx.notify_muxer();
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_SHUTDOWN);
        assert_ne!(ctx.pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_SEND, 0);
        assert_ne!(ctx.pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_RCV, 0);
        assert_eq!(ctx.pkt.src_port(), local_port);
        assert_eq!(ctx.pkt.dst_port(), peer_port);

        // The connection should get removed (and its local port freed), after the peer replies
        // with an RST.
        ctx.init_pkt(local_port, peer_port, uapi::VSOCK_OP_RST);
        ctx.send();
        let key = ConnMapKey {
            local_port,
            peer_port,
        };
        assert!(!ctx.muxer.conn_map.contains_key(&key));
        assert!(!ctx.muxer.local_port_set.contains(&local_port));
    }

    /// A full SHUTDOWN from the peer makes the muxer drop the connection, reply with an RST,
    /// and close the host-side socket.
    #[test]
    fn test_peer_close() {
        let peer_port = 1025;
        let local_port = 1026;
        let mut ctx = MuxerTestContext::new("peer_close");

        let mut sock = ctx.create_local_listener(local_port);
        ctx.init_pkt(local_port, peer_port, uapi::VSOCK_OP_REQUEST);
        ctx.send();
        let mut stream = sock.accept();

        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE);
        assert_eq!(ctx.pkt.src_port(), local_port);
        assert_eq!(ctx.pkt.dst_port(), peer_port);
        let key = ConnMapKey {
            local_port,
            peer_port,
        };
        assert!(ctx.muxer.conn_map.contains_key(&key));

        // Emulate a full shutdown from the peer (no-more-send + no-more-recv).
        ctx.init_pkt(local_port, peer_port, uapi::VSOCK_OP_SHUTDOWN)
            .set_flag(uapi::VSOCK_FLAGS_SHUTDOWN_SEND)
            .set_flag(uapi::VSOCK_FLAGS_SHUTDOWN_RCV);
        ctx.send();

        // Now, the muxer should remove the connection from its map, and reply with an RST.
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST);
        assert_eq!(ctx.pkt.src_port(), local_port);
        assert_eq!(ctx.pkt.dst_port(), peer_port);
        let key = ConnMapKey {
            local_port,
            peer_port,
        };
        assert!(!ctx.muxer.conn_map.contains_key(&key));

        // The muxer should also drop / close the local Unix socket for this connection.
        let mut buf = vec![0u8; 16];
        assert_eq!(stream.read(buf.as_mut_slice()).unwrap(), 0);
    }

    /// RX queue overflow: the queue desyncs, RSTs evict non-RST entries, and a later recv
    /// resyncs the queue so the evicted responses are still delivered.
    #[test]
    fn test_muxer_rxq() {
        let mut ctx = MuxerTestContext::new("muxer_rxq");
        let local_port = 1026;
        let peer_port_first = 1025;
        let mut listener = ctx.create_local_listener(local_port);
        let mut streams: Vec<UnixStream> = Vec::new();

        for peer_port in peer_port_first..peer_port_first + defs::MUXER_RXQ_SIZE {
            ctx.init_pkt(local_port, peer_port as u32, uapi::VSOCK_OP_REQUEST);
            ctx.send();
            streams.push(listener.accept());
        }

        // The muxer RX queue should now be full (with connection responses), but still
        // synchronized.
        assert!(ctx.muxer.rxq.is_synced());

        // One more queued reply should desync the RX queue.
        ctx.init_pkt(
            local_port,
            (peer_port_first + defs::MUXER_RXQ_SIZE) as u32,
            uapi::VSOCK_OP_REQUEST,
        );
        ctx.send();
        assert!(!ctx.muxer.rxq.is_synced());

        // With an out-of-sync queue, an RST should evict any non-RST packet from the queue, and
        // take its place. We'll check that by making sure that the last packet popped from the
        // queue is an RST.
        ctx.init_pkt(
            local_port + 1,
            peer_port_first as u32,
            uapi::VSOCK_OP_REQUEST,
        );
        ctx.send();

        for peer_port in peer_port_first..peer_port_first + defs::MUXER_RXQ_SIZE - 1 {
            ctx.recv();
            assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE);
            // The response order should hold. The evicted response should have been the last
            // enqueued.
            assert_eq!(ctx.pkt.dst_port(), peer_port as u32);
        }
        // There should be one more packet in the queue: the RST.
        assert_eq!(ctx.muxer.rxq.len(), 1);
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST);

        // The queue should now be empty, but out-of-sync, so the muxer should report it has some
        // pending RX.
        assert!(ctx.muxer.rxq.is_empty());
        assert!(!ctx.muxer.rxq.is_synced());
        assert!(ctx.muxer.has_pending_rx());

        // The next recv should sync the queue back up. It should also yield one of the two
        // responses that are still left:
        // - the one that desynchronized the queue; and
        // - the one that got evicted by the RST.
        ctx.recv();
        assert!(ctx.muxer.rxq.is_synced());
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE);

        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE);
    }

    /// Kill queue overflow: the queue desyncs, and once the kill timers expire a sweep kills
    /// the dying connections (one RST each) and resyncs the queue.
    #[test]
    fn test_muxer_killq() {
        let mut ctx = MuxerTestContext::new("muxer_killq");
        let local_port = 1026;
        let peer_port_first = 1025;
        let peer_port_last = peer_port_first + defs::MUXER_KILLQ_SIZE;
        let mut listener = ctx.create_local_listener(local_port);

        for peer_port in peer_port_first..=peer_port_last {
            ctx.init_pkt(local_port, peer_port as u32, uapi::VSOCK_OP_REQUEST);
            ctx.send();
            ctx.notify_muxer();
            ctx.recv();
            assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE);
            assert_eq!(ctx.pkt.src_port(), local_port);
            assert_eq!(ctx.pkt.dst_port(), peer_port as u32);
            {
                let _stream = listener.accept();
            }
            ctx.notify_muxer();
            ctx.recv();
            assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_SHUTDOWN);
            assert_eq!(ctx.pkt.src_port(), local_port);
            assert_eq!(ctx.pkt.dst_port(), peer_port as u32);
            // The kill queue should be synchronized, up until the `defs::MUXER_KILLQ_SIZE`th
            // connection we schedule for termination.
            assert_eq!(
                ctx.muxer.killq.is_synced(),
                peer_port < peer_port_first + defs::MUXER_KILLQ_SIZE
            );
        }

        assert!(!ctx.muxer.killq.is_synced());
        assert!(!ctx.muxer.has_pending_rx());

        // Wait for the kill timers to expire.
        std::thread::sleep(std::time::Duration::from_millis(
            csm_defs::CONN_SHUTDOWN_TIMEOUT_MS,
        ));

        // Trigger a kill queue sweep, by requesting a new connection.
        ctx.init_pkt(
            local_port,
            peer_port_last as u32 + 1,
            uapi::VSOCK_OP_REQUEST,
        );
        ctx.send();

        // After sweeping the kill queue, it should now be synced (assuming the RX queue is larger
        // than the kill queue, since an RST packet will be queued for each killed connection).
        assert!(ctx.muxer.killq.is_synced());
        assert!(ctx.muxer.has_pending_rx());
        // There should be `defs::MUXER_KILLQ_SIZE` RSTs in the RX queue, from terminating the
        // dying connections in the recent killq sweep.
        for _p in peer_port_first..peer_port_last {
            ctx.recv();
            assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST);
            assert_eq!(ctx.pkt.src_port(), local_port);
        }

        // There should be one more packet in the RX queue: the connection response to our request
        // that triggered the kill queue sweep.
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE);
        assert_eq!(ctx.pkt.dst_port(), peer_port_last as u32 + 1);

        assert!(!ctx.muxer.has_pending_rx());
    }

    #[test]
    fn test_regression_handshake() {
        // Address one of the issues found while fixing the following issue:
        // https://github.com/firecracker-microvm/firecracker/issues/1751
        // This test checks that the handshake message ("OK <port>\n") is not accounted for in
        // the connection's forwarded-bytes counter (`fwd_cnt`).
        let mut ctx = MuxerTestContext::new("regression_handshake");
        let peer_port = 1025;

        // Create a local connection.
        let (_, local_port) = ctx.local_connect(peer_port);

        // Get the connection from the connection map.
        let key = ConnMapKey {
            local_port,
            peer_port,
        };
        let conn = ctx.muxer.conn_map.get_mut(&key).unwrap();

        // Check that fwd_cnt is 0 - "OK ..." was not accounted for.
        assert_eq!(conn.fwd_cnt().0, 0);
    }

    #[test]
    fn test_regression_rxq_pop() {
        // Address one of the issues found while fixing the following issue:
        // https://github.com/firecracker-microvm/firecracker/issues/1751
        // This test checks that a connection is not popped out of the muxer
        // rxq when multiple flags are set
        let mut ctx = MuxerTestContext::new("regression_rxq_pop");
        let peer_port = 1025;
        let (mut stream, local_port) = ctx.local_connect(peer_port);

        // Send some data.
        let data = [5u8, 6, 7, 8];
        stream.write_all(&data).unwrap();
        ctx.notify_muxer();

        // Get the connection from the connection map.
        let key = ConnMapKey {
            local_port,
            peer_port,
        };
        let conn = ctx.muxer.conn_map.get_mut(&key).unwrap();

        // Forcefully insert another flag.
        conn.insert_credit_update();

        // Call recv twice in order to check that the connection is still
        // in the rxq.
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();

        // Since initially the connection had two flags set, now there should
        // not be any pending RX in the muxer.
        assert!(!ctx.muxer.has_pending_rx());
    }
}
1455