// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
//

//! `VsockMuxer` is the device-facing component of the Unix domain sockets vsock backend. I.e.
//! by implementing the `VsockBackend` trait, it abstracts away the gory details of translating
//! between AF_VSOCK and AF_UNIX, and presents a clean interface to the rest of the vsock
//! device model.
//!
//! The vsock muxer has two main roles:
//!
//! ## Vsock connection multiplexer
//!
//! It's the muxer's job to create, manage, and terminate `VsockConnection` objects. The
//! muxer also routes packets to their owning connections. It does so via a connection
//! `HashMap`, keyed by what is basically a (host_port, guest_port) tuple.
//!
//! Vsock packet traffic needs to be inspected, in order to detect connection request
//! packets (leading to the creation of a new connection), and connection reset packets
//! (leading to the termination of an existing connection). All other packets, though, must
//! belong to an existing connection and, as such, the muxer simply forwards them.
//!
//! ## Event dispatcher
//!
//! There are three event categories that the vsock backend is interested in:
//! 1. A new host-initiated connection is ready to be accepted from the listening host Unix
//!    socket;
//! 2. Data is available for reading from a newly-accepted host-initiated connection (i.e.
//!    the host is ready to issue a vsock connection request, informing us of the
//!    destination port to which it wants to connect);
//! 3. Some event was triggered for a connected Unix socket that belongs to a
//!    `VsockConnection`.
//!
//! The muxer gets notified about all of these events, because, as a `VsockEpollListener`
//! implementor, it gets to register a nested epoll FD into the main VMM epolling loop. All
//! other pollable FDs are then registered under this nested epoll FD.
//!
//! To route all these events to their handlers, the muxer uses another `HashMap` object,
//! mapping `RawFd`s to `EpollListener`s.
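//!
//! As an illustrative sketch of the host-side handshake (the port numbers are made up;
//! the command parsing and the ack live in `read_local_stream_port()` and
//! `apply_conn_mutation()` below):
//!
//! ```text
//! (host)   connects to <host_sock_path>            # plain AF_UNIX connect
//! (host)   sends "CONNECT 52\n"                    # request guest vsock port 52
//! (muxer)  replies "OK 1073741824\n"               # ack, with the allocated local port
//! (host <-> guest)  raw byte stream, proxied by the muxer from then on
//! ```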

use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::{self, ErrorKind, Read};
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::{UnixListener, UnixStream};

use super::super::csm::ConnState;
use super::super::defs::uapi;
use super::super::packet::VsockPacket;
use super::super::{
    Result as VsockResult, VsockBackend, VsockChannel, VsockEpollListener, VsockError,
};
use super::defs;
use super::muxer_killq::MuxerKillQ;
use super::muxer_rxq::MuxerRxQ;
use super::MuxerConnection;
use super::{Error, Result};

/// A unique identifier of a `MuxerConnection` object. Connections are stored in a hash map,
/// keyed by a `ConnMapKey` object.
///
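/// For example, `send_pkt()` builds the routing key from a guest-generated packet as
/// `local_port = pkt.dst_port()`, `peer_port = pkt.src_port()`, so both directions of a
/// connection resolve to the same map entry.
///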
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ConnMapKey {
    local_port: u32,
    peer_port: u32,
}

/// A muxer RX queue item.
///
#[derive(Clone, Copy, Debug)]
pub enum MuxerRx {
    /// The packet must be fetched from the connection identified by `ConnMapKey`.
    ConnRx(ConnMapKey),
    /// The muxer must produce an RST packet.
    RstPkt { local_port: u32, peer_port: u32 },
}

/// An epoll listener, registered under the muxer's nested epoll FD.
///
enum EpollListener {
    /// The listener is a `MuxerConnection`, identified by `key`, and interested in the events
    /// in `evset`. Since `MuxerConnection` implements `VsockEpollListener`, notifications will
    /// be forwarded to the listener via `VsockEpollListener::notify()`.
    Connection {
        key: ConnMapKey,
        evset: epoll::Events,
    },
    /// A listener interested in new host-initiated connections.
    HostSock,
    /// A listener interested in reading host "connect \<port>" commands from a freshly
    /// connected host socket.
    LocalStream(UnixStream),
}

/// A partially read "CONNECT" command.
#[derive(Default)]
struct PartiallyReadCommand {
    /// The bytes of the command that have been read so far.
    buf: [u8; 32],
    /// How much of `buf` has been used.
    len: usize,
}

/// The vsock connection multiplexer.
///
pub struct VsockMuxer {
    /// Guest CID.
    cid: u64,
    /// A hash map used to store the active connections.
    conn_map: HashMap<ConnMapKey, MuxerConnection>,
    /// A hash map used to store epoll event listeners / handlers.
    listener_map: HashMap<RawFd, EpollListener>,
    /// A hash map used to store partially read "connect" commands.
    partial_command_map: HashMap<RawFd, PartiallyReadCommand>,
    /// The RX queue. Items in this queue are consumed by `VsockMuxer::recv_pkt()`, and
    /// produced
    /// - by `VsockMuxer::send_pkt()` (e.g. RST in response to a connection request packet);
    ///   and
    /// - in response to EPOLLIN events (e.g. data available to be read from an AF_UNIX
    ///   socket).
    rxq: MuxerRxQ,
    /// A queue used for terminating connections that are taking too long to shut down.
    killq: MuxerKillQ,
    /// The Unix socket, through which host-initiated connections are accepted.
    host_sock: UnixListener,
    /// The file system path of the host-side Unix socket. This is used to figure out the path
    /// to Unix sockets listening on specific ports. I.e. "\<this path>_\<port number>".
    host_sock_path: String,
    /// The nested epoll File, used to register epoll listeners.
    epoll_file: File,
    /// A hash set used to keep track of used host-side (local) ports, in order to assign local
    /// ports to host-initiated connections.
    local_port_set: HashSet<u32>,
    /// The last used host-side port.
    local_port_last: u32,
}

impl VsockChannel for VsockMuxer {
    /// Deliver a vsock packet to the guest vsock driver.
    ///
    /// Returns:
    /// - `Ok(())`: `pkt` has been successfully filled in; or
    /// - `Err(VsockError::NoData)`: there was no available data with which to fill in the
    ///   packet.
    ///
    fn recv_pkt(&mut self, pkt: &mut VsockPacket) -> VsockResult<()> {
        // We'll look for instructions on how to build the RX packet in the RX queue. If the
        // queue is empty, that doesn't necessarily mean we don't have any pending RX, since
        // the queue might be out-of-sync. If that's the case, we'll attempt to sync it first,
        // and then try to pop something out again.
        if self.rxq.is_empty() && !self.rxq.is_synced() {
            self.rxq = MuxerRxQ::from_conn_map(&self.conn_map);
        }

        while let Some(rx) = self.rxq.peek() {
            let res = match rx {
                // We need to build an RST packet, going from `local_port` to `peer_port`.
                MuxerRx::RstPkt {
                    local_port,
                    peer_port,
                } => {
                    pkt.set_op(uapi::VSOCK_OP_RST)
                        .set_src_cid(uapi::VSOCK_HOST_CID)
                        .set_dst_cid(self.cid)
                        .set_src_port(local_port)
                        .set_dst_port(peer_port)
                        .set_len(0)
                        .set_type(uapi::VSOCK_TYPE_STREAM)
                        .set_flags(0)
                        .set_buf_alloc(0)
                        .set_fwd_cnt(0);
                    self.rxq.pop().unwrap();
                    return Ok(());
                }

                // We'll defer building the packet to this connection, since it has something
                // to say.
                MuxerRx::ConnRx(key) => {
                    let mut conn_res = Err(VsockError::NoData);
                    let mut do_pop = true;
                    self.apply_conn_mutation(key, |conn| {
                        conn_res = conn.recv_pkt(pkt);
                        do_pop = !conn.has_pending_rx();
                    });
                    if do_pop {
                        self.rxq.pop().unwrap();
                    }
                    conn_res
                }
            };

            if res.is_ok() {
                // Inspect traffic, looking for RST packets, since that means we have to
                // terminate and remove this connection from the active connection pool.
                //
                if pkt.op() == uapi::VSOCK_OP_RST {
                    self.remove_connection(ConnMapKey {
                        local_port: pkt.src_port(),
                        peer_port: pkt.dst_port(),
                    });
                }

                debug!("vsock muxer: RX pkt: {:?}", pkt.hdr());
                return Ok(());
            }
        }

        Err(VsockError::NoData)
    }

    /// Deliver a guest-generated packet to its destination in the vsock backend.
    ///
    /// This absorbs unexpected packets, handles RSTs (by dropping connections), and forwards
    /// all the rest to their owning `MuxerConnection`.
    ///
    /// Returns:
    /// always `Ok(())` - the packet has been consumed, and its virtio TX buffers can be
    /// returned to the guest vsock driver.
    ///
    fn send_pkt(&mut self, pkt: &VsockPacket) -> VsockResult<()> {
        let conn_key = ConnMapKey {
            local_port: pkt.dst_port(),
            peer_port: pkt.src_port(),
        };

        debug!(
            "vsock: muxer.send[rxq.len={}]: {:?}",
            self.rxq.len(),
            pkt.hdr()
        );

        // If this packet has an unsupported type (!=stream), we must send back an RST.
        //
        if pkt.type_() != uapi::VSOCK_TYPE_STREAM {
            self.enq_rst(pkt.dst_port(), pkt.src_port());
            return Ok(());
        }

        // We don't know how to handle packets addressed to other CIDs. We only handle the host
        // part of the guest - host communication here.
        if pkt.dst_cid() != uapi::VSOCK_HOST_CID {
            info!(
                "vsock: dropping guest packet for unknown CID: {:?}",
                pkt.hdr()
            );
            return Ok(());
        }

        if !self.conn_map.contains_key(&conn_key) {
            // This packet can't be routed to any active connection (based on its src and dst
            // ports).  The only orphan / unroutable packets we know how to handle are
            // connection requests.
            if pkt.op() == uapi::VSOCK_OP_REQUEST {
                // Oh, this is a connection request!
                self.handle_peer_request_pkt(pkt);
            } else {
                // Send back an RST, to let the driver know we weren't expecting this packet.
                self.enq_rst(pkt.dst_port(), pkt.src_port());
            }
            return Ok(());
        }

        // Right, we know where to send this packet, then (to `conn_key`).
        // However, if this is an RST, we have to forcefully terminate the connection, so
        // there's no point in forwarding the packet to it.
        if pkt.op() == uapi::VSOCK_OP_RST {
            self.remove_connection(conn_key);
            return Ok(());
        }

        // Alright, everything looks in order - forward this packet to its owning connection.
        let mut res: VsockResult<()> = Ok(());
        self.apply_conn_mutation(conn_key, |conn| {
            res = conn.send_pkt(pkt);
        });

        res
    }

    /// Check if the muxer has any pending RX data, with which to fill a guest-provided RX
    /// buffer.
    ///
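    /// An out-of-sync RX queue also counts as pending RX, since `recv_pkt()` will rebuild
    /// the queue from the connection map before yielding a packet.
    ///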
    fn has_pending_rx(&self) -> bool {
        !self.rxq.is_empty() || !self.rxq.is_synced()
    }
}

impl VsockEpollListener for VsockMuxer {
    /// Get the FD to be registered for polling upstream (in the main VMM epoll loop, in this
    /// case).
    ///
    /// This will be the muxer's nested epoll FD.
    ///
    fn get_polled_fd(&self) -> RawFd {
        self.epoll_file.as_raw_fd()
    }

    /// Get the epoll events to be polled upstream.
    ///
    /// Since the polled FD is a nested epoll FD, we're only interested in EPOLLIN events (i.e.
    /// some event occurred on one of the FDs registered under our epoll FD).
    ///
    fn get_polled_evset(&self) -> epoll::Events {
        epoll::Events::EPOLLIN
    }

    /// Notify the muxer about a pending event having occurred under its nested epoll FD.
    ///
    fn notify(&mut self, _: epoll::Events) {
        debug!("vsock: muxer received kick");

        let mut epoll_events = vec![epoll::Event::new(epoll::Events::empty(), 0); 32];
        'epoll: loop {
            match epoll::wait(self.epoll_file.as_raw_fd(), 0, epoll_events.as_mut_slice()) {
                Ok(ev_cnt) => {
                    for evt in epoll_events.iter().take(ev_cnt) {
                        self.handle_event(
                            evt.data as RawFd,
                            // It's ok to unwrap here, since `evt.events` is filled
                            // in by `epoll::wait()`, and therefore contains only valid epoll
                            // flags.
                            epoll::Events::from_bits(evt.events).unwrap(),
                        );
                    }
                }
                Err(e) => {
                    if e.kind() == io::ErrorKind::Interrupted {
                        // It's well defined from the epoll_wait() syscall
                        // documentation that the epoll loop can be interrupted
                        // before any of the requested events occurred or the
                        // timeout expired. In both those cases, epoll_wait()
                        // returns an error of type EINTR, but this should not
                        // be considered as a regular error. Instead it is more
                        // appropriate to retry, by calling into epoll_wait().
                        continue;
                    }
                    warn!("vsock: failed to consume muxer epoll event: {}", e);
                }
            }
            break 'epoll;
        }
    }
}

impl VsockBackend for VsockMuxer {}

impl VsockMuxer {
    /// Muxer constructor.
    ///
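    /// A minimal usage sketch (the CID and socket path below are illustrative):
    ///
    /// ```text
    /// let muxer = VsockMuxer::new(3, "/tmp/ch.vsock".to_string())?;
    /// ```
    ///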
    pub fn new(cid: u32, host_sock_path: String) -> Result<Self> {
        // Create the nested epoll FD. This FD will be added to the VMM `EpollContext`, at
        // device activation time.
        let epoll_fd = epoll::create(true).map_err(Error::EpollFdCreate)?;
        // Use 'File' to enforce closing on 'epoll_fd'
        // SAFETY: epoll_fd is a valid fd
        let epoll_file = unsafe { File::from_raw_fd(epoll_fd) };

        // Open/bind/listen on the host Unix socket, so we can accept host-initiated
        // connections.
        let host_sock = UnixListener::bind(&host_sock_path)
            .and_then(|sock| sock.set_nonblocking(true).map(|_| sock))
            .map_err(Error::UnixBind)?;

        let mut muxer = Self {
            cid: cid.into(),
            host_sock,
            host_sock_path,
            epoll_file,
            rxq: MuxerRxQ::new(),
            conn_map: HashMap::with_capacity(defs::MAX_CONNECTIONS),
            listener_map: HashMap::with_capacity(defs::MAX_CONNECTIONS + 1),
            partial_command_map: Default::default(),
            killq: MuxerKillQ::new(),
            local_port_last: (1u32 << 30) - 1,
            local_port_set: HashSet::with_capacity(defs::MAX_CONNECTIONS),
        };

        muxer.add_listener(muxer.host_sock.as_raw_fd(), EpollListener::HostSock)?;
        Ok(muxer)
    }

    /// Handle/dispatch an epoll event to its listener.
    ///
    fn handle_event(&mut self, fd: RawFd, event_set: epoll::Events) {
        debug!(
            "vsock: muxer processing event: fd={}, event_set={:?}",
            fd, event_set
        );

        match self.listener_map.get_mut(&fd) {
            // This event needs to be forwarded to a `MuxerConnection` that is listening for
            // it.
            //
            Some(EpollListener::Connection { key, evset: _ }) => {
                let key_copy = *key;
                // The handling of this event will most probably mutate the state of the
                // receiving connection. We'll need to check for new pending RX, event set
                // mutation, and all that, so we're wrapping the event delivery inside those
                // checks.
                self.apply_conn_mutation(key_copy, |conn| {
                    conn.notify(event_set);
                });
            }

            // A new host-initiated connection is ready to be accepted.
            //
            Some(EpollListener::HostSock) => {
                if self.conn_map.len() == defs::MAX_CONNECTIONS {
                    // If we're already maxed-out on connections, we'll just accept and
                    // immediately discard this potentially new one.
                    warn!("vsock: connection limit reached; refusing new host connection");
                    self.host_sock.accept().map(|_| 0).unwrap_or(0);
                    return;
                }
                self.host_sock
                    .accept()
                    .map_err(Error::UnixAccept)
                    .and_then(|(stream, _)| {
                        stream
                            .set_nonblocking(true)
                            .map(|_| stream)
                            .map_err(Error::UnixAccept)
                    })
                    .and_then(|stream| {
                        // Before forwarding this connection to a listening AF_VSOCK socket on
                        // the guest side, we need to know the destination port. We'll read
                        // that port from a "connect" command received on this socket, so the
                        // next step is to ask to be notified the moment we can read from it.
                        self.add_listener(stream.as_raw_fd(), EpollListener::LocalStream(stream))
                    })
                    .unwrap_or_else(|err| {
                        warn!("vsock: unable to accept local connection: {:?}", err);
                    });
            }

            // Data is ready to be read from a host-initiated connection. That would be the
            // "connect" command that we're expecting.
            Some(EpollListener::LocalStream(_)) => {
                if let Some(EpollListener::LocalStream(stream)) = self.listener_map.get_mut(&fd) {
                    let port = Self::read_local_stream_port(&mut self.partial_command_map, stream);

                    if let Err(Error::UnixRead(ref e)) = port {
                        if e.kind() == ErrorKind::WouldBlock {
                            return;
                        }
                    }

                    let stream = match self.remove_listener(fd) {
                        Some(EpollListener::LocalStream(s)) => s,
                        _ => unreachable!(),
                    };

                    port.and_then(|peer_port| {
                        let local_port = self.allocate_local_port();

                        self.add_connection(
                            ConnMapKey {
                                local_port,
                                peer_port,
                            },
                            MuxerConnection::new_local_init(
                                stream,
                                uapi::VSOCK_HOST_CID,
                                self.cid,
                                local_port,
                                peer_port,
                            ),
                        )
                    })
                    .unwrap_or_else(|err| {
                        info!("vsock: error adding local-init connection: {:?}", err);
                    })
                }
            }

            _ => {
                info!(
                    "vsock: unexpected event: fd={:?}, event_set={:?}",
                    fd, event_set
                );
            }
        }
    }

    /// Parse a host "connect" command, and extract the destination vsock port.
    ///
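    /// For example, `b"connect 1025\n"` parses to `Ok(1025)`; the command word is matched
    /// case-insensitively, so `b"CONNECT 1025\n"` is accepted as well.
    ///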
    fn read_local_stream_port(
        partial_command_map: &mut HashMap<RawFd, PartiallyReadCommand>,
        stream: &mut UnixStream,
    ) -> Result<u32> {
        let command = partial_command_map.entry(stream.as_raw_fd()).or_default();

        // This is the minimum number of bytes that we should be able to read, when parsing a
        // valid connection request. I.e. `b"connect 0\n".len()`.
        const MIN_COMMAND_LEN: usize = 10;

        // Bring in the minimum number of bytes that we should be able to read.
        if command.len < MIN_COMMAND_LEN {
            command.len += stream
                .read(&mut command.buf[command.len..MIN_COMMAND_LEN])
                .map_err(Error::UnixRead)?;
        }

        // Now, finish reading the destination port number, by bringing in one byte at a time,
        // until we reach an EOL terminator (or our buffer space runs out).  Yeah, not
        // particularly proud of this approach, but it will have to do for now.
        while command.buf[command.len - 1] != b'\n' && command.len < command.buf.len() {
            command.len += stream
                .read(&mut command.buf[command.len..=command.len])
                .map_err(Error::UnixRead)?;
        }

        let _ = command;
        let command = partial_command_map.remove(&stream.as_raw_fd()).unwrap();

        let mut word_iter = std::str::from_utf8(&command.buf[..command.len])
            .map_err(Error::ConvertFromUtf8)?
            .split_whitespace();

        word_iter
            .next()
            .ok_or(Error::InvalidPortRequest)
            .and_then(|word| {
                if word.to_lowercase() == "connect" {
                    Ok(())
                } else {
                    Err(Error::InvalidPortRequest)
                }
            })
            .and_then(|_| word_iter.next().ok_or(Error::InvalidPortRequest))
            .and_then(|word| word.parse::<u32>().map_err(Error::ParseInteger))
            .map_err(|e| Error::ReadStreamPort(Box::new(e)))
    }

    /// Add a new connection to the active connection pool.
    ///
    fn add_connection(&mut self, key: ConnMapKey, conn: MuxerConnection) -> Result<()> {
        // We might need to make room for this new connection, so let's sweep the kill queue
        // first.  It's fine to do this here because:
        // - unless the kill queue is out of sync, this is a pretty inexpensive operation; and
        // - we are under no pressure to respect any accurate timing for connection
        //   termination.
        self.sweep_killq();

        if self.conn_map.len() >= defs::MAX_CONNECTIONS {
            info!(
                "vsock: muxer connection limit reached ({})",
                defs::MAX_CONNECTIONS
            );
            return Err(Error::TooManyConnections);
        }

        self.add_listener(
            conn.get_polled_fd(),
            EpollListener::Connection {
                key,
                evset: conn.get_polled_evset(),
            },
        )
        .map(|_| {
            if conn.has_pending_rx() {
                // We can safely ignore any error in adding a connection RX indication. Worst
                // case scenario, the RX queue will get desynchronized, but we'll handle that
                // the next time we need to yield an RX packet.
                self.rxq.push(MuxerRx::ConnRx(key));
            }
            self.conn_map.insert(key, conn);
        })
    }

    /// Remove a connection from the active connection pool.
    ///
    fn remove_connection(&mut self, key: ConnMapKey) {
        if let Some(conn) = self.conn_map.remove(&key) {
            self.remove_listener(conn.get_polled_fd());
        }
        self.free_local_port(key.local_port);
    }

    /// Schedule a connection for immediate termination.
    /// I.e. as soon as we can also let our peer know we're dropping the connection, by sending
    /// it an RST packet.
    ///
    fn kill_connection(&mut self, key: ConnMapKey) {
        let mut had_rx = false;
        self.conn_map.entry(key).and_modify(|conn| {
            had_rx = conn.has_pending_rx();
            conn.kill();
        });
        // This connection will now have an RST packet to yield, so we need to add it to the RX
        // queue.  However, there's no point in doing that if it was already in the queue.
        if !had_rx {
            // We can safely ignore any error in adding a connection RX indication. Worst case
            // scenario, the RX queue will get desynchronized, but we'll handle that the next
            // time we need to yield an RX packet.
            self.rxq.push(MuxerRx::ConnRx(key));
        }
    }

    /// Register a new epoll listener under the muxer's nested epoll FD.
    ///
    fn add_listener(&mut self, fd: RawFd, listener: EpollListener) -> Result<()> {
        let evset = match listener {
            EpollListener::Connection { evset, .. } => evset,
            EpollListener::LocalStream(_) => epoll::Events::EPOLLIN,
            EpollListener::HostSock => epoll::Events::EPOLLIN,
        };

        epoll::ctl(
            self.epoll_file.as_raw_fd(),
            epoll::ControlOptions::EPOLL_CTL_ADD,
            fd,
            epoll::Event::new(evset, fd as u64),
        )
        .map(|_| {
            self.listener_map.insert(fd, listener);
        })
        .map_err(Error::EpollAdd)?;

        Ok(())
    }

    /// Remove (and return) a previously registered epoll listener.
    ///
    fn remove_listener(&mut self, fd: RawFd) -> Option<EpollListener> {
        let maybe_listener = self.listener_map.remove(&fd);

        if maybe_listener.is_some() {
            epoll::ctl(
                self.epoll_file.as_raw_fd(),
                epoll::ControlOptions::EPOLL_CTL_DEL,
                fd,
                epoll::Event::new(epoll::Events::empty(), 0),
            )
            .unwrap_or_else(|err| {
                warn!(
                    "vsock muxer: error removing epoll listener for fd {:?}: {:?}",
                    fd, err
                );
            });
        }

        maybe_listener
    }

    /// Allocate a host-side port to be assigned to a new host-initiated connection.
    ///
    fn allocate_local_port(&mut self) -> u32 {
        // TODO: this doesn't seem very space-efficient.
        // Maybe rewrite this to limit port range and use a bitmap?
        //

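        // The expression below clears bit 31 and sets bit 30 of the incremented counter,
        // so allocated ports always fall within [2^30, 2^31), wrapping around within that
        // range until a free port is found.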
        loop {
            self.local_port_last = (self.local_port_last + 1) & !(1 << 31) | (1 << 30);
            if self.local_port_set.insert(self.local_port_last) {
                break;
            }
        }
        self.local_port_last
    }

    /// Mark a previously used host-side port as free.
    ///
    fn free_local_port(&mut self, port: u32) {
        self.local_port_set.remove(&port);
    }

    /// Handle a new connection request coming from our peer (the guest vsock driver).
    ///
    /// This will attempt to connect to a host-side Unix socket, expected to be listening at
    /// the file system path corresponding to the destination port. If successful, a new
    /// connection object will be created and added to the connection pool. On failure, a new
    /// RST packet will be scheduled for delivery to the guest.
    ///
    fn handle_peer_request_pkt(&mut self, pkt: &VsockPacket) {
        let port_path = format!("{}_{}", self.host_sock_path, pkt.dst_port());

        UnixStream::connect(port_path)
            .and_then(|stream| stream.set_nonblocking(true).map(|_| stream))
            .map_err(Error::UnixConnect)
            .and_then(|stream| {
                self.add_connection(
                    ConnMapKey {
                        local_port: pkt.dst_port(),
                        peer_port: pkt.src_port(),
                    },
                    MuxerConnection::new_peer_init(
                        stream,
                        uapi::VSOCK_HOST_CID,
                        self.cid,
                        pkt.dst_port(),
                        pkt.src_port(),
                        pkt.buf_alloc(),
                    ),
                )
            })
            .unwrap_or_else(|_| self.enq_rst(pkt.dst_port(), pkt.src_port()));
    }

    /// Perform an action that might mutate a connection's state.
    ///
    /// This is used as shorthand for repetitive tasks that need to be performed after a
    /// connection object mutates. E.g.
    /// - update the connection's epoll listener;
    /// - schedule the connection to be queried for RX data;
    /// - kill the connection if an unrecoverable error occurs.
    ///
    fn apply_conn_mutation<F>(&mut self, key: ConnMapKey, mut_fn: F)
    where
        F: FnOnce(&mut MuxerConnection),
    {
        if let Some(conn) = self.conn_map.get_mut(&key) {
            let had_rx = conn.has_pending_rx();
            let was_expiring = conn.will_expire();
            let prev_state = conn.state();

            mut_fn(conn);

            // If this is a host-initiated connection that has just become established, we'll have
            // to send an ack message to the host end.
            if prev_state == ConnState::LocalInit && conn.state() == ConnState::Established {
                let msg = format!("OK {}\n", key.local_port);
                match conn.send_bytes_raw(msg.as_bytes()) {
                    Ok(written) if written == msg.len() => (),
                    Ok(_) => {
                        // If we can't write a dozen bytes to a pristine connection, something
                        // must be really wrong. Killing it.
                        conn.kill();
                        warn!("vsock: unable to fully write connection ack msg.");
                    }
                    Err(err) => {
                        conn.kill();
                        warn!("vsock: unable to ack host connection: {:?}", err);
                    }
                };
            }

            // If the connection wasn't previously scheduled for RX, add it to our RX queue.
            if !had_rx && conn.has_pending_rx() {
                self.rxq.push(MuxerRx::ConnRx(key));
            }

            // If the connection wasn't previously scheduled for termination, add it to the
            // kill queue.
            if !was_expiring && conn.will_expire() {
                // It's safe to unwrap here, since `conn.will_expire()` already guaranteed that
                // a `conn.expiry` is available.
                self.killq.push(key, conn.expiry().unwrap());
            }

            let fd = conn.get_polled_fd();
            let new_evset = conn.get_polled_evset();
            if new_evset.is_empty() {
                // If the connection no longer needs epoll notifications, remove its listener
                // from our list.
                self.remove_listener(fd);
                return;
            }
            if let Some(EpollListener::Connection { evset, .. }) = self.listener_map.get_mut(&fd) {
                if *evset != new_evset {
                    // If the set of events that the connection is interested in has changed,
                    // we need to update its epoll listener.
                    debug!(
                        "vsock: updating listener for (lp={}, pp={}): old={:?}, new={:?}",
                        key.local_port, key.peer_port, *evset, new_evset
                    );

                    *evset = new_evset;
                    epoll::ctl(
                        self.epoll_file.as_raw_fd(),
                        epoll::ControlOptions::EPOLL_CTL_MOD,
                        fd,
                        epoll::Event::new(new_evset, fd as u64),
                    )
                    .unwrap_or_else(|err| {
                        // This really shouldn't happen, like, ever. However, "famous last
                        // words" and all that, so let's just kill it with fire, and walk away.
                        self.kill_connection(key);
                        error!(
                            "vsock: error updating epoll listener for (lp={}, pp={}): {:?}",
                            key.local_port, key.peer_port, err
                        );
                    });
                }
            } else {
                // The connection had previously asked to be removed from the listener map (by
                // returning an empty event set via `get_polled_evset()`), but now wants back in.
                self.add_listener(
                    fd,
                    EpollListener::Connection {
                        key,
                        evset: new_evset,
                    },
                )
                .unwrap_or_else(|err| {
                    self.kill_connection(key);
                    error!(
                        "vsock: error updating epoll listener for (lp={}, pp={}): {:?}",
                        key.local_port, key.peer_port, err
                    );
                });
            }
        }
    }

    /// Check if any connections have timed out, and if so, schedule them for immediate
    /// termination.
    ///
    fn sweep_killq(&mut self) {
        while let Some(key) = self.killq.pop() {
            // Connections don't get removed from the kill queue when their kill timer is
            // disarmed, since that would be a costly operation. This means we must check if
            // the connection has indeed expired, prior to killing it.
            let mut kill = false;
            self.conn_map
                .entry(key)
                .and_modify(|conn| kill = conn.has_expired());
            if kill {
                self.kill_connection(key);
            }
        }

        if self.killq.is_empty() && !self.killq.is_synced() {
            self.killq = MuxerKillQ::from_conn_map(&self.conn_map);
            // If we've just re-created the kill queue, we can sweep it again; maybe there's
            // more to kill.
            self.sweep_killq();
        }
    }

    /// Enqueue an RST packet into `self.rxq`.
    ///
    /// Enqueue errors aren't propagated up the call chain, since there is nothing we can do to
    /// handle them. We do, however, log a warning, since not being able to enqueue an RST
    /// packet means we have to drop it, which is not normal operation.
    ///
    fn enq_rst(&mut self, local_port: u32, peer_port: u32) {
        let pushed = self.rxq.push(MuxerRx::RstPkt {
            local_port,
            peer_port,
        });
        if !pushed {
            warn!(
                "vsock: muxer.rxq full; dropping RST packet for lp={}, pp={}",
                local_port, peer_port
            );
        }
    }
}

#[cfg(test)]
mod tests {
    use std::io::{Read, Write};
    use std::ops::Drop;
    use std::os::unix::net::{UnixListener, UnixStream};
    use std::path::{Path, PathBuf};

    use virtio_queue::QueueOwnedT;

    use super::super::super::csm::defs as csm_defs;
    use super::super::super::tests::TestContext as VsockTestContext;
    use super::*;

    const PEER_CID: u32 = 3;
    const PEER_BUF_ALLOC: u32 = 64 * 1024;

    struct MuxerTestContext {
        _vsock_test_ctx: VsockTestContext,
        pkt: VsockPacket,
        muxer: VsockMuxer,
    }

    impl Drop for MuxerTestContext {
        fn drop(&mut self) {
            std::fs::remove_file(self.muxer.host_sock_path.as_str()).unwrap();
        }
    }

    impl MuxerTestContext {
        fn new(name: &str) -> Self {
            let vsock_test_ctx = VsockTestContext::new();
            let mut handler_ctx = vsock_test_ctx.create_epoll_handler_context();
            let pkt = VsockPacket::from_rx_virtq_head(
                &mut handler_ctx.handler.queues[0]
                    .iter(&vsock_test_ctx.mem)
                    .unwrap()
                    .next()
                    .unwrap(),
                None,
            )
            .unwrap();
            let uds_path = format!("test_vsock_{name}.sock");
            let muxer = VsockMuxer::new(PEER_CID, uds_path).unwrap();

            Self {
                _vsock_test_ctx: vsock_test_ctx,
                pkt,
                muxer,
            }
        }

        fn init_pkt(&mut self, local_port: u32, peer_port: u32, op: u16) -> &mut VsockPacket {
            for b in self.pkt.hdr_mut() {
                *b = 0;
            }
            self.pkt
                .set_type(uapi::VSOCK_TYPE_STREAM)
                .set_src_cid(PEER_CID.into())
                .set_dst_cid(uapi::VSOCK_HOST_CID)
                .set_src_port(peer_port)
                .set_dst_port(local_port)
                .set_op(op)
                .set_buf_alloc(PEER_BUF_ALLOC)
        }

        fn init_data_pkt(
            &mut self,
            local_port: u32,
            peer_port: u32,
            data: &[u8],
        ) -> &mut VsockPacket {
            assert!(data.len() <= self.pkt.buf().unwrap().len());
            self.init_pkt(local_port, peer_port, uapi::VSOCK_OP_RW)
                .set_len(data.len() as u32);
            self.pkt.buf_mut().unwrap()[..data.len()].copy_from_slice(data);
            &mut self.pkt
        }

        fn send(&mut self) {
            self.muxer.send_pkt(&self.pkt).unwrap();
        }

        fn recv(&mut self) {
            self.muxer.recv_pkt(&mut self.pkt).unwrap();
        }

        fn notify_muxer(&mut self) {
            self.muxer.notify(epoll::Events::EPOLLIN);
        }

        fn count_epoll_listeners(&self) -> (usize, usize) {
            let mut local_lsn_count = 0usize;
            let mut conn_lsn_count = 0usize;
            for key in self.muxer.listener_map.values() {
                match key {
                    EpollListener::LocalStream(_) => local_lsn_count += 1,
                    EpollListener::Connection { .. } => conn_lsn_count += 1,
                    _ => (),
                };
            }
            (local_lsn_count, conn_lsn_count)
        }

        fn create_local_listener(&self, port: u32) -> LocalListener {
            LocalListener::new(format!("{}_{}", self.muxer.host_sock_path, port))
        }

        fn local_connect(&mut self, peer_port: u32) -> (UnixStream, u32) {
            let (init_local_lsn_count, init_conn_lsn_count) = self.count_epoll_listeners();

            let mut stream = UnixStream::connect(self.muxer.host_sock_path.clone()).unwrap();
            stream.set_nonblocking(true).unwrap();
            // The muxer would now get notified of a new connection having arrived at its Unix
            // socket, so it can accept it.
            self.notify_muxer();

            // Just after having accepted a new local connection, the muxer should've added a new
            // `LocalStream` listener to its `listener_map`.
            let (local_lsn_count, _) = self.count_epoll_listeners();
            assert_eq!(local_lsn_count, init_local_lsn_count + 1);

            let buf = format!("CONNECT {peer_port}\n");
            stream.write_all(buf.as_bytes()).unwrap();
            // The muxer would now get notified that data is available for reading from the locally
            // initiated connection.
            self.notify_muxer();

            // Successfully reading and parsing the connection request should have removed the
            // LocalStream epoll listener and added a Connection epoll listener.
            let (local_lsn_count, conn_lsn_count) = self.count_epoll_listeners();
            assert_eq!(local_lsn_count, init_local_lsn_count);
            assert_eq!(conn_lsn_count, init_conn_lsn_count + 1);

            // A LocalInit connection should've been added to the muxer connection map.  A new
            // local port should also have been allocated for the new LocalInit connection.
            let local_port = self.muxer.local_port_last;
            let key = ConnMapKey {
                local_port,
                peer_port,
            };
            assert!(self.muxer.conn_map.contains_key(&key));
            assert!(self.muxer.local_port_set.contains(&local_port));

            // A connection request for the peer should now be available from the muxer.
            assert!(self.muxer.has_pending_rx());
            self.recv();
            assert_eq!(self.pkt.op(), uapi::VSOCK_OP_REQUEST);
            assert_eq!(self.pkt.dst_port(), peer_port);
            assert_eq!(self.pkt.src_port(), local_port);

            self.init_pkt(local_port, peer_port, uapi::VSOCK_OP_RESPONSE);
            self.send();

            let mut buf = [0u8; 32];
            let len = stream.read(&mut buf[..]).unwrap();
            assert_eq!(&buf[..len], format!("OK {local_port}\n").as_bytes());

            (stream, local_port)
        }
    }

    struct LocalListener {
        path: PathBuf,
        sock: UnixListener,
    }
    impl LocalListener {
        fn new<P: AsRef<Path> + Clone>(path: P) -> Self {
            let path_buf = path.as_ref().to_path_buf();
            let sock = UnixListener::bind(path).unwrap();
            sock.set_nonblocking(true).unwrap();
            Self {
                path: path_buf,
                sock,
            }
        }
        fn accept(&mut self) -> UnixStream {
            let (stream, _) = self.sock.accept().unwrap();
            stream.set_nonblocking(true).unwrap();
            stream
        }
    }
    impl Drop for LocalListener {
        fn drop(&mut self) {
            std::fs::remove_file(&self.path).unwrap();
        }
    }

    #[test]
    fn test_muxer_epoll_listener() {
        let ctx = MuxerTestContext::new("muxer_epoll_listener");
        assert_eq!(ctx.muxer.get_polled_fd(), ctx.muxer.epoll_file.as_raw_fd());
        assert_eq!(ctx.muxer.get_polled_evset(), epoll::Events::EPOLLIN);
    }

    #[test]
    fn test_bad_peer_pkt() {
        const LOCAL_PORT: u32 = 1026;
        const PEER_PORT: u32 = 1025;
        const SOCK_DGRAM: u16 = 2;

        let mut ctx = MuxerTestContext::new("bad_peer_pkt");
        ctx.init_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST)
            .set_type(SOCK_DGRAM);
        ctx.send();

        // The guest sent a SOCK_DGRAM packet. Per the vsock spec, we need to reply with an RST
        // packet, since vsock only supports stream sockets.
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST);
        assert_eq!(ctx.pkt.src_cid(), uapi::VSOCK_HOST_CID);
        assert_eq!(ctx.pkt.dst_cid(), PEER_CID as u64);
        assert_eq!(ctx.pkt.src_port(), LOCAL_PORT);
        assert_eq!(ctx.pkt.dst_port(), PEER_PORT);

        // Any orphan (i.e. without a connection) non-RST packet should be replied to with an
        // RST.
        let bad_ops = [
            uapi::VSOCK_OP_RESPONSE,
            uapi::VSOCK_OP_CREDIT_REQUEST,
            uapi::VSOCK_OP_CREDIT_UPDATE,
            uapi::VSOCK_OP_SHUTDOWN,
            uapi::VSOCK_OP_RW,
        ];
        for op in bad_ops.iter() {
            ctx.init_pkt(LOCAL_PORT, PEER_PORT, *op);
            ctx.send();
            assert!(ctx.muxer.has_pending_rx());
            ctx.recv();
            assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST);
            assert_eq!(ctx.pkt.src_port(), LOCAL_PORT);
            assert_eq!(ctx.pkt.dst_port(), PEER_PORT);
        }

        // Any packet addressed to anything other than VSOCK_HOST_CID should get dropped.
        assert!(!ctx.muxer.has_pending_rx());
        ctx.init_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST)
            .set_dst_cid(uapi::VSOCK_HOST_CID + 1);
        ctx.send();
        assert!(!ctx.muxer.has_pending_rx());
    }

    #[test]
    fn test_peer_connection() {
        const LOCAL_PORT: u32 = 1026;
        const PEER_PORT: u32 = 1025;

        let mut ctx = MuxerTestContext::new("peer_connection");

        // Test peer connection refused.
        ctx.init_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST);
        ctx.send();
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST);
        assert_eq!(ctx.pkt.len(), 0);
        assert_eq!(ctx.pkt.src_cid(), uapi::VSOCK_HOST_CID);
        assert_eq!(ctx.pkt.dst_cid(), PEER_CID as u64);
        assert_eq!(ctx.pkt.src_port(), LOCAL_PORT);
        assert_eq!(ctx.pkt.dst_port(), PEER_PORT);

        // Test peer connection accepted.
        let mut listener = ctx.create_local_listener(LOCAL_PORT);
        ctx.init_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST);
        ctx.send();
        assert_eq!(ctx.muxer.conn_map.len(), 1);
        let mut stream = listener.accept();
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE);
        assert_eq!(ctx.pkt.len(), 0);
        assert_eq!(ctx.pkt.src_cid(), uapi::VSOCK_HOST_CID);
        assert_eq!(ctx.pkt.dst_cid(), PEER_CID as u64);
        assert_eq!(ctx.pkt.src_port(), LOCAL_PORT);
        assert_eq!(ctx.pkt.dst_port(), PEER_PORT);
        let key = ConnMapKey {
            local_port: LOCAL_PORT,
            peer_port: PEER_PORT,
        };
        assert!(ctx.muxer.conn_map.contains_key(&key));

        // Test guest -> host data flow.
        let data = [1, 2, 3, 4];
        ctx.init_data_pkt(LOCAL_PORT, PEER_PORT, &data);
        ctx.send();
        let mut buf = vec![0; data.len()];
        stream.read_exact(buf.as_mut_slice()).unwrap();
        assert_eq!(buf.as_slice(), data);

        // Test host -> guest data flow.
        let data = [5u8, 6, 7, 8];
        stream.write_all(&data).unwrap();

        // When data is available on the local stream, an EPOLLIN event would normally be delivered
        // to the muxer's nested epoll FD. For testing only, we can fake that event notification
        // here.
        ctx.notify_muxer();
        // After being notified, the muxer should've figured out that RX data was available for one
        // of its connections, so it should now be reporting that it can fill in an RX packet.
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RW);
        assert_eq!(ctx.pkt.buf().unwrap()[..data.len()], data);
        assert_eq!(ctx.pkt.src_port(), LOCAL_PORT);
        assert_eq!(ctx.pkt.dst_port(), PEER_PORT);

        assert!(!ctx.muxer.has_pending_rx());
    }

    #[test]
    fn test_local_connection() {
        let mut ctx = MuxerTestContext::new("local_connection");
        let peer_port = 1025;
        let (mut stream, local_port) = ctx.local_connect(peer_port);

        // Test guest -> host data flow.
        let data = [1, 2, 3, 4];
        ctx.init_data_pkt(local_port, peer_port, &data);
        ctx.send();

        let mut buf = vec![0u8; data.len()];
        stream.read_exact(buf.as_mut_slice()).unwrap();
        assert_eq!(buf.as_slice(), &data);

        // Test host -> guest data flow.
        let data = [5, 6, 7, 8];
        stream.write_all(&data).unwrap();
        ctx.notify_muxer();

        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RW);
        assert_eq!(ctx.pkt.src_port(), local_port);
        assert_eq!(ctx.pkt.dst_port(), peer_port);
        assert_eq!(ctx.pkt.buf().unwrap()[..data.len()], data);
    }

    #[test]
    fn test_local_close() {
        let peer_port = 1025;
        let mut ctx = MuxerTestContext::new("local_close");
        let local_port;
        {
            let (_stream, local_port_) = ctx.local_connect(peer_port);
            local_port = local_port_;
        }
        // Local var `_stream` has now been dropped, thus closing the local stream. After the
        // muxer gets notified via EPOLLIN, it should attempt to gracefully shut down the
        // connection, issuing a VSOCK_OP_SHUTDOWN with both no-more-send and no-more-recv
        // indications set.
        ctx.notify_muxer();
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_SHUTDOWN);
        assert_ne!(ctx.pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_SEND, 0);
        assert_ne!(ctx.pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_RCV, 0);
        assert_eq!(ctx.pkt.src_port(), local_port);
        assert_eq!(ctx.pkt.dst_port(), peer_port);

        // The connection should get removed (and its local port freed), after the peer replies
        // with an RST.
        ctx.init_pkt(local_port, peer_port, uapi::VSOCK_OP_RST);
        ctx.send();
        let key = ConnMapKey {
            local_port,
            peer_port,
        };
        assert!(!ctx.muxer.conn_map.contains_key(&key));
        assert!(!ctx.muxer.local_port_set.contains(&local_port));
    }

    #[test]
    fn test_peer_close() {
        let peer_port = 1025;
        let local_port = 1026;
        let mut ctx = MuxerTestContext::new("peer_close");

        let mut sock = ctx.create_local_listener(local_port);
        ctx.init_pkt(local_port, peer_port, uapi::VSOCK_OP_REQUEST);
        ctx.send();
        let mut stream = sock.accept();

        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE);
        assert_eq!(ctx.pkt.src_port(), local_port);
        assert_eq!(ctx.pkt.dst_port(), peer_port);
        let key = ConnMapKey {
            local_port,
            peer_port,
        };
        assert!(ctx.muxer.conn_map.contains_key(&key));

        // Emulate a full shutdown from the peer (no-more-send + no-more-recv).
        ctx.init_pkt(local_port, peer_port, uapi::VSOCK_OP_SHUTDOWN)
            .set_flag(uapi::VSOCK_FLAGS_SHUTDOWN_SEND)
            .set_flag(uapi::VSOCK_FLAGS_SHUTDOWN_RCV);
        ctx.send();

        // Now, the muxer should remove the connection from its map, and reply with an RST.
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST);
        assert_eq!(ctx.pkt.src_port(), local_port);
        assert_eq!(ctx.pkt.dst_port(), peer_port);
        let key = ConnMapKey {
            local_port,
            peer_port,
        };
        assert!(!ctx.muxer.conn_map.contains_key(&key));

        // The muxer should also drop / close the local Unix socket for this connection.
        let mut buf = vec![0u8; 16];
        assert_eq!(stream.read(buf.as_mut_slice()).unwrap(), 0);
    }

    #[test]
    fn test_muxer_rxq() {
        let mut ctx = MuxerTestContext::new("muxer_rxq");
        let local_port = 1026;
        let peer_port_first = 1025;
        let mut listener = ctx.create_local_listener(local_port);
        let mut streams: Vec<UnixStream> = Vec::new();

        for peer_port in peer_port_first..peer_port_first + defs::MUXER_RXQ_SIZE {
            ctx.init_pkt(local_port, peer_port as u32, uapi::VSOCK_OP_REQUEST);
            ctx.send();
            streams.push(listener.accept());
        }

        // The muxer RX queue should now be full (with connection responses), but still
        // synchronized.
        assert!(ctx.muxer.rxq.is_synced());

        // One more queued reply should desync the RX queue.
        ctx.init_pkt(
            local_port,
            (peer_port_first + defs::MUXER_RXQ_SIZE) as u32,
            uapi::VSOCK_OP_REQUEST,
        );
        ctx.send();
        assert!(!ctx.muxer.rxq.is_synced());

        // With an out-of-sync queue, an RST should evict any non-RST packet from the queue, and
        // take its place. We'll check that by making sure that the last packet popped from the
        // queue is an RST.
        ctx.init_pkt(
            local_port + 1,
            peer_port_first as u32,
            uapi::VSOCK_OP_REQUEST,
        );
        ctx.send();

        for peer_port in peer_port_first..peer_port_first + defs::MUXER_RXQ_SIZE - 1 {
            ctx.recv();
            assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE);
            // The response order should hold. The evicted response should have been the last
            // enqueued.
            assert_eq!(ctx.pkt.dst_port(), peer_port as u32);
        }
        // There should be one more packet in the queue: the RST.
        assert_eq!(ctx.muxer.rxq.len(), 1);
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST);

        // The queue should now be empty, but out-of-sync, so the muxer should report it has some
        // pending RX.
        assert!(ctx.muxer.rxq.is_empty());
        assert!(!ctx.muxer.rxq.is_synced());
        assert!(ctx.muxer.has_pending_rx());

        // The next recv should sync the queue back up. It should also yield one of the two
        // responses that are still left:
        // - the one that desynchronized the queue; and
        // - the one that got evicted by the RST.
        ctx.recv();
        assert!(ctx.muxer.rxq.is_synced());
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE);

        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE);
    }

    #[test]
    fn test_muxer_killq() {
        let mut ctx = MuxerTestContext::new("muxer_killq");
        let local_port = 1026;
        let peer_port_first = 1025;
        let peer_port_last = peer_port_first + defs::MUXER_KILLQ_SIZE;
        let mut listener = ctx.create_local_listener(local_port);

        for peer_port in peer_port_first..=peer_port_last {
            ctx.init_pkt(local_port, peer_port as u32, uapi::VSOCK_OP_REQUEST);
            ctx.send();
            ctx.notify_muxer();
            ctx.recv();
            assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE);
            assert_eq!(ctx.pkt.src_port(), local_port);
            assert_eq!(ctx.pkt.dst_port(), peer_port as u32);
            {
                let _stream = listener.accept();
            }
            ctx.notify_muxer();
            ctx.recv();
            assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_SHUTDOWN);
            assert_eq!(ctx.pkt.src_port(), local_port);
            assert_eq!(ctx.pkt.dst_port(), peer_port as u32);
            // The kill queue should be synchronized, up until the `defs::MUXER_KILLQ_SIZE`th
            // connection we schedule for termination.
            assert_eq!(
                ctx.muxer.killq.is_synced(),
                peer_port < peer_port_first + defs::MUXER_KILLQ_SIZE
            );
        }

        assert!(!ctx.muxer.killq.is_synced());
        assert!(!ctx.muxer.has_pending_rx());

        // Wait for the kill timers to expire.
        std::thread::sleep(std::time::Duration::from_millis(
            csm_defs::CONN_SHUTDOWN_TIMEOUT_MS,
        ));

        // Trigger a kill queue sweep, by requesting a new connection.
        ctx.init_pkt(
            local_port,
            peer_port_last as u32 + 1,
            uapi::VSOCK_OP_REQUEST,
        );
        ctx.send();

        // After sweeping the kill queue, it should now be synced (assuming the RX queue is larger
        // than the kill queue, since an RST packet will be queued for each killed connection).
        assert!(ctx.muxer.killq.is_synced());
        assert!(ctx.muxer.has_pending_rx());
        // There should be `defs::MUXER_KILLQ_SIZE` RSTs in the RX queue, from terminating the
        // dying connections in the recent killq sweep.
        for _p in peer_port_first..peer_port_last {
            ctx.recv();
            assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST);
            assert_eq!(ctx.pkt.src_port(), local_port);
        }

        // There should be one more packet in the RX queue: the connection response to the request
        // that triggered the kill queue sweep.
        ctx.recv();
        assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE);
        assert_eq!(ctx.pkt.dst_port(), peer_port_last as u32 + 1);

        assert!(!ctx.muxer.has_pending_rx());
    }

    #[test]
    fn test_regression_handshake() {
        // Address one of the issues found while fixing the following issue:
        // https://github.com/firecracker-microvm/firecracker/issues/1751
        // This test checks that the handshake message is not accounted for.
        let mut ctx = MuxerTestContext::new("regression_handshake");
        let peer_port = 1025;

        // Create a local connection.
        let (_, local_port) = ctx.local_connect(peer_port);

        // Get the connection from the connection map.
        let key = ConnMapKey {
            local_port,
            peer_port,
        };
        let conn = ctx.muxer.conn_map.get_mut(&key).unwrap();

        // Check that fwd_cnt is 0 - the "OK ..." message was not accounted for.
        assert_eq!(conn.fwd_cnt().0, 0);
    }

    #[test]
    fn test_regression_rxq_pop() {
        // Address one of the issues found while fixing the following issue:
        // https://github.com/firecracker-microvm/firecracker/issues/1751
        // This test checks that a connection is not popped out of the muxer
        // rxq when multiple flags are set.
        let mut ctx = MuxerTestContext::new("regression_rxq_pop");
        let peer_port = 1025;
        let (mut stream, local_port) = ctx.local_connect(peer_port);

        // Send some data.
        let data = [5u8, 6, 7, 8];
        stream.write_all(&data).unwrap();
        ctx.notify_muxer();

        // Get the connection from the connection map.
        let key = ConnMapKey {
            local_port,
            peer_port,
        };
        let conn = ctx.muxer.conn_map.get_mut(&key).unwrap();

        // Forcefully insert another flag.
        conn.insert_credit_update();

        // Call recv twice in order to check that the connection is still
        // in the rxq.
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();

        // Since initially the connection had two flags set, now there should
        // not be any pending RX in the muxer.
        assert!(!ctx.muxer.has_pending_rx());
    }
}