xref: /cloud-hypervisor/virtio-devices/src/vhost_user/vu_common_ctrl.rs (revision 4d7a4c598ac247aaf770b00dfb057cdac891f67d)
1 // Copyright 2019 Intel Corporation. All Rights Reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use super::{Error, Result};
5 use crate::vhost_user::Inflight;
6 use crate::{
7     get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt,
8     VirtioInterruptType,
9 };
10 use std::ffi;
11 use std::fs::File;
12 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
13 use std::os::unix::net::UnixListener;
14 use std::sync::atomic::Ordering;
15 use std::sync::Arc;
16 use std::thread::sleep;
17 use std::time::{Duration, Instant};
18 use std::vec::Vec;
19 use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG};
20 use vhost::vhost_user::message::{
21     VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
22 };
23 use vhost::vhost_user::{
24     Frontend, FrontendReqHandler, VhostUserFrontend, VhostUserFrontendReqHandler,
25 };
26 use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
27 use virtio_queue::{Descriptor, Queue, QueueT};
28 use vm_memory::{
29     Address, Error as MmapError, FileOffset, GuestAddress, GuestMemory, GuestMemoryRegion,
30 };
31 use vm_migration::protocol::MemoryRangeTable;
32 use vmm_sys_util::eventfd::EventFd;
33 
// Size of a dirty page for vhost-user.
// The vhost-user dirty-log protocol assumes a fixed 4kiB page granularity
// for the log bitmap (one bit per page), independent of the host page size.
const VHOST_LOG_PAGE: u64 = 0x1000;
36 
/// User-facing configuration for a vhost-user device: the backend socket
/// path and the requested virtqueue topology.
#[derive(Debug, Clone)]
pub struct VhostUserConfig {
    /// Path to the vhost-user backend UNIX socket.
    pub socket: String,
    /// Number of virtqueues to set up with the backend.
    pub num_queues: usize,
    /// Size (number of descriptors) of each virtqueue.
    pub queue_size: u16,
}
43 
// Per-vring state kept around so that dirty logging can later be toggled by
// resending the vring addresses with updated flags (see set_vring_logging()).
#[derive(Clone)]
struct VringInfo {
    // Vring configuration as sent through VHOST_USER_SET_VRING_ADDR.
    config_data: VringConfigData,
    // Guest physical address of the used ring, used as the log address when
    // dirty logging is enabled.
    used_guest_addr: u64,
}
49 
/// Wrapper around a vhost-user frontend connection, tracking the negotiated
/// state needed to (re)configure, pause/resume and migrate the backend.
#[derive(Clone)]
pub struct VhostUserHandle {
    // Underlying vhost-user frontend (connection to the backend).
    vu: Frontend,
    // Set to true once setup_vhost_user() has completed successfully.
    ready: bool,
    // True when both VHOST_F_LOG_ALL and the LOG_SHMFD protocol feature
    // have been acked (see update_supports_migration()).
    supports_migration: bool,
    // Shared memory region backing the dirty-log bitmap. Held here so the
    // mapping is not released while the backend may still write to it.
    shm_log: Option<Arc<MmapRegion>>,
    // Features acked with the backend; restored when dirty logging stops.
    acked_features: u64,
    // Per-queue vring configuration, required to toggle vring logging.
    vrings_info: Option<Vec<VringInfo>>,
    // Indexes of the queues that have been activated with the backend.
    queue_indexes: Vec<usize>,
}
60 
61 impl VhostUserHandle {
62     pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> {
63         let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
64         for region in mem.iter() {
65             let (mmap_handle, mmap_offset) = match region.file_offset() {
66                 Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()),
67                 None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)),
68             };
69 
70             let vhost_user_net_reg = VhostUserMemoryRegionInfo {
71                 guest_phys_addr: region.start_addr().raw_value(),
72                 memory_size: region.len(),
73                 userspace_addr: region.as_ptr() as u64,
74                 mmap_offset,
75                 mmap_handle,
76             };
77 
78             regions.push(vhost_user_net_reg);
79         }
80 
81         self.vu
82             .set_mem_table(regions.as_slice())
83             .map_err(Error::VhostUserSetMemTable)?;
84 
85         Ok(())
86     }
87 
88     pub fn add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()> {
89         let (mmap_handle, mmap_offset) = match region.file_offset() {
90             Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
91             None => return Err(Error::MissingRegionFd),
92         };
93 
94         let region = VhostUserMemoryRegionInfo {
95             guest_phys_addr: region.start_addr().raw_value(),
96             memory_size: region.len(),
97             userspace_addr: region.as_ptr() as u64,
98             mmap_offset,
99             mmap_handle,
100         };
101 
102         self.vu
103             .add_mem_region(&region)
104             .map_err(Error::VhostUserAddMemReg)
105     }
106 
    /// Negotiates virtio and vhost-user protocol features with the backend.
    ///
    /// Claims ownership of the session, then intersects the provided
    /// `avail_features`/`avail_protocol_features` with what the backend
    /// advertises. Protocol features are only negotiated when the acked
    /// features contain VHOST_USER_F_PROTOCOL_FEATURES. When REPLY_ACK is
    /// available and acked, subsequent requests are flagged as requiring an
    /// explicit reply from the backend.
    ///
    /// Returns `(acked_features, acked_protocol_features_bits)`.
    pub fn negotiate_features_vhost_user(
        &mut self,
        avail_features: u64,
        avail_protocol_features: VhostUserProtocolFeatures,
    ) -> Result<(u64, u64)> {
        // Set vhost-user owner.
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;

        // Get features from backend, do negotiation to get a feature collection which
        // both VMM and backend support.
        let backend_features = self
            .vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;
        let acked_features = avail_features & backend_features;

        let acked_protocol_features =
            if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
                let backend_protocol_features = self
                    .vu
                    .get_protocol_features()
                    .map_err(Error::VhostUserGetProtocolFeatures)?;

                let acked_protocol_features = avail_protocol_features & backend_protocol_features;

                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                acked_protocol_features
            } else {
                VhostUserProtocolFeatures::empty()
            };

        if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
            && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
        {
            self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
        }

        // Record whether the features required for live migration were acked.
        self.update_supports_migration(acked_features, acked_protocol_features.bits());

        Ok((acked_features, acked_protocol_features.bits()))
    }
