xref: /cloud-hypervisor/virtio-devices/src/vhost_user/vu_common_ctrl.rs (revision fa7a000dbe9637eb256af18ae8c3c4a8d5bf9c8f)
// Copyright 2019 Intel Corporation. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use super::{Error, Result};
use crate::vhost_user::Inflight;
use crate::{
    get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt,
    VirtioInterruptType,
};
use std::ffi;
use std::fs::File;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixListener;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::thread::sleep;
use std::time::{Duration, Instant};
use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG};
use vhost::vhost_user::message::{
    VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
};
use vhost::vhost_user::{
    Frontend, FrontendReqHandler, VhostUserFrontend, VhostUserFrontendReqHandler,
};
use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
use virtio_queue::{Descriptor, Queue, QueueT};
use vm_memory::{
    Address, Error as MmapError, FileOffset, GuestAddress, GuestMemory, GuestMemoryRegion,
};
use vm_migration::protocol::MemoryRangeTable;
use vmm_sys_util::eventfd::EventFd;

// Size of a dirty page for vhost-user.
const VHOST_LOG_PAGE: u64 = 0x1000;

#[derive(Debug, Clone)]
pub struct VhostUserConfig {
    pub socket: String,
    pub num_queues: usize,
    pub queue_size: u16,
}

#[derive(Clone)]
struct VringInfo {
    config_data: VringConfigData,
    used_guest_addr: u64,
}

#[derive(Clone)]
pub struct VhostUserHandle {
    vu: Frontend,
    ready: bool,
    supports_migration: bool,
    shm_log: Option<Arc<MmapRegion>>,
    acked_features: u64,
    vrings_info: Option<Vec<VringInfo>>,
    queue_indexes: Vec<usize>,
}

impl VhostUserHandle {
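    /// Send the current guest memory layout to the backend. Every guest
    /// memory region must be backed by a file descriptor so that it can be
    /// shared with the backend through the memory table.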
    pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> {
        let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
        for region in mem.iter() {
            let (mmap_handle, mmap_offset) = match region.file_offset() {
                Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()),
                None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)),
            };

            let vhost_user_net_reg = VhostUserMemoryRegionInfo {
                guest_phys_addr: region.start_addr().raw_value(),
                memory_size: region.len(),
                userspace_addr: region.as_ptr() as u64,
                mmap_offset,
                mmap_handle,
            };

            regions.push(vhost_user_net_reg);
        }

        self.vu
            .set_mem_table(regions.as_slice())
            .map_err(Error::VhostUserSetMemTable)?;

        Ok(())
    }

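    /// Register a single additional guest memory region with the backend,
    /// without resending the whole memory table.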
    pub fn add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()> {
        let (mmap_handle, mmap_offset) = match region.file_offset() {
            Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
            None => return Err(Error::MissingRegionFd),
        };

        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: region.start_addr().raw_value(),
            memory_size: region.len(),
            userspace_addr: region.as_ptr() as u64,
            mmap_offset,
            mmap_handle,
        };

        self.vu
            .add_mem_region(&region)
            .map_err(Error::VhostUserAddMemReg)
    }

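    /// Negotiate virtio and vhost-user protocol features with the backend:
    /// take ownership of the session, intersect the advertised features with
    /// what the backend reports, and return both acked feature sets.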
    pub fn negotiate_features_vhost_user(
        &mut self,
        avail_features: u64,
        avail_protocol_features: VhostUserProtocolFeatures,
    ) -> Result<(u64, u64)> {
        // Set vhost-user owner.
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;

        // Get the features from the backend and negotiate the subset supported
        // by both the VMM and the backend.
        let backend_features = self
            .vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;
        let acked_features = avail_features & backend_features;

        let acked_protocol_features =
            if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
                let backend_protocol_features = self
                    .vu
                    .get_protocol_features()
                    .map_err(Error::VhostUserGetProtocolFeatures)?;

                let acked_protocol_features = avail_protocol_features & backend_protocol_features;

                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                acked_protocol_features
            } else {
                VhostUserProtocolFeatures::empty()
            };

        if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
            && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
        {
            self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
        }

        self.update_supports_migration(acked_features, acked_protocol_features.bits());

        Ok((acked_features, acked_protocol_features.bits()))
    }

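    /// Fully configure the backend: push the acked features and the memory
    /// table, optionally set up inflight I/O tracking, then describe each
    /// vring (size, addresses, base index, call/kick eventfds) and enable it.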
    #[allow(clippy::too_many_arguments)]
    pub fn setup_vhost_user<S: VhostUserFrontendReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<(usize, Queue, EventFd)>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        backend_req_handler: &Option<FrontendReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.vu
            .set_features(acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Update internal value after it's been sent to the backend.
        self.acked_features = acked_features;

        // Let's first provide the memory table to the backend.
        self.update_mem_table(mem)?;

        // Send set_vring_num here, since it tells backends such as SPDK how
        // many virtqueues have to be handled, which the backend needs to know
        // at an early stage.
        for (queue_index, queue, _) in queues.iter() {
            self.vu
                .set_vring_num(*queue_index, queue.size())
                .map_err(Error::VhostUserSetVringNum)?;
        }

        // Set up the shared memory used for inflight I/O tracking.
        if let Some(inflight) = inflight {
            if inflight.fd.is_none() {
                let inflight_req_info = VhostUserInflight {
                    mmap_size: 0,
                    mmap_offset: 0,
                    num_queues: queues.len() as u16,
                    queue_size: queues[0].1.size(),
                };
                let (info, fd) = self
                    .vu
                    .get_inflight_fd(&inflight_req_info)
                    .map_err(Error::VhostUserGetInflight)?;
                inflight.info = info;
                inflight.fd = Some(fd);
            }
            // Unwrapping the inflight fd is safe here since we know it can't be None.
            self.vu
                .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd())
                .map_err(Error::VhostUserSetInflight)?;
        }

        let mut vrings_info = Vec::new();
        for (queue_index, queue, queue_evt) in queues.iter() {
            let actual_size: usize = queue.size().into();

            let config_data = VringConfigData {
                queue_max_size: queue.max_size(),
                queue_size: queue.size(),
                flags: 0u32,
                desc_table_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.desc_table()),
                    actual_size * std::mem::size_of::<Descriptor>(),
                )
                .ok_or(Error::DescriptorTableAddress)? as u64,
                // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u32, len: u32}; actual_size]},
                // i.e. 4 + (4 + 4) * actual_size.
                used_ring_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.used_ring()),
                    4 + actual_size * 8,
                )
                .ok_or(Error::UsedAddress)? as u64,
                // The available ring is {flags: u16; idx: u16; ring [u16; actual_size]},
                // i.e. 4 + 2 * actual_size.
                avail_ring_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.avail_ring()),
                    4 + actual_size * 2,
                )
                .ok_or(Error::AvailAddress)? as u64,
                log_addr: None,
            };

            vrings_info.push(VringInfo {
                config_data,
                used_guest_addr: queue.used_ring(),
            });

            self.vu
                .set_vring_addr(*queue_index, &config_data)
                .map_err(Error::VhostUserSetVringAddr)?;
            self.vu
                .set_vring_base(
                    *queue_index,
                    queue
                        .avail_idx(mem, Ordering::Acquire)
                        .map_err(Error::GetAvailableIndex)?
                        .0,
                )
                .map_err(Error::VhostUserSetVringBase)?;

            if let Some(eventfd) =
                virtio_interrupt.notifier(VirtioInterruptType::Queue(*queue_index as u16))
            {
                self.vu
                    .set_vring_call(*queue_index, &eventfd)
                    .map_err(Error::VhostUserSetVringCall)?;
            }

            self.vu
                .set_vring_kick(*queue_index, queue_evt)
                .map_err(Error::VhostUserSetVringKick)?;

            self.queue_indexes.push(*queue_index);
        }

        self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;

        if let Some(backend_req_handler) = backend_req_handler {
            self.vu
                .set_backend_request_fd(&backend_req_handler.get_tx_raw_fd())
                .map_err(Error::VhostUserSetBackendRequestFd)?;
        }

        self.vrings_info = Some(vrings_info);
        self.ready = true;

        Ok(())
    }

    fn enable_vhost_user_vrings(&mut self, queue_indexes: Vec<usize>, enable: bool) -> Result<()> {
        for queue_index in queue_indexes {
            self.vu
                .set_vring_enable(queue_index, enable)
                .map_err(Error::VhostUserSetVringEnable)?;
        }

        Ok(())
    }

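    /// Disable every vring and fetch its base index, which also tells the
    /// backend to stop processing that ring.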
    pub fn reset_vhost_user(&mut self) -> Result<()> {
        for queue_index in self.queue_indexes.drain(..) {
            self.vu
                .set_vring_enable(queue_index, false)
                .map_err(Error::VhostUserSetVringEnable)?;

            let _ = self
                .vu
                .get_vring_base(queue_index)
                .map_err(Error::VhostUserGetVringBase)?;
        }

        Ok(())
    }

    pub fn set_protocol_features_vhost_user(
        &mut self,
        acked_features: u64,
        acked_protocol_features: u64,
    ) -> Result<()> {
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
        self.vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;

        if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
            if let Some(acked_protocol_features) =
                VhostUserProtocolFeatures::from_bits(acked_protocol_features)
            {
                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
                    self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
                }
            }
        }

        self.update_supports_migration(acked_features, acked_protocol_features);

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    pub fn reinitialize_vhost_user<S: VhostUserFrontendReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<(usize, Queue, EventFd)>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        acked_protocol_features: u64,
        backend_req_handler: &Option<FrontendReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.set_protocol_features_vhost_user(acked_features, acked_protocol_features)?;

        self.setup_vhost_user(
            mem,
            queues,
            virtio_interrupt,
            acked_features,
            backend_req_handler,
            inflight,
        )
    }

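    /// Establish the vhost-user connection, either by binding a Unix socket
    /// and waiting for the backend to connect (server mode), or by retrying
    /// a client connection to the backend socket for up to one minute.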
    pub fn connect_vhost_user(
        server: bool,
        socket_path: &str,
        num_queues: u64,
        unlink_socket: bool,
    ) -> Result<Self> {
        if server {
            if unlink_socket {
                std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?;
            }

            info!("Binding vhost-user listener...");
            let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?;
            info!("Waiting for incoming vhost-user connection...");
            let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?;

            Ok(VhostUserHandle {
                vu: Frontend::from_stream(stream, num_queues),
                ready: false,
                supports_migration: false,
                shm_log: None,
                acked_features: 0,
                vrings_info: None,
                queue_indexes: Vec::new(),
            })
        } else {
            let now = Instant::now();

            // Retry connecting for a full minute
            let err = loop {
                let err = match Frontend::connect(socket_path, num_queues) {
                    Ok(m) => {
                        return Ok(VhostUserHandle {
                            vu: m,
                            ready: false,
                            supports_migration: false,
                            shm_log: None,
                            acked_features: 0,
                            vrings_info: None,
                            queue_indexes: Vec::new(),
                        })
                    }
                    Err(e) => e,
                };
                sleep(Duration::from_millis(100));

                if now.elapsed().as_secs() >= 60 {
                    break err;
                }
            };

            error!(
                "Failed to connect to the backend after trying for 1 minute: {:?}",
                err
            );
            Err(Error::VhostUserConnect)
        }
    }

    pub fn socket_handle(&mut self) -> &mut Frontend {
        &mut self.vu
    }

    pub fn pause_vhost_user(&mut self) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(self.queue_indexes.clone(), false)?;
        }

        Ok(())
    }

    pub fn resume_vhost_user(&mut self) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;
        }

        Ok(())
    }

    fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) {
        if (acked_features & u64::from(vhost::vhost_kern::vhost_binding::VHOST_F_LOG_ALL) != 0)
            && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0)
        {
            self.supports_migration = true;
        }
    }

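    /// Create a new sealed memfd-backed dirty log region sized for the guest
    /// RAM, share it with the backend through SET_LOG_BASE, and return the
    /// previously installed region (if any) so it can still be processed.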
    fn update_log_base(&mut self, last_ram_addr: u64) -> Result<Option<Arc<MmapRegion>>> {
        // Create the memfd
        let fd = memfd_create(
            &ffi::CString::new("vhost_user_dirty_log").unwrap(),
            libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
        )
        .map_err(Error::MemfdCreate)?;

        // SAFETY: we checked the file descriptor is valid
        let file = unsafe { File::from_raw_fd(fd) };
        // The size of the memory mapping corresponds to the size of a bitmap
        // covering all guest pages for addresses from 0 to the last physical
        // address in guest RAM.
        // A page is always 4kiB from a vhost-user perspective, and each bit is
        // a page. That's how we can compute mmap_size from the last address.
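        // As a rough example, covering 4 GiB of guest RAM takes about
        // 4 GiB / (4096 * 8) = 128 KiB of log space.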
        let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
        let mmap_handle = file.as_raw_fd();

        // Set shm_log region size
        file.set_len(mmap_size).map_err(Error::SetFileSize)?;

        // Set the seals
        // SAFETY: FFI call with valid arguments
        let res = unsafe {
            libc::fcntl(
                file.as_raw_fd(),
                libc::F_ADD_SEALS,
                libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL,
            )
        };
        if res < 0 {
            return Err(Error::SetSeals(std::io::Error::last_os_error()));
        }

        // Mmap shm_log region
        let region = MmapRegion::build(
            Some(FileOffset::new(file, 0)),
            mmap_size as usize,
            libc::PROT_READ | libc::PROT_WRITE,
            libc::MAP_SHARED,
        )
        .map_err(Error::NewMmapRegion)?;

        // Make sure we hold onto the region to prevent the mapping from being
        // released.
        let old_region = self.shm_log.replace(Arc::new(region));

        // Send the shm_log fd over to the backend
        let log = VhostUserDirtyLogRegion {
            mmap_size,
            mmap_offset: 0,
            mmap_handle,
        };
        self.vu
            .set_log_base(0, Some(log))
            .map_err(Error::VhostUserSetLogBase)?;

        Ok(old_region)
    }

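    /// Enable or disable used ring logging on every configured vring by
    /// resending the vring addresses with or without the VHOST_VRING_F_LOG
    /// flag and the log address.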
    fn set_vring_logging(&mut self, enable: bool) -> Result<()> {
        if let Some(vrings_info) = &self.vrings_info {
            for (i, vring_info) in vrings_info.iter().enumerate() {
                let mut config_data = vring_info.config_data;
                config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 };
                config_data.log_addr = if enable {
                    Some(vring_info.used_guest_addr)
                } else {
                    None
                };

                self.vu
                    .set_vring_addr(i, &config_data)
                    .map_err(Error::VhostUserSetVringAddr)?;
            }
        }

        Ok(())
    }

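    /// Start dirty page logging: install a fresh log region, enable the
    /// VHOST_F_LOG_ALL feature and turn on used ring logging for all queues.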
    pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Set the shm log region
        self.update_log_base(last_ram_addr)?;

        // Enable VHOST_F_LOG_ALL feature
        let features = self.acked_features | (1 << VHOST_F_LOG_ALL);
        self.vu
            .set_features(features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Enable dirty page logging of used ring for all queues
        self.set_vring_logging(true)
    }

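    /// Stop dirty page logging: turn off used ring logging, restore the
    /// original acked features and drop the shared log region.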
    pub fn stop_dirty_log(&mut self) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Disable dirty page logging of used ring for all queues
        self.set_vring_logging(false)?;

        // Disable VHOST_F_LOG_ALL feature
        self.vu
            .set_features(self.acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Dropping the shm_log region here is important: it invokes the Drop
        // implementation, which unmaps the memory.
        self.shm_log = None;

        Ok(())
    }

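    /// Retrieve the pages dirtied by the backend since the last call by
    /// swapping in a new log region and converting the previous one into a
    /// MemoryRangeTable.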
    pub fn dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable> {
        // The log region is updated by creating a new region that is sent to
        // the backend. This ensures the backend stops logging to the previous
        // region. The previous region is returned and processed to create the
        // bitmap representing the dirty pages.
        if let Some(region) = self.update_log_base(last_ram_addr)? {
            // The region size is expressed in bytes (u8), but the bitmap is
            // read as u64 words, so the length must be divided by 8.
            let len = region.size() / 8;
            // SAFETY: the region holds at least `len` u64 words
            let bitmap = unsafe {
                // Cast the pointer to u64
                let ptr = region.as_ptr() as *const u64;
                std::slice::from_raw_parts(ptr, len).to_vec()
            };
            Ok(MemoryRangeTable::from_bitmap(bitmap, 0, 4096))
        } else {
            Err(Error::MissingShmLogRegion)
        }
    }
}

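// Thin wrapper around the memfd_create(2) syscall returning the raw fd.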
fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error> {
    // SAFETY: FFI call with valid arguments
    let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) };

    if res < 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(res as RawFd)
    }
}