xref: /cloud-hypervisor/virtio-devices/src/vhost_user/vu_common_ctrl.rs (revision f7f2f25a574b1b2dba22c094fc8226d404157d15)
// Copyright 2019 Intel Corporation. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use super::super::{Descriptor, Queue};
use super::{Error, Result};
use crate::vhost_user::Inflight;
use crate::{
    get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt,
    VirtioInterruptType,
};
use std::convert::TryInto;
use std::ffi;
use std::fs::File;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixListener;
use std::sync::Arc;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::vec::Vec;
use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG};
use vhost::vhost_user::message::{
    VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
};
use vhost::vhost_user::{Master, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler};
use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
use vm_memory::{Address, Error as MmapError, FileOffset, GuestMemory, GuestMemoryRegion};
use vm_migration::protocol::MemoryRangeTable;
use vmm_sys_util::eventfd::EventFd;

// Size of a dirty page for vhost-user.
const VHOST_LOG_PAGE: u64 = 0x1000;

#[derive(Debug, Clone)]
pub struct VhostUserConfig {
    pub socket: String,
    pub num_queues: usize,
    pub queue_size: u16,
}

#[derive(Clone)]
struct VringInfo {
    config_data: VringConfigData,
    used_guest_addr: u64,
}

#[derive(Clone)]
pub struct VhostUserHandle {
    vu: Master,
    ready: bool,
    supports_migration: bool,
    shm_log: Option<Arc<MmapRegion>>,
    acked_features: u64,
    vrings_info: Option<Vec<VringInfo>>,
}

impl VhostUserHandle {
    pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> {
        let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
        for region in mem.iter() {
            let (mmap_handle, mmap_offset) = match region.file_offset() {
                Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()),
                None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)),
            };

            let vhost_user_net_reg = VhostUserMemoryRegionInfo {
                guest_phys_addr: region.start_addr().raw_value(),
                memory_size: region.len() as u64,
                userspace_addr: region.as_ptr() as u64,
                mmap_offset,
                mmap_handle,
            };

            regions.push(vhost_user_net_reg);
        }

        self.vu
            .set_mem_table(regions.as_slice())
            .map_err(Error::VhostUserSetMemTable)?;

        Ok(())
    }

    pub fn add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()> {
        let (mmap_handle, mmap_offset) = match region.file_offset() {
            Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
            None => return Err(Error::MissingRegionFd),
        };

        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: region.start_addr().raw_value(),
            memory_size: region.len() as u64,
            userspace_addr: region.as_ptr() as u64,
            mmap_offset,
            mmap_handle,
        };

        self.vu
            .add_mem_region(&region)
            .map_err(Error::VhostUserAddMemReg)
    }

    pub fn negotiate_features_vhost_user(
        &mut self,
        avail_features: u64,
        avail_protocol_features: VhostUserProtocolFeatures,
    ) -> Result<(u64, u64)> {
        // Set vhost-user owner.
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;

        // Get the backend's features, then negotiate the subset of features
        // supported by both the VMM and the backend.
        let backend_features = self
            .vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;
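        // Illustrative example (made-up values): if the device offers
        // VIRTIO_F_VERSION_1, VIRTIO_RING_F_EVENT_IDX and PROTOCOL_FEATURES,
        // while the backend only reports VIRTIO_F_VERSION_1 and
        // PROTOCOL_FEATURES, the bitwise AND below keeps VIRTIO_F_VERSION_1
        // and PROTOCOL_FEATURES and silently drops EVENT_IDX.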
        let acked_features = avail_features & backend_features;

        let acked_protocol_features =
            if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
                let backend_protocol_features = self
                    .vu
                    .get_protocol_features()
                    .map_err(Error::VhostUserGetProtocolFeatures)?;

                let acked_protocol_features = avail_protocol_features & backend_protocol_features;

                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                acked_protocol_features
            } else {
                VhostUserProtocolFeatures::empty()
            };

        if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
            && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
        {
            self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
        }

        self.update_supports_migration(acked_features, acked_protocol_features.bits());

        Ok((acked_features, acked_protocol_features.bits()))
    }

    #[allow(clippy::too_many_arguments)]
    pub fn setup_vhost_user<S: VhostUserMasterReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<Queue>,
        queue_evts: Vec<EventFd>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        slave_req_handler: &Option<MasterReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.vu
            .set_features(acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Update internal value after it's been sent to the backend.
        self.acked_features = acked_features;

        // Let's first provide the memory table to the backend.
        self.update_mem_table(mem)?;

        // Send set_vring_num here, since it tells backends such as SPDK how
        // many virtqueues they have to handle, which some backends need to
        // know at an early stage.
        for (queue_index, queue) in queues.iter().enumerate() {
            self.vu
                .set_vring_num(queue_index, queue.actual_size())
                .map_err(Error::VhostUserSetVringNum)?;
        }

        // Setup for inflight I/O tracking shared memory.
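        // On a first connection the backend allocates the inflight region and
        // hands its fd back through get_inflight_fd(); on a reconnection the
        // fd already stored in `inflight` is passed back unchanged through
        // set_inflight_fd() so the backend can resubmit in-flight requests.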
        if let Some(inflight) = inflight {
            if inflight.fd.is_none() {
                let inflight_req_info = VhostUserInflight {
                    mmap_size: 0,
                    mmap_offset: 0,
                    num_queues: queues.len() as u16,
                    queue_size: queues[0].actual_size(),
                };
                let (info, fd) = self
                    .vu
                    .get_inflight_fd(&inflight_req_info)
                    .map_err(Error::VhostUserGetInflight)?;
                inflight.info = info;
                inflight.fd = Some(fd);
            }
            // Unwrapping the inflight fd is safe here since we know it can't be None.
            self.vu
                .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd())
                .map_err(Error::VhostUserSetInflight)?;
        }

        let num_queues = queues.len();

        let mut vrings_info = Vec::new();
        for (queue_index, queue) in queues.into_iter().enumerate() {
            let actual_size: usize = queue.actual_size().try_into().unwrap();

            let config_data = VringConfigData {
                queue_max_size: queue.get_max_size(),
                queue_size: queue.actual_size(),
                flags: 0u32,
                desc_table_addr: get_host_address_range(
                    mem,
                    queue.desc_table,
                    actual_size * std::mem::size_of::<Descriptor>(),
                )
                .ok_or(Error::DescriptorTableAddress)? as u64,
                // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u32, len: u32}; actual_size]},
                // i.e. 4 + (4 + 4) * actual_size (see the worked example after this struct).
                used_ring_addr: get_host_address_range(mem, queue.used_ring, 4 + actual_size * 8)
                    .ok_or(Error::UsedAddress)? as u64,
                // The available ring is {flags: u16; idx: u16; ring [u16; actual_size]},
                // i.e. 4 + 2 * actual_size.
                avail_ring_addr: get_host_address_range(mem, queue.avail_ring, 4 + actual_size * 2)
                    .ok_or(Error::AvailAddress)? as u64,
                log_addr: None,
            };
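            // Worked example for a 256-entry queue: the descriptor table spans
            // 256 * 16 = 4096 bytes, the used ring 4 + 256 * 8 = 2052 bytes
            // and the available ring 4 + 256 * 2 = 516 bytes.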

            vrings_info.push(VringInfo {
                config_data,
                used_guest_addr: queue.used_ring.raw_value(),
            });

            self.vu
                .set_vring_addr(queue_index, &config_data)
                .map_err(Error::VhostUserSetVringAddr)?;
            self.vu
                .set_vring_base(
                    queue_index,
                    queue
                        .used_index_from_memory(mem)
                        .map_err(Error::GetAvailableIndex)?,
                )
                .map_err(Error::VhostUserSetVringBase)?;

            if let Some(eventfd) =
                virtio_interrupt.notifier(&VirtioInterruptType::Queue, Some(&queue))
            {
                self.vu
                    .set_vring_call(queue_index, &eventfd)
                    .map_err(Error::VhostUserSetVringCall)?;
            }

            self.vu
                .set_vring_kick(queue_index, &queue_evts[queue_index])
                .map_err(Error::VhostUserSetVringKick)?;
        }

        self.enable_vhost_user_vrings(num_queues, true)?;

        if let Some(slave_req_handler) = slave_req_handler {
            self.vu
                .set_slave_request_fd(&slave_req_handler.get_tx_raw_fd())
                .map_err(Error::VhostUserSetSlaveRequestFd)?;
        }

        self.vrings_info = Some(vrings_info);
        self.ready = true;

        Ok(())
    }

    fn enable_vhost_user_vrings(&mut self, num_queues: usize, enable: bool) -> Result<()> {
        for queue_index in 0..num_queues {
            self.vu
                .set_vring_enable(queue_index, enable)
                .map_err(Error::VhostUserSetVringEnable)?;
        }

        Ok(())
    }

    pub fn reset_vhost_user(&mut self, num_queues: usize) -> Result<()> {
        self.enable_vhost_user_vrings(num_queues, false)?;

        // Reset the owner.
        self.vu.reset_owner().map_err(Error::VhostUserResetOwner)
    }

    #[allow(clippy::too_many_arguments)]
    pub fn reinitialize_vhost_user<S: VhostUserMasterReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<Queue>,
        queue_evts: Vec<EventFd>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        acked_protocol_features: u64,
        slave_req_handler: &Option<MasterReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
        self.vu
            .get_features()
            .map_err(Error::VhostUserGetFeatures)?;

        if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
            if let Some(acked_protocol_features) =
                VhostUserProtocolFeatures::from_bits(acked_protocol_features)
            {
                self.vu
                    .set_protocol_features(acked_protocol_features)
                    .map_err(Error::VhostUserSetProtocolFeatures)?;

                if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
                    self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
                }
            }
        }

        self.update_supports_migration(acked_features, acked_protocol_features);

        self.setup_vhost_user(
            mem,
            queues,
            queue_evts,
            virtio_interrupt,
            acked_features,
            slave_req_handler,
            inflight,
        )
    }

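    // Typical call sequence from a virtio device, as an illustrative sketch
    // (the socket path and variable names below are assumptions, not fixed by
    // this module):
    //
    //     let mut vu = VhostUserHandle::connect_vhost_user(
    //         false, "/tmp/vhost.sock", num_queues as u64, false)?;
    //     let (acked_features, acked_protocol_features) =
    //         vu.negotiate_features_vhost_user(avail_features, avail_protocol_features)?;
    //     vu.setup_vhost_user(&mem, queues, queue_evts, &interrupt_cb,
    //         acked_features, &slave_req_handler, inflight.as_mut())?;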
    pub fn connect_vhost_user(
        server: bool,
        socket_path: &str,
        num_queues: u64,
        unlink_socket: bool,
    ) -> Result<Self> {
        if server {
            if unlink_socket {
                std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?;
            }

            info!("Binding vhost-user listener...");
            let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?;
            info!("Waiting for incoming vhost-user connection...");
            let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?;

            Ok(VhostUserHandle {
                vu: Master::from_stream(stream, num_queues),
                ready: false,
                supports_migration: false,
                shm_log: None,
                acked_features: 0,
                vrings_info: None,
            })
        } else {
            let now = Instant::now();

            // Retry connecting for a full minute
            let err = loop {
                let err = match Master::connect(socket_path, num_queues) {
                    Ok(m) => {
                        return Ok(VhostUserHandle {
                            vu: m,
                            ready: false,
                            supports_migration: false,
                            shm_log: None,
                            acked_features: 0,
                            vrings_info: None,
                        })
                    }
                    Err(e) => e,
                };
                sleep(Duration::from_millis(100));

                if now.elapsed().as_secs() >= 60 {
                    break err;
                }
            };

            error!(
                "Failed to connect to the backend after trying for 1 minute: {:?}",
                err
            );
            Err(Error::VhostUserConnect)
        }
    }

    pub fn socket_handle(&mut self) -> &mut Master {
        &mut self.vu
    }

    pub fn pause_vhost_user(&mut self, num_queues: usize) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(num_queues, false)?;
        }

        Ok(())
    }

    pub fn resume_vhost_user(&mut self, num_queues: usize) -> Result<()> {
        if self.ready {
            self.enable_vhost_user_vrings(num_queues, true)?;
        }

        Ok(())
    }

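    // Live migration requires the backend to accept the VHOST_F_LOG_ALL
    // feature bit and to support the LOG_SHMFD protocol feature, so that
    // dirty pages can be reported through a shared memory log.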
    fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) {
        if (acked_features & (1 << VHOST_F_LOG_ALL) != 0)
            && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0)
        {
            self.supports_migration = true;
        }
    }

    fn update_log_base(&mut self, last_ram_addr: u64) -> Result<Option<Arc<MmapRegion>>> {
        // Create the memfd
        let fd = memfd_create(
            &ffi::CString::new("vhost_user_dirty_log").unwrap(),
            libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
        )
        .map_err(Error::MemfdCreate)?;

        // Safe because we checked the file descriptor is valid
        let file = unsafe { File::from_raw_fd(fd) };
        // The size of the memory mapping corresponds to the size of a bitmap
        // covering all guest pages for addresses from 0 to the last physical
        // address in guest RAM.
        // A page is always 4kiB from a vhost-user perspective, and each bit is
        // a page. That's how we can compute mmap_size from the last address.
        let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
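        // For example, with last_ram_addr at 4 GiB (0x1_0000_0000) the log
        // takes 0x1_0000_0000 / (4096 * 8) + 1 = 131_073 bytes: each byte
        // tracks eight 4 KiB pages, i.e. 32 KiB of guest RAM.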
        let mmap_handle = file.as_raw_fd();

        // Set shm_log region size
        file.set_len(mmap_size).map_err(Error::SetFileSize)?;

        // Set the seals
        let res = unsafe {
            libc::fcntl(
                file.as_raw_fd(),
                libc::F_ADD_SEALS,
                libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL,
            )
        };
        if res < 0 {
            return Err(Error::SetSeals(std::io::Error::last_os_error()));
        }

        // Mmap shm_log region
        let region = MmapRegion::build(
            Some(FileOffset::new(file, 0)),
            mmap_size as usize,
            libc::PROT_READ | libc::PROT_WRITE,
            libc::MAP_SHARED,
        )
        .map_err(Error::NewMmapRegion)?;

        // Make sure we hold onto the region to prevent the mapping from being
        // released.
        let old_region = self.shm_log.take();
        self.shm_log = Some(Arc::new(region));

        // Send the shm_log fd over to the backend
        let log = VhostUserDirtyLogRegion {
            mmap_size,
            mmap_offset: 0,
            mmap_handle,
        };
        self.vu
            .set_log_base(0, Some(log))
            .map_err(Error::VhostUserSetLogBase)?;

        Ok(old_region)
    }

    fn set_vring_logging(&mut self, enable: bool) -> Result<()> {
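        // Logging is toggled by resending the vring addresses with the
        // VHOST_VRING_F_LOG flag and a log address for the used ring, or with
        // both cleared when logging is disabled.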
        if let Some(vrings_info) = &self.vrings_info {
            for (i, vring_info) in vrings_info.iter().enumerate() {
                let mut config_data = vring_info.config_data;
                config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 };
                config_data.log_addr = if enable {
                    Some(vring_info.used_guest_addr)
                } else {
                    None
                };

                self.vu
                    .set_vring_addr(i, &config_data)
                    .map_err(Error::VhostUserSetVringAddr)?;
            }
        }

        Ok(())
    }

    pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Set the shm log region
        self.update_log_base(last_ram_addr)?;

        // Enable VHOST_F_LOG_ALL feature
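        // Note that self.acked_features is left untouched so that
        // stop_dirty_log() can restore the original feature set.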
        let features = self.acked_features | (1 << VHOST_F_LOG_ALL);
        self.vu
            .set_features(features)
            .map_err(Error::VhostUserSetFeatures)?;

        // Enable dirty page logging of used ring for all queues
        self.set_vring_logging(true)
    }

    pub fn stop_dirty_log(&mut self) -> Result<()> {
        if !self.supports_migration {
            return Err(Error::MigrationNotSupported);
        }

        // Disable dirty page logging of used ring for all queues
        self.set_vring_logging(false)?;

        // Disable VHOST_F_LOG_ALL feature
        self.vu
            .set_features(self.acked_features)
            .map_err(Error::VhostUserSetFeatures)?;

        // This is important here since the log region goes out of scope,
        // invoking the Drop trait, hence unmapping the memory.
        self.shm_log = None;

        Ok(())
    }

    pub fn dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable> {
        // The log region is updated by creating a new region that is sent to
        // the backend. This ensures the backend stops logging to the previous
        // region. The previous region is returned and processed to create the
        // bitmap representing the dirty pages.
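        // Each bit of that bitmap stands for one 4 KiB guest page
        // (VHOST_LOG_PAGE): bit N set means the page starting at guest
        // physical address N * 4096 was written since the log was last
        // swapped.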
        if let Some(region) = self.update_log_base(last_ram_addr)? {
            // Cast the pointer to u64 and copy the bitmap out of the shared
            // mapping: the MmapRegion still owns that memory, so we must not
            // let a Vec take ownership of (and later free) the mmap'ed buffer.
            let ptr = region.as_ptr() as *const u64;
            // Be careful with the size, as it was based on u8, meaning we must
            // divide it by 8.
            let len = region.size() / 8;
            let bitmap = unsafe { std::slice::from_raw_parts(ptr, len) }.to_vec();
            Ok(MemoryRangeTable::from_bitmap(bitmap, 0))
        } else {
            Err(Error::MissingShmLogRegion)
        }
    }
}

fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error> {
    let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) };

    if res < 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(res as RawFd)
    }
}
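
// Illustrative sketch (not part of the original module): a couple of unit
// tests exercising the size arithmetic documented above, assuming the same
// 4 KiB log page granularity and 16-byte descriptors.
#[cfg(test)]
mod size_arithmetic_examples {
    use super::VHOST_LOG_PAGE;

    #[test]
    fn dirty_log_bitmap_size() {
        // One bit per 4 KiB page, so one byte covers eight pages (32 KiB).
        let last_ram_addr: u64 = 4 << 30; // 4 GiB of guest RAM
        let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
        assert_eq!(mmap_size, 131_073); // ~128 KiB of bitmap
    }

    #[test]
    fn vring_layout_sizes() {
        // Sizes passed to get_host_address_range() for a 256-entry queue.
        let queue_size = 256;
        let desc_table = queue_size * 16; // 16-byte descriptors
        let used_ring = 4 + queue_size * 8; // flags + idx + 8-byte used elems
        let avail_ring = 4 + queue_size * 2; // flags + idx + 2-byte entries
        assert_eq!((desc_table, used_ring, avail_ring), (4096, 2052, 516));
    }
}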
562