151 
    /// Fully configures the backend: features, memory table, inflight
    /// tracking, and all vrings (addresses, base index, call/kick eventfds),
    /// then enables the rings. Marks the handle as ready on success.
    ///
    /// The order of the vhost-user requests below matters; backends expect
    /// features and the memory table before vring configuration.
    #[allow(clippy::too_many_arguments)]
    pub fn setup_vhost_user<S: VhostUserFrontendReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<(usize, Queue, EventFd)>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        backend_req_handler: &Option<FrontendReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.vu
            .set_features(acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Update internal value after it's been sent to the backend.
        self.acked_features = acked_features;

        // Let's first provide the memory table to the backend.
        self.update_mem_table(mem)?;

        // Send set_vring_num here, since it could tell backends, like SPDK,
        // how many virt queues to be handled, which backend required to know
        // at early stage.
        for (queue_index, queue, _) in queues.iter() {
            self.vu
                .set_vring_num(*queue_index, queue.size())
                .map_err(Error::VhostUserSetVringNum)?;
        }

        // Setup for inflight I/O tracking shared memory.
        if let Some(inflight) = inflight {
            // Only request a new inflight region from the backend when none
            // was preserved (e.g. from a previous connection).
            if inflight.fd.is_none() {
                // NOTE(review): assumes `queues` is non-empty — all queues
                // are presumed to share the size of queues[0].
                let inflight_req_info = VhostUserInflight {
                    mmap_size: 0,
                    mmap_offset: 0,
                    num_queues: queues.len() as u16,
                    queue_size: queues[0].1.size(),
                };
                let (info, fd) = self
                    .vu
                    .get_inflight_fd(&inflight_req_info)
                    .map_err(Error::VhostUserGetInflight)?;
                inflight.info = info;
                inflight.fd = Some(fd);
            }
            // Unwrapping the inflight fd is safe here since we know it can't be None.
            self.vu
                .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd())
                .map_err(Error::VhostUserSetInflight)?;
        }

        let mut vrings_info = Vec::new();
        for (queue_index, queue, queue_evt) in queues.iter() {
            let actual_size: usize = queue.size().into();

            let config_data = VringConfigData {
                queue_max_size: queue.max_size(),
                queue_size: queue.size(),
                flags: 0u32,
                desc_table_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.desc_table()),
                    actual_size * std::mem::size_of::<Descriptor>(),
                )
                .ok_or(Error::DescriptorTableAddress)? as u64,
                // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u32, len: u32}; actual_size]},
                // i.e. 4 + (4 + 4) * actual_size.
                used_ring_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.used_ring()),
                    4 + actual_size * 8,
                )
                .ok_or(Error::UsedAddress)? as u64,
                // The avail ring is {flags: u16; idx: u16; ring [u16; actual_size]},
                // i.e. 4 + (2) * actual_size.
                avail_ring_addr: get_host_address_range(
                    mem,
                    GuestAddress(queue.avail_ring()),
                    4 + actual_size * 2,
                )
                .ok_or(Error::AvailAddress)? as u64,
                log_addr: None,
            };

            // Keep the vring configuration around so dirty logging can be
            // toggled later by resending these addresses.
            vrings_info.push(VringInfo {
                config_data,
                used_guest_addr: queue.used_ring(),
            });

            self.vu
                .set_vring_addr(*queue_index, &config_data)
                .map_err(Error::VhostUserSetVringAddr)?;
            self.vu
                .set_vring_base(
                    *queue_index,
                    queue
                        .avail_idx(mem, Ordering::Acquire)
                        .map_err(Error::GetAvailableIndex)?
                        .0,
                )
                .map_err(Error::VhostUserSetVringBase)?;

            // Wire the interrupt (call) eventfd if the interrupt transport
            // provides a per-queue notifier.
            if let Some(eventfd) =
                virtio_interrupt.notifier(VirtioInterruptType::Queue(*queue_index as u16))
            {
                self.vu
                    .set_vring_call(*queue_index, &eventfd)
                    .map_err(Error::VhostUserSetVringCall)?;
            }

            self.vu
                .set_vring_kick(*queue_index, queue_evt)
                .map_err(Error::VhostUserSetVringKick)?;

            self.queue_indexes.push(*queue_index);
        }

        self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;

        if let Some(backend_req_handler) = backend_req_handler {
            self.vu
                .set_backend_request_fd(&backend_req_handler.get_tx_raw_fd())
                .map_err(Error::VhostUserSetBackendRequestFd)?;
        }

        self.vrings_info = Some(vrings_info);
        self.ready = true;

        Ok(())
    }
