xref: /cloud-hypervisor/virtio-devices/src/vhost_user/vu_common_ctrl.rs (revision 7d7bfb2034001d4cb15df2ddc56d2d350c8da30f)
1 // Copyright 2019 Intel Corporation. All Rights Reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use super::{Error, Result};
5 use crate::vhost_user::Inflight;
6 use crate::{
7     get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt,
8     VirtioInterruptType,
9 };
10 use std::convert::TryInto;
11 use std::ffi;
12 use std::fs::File;
13 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
14 use std::os::unix::net::UnixListener;
15 use std::sync::atomic::Ordering;
16 use std::sync::Arc;
17 use std::thread::sleep;
18 use std::time::{Duration, Instant};
19 use std::vec::Vec;
20 use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG};
21 use vhost::vhost_user::message::{
22     VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
23 };
24 use vhost::vhost_user::{Master, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler};
25 use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
26 use virtio_queue::{Descriptor, Queue};
27 use vm_memory::{
28     Address, Error as MmapError, FileOffset, GuestMemory, GuestMemoryAtomic, GuestMemoryRegion,
29 };
30 use vm_migration::protocol::MemoryRangeTable;
31 use vmm_sys_util::eventfd::EventFd;
32 
// Size of a dirty page for vhost-user: each bit in the shared dirty log
// covers one 4 KiB page, regardless of the host's actual page size.
const VHOST_LOG_PAGE: u64 = 0x1000;
35 
/// User-provided configuration for a vhost-user device backend.
#[derive(Debug, Clone)]
pub struct VhostUserConfig {
    /// Path to the backend's Unix domain socket.
    pub socket: String,
    /// Number of virtqueues to set up with the backend.
    pub num_queues: usize,
    /// Size (number of descriptors) of each virtqueue.
    pub queue_size: u16,
}
42 
// Per-vring state saved at setup time so the SET_VRING_ADDR configuration
// can be replayed later (e.g. to toggle dirty-logging without recomputing
// the host addresses).
#[derive(Clone)]
struct VringInfo {
    // Addresses and flags originally sent through SET_VRING_ADDR.
    config_data: VringConfigData,
    // Guest-physical address of the used ring; used as the log address when
    // enabling dirty page logging on this vring.
    used_guest_addr: u64,
}
48 
/// Wrapper around the master side of a vhost-user connection, tracking the
/// negotiated state needed for reconnection, pause/resume and migration.
#[derive(Clone)]
pub struct VhostUserHandle {
    // Master end of the vhost-user connection to the backend.
    vu: Master,
    // Set once setup_vhost_user() has fully configured the backend.
    ready: bool,
    // True when both VHOST_F_LOG_ALL and LOG_SHMFD were negotiated.
    supports_migration: bool,
    // Shared-memory region the backend logs dirty pages into; held here so
    // the mapping is not released while logging is active.
    shm_log: Option<Arc<MmapRegion>>,
    // Features acknowledged with the backend via SET_FEATURES.
    acked_features: u64,
    // Vring configuration saved by setup_vhost_user() for later replay.
    vrings_info: Option<Vec<VringInfo>>,
}
58 
59 impl VhostUserHandle {
60     pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> {
61         let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
62         for region in mem.iter() {
63             let (mmap_handle, mmap_offset) = match region.file_offset() {
64                 Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()),
65                 None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)),
66             };
67 
68             let vhost_user_net_reg = VhostUserMemoryRegionInfo {
69                 guest_phys_addr: region.start_addr().raw_value(),
70                 memory_size: region.len() as u64,
71                 userspace_addr: region.as_ptr() as u64,
72                 mmap_offset,
73                 mmap_handle,
74             };
75 
76             regions.push(vhost_user_net_reg);
77         }
78 
79         self.vu
80             .set_mem_table(regions.as_slice())
81             .map_err(Error::VhostUserSetMemTable)?;
82 
83         Ok(())
84     }
85 
86     pub fn add_memory_region(&mut self, region: &Arc<GuestRegionMmap>) -> Result<()> {
87         let (mmap_handle, mmap_offset) = match region.file_offset() {
88             Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()),
89             None => return Err(Error::MissingRegionFd),
90         };
91 
92         let region = VhostUserMemoryRegionInfo {
93             guest_phys_addr: region.start_addr().raw_value(),
94             memory_size: region.len() as u64,
95             userspace_addr: region.as_ptr() as u64,
96             mmap_offset,
97             mmap_handle,
98         };
99 
100         self.vu
101             .add_mem_region(&region)
102             .map_err(Error::VhostUserAddMemReg)
103     }
104 
105     pub fn negotiate_features_vhost_user(
106         &mut self,
107         avail_features: u64,
108         avail_protocol_features: VhostUserProtocolFeatures,
109     ) -> Result<(u64, u64)> {
110         // Set vhost-user owner.
111         self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
112 
113         // Get features from backend, do negotiation to get a feature collection which
114         // both VMM and backend support.
115         let backend_features = self
116             .vu
117             .get_features()
118             .map_err(Error::VhostUserGetFeatures)?;
119         let acked_features = avail_features & backend_features;
120 
121         let acked_protocol_features =
122             if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
123                 let backend_protocol_features = self
124                     .vu
125                     .get_protocol_features()
126                     .map_err(Error::VhostUserGetProtocolFeatures)?;
127 
128                 let acked_protocol_features = avail_protocol_features & backend_protocol_features;
129 
130                 self.vu
131                     .set_protocol_features(acked_protocol_features)
132                     .map_err(Error::VhostUserSetProtocolFeatures)?;
133 
134                 acked_protocol_features
135             } else {
136                 VhostUserProtocolFeatures::empty()
137             };
138 
139         if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
140             && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK)
141         {
142             self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
143         }
144 
145         self.update_supports_migration(acked_features, acked_protocol_features.bits());
146 
147         Ok((acked_features, acked_protocol_features.bits()))
148     }
149 
150     #[allow(clippy::too_many_arguments)]
151     pub fn setup_vhost_user<S: VhostUserMasterReqHandler>(
152         &mut self,
153         mem: &GuestMemoryMmap,
154         queues: Vec<Queue<GuestMemoryAtomic<GuestMemoryMmap>>>,
155         queue_evts: Vec<EventFd>,
156         virtio_interrupt: &Arc<dyn VirtioInterrupt>,
157         acked_features: u64,
158         slave_req_handler: &Option<MasterReqHandler<S>>,
159         inflight: Option<&mut Inflight>,
160     ) -> Result<()> {
161         self.vu
162             .set_features(acked_features)
163             .map_err(Error::VhostUserSetFeatures)?;
164 
165         // Update internal value after it's been sent to the backend.
166         self.acked_features = acked_features;
167 
168         // Let's first provide the memory table to the backend.
169         self.update_mem_table(mem)?;
170 
171         // Send set_vring_num here, since it could tell backends, like SPDK,
172         // how many virt queues to be handled, which backend required to know
173         // at early stage.
174         for (queue_index, queue) in queues.iter().enumerate() {
175             self.vu
176                 .set_vring_num(queue_index, queue.max_size())
177                 .map_err(Error::VhostUserSetVringNum)?;
178         }
179 
180         // Setup for inflight I/O tracking shared memory.
181         if let Some(inflight) = inflight {
182             if inflight.fd.is_none() {
183                 let inflight_req_info = VhostUserInflight {
184                     mmap_size: 0,
185                     mmap_offset: 0,
186                     num_queues: queues.len() as u16,
187                     queue_size: queues[0].max_size(),
188                 };
189                 let (info, fd) = self
190                     .vu
191                     .get_inflight_fd(&inflight_req_info)
192                     .map_err(Error::VhostUserGetInflight)?;
193                 inflight.info = info;
194                 inflight.fd = Some(fd);
195             }
196             // Unwrapping the inflight fd is safe here since we know it can't be None.
197             self.vu
198                 .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd())
199                 .map_err(Error::VhostUserSetInflight)?;
200         }
201 
202         let num_queues = queues.len() as usize;
203 
204         let mut vrings_info = Vec::new();
205         for (queue_index, queue) in queues.into_iter().enumerate() {
206             let actual_size: usize = queue.max_size().try_into().unwrap();
207 
208             let config_data = VringConfigData {
209                 queue_max_size: queue.max_size(),
210                 queue_size: queue.max_size(),
211                 flags: 0u32,
212                 desc_table_addr: get_host_address_range(
213                     mem,
214                     queue.state.desc_table,
215                     actual_size * std::mem::size_of::<Descriptor>(),
216                 )
217                 .ok_or(Error::DescriptorTableAddress)? as u64,
218                 // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u16, len: u16}; actual_size]},
219                 // i.e. 4 + (4 + 4) * actual_size.
220                 used_ring_addr: get_host_address_range(
221                     mem,
222                     queue.state.used_ring,
223                     4 + actual_size * 8,
224                 )
225                 .ok_or(Error::UsedAddress)? as u64,
226                 // The used ring is {flags: u16; idx: u16; elem [u16; actual_size]},
227                 // i.e. 4 + (2) * actual_size.
228                 avail_ring_addr: get_host_address_range(
229                     mem,
230                     queue.state.avail_ring,
231                     4 + actual_size * 2,
232                 )
233                 .ok_or(Error::AvailAddress)? as u64,
234                 log_addr: None,
235             };
236 
237             vrings_info.push(VringInfo {
238                 config_data,
239                 used_guest_addr: queue.state.used_ring.raw_value(),
240             });
241 
242             self.vu
243                 .set_vring_addr(queue_index, &config_data)
244                 .map_err(Error::VhostUserSetVringAddr)?;
245             self.vu
246                 .set_vring_base(
247                     queue_index,
248                     queue
249                         .avail_idx(Ordering::Acquire)
250                         .map_err(Error::GetAvailableIndex)?
251                         .0,
252                 )
253                 .map_err(Error::VhostUserSetVringBase)?;
254 
255             if let Some(eventfd) =
256                 virtio_interrupt.notifier(VirtioInterruptType::Queue(queue_index as u16))
257             {
258                 self.vu
259                     .set_vring_call(queue_index, &eventfd)
260                     .map_err(Error::VhostUserSetVringCall)?;
261             }
262 
263             self.vu
264                 .set_vring_kick(queue_index, &queue_evts[queue_index])
265                 .map_err(Error::VhostUserSetVringKick)?;
266         }
267 
268         self.enable_vhost_user_vrings(num_queues, true)?;
269 
270         if let Some(slave_req_handler) = slave_req_handler {
271             self.vu
272                 .set_slave_request_fd(&slave_req_handler.get_tx_raw_fd())
273                 .map_err(Error::VhostUserSetSlaveRequestFd)?;
274         }
275 
276         self.vrings_info = Some(vrings_info);
277         self.ready = true;
278 
279         Ok(())
280     }
281 
282     fn enable_vhost_user_vrings(&mut self, num_queues: usize, enable: bool) -> Result<()> {
283         for queue_index in 0..num_queues {
284             self.vu
285                 .set_vring_enable(queue_index, enable)
286                 .map_err(Error::VhostUserSetVringEnable)?;
287         }
288 
289         Ok(())
290     }
291 
292     pub fn reset_vhost_user(&mut self, num_queues: usize) -> Result<()> {
293         self.enable_vhost_user_vrings(num_queues, false)?;
294 
295         // Reset the owner.
296         self.vu.reset_owner().map_err(Error::VhostUserResetOwner)
297     }
298 
299     pub fn set_protocol_features_vhost_user(
300         &mut self,
301         acked_features: u64,
302         acked_protocol_features: u64,
303     ) -> Result<()> {
304         self.vu.set_owner().map_err(Error::VhostUserSetOwner)?;
305         self.vu
306             .get_features()
307             .map_err(Error::VhostUserGetFeatures)?;
308 
309         if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
310             if let Some(acked_protocol_features) =
311                 VhostUserProtocolFeatures::from_bits(acked_protocol_features)
312             {
313                 self.vu
314                     .set_protocol_features(acked_protocol_features)
315                     .map_err(Error::VhostUserSetProtocolFeatures)?;
316 
317                 if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
318                     self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
319                 }
320             }
321         }
322 
323         self.update_supports_migration(acked_features, acked_protocol_features);
324 
325         Ok(())
326     }
327 
    /// Re-runs the full device setup after a backend reconnection: first
    /// replays the previously negotiated (protocol) features, then performs
    /// the standard setup sequence with the saved queue state.
    #[allow(clippy::too_many_arguments)]
    pub fn reinitialize_vhost_user<S: VhostUserMasterReqHandler>(
        &mut self,
        mem: &GuestMemoryMmap,
        queues: Vec<Queue<GuestMemoryAtomic<GuestMemoryMmap>>>,
        queue_evts: Vec<EventFd>,
        virtio_interrupt: &Arc<dyn VirtioInterrupt>,
        acked_features: u64,
        acked_protocol_features: u64,
        slave_req_handler: &Option<MasterReqHandler<S>>,
        inflight: Option<&mut Inflight>,
    ) -> Result<()> {
        self.set_protocol_features_vhost_user(acked_features, acked_protocol_features)?;

        self.setup_vhost_user(
            mem,
            queues,
            queue_evts,
            virtio_interrupt,
            acked_features,
            slave_req_handler,
            inflight,
        )
    }
