xref: /cloud-hypervisor/virtio-devices/src/vhost_user/vu_common_ctrl.rs (revision f67b3f79ea19c9a66e04074cbbf5d292f6529e43)
1 // Copyright 2019 Intel Corporation. All Rights Reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use super::super::{Descriptor, Queue};
5 use super::{Error, Result};
6 use crate::vhost_user::Inflight;
7 use crate::{
8     get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt,
9     VirtioInterruptType,
10 };
11 use std::convert::TryInto;
12 use std::ffi;
13 use std::fs::File;
14 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
15 use std::os::unix::net::UnixListener;
16 use std::sync::Arc;
17 use std::thread::sleep;
18 use std::time::{Duration, Instant};
19 use std::vec::Vec;
20 use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG};
21 use vhost::vhost_user::message::{
22     VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
23 };
24 use vhost::vhost_user::{Master, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler};
25 use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
26 use vm_memory::{Address, Error as MmapError, FileOffset, GuestMemory, GuestMemoryRegion};
27 use vm_migration::protocol::MemoryRangeTable;
28 use vmm_sys_util::eventfd::EventFd;
29 
30 // Size of a dirty page for vhost-user.
31 const VHOST_LOG_PAGE: u64 = 0x1000;
32 
/// User-provided configuration for a vhost-user device.
#[derive(Debug, Clone)]
pub struct VhostUserConfig {
    /// Path to the backend's Unix domain socket.
    pub socket: String,
    /// Number of virtqueues exposed by the device.
    pub num_queues: usize,
    /// Number of descriptors per virtqueue.
    pub queue_size: u16,
}
39 
/// Per-queue configuration captured at setup time so vring settings can be
/// replayed later (e.g. when toggling dirty-page logging).
#[derive(Clone)]
struct VringInfo {
    /// Addresses/flags as sent to the backend via SET_VRING_ADDR.
    config_data: VringConfigData,
    /// Guest-physical address of the used ring; reused as the per-vring
    /// log address when dirty logging is enabled.
    used_guest_addr: u64,
}
45 
/// Wrapper around a vhost-user master connection, tracking the state needed
/// to (re)configure the backend and to drive live migration.
#[derive(Clone)]
pub struct VhostUserHandle {
    /// Connection to the vhost-user backend.
    vu: Master,
    /// Set once setup_vhost_user() has completed; gates pause/resume.
    ready: bool,
    /// Set when both VHOST_F_LOG_ALL and LOG_SHMFD have been negotiated.
    supports_migration: bool,
    /// Shared-memory region holding the dirty-log bitmap; kept here so the
    /// mapping shared with the backend is not released prematurely.
    shm_log: Option<Arc<MmapRegion>>,
    /// Features last sent to the backend via SET_FEATURES.
    acked_features: u64,
    /// Vring configs recorded during setup, needed by set_vring_logging().
    vrings_info: Option<Vec<VringInfo>>,
}
55 
56 impl VhostUserHandle {
57     pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> {
58         let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
59         for region in mem.iter() {
60             let (mmap_handle, mmap_offset) = match region.file_offset() {
61                 Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()),
62                 None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)),
63             };
64 
65             let vhost_user_net_reg = VhostUserMemoryRegionInfo {
66                 guest_phys_addr: region.start_addr().raw_value(),
67                 memory_size: region.len() as u64,
68                 userspace_addr: region.as_ptr() as u64,
69                 mmap_offset,
70                 mmap_handle,
71             };
72 
73             regions.push(vhost_user_net_reg);
74         }
75 
76         self.vu
77             .set_mem_table(regions.as_slice())
78             .map_err(Error::VhostUserSetMemTable)?;
79 
80         Ok(())
81     }
82 
83     pub fn add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()> {
84         let (mmap_handle, mmap_offset) = match region.file_offset() {
85             Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
86             None => return Err(Error::MissingRegionFd),
87         };
88 
89         let region = VhostUserMemoryRegionInfo {
90             guest_phys_addr: region.start_addr().raw_value(),
91             memory_size: region.len() as u64,
92             userspace_addr: region.as_ptr() as u64,
93             mmap_offset,
94             mmap_handle,
95         };
96 
97         self.vu
98             .add_mem_region(&region)
99             .map_err(Error::VhostUserAddMemReg)
100     }
101 
    /// Performs the vhost-user feature handshake with the backend.
    ///
    /// Returns `(acked_features, acked_protocol_feature_bits)`: the
    /// intersection of what the VMM offers (`avail_features`,
    /// `avail_protocol_features`) and what the backend advertises.
    pub fn negotiate_features_vhost_user(
        &mut self,
        avail_features: u64,
        avail_protocol_features: VhostUserProtocolFeatures,
    ) -> Result<(u64, u64)> {
        // Set vhost-user owner.
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;

        // Get features from backend, do negotiation to get a feature collection which
        // both VMM and backend support.
        let backend_features = self
            .vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;
        let acked_features = avail_features & backend_features;

        // Protocol features can only be negotiated if the backend acked
        // VHOST_USER_F_PROTOCOL_FEATURES among the virtio features.
        let acked_protocol_features =
            if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
                let backend_protocol_features = self
                    .vu
                    .get_protocol_features()
                    .map_err(Error::VhostUserGetProtocolFeatures)?;

                let acked_protocol_features = avail_protocol_features & backend_protocol_features;

                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                acked_protocol_features
            } else {
                VhostUserProtocolFeatures::empty()
            };

        // If REPLY_ACK was both offered and negotiated, request an explicit
        // acknowledgment from the backend for every subsequent message.
        if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
            && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
        {
            self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
        }

        // Record whether this feature combination allows live migration.
        self.update_supports_migration(acked_features, acked_protocol_features.bits());

        Ok((acked_features, acked_protocol_features.bits()))
    }