282 
283     fn enable_vhost_user_vrings(&mut self, queue_indexes: Vec<usize>, enable: bool) -> Result<()> {
284         for queue_index in queue_indexes {
285             self.vu
286                 .set_vring_enable(queue_index, enable)
287                 .map_err(Error::VhostUserSetVringEnable)?;
288         }
289 
290         Ok(())
291     }
292 
293     pub fn reset_vhost_user(&mut self) -> Result<()> {
294         for queue_index in self.queue_indexes.drain(..) {
295             self.vu
296                 .set_vring_enable(queue_index, false)
297                 .map_err(Error::VhostUserSetVringEnable)?;
298 
299             let _ = self
300                 .vu
301                 .get_vring_base(queue_index)
302                 .map_err(Error::VhostUserGetVringBase)?;
303         }
304 
305         Ok(())
306     }
307 
308     pub fn set_protocol_features_vhost_user(
309         &mut self,
310         acked_features: u64,
311         acked_protocol_features: u64,
312     ) -> Result<()> {
313         self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
314         self.vu
315             .get_features()
316             .map_err(Error::VhostUserGetFeatures)?;
317 
318         if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
319             if let Some(acked_protocol_features) =
320                 VhostUserProtocolFeatures::from_bits(acked_protocol_features)
321             {
322                 self.vu
323                     .set_protocol_features(acked_protocol_features)
324                     .map_err(Error::VhostUserSetProtocolFeatures)?;
325 
326                 if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
327                     self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
328                 }
329             }
330         }
331 
332         self.update_supports_migration(acked_features, acked_protocol_features);
333 
334         Ok(())
335     }
336 
    /// Re-runs the whole backend initialization: first replays the feature
    /// negotiation with the previously acked values, then performs the
    /// regular setup path (memory table, vrings, eventfds).
    #[allow(clippy::too_many_arguments)]
    pub fn reinitialize_vhost_user<S: VhostUserFrontendReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<(usize, Queue, EventFd)>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        acked_protocol_features: u64,
        backend_req_handler: &Option<FrontendReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.set_protocol_features_vhost_user(acked_features, acked_protocol_features)?;

        self.setup_vhost_user(
            mem,
            queues,
            virtio_interrupt,
            acked_features,
            backend_req_handler,
            inflight,
        )
    }