352 
353     pub fn connect_vhost_user(
354         server: bool,
355         socket_path: &str,
356         num_queues: u64,
357         unlink_socket: bool,
358     ) -> Result<Self> {
359         if server {
360             if unlink_socket {
361                 std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?;
362             }
363 
364             info!("Binding vhost-user listener...");
365             let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?;
366             info!("Waiting for incoming vhost-user connection...");
367             let (stream, _) = listener.accept().map_err(Error::AcceptConnection)?;
368 
369             Ok(VhostUserHandle {
370                 vu: Master::from_stream(stream, num_queues),
371                 ready: false,
372                 supports_migration: false,
373                 shm_log: None,
374                 acked_features: 0,
375                 vrings_info: None,
376             })
377         } else {
378             let now = Instant::now();
379 
380             // Retry connecting for a full minute
381             let err = loop {
382                 let err = match Master::connect(socket_path, num_queues) {
383                     Ok(m) => {
384                         return Ok(VhostUserHandle {
385                             vu: m,
386                             ready: false,
387                             supports_migration: false,
388                             shm_log: None,
389                             acked_features: 0,
390                             vrings_info: None,
391                         })
392                     }
393                     Err(e) => e,
394                 };
395                 sleep(Duration::from_millis(100));
396 
397                 if now.elapsed().as_secs() >= 60 {
398                     break err;
399                 }
400             };
401 
402             error!(
403                 "Failed connecting the backend after trying for 1 minute: {:?}",
404                 err
405             );
406             Err(Error::VhostUserConnect)
407         }
408     }
409 
    /// Gives mutable access to the underlying vhost-user master, for
    /// device-specific requests not covered by this wrapper.
    pub fn socket_handle(&mut self) -> &mut Master {
        &mut self.vu
    }