146 
147     #[allow(clippy::too_many_arguments)]
148     pub fn setup_vhost_user<S: VhostUserMasterReqHandler>(
149         &mut self,
150         mem: &GuestMemoryMmap,
151         queues: Vec<Queue>,
152         queue_evts: Vec<EventFd>,
153         virtio_interrupt: &Arc<dyn VirtioInterrupt>,
154         acked_features: u64,
155         slave_req_handler: &Option<MasterReqHandler<S>>,
156         inflight: Option<&mut Inflight>,
157     ) -> Result<()> {
158         self.vu
159             .set_features(acked_features)
160             .map_err(Error::VhostUserSetFeatures)?;
161 
162         // Update internal value after it's been sent to the backend.
163         self.acked_features = acked_features;
164 
165         // Let's first provide the memory table to the backend.
166         self.update_mem_table(mem)?;
167 
168         // Send set_vring_num here, since it could tell backends, like SPDK,
169         // how many virt queues to be handled, which backend required to know
170         // at early stage.
171         for (queue_index, queue) in queues.iter().enumerate() {
172             self.vu
173                 .set_vring_num(queue_index, queue.actual_size())
174                 .map_err(Error::VhostUserSetVringNum)?;
175         }
176 
177         // Setup for inflight I/O tracking shared memory.
178         if let Some(inflight) = inflight {
179             if inflight.fd.is_none() {
180                 let inflight_req_info = VhostUserInflight {
181                     mmap_size: 0,
182                     mmap_offset: 0,
183                     num_queues: queues.len() as u16,
184                     queue_size: queues[0].actual_size(),
185                 };
186                 let (info, fd) = self
187                     .vu
188                     .get_inflight_fd(&inflight_req_info)
189                     .map_err(Error::VhostUserGetInflight)?;
190                 inflight.info = info;
191                 inflight.fd = Some(fd);
192             }
193             // Unwrapping the inflight fd is safe here since we know it can't be None.
194             self.vu
195                 .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd())
196                 .map_err(Error::VhostUserSetInflight)?;
197         }
198 
199         let num_queues = queues.len() as usize;
200 
201         let mut vrings_info = Vec::new();
202         for (queue_index, queue) in queues.into_iter().enumerate() {
203             let actual_size: usize = queue.actual_size().try_into().unwrap();
204 
205             let config_data = VringConfigData {
206                 queue_max_size: queue.get_max_size(),
207                 queue_size: queue.actual_size(),
208                 flags: 0u32,
209                 desc_table_addr: get_host_address_range(
210                     mem,
211                     queue.desc_table,
212                     actual_size * std::mem::size_of::<Descriptor>(),
213                 )
214                 .ok_or(Error::DescriptorTableAddress)? as u64,
215                 // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u16, len: u16}; actual_size]},
216                 // i.e. 4 + (4 + 4) * actual_size.
217                 used_ring_addr: get_host_address_range(mem, queue.used_ring, 4 + actual_size * 8)
218                     .ok_or(Error::UsedAddress)? as u64,
219                 // The used ring is {flags: u16; idx: u16; elem [u16; actual_size]},
220                 // i.e. 4 + (2) * actual_size.
221                 avail_ring_addr: get_host_address_range(mem, queue.avail_ring, 4 + actual_size * 2)
222                     .ok_or(Error::AvailAddress)? as u64,
223                 log_addr: None,
224             };
225 
226             vrings_info.push(VringInfo {
227                 config_data,
228                 used_guest_addr: queue.used_ring.raw_value(),
229             });
230 
231             self.vu
232                 .set_vring_addr(queue_index, &config_data)
233                 .map_err(Error::VhostUserSetVringAddr)?;
234             self.vu
235                 .set_vring_base(
236                     queue_index,
237                     queue
238                         .used_index_from_memory(mem)
239                         .map_err(Error::GetAvailableIndex)?,
240                 )
241                 .map_err(Error::VhostUserSetVringBase)?;
242 
243             if let Some(eventfd) =
244                 virtio_interrupt.notifier(&VirtioInterruptType::Queue, Some(&queue))
245             {
246                 self.vu
247                     .set_vring_call(queue_index, &eventfd)
248                     .map_err(Error::VhostUserSetVringCall)?;
249             }
250 
251             self.vu
252                 .set_vring_kick(queue_index, &queue_evts[queue_index])
253                 .map_err(Error::VhostUserSetVringKick)?;
254         }
255 
256         self.enable_vhost_user_vrings(num_queues, true)?;
257 
258         if let Some(slave_req_handler) = slave_req_handler {
259             self.vu
260                 .set_slave_request_fd(&slave_req_handler.get_tx_raw_fd())
261                 .map_err(Error::VhostUserSetSlaveRequestFd)?;
262         }
263 
264         self.vrings_info = Some(vrings_info);
265         self.ready = true;
266 
267         Ok(())
268     }
269 
270     fn enable_vhost_user_vrings(&mut self, num_queues: usize, enable: bool) -> Result<()> {
271         for queue_index in 0..num_queues {
272             self.vu
273                 .set_vring_enable(queue_index, enable)
274                 .map_err(Error::VhostUserSetVringEnable)?;
275         }
276 
277         Ok(())
278     }
279 
280     pub fn reset_vhost_user(&mut self, num_queues: usize) -> Result<()> {
281         self.enable_vhost_user_vrings(num_queues, false)?;
282 
283         // Reset the owner.
284         self.vu.reset_owner().map_err(Error::VhostUserResetOwner)
285     }
286 
    /// Replays the ownership/feature handshake using previously negotiated
    /// values (e.g. after reconnecting to a restarted backend).
    ///
    /// Unlike negotiate_features_vhost_user(), nothing is negotiated here:
    /// the already-acked feature sets are simply pushed to the backend.
    pub fn set_protocol_features_vhost_user(
        &mut self,
        acked_features: u64,
        acked_protocol_features: u64,
    ) -> Result<()> {
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
        // GET_FEATURES is issued but its result is discarded — presumably to
        // follow the expected message sequence before SET_PROTOCOL_FEATURES;
        // TODO confirm the backend requires this ordering.
        self.vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;

        if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
            // NOTE(review): if `acked_protocol_features` contains bits unknown
            // to VhostUserProtocolFeatures, from_bits() returns None and the
            // protocol features are silently never sent — confirm this silent
            // skip is intended rather than an error.
            if let Some(acked_protocol_features) =
                VhostUserProtocolFeatures::from_bits(acked_protocol_features)
            {
                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
                    self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
                }
            }
        }

        // Re-evaluate migration support with the restored feature sets.
        self.update_supports_migration(acked_features, acked_protocol_features);

        Ok(())
    }
