xref: /cloud-hypervisor/virtio-devices/src/vhost_user/vu_common_ctrl.rs (revision 88a9f799449c04180c6b9a21d3b9c0c4b57e2bd6)
1 // Copyright 2019 Intel Corporation. All Rights Reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use std::ffi;
5 use std::fs::File;
6 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
7 use std::os::unix::net::UnixListener;
8 use std::sync::atomic::Ordering;
9 use std::sync::Arc;
10 use std::thread::sleep;
11 use std::time::{Duration, Instant};
12 
13 use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG};
14 use vhost::vhost_user::message::{
15     VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
16 };
17 use vhost::vhost_user::{
18     Frontend, FrontendReqHandler, VhostUserFrontend, VhostUserFrontendReqHandler,
19 };
20 use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
21 use virtio_queue::{Descriptor, Queue, QueueT};
22 use vm_memory::{
23     Address, Error as MmapError, FileOffset, GuestAddress, GuestMemory, GuestMemoryRegion,
24 };
25 use vm_migration::protocol::MemoryRangeTable;
26 use vmm_sys_util::eventfd::EventFd;
27 
28 use super::{Error, Result};
29 use crate::vhost_user::Inflight;
30 use crate::{
31     get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt,
32     VirtioInterruptType,
33 };
34 
// Size in bytes of one page as tracked by the vhost-user dirty log.
// Each bit of the shared log bitmap covers one page of this size; the value
// is used to size the log region from the last guest RAM address.
const VHOST_LOG_PAGE: u64 = 0x1000;
37 
/// User-facing parameters of a vhost-user device.
///
/// NOTE(review): nothing in this file consumes the struct, so the field
/// descriptions below are inferred from names and types; confirm against the
/// callers.
#[derive(Debug, Clone)]
pub struct VhostUserConfig {
    // Presumably the path of the vhost-user Unix socket.
    pub socket: String,
    // Presumably the number of virtqueues exposed by the device.
    pub num_queues: usize,
    // Presumably the size (number of descriptors) of each virtqueue.
    pub queue_size: u16,
}
44 
/// Per-queue state remembered after the initial vring setup so that the vring
/// address can be re-sent later with dirty logging switched on or off.
#[derive(Clone)]
struct VringInfo {
    // Vring configuration that was sent to the backend with `set_vring_addr`.
    config_data: VringConfigData,
    // Guest address of the used ring; passed as `log_addr` when logging is
    // enabled for the vring.
    used_guest_addr: u64,
}
50 
/// Frontend-side handle on a vhost-user backend connection, together with the
/// state needed to pause/resume the device and to drive dirty-page logging.
#[derive(Clone)]
pub struct VhostUserHandle {
    // Connection endpoint used to send every vhost-user request.
    vu: Frontend,
    // Set once `setup_vhost_user` completed; gates pause/resume.
    ready: bool,
    // Set by `update_supports_migration`; gates start/stop of dirty logging.
    supports_migration: bool,
    // Mapping of the dirty log region currently installed in the backend.
    shm_log: Option<Arc<MmapRegion>>,
    // Virtio features last sent to the backend by `setup_vhost_user`.
    acked_features: u64,
    // Vring configuration recorded by `setup_vhost_user`, one entry per queue.
    vrings_info: Option<Vec<VringInfo>>,
    // Indexes of the queues configured by `setup_vhost_user`; drained by
    // `reset_vhost_user`.
    queue_indexes: Vec<usize>,
}
61 
62 impl VhostUserHandle {
63     pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> {
64         let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
65         for region in mem.iter() {
66             let (mmap_handle, mmap_offset) = match region.file_offset() {
67                 Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()),
68                 None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)),
69             };
70 
71             let vhost_user_net_reg = VhostUserMemoryRegionInfo {
72                 guest_phys_addr: region.start_addr().raw_value(),
73                 memory_size: region.len(),
74                 userspace_addr: region.as_ptr() as u64,
75                 mmap_offset,
76                 mmap_handle,
77             };
78 
79             regions.push(vhost_user_net_reg);
80         }
81 
82         self.vu
83             .set_mem_table(regions.as_slice())
84             .map_err(Error::VhostUserSetMemTable)?;
85 
86         Ok(())
87     }
88 
89     pub fn add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()> {
90         let (mmap_handle, mmap_offset) = match region.file_offset() {
91             Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
92             None => return Err(Error::MissingRegionFd),
93         };
94 
95         let region = VhostUserMemoryRegionInfo {
96             guest_phys_addr: region.start_addr().raw_value(),
97             memory_size: region.len(),
98             userspace_addr: region.as_ptr() as u64,
99             mmap_offset,
100             mmap_handle,
101         };
102 
103         self.vu
104             .add_mem_region(&region)
105             .map_err(Error::VhostUserAddMemReg)
106     }
107 
108     pub fn negotiate_features_vhost_user(
109         &mut self,
110         avail_features: u64,
111         avail_protocol_features: VhostUserProtocolFeatures,
112     ) -> Result<(u64, u64)> {
113         // Set vhost-user owner.
114         self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
115 
116         // Get features from backend, do negotiation to get a feature collection which
117         // both VMM and backend support.
118         let backend_features = self
119             .vu
120             .get_features()
121             .map_err(Error::VhostUserGetFeatures)?;
122         let acked_features = avail_features & backend_features;
123 
124         let acked_protocol_features =
125             if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
126                 let backend_protocol_features = self
127                     .vu
128                     .get_protocol_features()
129                     .map_err(Error::VhostUserGetProtocolFeatures)?;
130 
131                 let acked_protocol_features = avail_protocol_features & backend_protocol_features;
132 
133                 self.vu
134                     .set_protocol_features(acked_protocol_features)
135                     .map_err(Error::VhostUserSetProtocolFeatures)?;
136 
137                 acked_protocol_features
138             } else {
139                 VhostUserProtocolFeatures::empty()
140             };
141 
142         if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
143             && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
144         {
145             self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
146         }
147 
148         self.update_supports_migration(acked_features, acked_protocol_features.bits());
149 
150         Ok((acked_features, acked_protocol_features.bits()))
151     }
152 
153     #[allow(clippy::too_many_arguments)]
154     pub fn setup_vhost_user<S: VhostUserFrontendReqHandler>(
155         &mut self,
156         mem: &GuestMemoryMmap,
157         queues: Vec<(usize, Queue, EventFd)>,
158         virtio_interrupt: &Arc<dyn VirtioInterrupt>,
159         acked_features: u64,
160         backend_req_handler: &Option<FrontendReqHandler<S>>,
161         inflight: Option<&mut Inflight>,
162     ) -> Result<()> {
163         self.vu
164             .set_features(acked_features)
165             .map_err(Error::VhostUserSetFeatures)?;
166 
167         // Update internal value after it's been sent to the backend.
168         self.acked_features = acked_features;
169 
170         // Let's first provide the memory table to the backend.
171         self.update_mem_table(mem)?;
172 
173         // Send set_vring_num here, since it could tell backends, like SPDK,
174         // how many virt queues to be handled, which backend required to know
175         // at early stage.
176         for (queue_index, queue, _) in queues.iter() {
177             self.vu
178                 .set_vring_num(*queue_index, queue.size())
179                 .map_err(Error::VhostUserSetVringNum)?;
180         }
181 
182         // Setup for inflight I/O tracking shared memory.
183         if let Some(inflight) = inflight {
184             if inflight.fd.is_none() {
185                 let inflight_req_info = VhostUserInflight {
186                     mmap_size: 0,
187                     mmap_offset: 0,
188                     num_queues: queues.len() as u16,
189                     queue_size: queues[0].1.size(),
190                 };
191                 let (info, fd) = self
192                     .vu
193                     .get_inflight_fd(&inflight_req_info)
194                     .map_err(Error::VhostUserGetInflight)?;
195                 inflight.info = info;
196                 inflight.fd = Some(fd);
197             }
198             // Unwrapping the inflight fd is safe here since we know it can't be None.
199             self.vu
200                 .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd())
201                 .map_err(Error::VhostUserSetInflight)?;
202         }
203 
204         let mut vrings_info = Vec::new();
205         for (queue_index, queue, queue_evt) in queues.iter() {
206             let actual_size: usize = queue.size().into();
207 
208             let config_data = VringConfigData {
209                 queue_max_size: queue.max_size(),
210                 queue_size: queue.size(),
211                 flags: 0u32,
212                 desc_table_addr: get_host_address_range(
213                     mem,
214                     GuestAddress(queue.desc_table()),
215                     actual_size * std::mem::size_of::<Descriptor>(),
216                 )
217                 .ok_or(Error::DescriptorTableAddress)? as u64,
218                 // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u16, len: u16}; actual_size]},
219                 // i.e. 4 + (4 + 4) * actual_size.
220                 used_ring_addr: get_host_address_range(
221                     mem,
222                     GuestAddress(queue.used_ring()),
223                     4 + actual_size * 8,
224                 )
225                 .ok_or(Error::UsedAddress)? as u64,
226                 // The used ring is {flags: u16; idx: u16; elem [u16; actual_size]},
227                 // i.e. 4 + (2) * actual_size.
228                 avail_ring_addr: get_host_address_range(
229                     mem,
230                     GuestAddress(queue.avail_ring()),
231                     4 + actual_size * 2,
232                 )
233                 .ok_or(Error::AvailAddress)? as u64,
234                 log_addr: None,
235             };
236 
237             vrings_info.push(VringInfo {
238                 config_data,
239                 used_guest_addr: queue.used_ring(),
240             });
241 
242             self.vu
243                 .set_vring_addr(*queue_index, &config_data)
244                 .map_err(Error::VhostUserSetVringAddr)?;
245             self.vu
246                 .set_vring_base(
247                     *queue_index,
248                     queue
249                         .avail_idx(mem, Ordering::Acquire)
250                         .map_err(Error::GetAvailableIndex)?
251                         .0,
252                 )
253                 .map_err(Error::VhostUserSetVringBase)?;
254 
255             if let Some(eventfd) =
256                 virtio_interrupt.notifier(VirtioInterruptType::Queue(*queue_index as u16))
257             {
258                 self.vu
259                     .set_vring_call(*queue_index, &eventfd)
260                     .map_err(Error::VhostUserSetVringCall)?;
261             }
262 
263             self.vu
264                 .set_vring_kick(*queue_index, queue_evt)
265                 .map_err(Error::VhostUserSetVringKick)?;
266 
267             self.queue_indexes.push(*queue_index);
268         }
269 
270         self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;
271 
272         if let Some(backend_req_handler) = backend_req_handler {
273             self.vu
274                 .set_backend_request_fd(&backend_req_handler.get_tx_raw_fd())
275                 .map_err(Error::VhostUserSetBackendRequestFd)?;
276         }
277 
278         self.vrings_info = Some(vrings_info);
279         self.ready = true;
280 
281         Ok(())
282     }
283 
284     fn enable_vhost_user_vrings(&mut self, queue_indexes: Vec<usize>, enable: bool) -> Result<()> {
285         for queue_index in queue_indexes {
286             self.vu
287                 .set_vring_enable(queue_index, enable)
288                 .map_err(Error::VhostUserSetVringEnable)?;
289         }
290 
291         Ok(())
292     }
293 
294     pub fn reset_vhost_user(&mut self) -> Result<()> {
295         for queue_index in self.queue_indexes.drain(..) {
296             self.vu
297                 .set_vring_enable(queue_index, false)
298                 .map_err(Error::VhostUserSetVringEnable)?;
299 
300             let _ = self
301                 .vu
302                 .get_vring_base(queue_index)
303                 .map_err(Error::VhostUserGetVringBase)?;
304         }
305 
306         Ok(())
307     }
308 
309     pub fn set_protocol_features_vhost_user(
310         &mut self,
311         acked_features: u64,
312         acked_protocol_features: u64,
313     ) -> Result<()> {
314         self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
315         self.vu
316             .get_features()
317             .map_err(Error::VhostUserGetFeatures)?;
318 
319         if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
320             if let Some(acked_protocol_features) =
321                 VhostUserProtocolFeatures::from_bits(acked_protocol_features)
322             {
323                 self.vu
324                     .set_protocol_features(acked_protocol_features)
325                     .map_err(Error::VhostUserSetProtocolFeatures)?;
326 
327                 if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
328                     self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
329                 }
330             }
331         }
332 
333         self.update_supports_migration(acked_features, acked_protocol_features);
334 
335         Ok(())
336     }
337 
338     #[allow(clippy::too_many_arguments)]
339     pub fn reinitialize_vhost_user<S: VhostUserFrontendReqHandler>(
340         &mut self,
341         mem: &GuestMemoryMmap,
342         queues: Vec<(usize, Queue, EventFd)>,
343         virtio_interrupt: &Arc<dyn VirtioInterrupt>,
344         acked_features: u64,
345         acked_protocol_features: u64,
346         backend_req_handler: &Option<FrontendReqHandler<S>>,
347         inflight: Option<&mut Inflight>,
348     ) -> Result<()> {
349         self.set_protocol_features_vhost_user(acked_features, acked_protocol_features)?;
350 
351         self.setup_vhost_user(
352             mem,
353             queues,
354             virtio_interrupt,
355             acked_features,
356             backend_req_handler,
357             inflight,
358         )
359     }
360 
361     pub fn connect_vhost_user(
362         server: bool,
363         socket_path: &str,
364         num_queues: u64,
365         unlink_socket: bool,
366     ) -> Result<Self> {
367         if server {
368             if unlink_socket {
369                 std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?;
370             }
371 
372             info!("Binding vhost-user listener...");
373             let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?;
374             info!("Waiting for incoming vhost-user connection...");
375             let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?;
376 
377             Ok(VhostUserHandle {
378                 vu: Frontend::from_stream(stream, num_queues),
379                 ready: false,
380                 supports_migration: false,
381                 shm_log: None,
382                 acked_features: 0,
383                 vrings_info: None,
384                 queue_indexes: Vec::new(),
385             })
386         } else {
387             let now = Instant::now();
388 
389             // Retry connecting for a full minute
390             let err = loop {
391                 let err = match Frontend::connect(socket_path, num_queues) {
392                     Ok(m) => {
393                         return Ok(VhostUserHandle {
394                             vu: m,
395                             ready: false,
396                             supports_migration: false,
397                             shm_log: None,
398                             acked_features: 0,
399                             vrings_info: None,
400                             queue_indexes: Vec::new(),
401                         })
402                     }
403                     Err(e) => e,
404                 };
405                 sleep(Duration::from_millis(100));
406 
407                 if now.elapsed().as_secs() >= 60 {
408                     break err;
409                 }
410             };
411 
412             error!(
413                 "Failed connecting the backend after trying for 1 minute: {:?}",
414                 err
415             );
416             Err(Error::VhostUserConnect)
417         }
418     }
419 
    /// Returns a mutable reference to the underlying vhost-user `Frontend`,
    /// for callers that need to issue requests not wrapped by this handle.
    pub fn socket_handle(&mut self) -> &mut Frontend {
        &mut self.vu
    }