413 
414     pub fn pause_vhost_user(&mut self, num_queues: usize) -> Result<()> {
415         if self.ready {
416             self.enable_vhost_user_vrings(num_queues, false)?;
417         }
418 
419         Ok(())
420     }
421 
422     pub fn resume_vhost_user(&mut self, num_queues: usize) -> Result<()> {
423         if self.ready {
424             self.enable_vhost_user_vrings(num_queues, true)?;
425         }
426 
427         Ok(())
428     }
429 
430     fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) {
431         if (acked_features & u64::from(vhost::vhost_kern::vhost_binding::VHOST_F_LOG_ALL) != 0)
432             && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0)
433         {
434             self.supports_migration = true;
435         }
436     }
437 
    /// Creates a fresh shared-memory dirty-log region, installs it on the
    /// backend via SET_LOG_BASE, and returns the previously installed
    /// region (if any) so its bitmap can still be consumed by the caller.
    fn update_log_base(&mut self, last_ram_addr: u64) -> Result<Option<Arc<MmapRegion>>> {
        // Create the memfd backing the log region.
        let fd = memfd_create(
            &ffi::CString::new("vhost_user_dirty_log").unwrap(),
            libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING,
        )
        .map_err(Error::MemfdCreate)?;

        // Safe because we checked the file descriptor is valid
        let file = unsafe { File::from_raw_fd(fd) };
        // The size of the memory mapping corresponds to the size of a bitmap
        // covering all guest pages for addresses from 0 to the last physical
        // address in guest RAM.
        // A page is always 4kiB from a vhost-user perspective, and each bit is
        // a page. That's how we can compute mmap_size from the last address.
        let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1;
        let mmap_handle = file.as_raw_fd();

        // Set shm_log region size
        file.set_len(mmap_size).map_err(Error::SetFileSize)?;

        // Seal the file so it can no longer grow, shrink, or be re-sealed
        // while it is shared with the backend.
        let res = unsafe {
            libc::fcntl(
                file.as_raw_fd(),
                libc::F_ADD_SEALS,
                libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL,
            )
        };
        if res < 0 {
            return Err(Error::SetSeals(std::io::Error::last_os_error()));
        }

        // Mmap shm_log region
        let region = MmapRegion::build(
            Some(FileOffset::new(file, 0)),
            mmap_size as usize,
            libc::PROT_READ | libc::PROT_WRITE,
            libc::MAP_SHARED,
        )
        .map_err(Error::NewMmapRegion)?;

        // Make sure we hold onto the region to prevent the mapping from being
        // released.
        let old_region = self.shm_log.replace(Arc::new(region));

        // Send the shm_log fd over to the backend
        let log = VhostUserDirtyLogRegion {
            mmap_size,
            mmap_offset: 0,
            mmap_handle,
        };
        self.vu
            .set_log_base(0, Some(log))
            .map_err(Error::VhostUserSetLogBase)?;

        Ok(old_region)
    }