315 
316     #[allow(clippy::too_many_arguments)]
317     pub fn reinitialize_vhost_user<S: VhostUserMasterReqHandler>(
318         &mut self,
319         mem: &GuestMemoryMmap,
320         queues: Vec<Queue>,
321         queue_evts: Vec<EventFd>,
322         virtio_interrupt: &Arc<dyn VirtioInterrupt>,
323         acked_features: u64,
324         acked_protocol_features: u64,
325         slave_req_handler: &Option<MasterReqHandler<S>>,
326         inflight: Option<&mut Inflight>,
327     ) -> Result<()> {
328         self.set_protocol_features_vhost_user(acked_features, acked_protocol_features)?;
329 
330         self.setup_vhost_user(
331             mem,
332             queues,
333             queue_evts,
334             virtio_interrupt,
335             acked_features,
336             slave_req_handler,
337             inflight,
338         )
339     }
340 
341     pub fn connect_vhost_user(
342         server: bool,
343         socket_path: &str,
344         num_queues: u64,
345         unlink_socket: bool,
346     ) -> Result<Self> {
347         if server {
348             if unlink_socket {
349                 std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?;
350             }
351 
352             info!("Binding vhost-user listener...");
353             let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?;
354             info!("Waiting for incoming vhost-user connection...");
355             let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?;
356 
357             Ok(VhostUserHandle {
358                 vu: Master::from_stream(stream, num_queues),
359                 ready: false,
360                 supports_migration: false,
361                 shm_log: None,
362                 acked_features: 0,
363                 vrings_info: None,
364             })
365         } else {
366             let now = Instant::now();
367 
368             // Retry connecting for a full minute
369             let err = loop {
370                 let err = match Master::connect(socket_path, num_queues) {
371                     Ok(m) => {
372                         return Ok(VhostUserHandle {
373                             vu: m,
374                             ready: false,
375                             supports_migration: false,
376                             shm_log: None,
377                             acked_features: 0,
378                             vrings_info: None,
379                         })
380                     }
381                     Err(e) => e,
382                 };
383                 sleep(Duration::from_millis(100));
384 
385                 if now.elapsed().as_secs() >= 60 {
386                     break err;
387                 }
388             };
389 
390             error!(
391                 "Failed connecting the backend after trying for 1 minute: {:?}",
392                 err
393             );
394             Err(Error::VhostUserConnect)
395         }
396     }
397 
    /// Gives direct mutable access to the underlying vhost-user master, for
    /// callers that need to send requests not wrapped by this handle.
    pub fn socket_handle(&mut self) -> &mut Master {
        &mut self.vu
    }