359 
    /// Establishes the vhost-user connection with the backend.
    ///
    /// In server mode, binds `socket_path` (optionally removing a stale
    /// socket file first) and blocks until a backend connects. In client
    /// mode, retries connecting every 100ms for up to one minute before
    /// giving up with `Error::VhostUserConnect`.
    pub fn connect_vhost_user(
        server: bool,
        socket_path: &str,
        num_queues: u64,
        unlink_socket: bool,
    ) -> Result<Self> {
        if server {
            if unlink_socket {
                std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?;
            }

            info!("Binding vhost-user listener...");
            let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?;
            info!("Waiting for incoming vhost-user connection...");
            // Blocks until the backend connects.
            let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?;

            Ok(VhostUserHandle {
                vu: Frontend::from_stream(stream, num_queues),
                ready: false,
                supports_migration: false,
                shm_log: None,
                acked_features: 0,
                vrings_info: None,
                queue_indexes: Vec::new(),
            })
        } else {
            let now = Instant::now();

            // Retry connecting for a full minute
            let err = loop {
                let err = match Frontend::connect(socket_path, num_queues) {
                    Ok(m) => {
                        return Ok(VhostUserHandle {
                            vu: m,
                            ready: false,
                            supports_migration: false,
                            shm_log: None,
                            acked_features: 0,
                            vrings_info: None,
                            queue_indexes: Vec::new(),
                        })
                    }
                    Err(e) => e,
                };
                sleep(Duration::from_millis(100));

                // Give up and report the last connection error once the
                // minute has elapsed.
                if now.elapsed().as_secs() >= 60 {
                    break err;
                }
            };

            error!(
                "Failed connecting the backend after trying for 1 minute: {:?}",
                err
            );
            Err(Error::VhostUserConnect)
        }
    }
418 
    /// Returns mutable access to the underlying vhost-user frontend, e.g.
    /// for issuing device-specific requests directly.
    pub fn socket_handle(&mut self) -> &mut Frontend {
        &mut self.vu
    }
422 
423     pub fn pause_vhost_user(&mut self) -> Result<()> {
424         if self.ready {
425             self.enable_vhost_user_vrings(self.queue_indexes.clone(), false)?;
426         }
427 
428         Ok(())
429     }
430 
431     pub fn resume_vhost_user(&mut self) -> Result<()> {
432         if self.ready {
433             self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?;
434         }
435 
436         Ok(())
437     }
438 
439     fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) {
440         if (acked_features & u64::from(vhost::vhost_kern::vhost_binding::VHOST_F_LOG_ALL) != 0)
441             && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0)
442         {
443             self.supports_migration = true;
444         }
445     }
446 
    /// (Re)creates the shared memory region used by the backend as the
    /// dirty-log bitmap and sends it through VHOST_USER_SET_LOG_BASE.
    ///
    /// Returns the previous log region (if any) so the caller can consume
    /// the bitmap the backend wrote into so far.
    fn update_log_base(&mut self, last_ram_addr: u64) -> Result<Option<Arc<MmapRegion>>> {
        // Create the memfd
        let fd = memfd_create(
            &ffi::CString::new("vhost_user_dirty_log").unwrap(),
            libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
        )
        .map_err(Error::MemfdCreate)?;

        // SAFETY: we checked the file descriptor is valid
        let file = unsafe { File::from_raw_fd(fd) };
        // The size of the memory mapping corresponds to the size of a bitmap
        // covering all guest pages for addresses from 0 to the last physical
        // address in guest RAM.
        // A page is always 4kiB from a vhost-user perspective, and each bit is
        // a page. That's how we can compute mmap_size from the last address.
        let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
        let mmap_handle = file.as_raw_fd();

        // Set shm_log region size
        file.set_len(mmap_size).map_err(Error::SetFileSize)?;

        // Set the seals so the region can neither grow nor shrink, and the
        // seals themselves cannot be removed.
        // SAFETY: FFI call with valid arguments
        let res = unsafe {
            libc::fcntl(
                file.as_raw_fd(),
                libc::F_ADD_SEALS,
                libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL,
            )
        };
        if res < 0 {
            return Err(Error::SetSeals(std::io::Error::last_os_error()));
        }

        // Mmap shm_log region
        let region = MmapRegion::build(
            Some(FileOffset::new(file, 0)),
            mmap_size as usize,
            libc::PROT_READ | libc::PROT_WRITE,
            libc::MAP_SHARED,
        )
        .map_err(Error::NewMmapRegion)?;

        // Make sure we hold onto the region to prevent the mapping from being
        // released.
        let old_region = self.shm_log.replace(Arc::new(region));

        // Send the shm_log fd over to the backend
        let log = VhostUserDirtyLogRegion {
            mmap_size,
            mmap_offset: 0,
            mmap_handle,
        };
        self.vu
            .set_log_base(0, Some(log))
            .map_err(Error::VhostUserSetLogBase)?;

        Ok(old_region)
    }