496 
497     fn set_vring_logging(&mut self, enable: bool) -> Result<()> {
498         if let Some(vrings_info) = &self.vrings_info {
499             for (i, vring_info) in vrings_info.iter().enumerate() {
500                 let mut config_data = vring_info.config_data;
501                 config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 };
502                 config_data.log_addr = if enable {
503                     Some(vring_info.used_guest_addr)
504                 } else {
505                     None
506                 };
507 
508                 self.vu
509                     .set_vring_addr(i, &config_data)
510                     .map_err(Error::VhostUserSetVringAddr)?;
511             }
512         }
513 
514         Ok(())
515     }
516 
517     pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> {
518         if !self.supports_migration {
519             return Err(Error::MigrationNotSupported);
520         }
521 
522         // Set the shm log region
523         self.update_log_base(last_ram_addr)?;
524 
525         // Enable VHOST_F_LOG_ALL feature
526         let features = self.acked_features | (1 << VHOST_F_LOG_ALL);
527         self.vu
528             .set_features(features)
529             .map_err(Error::VhostUserSetFeatures)?;
530 
531         // Enable dirty page logging of used ring for all queues
532         self.set_vring_logging(true)
533     }
534 
535     pub fn stop_dirty_log(&mut self) -> Result<()> {
536         if !self.supports_migration {
537             return Err(Error::MigrationNotSupported);
538         }
539 
540         // Disable dirty page logging of used ring for all queues
541         self.set_vring_logging(false)?;
542 
543         // Disable VHOST_F_LOG_ALL feature
544         self.vu
545             .set_features(self.acked_features)
546             .map_err(Error::VhostUserSetFeatures)?;
547 
548         // This is important here since the log region goes out of scope,
549         // invoking the Drop trait, hence unmapping the memory.
550         self.shm_log = None;
551 
552         Ok(())
553     }
554 
555     pub fn dirty_log(&mut self, last_ram_addr: u64) -> Result<MemoryRangeTable> {
556         // The log region is updated by creating a new region that is sent to
557         // the backend. This ensures the backend stops logging to the previous
558         // region. The previous region is returned and processed to create the
559         // bitmap representing the dirty pages.
560         if let Some(region) = self.update_log_base(last_ram_addr)? {
561             // Be careful with the size, as it was based on u8, meaning we must
562             // divide it by 8.
563             let len = region.size() / 8;
564             let bitmap = unsafe {
565                 // Cast the pointer to u64
566                 let ptr = region.as_ptr() as *const u64;
567                 std::slice::from_raw_parts(ptr, len).to_vec()
568             };
569             Ok(MemoryRangeTable::from_bitmap(bitmap, 0, 4096))
570         } else {
571             Err(Error::MissingShmLogRegion)
572         }
573     }
574 }
575 
576 fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result<RawFd, std::io::Error> {
577     let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) };
578 
579     if res < 0 {
580         Err(std::io::Error::last_os_error())
581     } else {
582         Ok(res as RawFd)
583     }
584 }
585