xref: /cloud-hypervisor/virtio-devices/src/vhost_user/vu_common_ctrl.rs (revision 9af2968a7dc47b89bf07ea9dc5e735084efcfa3a)
1 // Copyright 2019 Intel Corporation. All Rights Reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use super::super::{Descriptor, Queue};
5 use super::{Error, Result};
6 use crate::vhost_user::Inflight;
7 use crate::{get_host_address_range, VirtioInterrupt, VirtioInterruptType};
8 use crate::{GuestMemoryMmap, GuestRegionMmap};
9 use std::convert::TryInto;
10 use std::os::unix::io::AsRawFd;
11 use std::os::unix::net::UnixListener;
12 use std::sync::Arc;
13 use std::thread::sleep;
14 use std::time::{Duration, Instant};
15 use std::vec::Vec;
16 use vhost::vhost_user::message::{
17     VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
18 };
19 use vhost::vhost_user::{Master, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler};
20 use vhost::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};
21 use vm_memory::{Address, Error as MmapError, GuestMemory, GuestMemoryRegion};
22 use vmm_sys_util::eventfd::EventFd;
23 
/// User-supplied parameters describing a vhost-user backend.
#[derive(Clone, Debug)]
pub struct VhostUserConfig {
    /// Path of the vhost-user UNIX socket (presumably handed to
    /// `connect_vhost_user` as `socket_path` — confirm at call sites).
    pub socket: String,
    /// Number of virtqueues requested for the device.
    pub num_queues: usize,
    /// Requested size (number of entries) of each virtqueue.
    pub queue_size: u16,
}
30 
31 pub fn update_mem_table(vu: &mut Master, mem: &GuestMemoryMmap) -> Result<()> {
32     let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
33     for region in mem.iter() {
34         let (mmap_handle, mmap_offset) = match region.file_offset() {
35             Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()),
36             None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)),
37         };
38 
39         let vhost_user_net_reg = VhostUserMemoryRegionInfo {
40             guest_phys_addr: region.start_addr().raw_value(),
41             memory_size: region.len() as u64,
42             userspace_addr: region.as_ptr() as u64,
43             mmap_offset,
44             mmap_handle,
45         };
46 
47         regions.push(vhost_user_net_reg);
48     }
49 
50     vu.set_mem_table(regions.as_slice())
51         .map_err(Error::VhostUserSetMemTable)?;
52 
53     Ok(())
54 }
55 
56 pub fn add_memory_region(vu: &mut Master, region: &Arc<GuestRegionMmap>) -> Result<()> {
57     let (mmap_handle, mmap_offset) = match region.file_offset() {
58         Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
59         None => return Err(Error::MissingRegionFd),
60     };
61 
62     let region = VhostUserMemoryRegionInfo {
63         guest_phys_addr: region.start_addr().raw_value(),
64         memory_size: region.len() as u64,
65         userspace_addr: region.as_ptr() as u64,
66         mmap_offset,
67         mmap_handle,
68     };
69 
70     vu.add_mem_region(&region)
71         .map_err(Error::VhostUserAddMemReg)
72 }
73 
74 pub fn negotiate_features_vhost_user(
75     vu: &mut Master,
76     avail_features: u64,
77     avail_protocol_features: VhostUserProtocolFeatures,
78 ) -> Result<(u64, u64)> {
79     // Set vhost-user owner.
80     vu.set_owner().map_err(Error::VhostUserSetOwner)?;
81 
82     // Get features from backend, do negotiation to get a feature collection which
83     // both VMM and backend support.
84     let backend_features = vu.get_features().map_err(Error::VhostUserGetFeatures)?;
85     let acked_features = avail_features & backend_features;
86 
87     let acked_protocol_features =
88         if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
89             let backend_protocol_features = vu
90                 .get_protocol_features()
91                 .map_err(Error::VhostUserGetProtocolFeatures)?;
92 
93             let acked_protocol_features = avail_protocol_features & backend_protocol_features;
94 
95             vu.set_protocol_features(acked_protocol_features)
96                 .map_err(Error::VhostUserSetProtocolFeatures)?;
97 
98             acked_protocol_features
99         } else {
100             VhostUserProtocolFeatures::empty()
101         };
102 
103     if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
104         && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
105     {
106         vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
107     }
108 
109     Ok((acked_features, acked_protocol_features.bits()))
110 }
111 
112 #[allow(clippy::too_many_arguments)]
113 pub fn setup_vhost_user<S: VhostUserMasterReqHandler>(
114     vu: &mut Master,
115     mem: &GuestMemoryMmap,
116     queues: Vec<Queue>,
117     queue_evts: Vec<EventFd>,
118     virtio_interrupt: &Arc<dyn VirtioInterrupt>,
119     acked_features: u64,
120     slave_req_handler: &Option<MasterReqHandler<S>>,
121     inflight: Option<&mut Inflight>,
122 ) -> Result<()> {
123     vu.set_features(acked_features)
124         .map_err(Error::VhostUserSetFeatures)?;
125 
126     // Let's first provide the memory table to the backend.
127     update_mem_table(vu, mem)?;
128 
129     // Setup for inflight I/O tracking shared memory.
130     if let Some(inflight) = inflight {
131         if inflight.fd.is_none() {
132             let inflight_req_info = VhostUserInflight {
133                 mmap_size: 0,
134                 mmap_offset: 0,
135                 num_queues: queues.len() as u16,
136                 queue_size: queues[0].actual_size(),
137             };
138             let (info, fd) = vu
139                 .get_inflight_fd(&inflight_req_info)
140                 .map_err(Error::VhostUserGetInflight)?;
141             inflight.info = info;
142             inflight.fd = Some(fd);
143         }
144         // Unwrapping the inflight fd is safe here since we know it can't be None.
145         vu.set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd())
146             .map_err(Error::VhostUserSetInflight)?;
147     }
148 
149     for (queue_index, queue) in queues.into_iter().enumerate() {
150         let actual_size: usize = queue.actual_size().try_into().unwrap();
151 
152         vu.set_vring_num(queue_index, queue.actual_size())
153             .map_err(Error::VhostUserSetVringNum)?;
154 
155         let config_data = VringConfigData {
156             queue_max_size: queue.get_max_size(),
157             queue_size: queue.actual_size(),
158             flags: 0u32,
159             desc_table_addr: get_host_address_range(
160                 mem,
161                 queue.desc_table,
162                 actual_size * std::mem::size_of::<Descriptor>(),
163             )
164             .ok_or(Error::DescriptorTableAddress)? as u64,
165             // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u16, len: u16}; actual_size]},
166             // i.e. 4 + (4 + 4) * actual_size.
167             used_ring_addr: get_host_address_range(mem, queue.used_ring, 4 + actual_size * 8)
168                 .ok_or(Error::UsedAddress)? as u64,
169             // The used ring is {flags: u16; idx: u16; elem [u16; actual_size]},
170             // i.e. 4 + (2) * actual_size.
171             avail_ring_addr: get_host_address_range(mem, queue.avail_ring, 4 + actual_size * 2)
172                 .ok_or(Error::AvailAddress)? as u64,
173             log_addr: None,
174         };
175 
176         vu.set_vring_addr(queue_index, &config_data)
177             .map_err(Error::VhostUserSetVringAddr)?;
178         vu.set_vring_base(
179             queue_index,
180             queue
181                 .avail_index_from_memory(mem)
182                 .map_err(Error::GetAvailableIndex)?,
183         )
184         .map_err(Error::VhostUserSetVringBase)?;
185 
186         if let Some(eventfd) = virtio_interrupt.notifier(&VirtioInterruptType::Queue, Some(&queue))
187         {
188             vu.set_vring_call(queue_index, &eventfd)
189                 .map_err(Error::VhostUserSetVringCall)?;
190         }
191 
192         vu.set_vring_kick(queue_index, &queue_evts[queue_index])
193             .map_err(Error::VhostUserSetVringKick)?;
194 
195         vu.set_vring_enable(queue_index, true)
196             .map_err(Error::VhostUserSetVringEnable)?;
197     }
198 
199     if let Some(slave_req_handler) = slave_req_handler {
200         vu.set_slave_request_fd(&slave_req_handler.get_tx_raw_fd())
201             .map_err(Error::VhostUserSetSlaveRequestFd)
202     } else {
203         Ok(())
204     }
205 }
206 
207 pub fn reset_vhost_user(vu: &mut Master, num_queues: usize) -> Result<()> {
208     for queue_index in 0..num_queues {
209         // Disable the vrings.
210         vu.set_vring_enable(queue_index, false)
211             .map_err(Error::VhostUserSetVringEnable)?;
212     }
213 
214     // Reset the owner.
215     vu.reset_owner().map_err(Error::VhostUserResetOwner)
216 }
217 
218 #[allow(clippy::too_many_arguments)]
219 pub fn reinitialize_vhost_user<S: VhostUserMasterReqHandler>(
220     vu: &mut Master,
221     mem: &GuestMemoryMmap,
222     queues: Vec<Queue>,
223     queue_evts: Vec<EventFd>,
224     virtio_interrupt: &Arc<dyn VirtioInterrupt>,
225     acked_features: u64,
226     acked_protocol_features: u64,
227     slave_req_handler: &Option<MasterReqHandler<S>>,
228     inflight: Option<&mut Inflight>,
229 ) -> Result<()> {
230     vu.set_owner().map_err(Error::VhostUserSetOwner)?;
231     vu.get_features().map_err(Error::VhostUserGetFeatures)?;
232 
233     if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
234         if let Some(acked_protocol_features) =
235             VhostUserProtocolFeatures::from_bits(acked_protocol_features)
236         {
237             vu.set_protocol_features(acked_protocol_features)
238                 .map_err(Error::VhostUserSetProtocolFeatures)?;
239 
240             if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
241                 vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
242             }
243         }
244     }
245 
246     setup_vhost_user(
247         vu,
248         mem,
249         queues,
250         queue_evts,
251         virtio_interrupt,
252         acked_features,
253         slave_req_handler,
254         inflight,
255     )
256 }
257 
258 pub fn connect_vhost_user(
259     server: bool,
260     socket_path: &str,
261     num_queues: u64,
262     unlink_socket: bool,
263 ) -> Result<Master> {
264     if server {
265         if unlink_socket {
266             std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?;
267         }
268 
269         info!("Binding vhost-user listener...");
270         let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?;
271         info!("Waiting for incoming vhost-user connection...");
272         let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?;
273 
274         Ok(Master::from_stream(stream, num_queues))
275     } else {
276         let now = Instant::now();
277 
278         // Retry connecting for a full minute
279         let err = loop {
280             let err = match Master::connect(socket_path, num_queues) {
281                 Ok(m) => return Ok(m),
282                 Err(e) => e,
283             };
284             sleep(Duration::from_millis(100));
285 
286             if now.elapsed().as_secs() >= 60 {
287                 break err;
288             }
289         };
290 
291         error!(
292             "Failed connecting the backend after trying for 1 minute: {:?}",
293             err
294         );
295         Err(Error::VhostUserConnect)
296     }
297 }
298