506 
507     fn set_vring_logging(&mut self, enable: bool) -> Result<()> {
508         if let Some(vrings_info) = &self.vrings_info {
509             for (i, vring_info) in vrings_info.iter().enumerate() {
510                 let mut config_data = vring_info.config_data;
511                 config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 };
512                 config_data.log_addr = if enable {
513                     Some(vring_info.used_guest_addr)
514                 } else {
515                     None
516                 };
517 
518                 self.vu
519                     .set_vring_addr(i, &config_data)
520                     .map_err(Error::VhostUserSetVringAddr)?;
521             }
522         }
523 
524         Ok(())
525     }
526 
527     pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> {
528         if !self.supports_migration {
529             return Err(Error::MigrationNotSupported);
530         }
531 
532         // Set the shm log region
533         self.update_log_base(last_ram_addr)?;
534 
535         // Enable VHOST_F_LOG_ALL feature
536         let features = self.acked_features | (1 << VHOST_F_LOG_ALL);
537         self.vu
538             .set_features(features)
539             .map_err(Error::VhostUserSetFeatures)?;
540 
541         // Enable dirty page logging of used ring for all queues
542         self.set_vring_logging(true)
543     }
544 
545     pub fn stop_dirty_log(&mut self) -> Result<()> {
546         if !self.supports_migration {
547             return Err(Error::MigrationNotSupported);
548         }
549 
550         // Disable dirty page logging of used ring for all queues
551         self.set_vring_logging(false)?;
552 
553         // Disable VHOST_F_LOG_ALL feature
554         self.vu
555             .set_features(self.acked_features)
556             .map_err(Error::VhostUserSetFeatures)?;
557 
558         // This is important here since the log region goes out of scope,
559         // invoking the Drop trait, hence unmapping the memory.
560         self.shm_log = None;
561 
562         Ok(())
563     }
564 
    /// Collects the dirty pages logged so far and returns them as a
    /// `MemoryRangeTable` with a 4kiB page granularity.
    pub fn dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable> {
        // The log region is updated by creating a new region that is sent to
        // the backend. This ensures the backend stops logging to the previous
        // region. The previous region is returned and processed to create the
        // bitmap representing the dirty pages.
        if let Some(region) = self.update_log_base(last_ram_addr)? {
            // The region size is in bytes, but the bitmap is consumed as u64
            // words, hence the division by 8.
            let len = region.size() / 8;
            // SAFETY: region is of size len
            let bitmap = unsafe {
                // Cast the pointer to u64
                let ptr = region.as_ptr() as *const u64;
                std::slice::from_raw_parts(ptr, len).to_vec()
            };
            // Each bit covers one 4kiB (VHOST_LOG_PAGE) page starting from
            // guest address 0.
            Ok(MemoryRangeTable::from_bitmap(bitmap, 0, 4096))
        } else {
            Err(Error::MissingShmLogRegion)
        }
    }
585 }
586 
587 fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error> {
588     // SAFETY: FFI call with valid arguments
589     let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) };
590 
591     if res < 0 {
592         Err(std::io::Error::last_os_error())
593     } else {
594         Ok(res as RawFd)
595     }
596 }
597