423 
424     pub fn pause_vhost_user(&mut self) -> Result<()> {
425         if self.ready {
426             self.enable_vhost_user_vrings(self.queue_indexes.clone(), false)?;
427         }
428 
429         Ok(())
430     }
431 
432     pub fn resume_vhost_user(&mut self) -> Result<()> {
433         if self.ready {
434             self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;
435         }
436 
437         Ok(())
438     }
439 
440     fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) {
441         if (acked_features & u64::from(vhost::vhost_kern::vhost_binding::VHOST_F_LOG_ALL) != 0)
442             && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0)
443         {
444             self.supports_migration = true;
445         }
446     }
447 
448     fn update_log_base(&mut self, last_ram_addr: u64) -> Result<Option<Arc<MmapRegion>>> {
449         // Create the memfd
450         let fd = memfd_create(
451             &ffi::CString::new("vhost_user_dirty_log").unwrap(),
452             libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
453         )
454         .map_err(Error::MemfdCreate)?;
455 
456         // SAFETY: we checked the file descriptor is valid
457         let file = unsafe { File::from_raw_fd(fd) };
458         // The size of the memory mapping corresponds to the size of a bitmap
459         // covering all guest pages for addresses from 0 to the last physical
460         // address in guest RAM.
461         // A page is always 4kiB from a vhost-user perspective, and each bit is
462         // a page. That's how we can compute mmap_size from the last address.
463         let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
464         let mmap_handle = file.as_raw_fd();
465 
466         // Set shm_log region size
467         file.set_len(mmap_size).map_err(Error::SetFileSize)?;
468 
469         // Set the seals
470         // SAFETY: FFI call with valid arguments
471         let res = unsafe {
472             libc::fcntl(
473                 file.as_raw_fd(),
474                 libc::F_ADD_SEALS,
475                 libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL,
476             )
477         };
478         if res < 0 {
479             return Err(Error::SetSeals(std::io::Error::last_os_error()));
480         }
481 
482         // Mmap shm_log region
483         let region = MmapRegion::build(
484             Some(FileOffset::new(file, 0)),
485             mmap_size as usize,
486             libc::PROT_READ | libc::PROT_WRITE,
487             libc::MAP_SHARED,
488         )
489         .map_err(Error::NewMmapRegion)?;
490 
491         // Make sure we hold onto the region to prevent the mapping from being
492         // released.
493         let old_region = self.shm_log.replace(Arc::new(region));
494 
495         // Send the shm_log fd over to the backend
496         let log = VhostUserDirtyLogRegion {
497             mmap_size,
498             mmap_offset: 0,
499             mmap_handle,
500         };
501         self.vu
502             .set_log_base(0, Some(log))
503             .map_err(Error::VhostUserSetLogBase)?;
504 
505         Ok(old_region)
506     }
507 
508     fn set_vring_logging(&mut self, enable: bool) -> Result<()> {
509         if let Some(vrings_info) = &self.vrings_info {
510             for (i, vring_info) in vrings_info.iter().enumerate() {
511                 let mut config_data = vring_info.config_data;
512                 config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 };
513                 config_data.log_addr = if enable {
514                     Some(vring_info.used_guest_addr)
515                 } else {
516                     None
517                 };
518 
519                 self.vu
520                     .set_vring_addr(i, &config_data)
521                     .map_err(Error::VhostUserSetVringAddr)?;
522             }
523         }
524 
525         Ok(())
526     }
527 
528     pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> {
529         if !self.supports_migration {
530             return Err(Error::MigrationNotSupported);
531         }
532 
533         // Set the shm log region
534         self.update_log_base(last_ram_addr)?;
535 
536         // Enable VHOST_F_LOG_ALL feature
537         let features = self.acked_features | (1 << VHOST_F_LOG_ALL);
538         self.vu
539             .set_features(features)
540             .map_err(Error::VhostUserSetFeatures)?;
541 
542         // Enable dirty page logging of used ring for all queues
543         self.set_vring_logging(true)
544     }
545 
546     pub fn stop_dirty_log(&mut self) -> Result<()> {
547         if !self.supports_migration {
548             return Err(Error::MigrationNotSupported);
549         }
550 
551         // Disable dirty page logging of used ring for all queues
552         self.set_vring_logging(false)?;
553 
554         // Disable VHOST_F_LOG_ALL feature
555         self.vu
556             .set_features(self.acked_features)
557             .map_err(Error::VhostUserSetFeatures)?;
558 
559         // This is important here since the log region goes out of scope,
560         // invoking the Drop trait, hence unmapping the memory.
561         self.shm_log = None;
562 
563         Ok(())
564     }
565 
566     pub fn dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable> {
567         // The log region is updated by creating a new region that is sent to
568         // the backend. This ensures the backend stops logging to the previous
569         // region. The previous region is returned and processed to create the
570         // bitmap representing the dirty pages.
571         if let Some(region) = self.update_log_base(last_ram_addr)? {
572             // Be careful with the size, as it was based on u8, meaning we must
573             // divide it by 8.
574             let len = region.size() / 8;
575             // SAFETY: region is of size len
576             let bitmap = unsafe {
577                 // Cast the pointer to u64
578                 let ptr = region.as_ptr() as *const u64;
579                 std::slice::from_raw_parts(ptr, len).to_vec()
580             };
581             Ok(MemoryRangeTable::from_bitmap(bitmap, 0, 4096))
582         } else {
583             Err(Error::MissingShmLogRegion)
584         }
585     }
586 }
587 
588 fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error> {
589     // SAFETY: FFI call with valid arguments
590     let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) };
591 
592     if res < 0 {
593         Err(std::io::Error::last_os_error())
594     } else {
595         Ok(res as RawFd)
596     }
597 }
598