401 
402     pub fn pause_vhost_user(&mut self, num_queues: usize) -> Result<()> {
403         if self.ready {
404             self.enable_vhost_user_vrings(num_queues, false)?;
405         }
406 
407         Ok(())
408     }
409 
410     pub fn resume_vhost_user(&mut self, num_queues: usize) -> Result<()> {
411         if self.ready {
412             self.enable_vhost_user_vrings(num_queues, true)?;
413         }
414 
415         Ok(())
416     }
417 
418     fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) {
419         if (acked_features & u64::from(vhost::vhost_kern::vhost_binding::VHOST_F_LOG_ALL) != 0)
420             && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0)
421         {
422             self.supports_migration = true;
423         }
424     }
425 
    /// Allocates a fresh shared-memory dirty-log region, maps it locally and
    /// hands it to the backend via SET_LOG_BASE.
    ///
    /// Returns the previously installed region (if any) so the caller can
    /// read out the dirty bits that were logged into it.
    fn update_log_base(&mut self, last_ram_addr: u64) -> Result<Option<Arc<MmapRegion>>> {
        // Create the memfd
        let fd = memfd_create(
            &ffi::CString::new("vhost_user_dirty_log").unwrap(),
            libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
        )
        .map_err(Error::MemfdCreate)?;

        // Safe because we checked the file descriptor is valid
        let file = unsafe { File::from_raw_fd(fd) };
        // The size of the memory mapping corresponds to the size of a bitmap
        // covering all guest pages for addresses from 0 to the last physical
        // address in guest RAM.
        // A page is always 4kiB from a vhost-user perspective, and each bit is
        // a page. That's how we can compute mmap_size from the last address.
        let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
        let mmap_handle = file.as_raw_fd();

        // Set shm_log region size
        file.set_len(mmap_size).map_err(Error::SetFileSize)?;

        // Set the seals: prevent the region from being grown or shrunk (and
        // the seals themselves from being removed) once the fd has been
        // shared with the backend.
        let res = unsafe {
            libc::fcntl(
                file.as_raw_fd(),
                libc::F_ADD_SEALS,
                libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL,
            )
        };
        if res < 0 {
            return Err(Error::SetSeals(std::io::Error::last_os_error()));
        }

        // Mmap shm_log region
        let region = MmapRegion::build(
            Some(FileOffset::new(file, 0)),
            mmap_size as usize,
            libc::PROT_READ | libc::PROT_WRITE,
            libc::MAP_SHARED,
        )
        .map_err(Error::NewMmapRegion)?;

        // Make sure we hold onto the region to prevent the mapping from being
        // released.
        let old_region = self.shm_log.replace(Arc::new(region));

        // Send the shm_log fd over to the backend
        let log = VhostUserDirtyLogRegion {
            mmap_size,
            mmap_offset: 0,
            mmap_handle,
        };
        self.vu
            .set_log_base(0, Some(log))
            .map_err(Error::VhostUserSetLogBase)?;

        Ok(old_region)
    }
484 
485     fn set_vring_logging(&mut self, enable: bool) -> Result<()> {
486         if let Some(vrings_info) = &self.vrings_info {
487             for (i, vring_info) in vrings_info.iter().enumerate() {
488                 let mut config_data = vring_info.config_data;
489                 config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 };
490                 config_data.log_addr = if enable {
491                     Some(vring_info.used_guest_addr)
492                 } else {
493                     None
494                 };
495 
496                 self.vu
497                     .set_vring_addr(i, &config_data)
498                     .map_err(Error::VhostUserSetVringAddr)?;
499             }
500         }
501 
502         Ok(())
503     }
504 
505     pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> {
506         if !self.supports_migration {
507             return Err(Error::MigrationNotSupported);
508         }
509 
510         // Set the shm log region
511         self.update_log_base(last_ram_addr)?;
512 
513         // Enable VHOST_F_LOG_ALL feature
514         let features = self.acked_features | (1 << VHOST_F_LOG_ALL);
515         self.vu
516             .set_features(features)
517             .map_err(Error::VhostUserSetFeatures)?;
518 
519         // Enable dirty page logging of used ring for all queues
520         self.set_vring_logging(true)
521     }
522 
523     pub fn stop_dirty_log(&mut self) -> Result<()> {
524         if !self.supports_migration {
525             return Err(Error::MigrationNotSupported);
526         }
527 
528         // Disable dirty page logging of used ring for all queues
529         self.set_vring_logging(false)?;
530 
531         // Disable VHOST_F_LOG_ALL feature
532         self.vu
533             .set_features(self.acked_features)
534             .map_err(Error::VhostUserSetFeatures)?;
535 
536         // This is important here since the log region goes out of scope,
537         // invoking the Drop trait, hence unmapping the memory.
538         self.shm_log = None;
539 
540         Ok(())
541     }
542 
    /// Retrieves the dirty-page bitmap logged by the backend since the last
    /// call, returning it as a `MemoryRangeTable`.
    pub fn dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable> {
        // The log region is updated by creating a new region that is sent to
        // the backend. This ensures the backend stops logging to the previous
        // region. The previous region is returned and processed to create the
        // bitmap representing the dirty pages.
        if let Some(region) = self.update_log_base(last_ram_addr)? {
            // Be careful with the size, as it was based on u8, meaning we must
            // divide it by 8.
            // NOTE(review): if the region size is not a multiple of 8, the
            // trailing bytes are silently ignored by this division — confirm
            // mmap_size is always 8-byte aligned, otherwise up to 7 bytes of
            // dirty bits (224 pages) could be lost.
            let len = region.size() / 8;
            let bitmap = unsafe {
                // SAFETY: the region maps at least `len * 8` bytes of page-
                // aligned memory, so reading `len` u64 values from its base
                // pointer stays in bounds and is suitably aligned.
                // Cast the pointer to u64
                let ptr = region.as_ptr() as *const u64;
                std::slice::from_raw_parts(ptr, len).to_vec()
            };
            Ok(MemoryRangeTable::from_bitmap(bitmap, 0))
        } else {
            Err(Error::MissingShmLogRegion)
        }
    }
562 }
563 
564 fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error> {
565     let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) };
566 
567     if res < 0 {
568         Err(std::io::Error::last_os_error())
569     } else {
570         Ok(res as RawFd)
571     }
572 }
573