xref: /cloud-hypervisor/virtio-devices/src/block.rs (revision 21f05ebb4fb0ddf1f148d9b5329c9259297ed3c7)
// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE-BSD-3-Clause file.
//
// Copyright © 2020 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause

use std::collections::{BTreeMap, HashMap, VecDeque};
use std::num::Wrapping;
use std::ops::Deref;
use std::os::unix::io::AsRawFd;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Barrier};
use std::{io, result};

use anyhow::anyhow;
use block::async_io::{AsyncIo, AsyncIoError, DiskFile};
use block::{build_serial, Request, RequestType, VirtioBlockConfig};
use rate_limiter::group::{RateLimiterGroup, RateLimiterGroupHandle};
use rate_limiter::TokenType;
use seccompiler::SeccompAction;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use virtio_bindings::virtio_blk::*;
use virtio_bindings::virtio_config::*;
use virtio_bindings::virtio_ring::{VIRTIO_RING_F_EVENT_IDX, VIRTIO_RING_F_INDIRECT_DESC};
use virtio_queue::{Queue, QueueOwnedT, QueueT};
use vm_memory::{ByteValued, Bytes, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryError};
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::AccessPlatform;
use vmm_sys_util::eventfd::EventFd;

use super::{
    ActivateError, ActivateResult, EpollHelper, EpollHelperError, EpollHelperHandler,
    Error as DeviceError, VirtioCommon, VirtioDevice, VirtioDeviceType, VirtioInterruptType,
    EPOLL_HELPER_EVENT_LAST,
};
use crate::seccomp_filters::Thread;
use crate::thread_helper::spawn_virtio_thread;
use crate::{GuestMemoryMmap, VirtioInterrupt};

const SECTOR_SHIFT: u8 = 9;
pub const SECTOR_SIZE: u64 = 0x01 << SECTOR_SHIFT;

// New descriptors are pending on the virtio queue.
const QUEUE_AVAIL_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 1;
// New completed tasks are pending on the completion ring.
const COMPLETION_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 2;
// New 'wake up' event from the rate limiter
const RATE_LIMITER_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 3;

// Latency scale, used to reduce precision loss when calculating the
// average latency.
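// Averages are kept as fixed-point values scaled by 10_000: for example,
// a true average of 123.4567 µs is stored as 1_234_567 and divided back
// down by LATENCY_SCALE when reported in counters().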
const LATENCY_SCALE: u64 = 10000;

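// A virtio-blk request occupies at least two descriptors (the request
// header and the status byte) in addition to any data descriptors, which
// is why seg_max below is derived as queue_size minus this value.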
pub const MINIMUM_BLOCK_QUEUE_SIZE: u16 = 2;

#[derive(Error, Debug)]
pub enum Error {
    #[error("Failed to parse the request: {0}")]
    RequestParsing(block::Error),
    #[error("Failed to execute the request: {0}")]
    RequestExecuting(block::ExecuteError),
    #[error("Failed to complete the request: {0}")]
    RequestCompleting(block::Error),
    #[error("Missing the expected entry in the list of requests")]
    MissingEntryRequestList,
    #[error("The asynchronous request returned with failure")]
    AsyncRequestFailure,
    #[error("Failed synchronizing the file: {0}")]
    Fsync(AsyncIoError),
    #[error("Failed adding used index: {0}")]
    QueueAddUsed(virtio_queue::Error),
    #[error("Failed creating an iterator over the queue: {0}")]
    QueueIterator(virtio_queue::Error),
    #[error("Failed to update request status: {0}")]
    RequestStatus(GuestMemoryError),
    #[error("Failed to enable notification: {0}")]
    QueueEnableNotification(virtio_queue::Error),
}

pub type Result<T> = result::Result<T, Error>;

// Latencies are recorded in microseconds; the average latency is stored
// as a LATENCY_SCALE-scaled value.
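// Each field is an Arc<AtomicU64> so the per-queue epoll handler threads
// and the control plane's counters() call share the same counters
// without locking.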
#[derive(Clone)]
pub struct BlockCounters {
    read_bytes: Arc<AtomicU64>,
    read_ops: Arc<AtomicU64>,
    read_latency_min: Arc<AtomicU64>,
    read_latency_max: Arc<AtomicU64>,
    read_latency_avg: Arc<AtomicU64>,
    write_bytes: Arc<AtomicU64>,
    write_ops: Arc<AtomicU64>,
    write_latency_min: Arc<AtomicU64>,
    write_latency_max: Arc<AtomicU64>,
    write_latency_avg: Arc<AtomicU64>,
}

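// Latency counters start out at u64::MAX, which acts as a "no sample yet"
// sentinel; process_queue_complete() special-cases it when the first
// request completes.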
impl Default for BlockCounters {
    fn default() -> Self {
        BlockCounters {
            read_bytes: Arc::new(AtomicU64::new(0)),
            read_ops: Arc::new(AtomicU64::new(0)),
            read_latency_min: Arc::new(AtomicU64::new(u64::MAX)),
            read_latency_max: Arc::new(AtomicU64::new(u64::MAX)),
            read_latency_avg: Arc::new(AtomicU64::new(u64::MAX)),
            write_bytes: Arc::new(AtomicU64::new(0)),
            write_ops: Arc::new(AtomicU64::new(0)),
            write_latency_min: Arc::new(AtomicU64::new(u64::MAX)),
            write_latency_max: Arc::new(AtomicU64::new(u64::MAX)),
            write_latency_avg: Arc::new(AtomicU64::new(u64::MAX)),
        }
    }
}

struct BlockEpollHandler {
    queue_index: u16,
    queue: Queue,
    mem: GuestMemoryAtomic<GuestMemoryMmap>,
    disk_image: Box<dyn AsyncIo>,
    disk_nsectors: u64,
    interrupt_cb: Arc<dyn VirtioInterrupt>,
    serial: Vec<u8>,
    kill_evt: EventFd,
    pause_evt: EventFd,
    writeback: Arc<AtomicBool>,
    counters: BlockCounters,
    queue_evt: EventFd,
    inflight_requests: VecDeque<(u16, Request)>,
    rate_limiter: Option<RateLimiterGroupHandle>,
    access_platform: Option<Arc<dyn AccessPlatform>>,
    read_only: bool,
    host_cpus: Option<Vec<usize>>,
}

impl BlockEpollHandler {
    fn process_queue_submit(&mut self) -> Result<()> {
        let queue = &mut self.queue;

        while let Some(mut desc_chain) = queue.pop_descriptor_chain(self.mem.memory()) {
            let mut request = Request::parse(&mut desc_chain, self.access_platform.as_ref())
                .map_err(Error::RequestParsing)?;

            // For virtio spec compliance:
            // "A device MUST set the status byte to VIRTIO_BLK_S_IOERR for a write request
            // if the VIRTIO_BLK_F_RO feature is offered, and MUST NOT write any data."
            if self.read_only
                && (request.request_type == RequestType::Out
                    || request.request_type == RequestType::Flush)
            {
                desc_chain
                    .memory()
                    .write_obj(VIRTIO_BLK_S_IOERR as u8, request.status_addr)
                    .map_err(Error::RequestStatus)?;

                // If no asynchronous operation has been submitted, we can
                // simply return the used descriptor.
                queue
                    .add_used(desc_chain.memory(), desc_chain.head_index(), 0)
                    .map_err(Error::QueueAddUsed)?;
                queue
                    .enable_notification(self.mem.memory().deref())
                    .map_err(Error::QueueEnableNotification)?;
                continue;
            }

            if let Some(rate_limiter) = &mut self.rate_limiter {
                // If limiter.consume() fails it means there is no more TokenType::Ops
                // budget and rate limiting is in effect.
                if !rate_limiter.consume(1, TokenType::Ops) {
                    // Stop processing the queue and return this descriptor chain to the
                    // avail ring, for later processing.
                    queue.go_to_previous_position();
                    break;
                }
                // Exercise the rate limiter only if this request is of data transfer type.
                if request.request_type == RequestType::In
                    || request.request_type == RequestType::Out
                {
                    let mut bytes = Wrapping(0);
                    for (_, data_len) in &request.data_descriptors {
                        bytes += Wrapping(*data_len as u64);
                    }

                    // If limiter.consume() fails it means there is no more TokenType::Bytes
                    // budget and rate limiting is in effect.
                    if !rate_limiter.consume(bytes.0, TokenType::Bytes) {
                        // Revert the OPS consume().
                        rate_limiter.manual_replenish(1, TokenType::Ops);
                        // Stop processing the queue and return this descriptor chain to the
                        // avail ring, for later processing.
                        queue.go_to_previous_position();
                        break;
                    }
                }
            }

            request.set_writeback(self.writeback.load(Ordering::Acquire));

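            // execute_async() returns true when the request was submitted to
            // the asynchronous I/O engine, in which case its completion will
            // show up on the completion ring handled by
            // process_queue_complete(). A false return means the request was
            // completed inline, so the descriptor can be returned right away.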
            if request
                .execute_async(
                    desc_chain.memory(),
                    self.disk_nsectors,
                    self.disk_image.as_mut(),
                    &self.serial,
                    desc_chain.head_index() as u64,
                )
                .map_err(Error::RequestExecuting)?
            {
                self.inflight_requests
                    .push_back((desc_chain.head_index(), request));
            } else {
                desc_chain
                    .memory()
                    .write_obj(VIRTIO_BLK_S_OK as u8, request.status_addr)
                    .map_err(Error::RequestStatus)?;

                // If no asynchronous operation has been submitted, we can
                // simply return the used descriptor.
                queue
                    .add_used(desc_chain.memory(), desc_chain.head_index(), 0)
                    .map_err(Error::QueueAddUsed)?;
                queue
                    .enable_notification(self.mem.memory().deref())
                    .map_err(Error::QueueEnableNotification)?;
            }
        }

        Ok(())
    }

    fn try_signal_used_queue(&mut self) -> result::Result<(), EpollHelperError> {
        if self
            .queue
            .needs_notification(self.mem.memory().deref())
            .map_err(|e| {
                EpollHelperError::HandleEvent(anyhow!(
                    "Failed to check needs_notification: {:?}",
                    e
                ))
            })?
        {
            self.signal_used_queue().map_err(|e| {
                EpollHelperError::HandleEvent(anyhow!("Failed to signal used queue: {:?}", e))
            })?;
        }

        Ok(())
    }

    fn process_queue_submit_and_signal(&mut self) -> result::Result<(), EpollHelperError> {
        self.process_queue_submit().map_err(|e| {
            EpollHelperError::HandleEvent(anyhow!("Failed to process queue (submit): {:?}", e))
        })?;

        self.try_signal_used_queue()
    }

    #[inline]
    fn find_inflight_request(&mut self, completed_head: u16) -> Result<Request> {
        // This loop neatly handles the fast path where completions arrive
        // in order (it turns into just a pop_front()) and the ~1% of the
        // time (measured during boot) where slightly out-of-order
        // completions have been observed, e.g.
        // Submissions: 1 2 3 4 5 6 7
        // Completions: 2 1 3 5 4 7 6
        // In that case, find the corresponding item and swap it with the
        // front. This is an O(1) operation and prepares for the likely
        // case that the next completion is for the request that was
        // skipped, which is now the new front.
        for (i, (head, _)) in self.inflight_requests.iter().enumerate() {
            if head == &completed_head {
                return Ok(self.inflight_requests.swap_remove_front(i).unwrap().1);
            }
        }

        Err(Error::MissingEntryRequestList)
    }

    fn process_queue_complete(&mut self) -> Result<()> {
        let mem = self.mem.memory();
        let mut read_bytes = Wrapping(0);
        let mut write_bytes = Wrapping(0);
        let mut read_ops = Wrapping(0);
        let mut write_ops = Wrapping(0);

        while let Some((user_data, result)) = self.disk_image.next_completed_request() {
            let desc_index = user_data as u16;

            let mut request = self.find_inflight_request(desc_index)?;

            request.complete_async().map_err(Error::RequestCompleting)?;

            let latency = request.start.elapsed().as_micros() as u64;
            let read_ops_last = self.counters.read_ops.load(Ordering::Relaxed);
            let write_ops_last = self.counters.write_ops.load(Ordering::Relaxed);
            let read_max = self.counters.read_latency_max.load(Ordering::Relaxed);
            let write_max = self.counters.write_latency_max.load(Ordering::Relaxed);
            let mut read_avg = self.counters.read_latency_avg.load(Ordering::Relaxed);
            let mut write_avg = self.counters.write_latency_avg.load(Ordering::Relaxed);
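            // The running averages below use the standard incremental mean,
            //   avg_n = avg_(n-1) + (x_n - avg_(n-1)) / n,
            // computed on LATENCY_SCALE-scaled samples so that the integer
            // division discards less precision.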
            let (status, len) = if result >= 0 {
                match request.request_type {
                    RequestType::In => {
                        for (_, data_len) in &request.data_descriptors {
                            read_bytes += Wrapping(*data_len as u64);
                        }
                        read_ops += Wrapping(1);
                        if latency < self.counters.read_latency_min.load(Ordering::Relaxed) {
                            self.counters
                                .read_latency_min
                                .store(latency, Ordering::Relaxed);
                        }
                        if latency > read_max || read_max == u64::MAX {
                            self.counters
                                .read_latency_max
                                .store(latency, Ordering::Relaxed);
                        }

                        // Special-case the first real latency report
                        read_avg = if read_avg == u64::MAX {
                            latency * LATENCY_SCALE
                        } else {
                            // Cumulative average is guaranteed to be
                            // positive if being calculated properly
                            (read_avg as i64
                                + ((latency * LATENCY_SCALE) as i64 - read_avg as i64)
                                    / (read_ops_last + read_ops.0) as i64)
                                .try_into()
                                .unwrap()
                        };
                    }
                    RequestType::Out => {
                        if !request.writeback {
                            self.disk_image.fsync(None).map_err(Error::Fsync)?;
                        }
                        for (_, data_len) in &request.data_descriptors {
                            write_bytes += Wrapping(*data_len as u64);
                        }
                        write_ops += Wrapping(1);
                        if latency < self.counters.write_latency_min.load(Ordering::Relaxed) {
                            self.counters
                                .write_latency_min
                                .store(latency, Ordering::Relaxed);
                        }
                        if latency > write_max || write_max == u64::MAX {
                            self.counters
                                .write_latency_max
                                .store(latency, Ordering::Relaxed);
                        }

                        // Special-case the first real latency report
                        write_avg = if write_avg == u64::MAX {
                            latency * LATENCY_SCALE
                        } else {
                            // Cumulative average is guaranteed to be
                            // positive if being calculated properly
                            (write_avg as i64
                                + ((latency * LATENCY_SCALE) as i64 - write_avg as i64)
                                    / (write_ops_last + write_ops.0) as i64)
                                .try_into()
                                .unwrap()
                        }
                    }
                    _ => {}
                }

                self.counters
                    .read_latency_avg
                    .store(read_avg, Ordering::Relaxed);

                self.counters
                    .write_latency_avg
                    .store(write_avg, Ordering::Relaxed);

                (VIRTIO_BLK_S_OK as u8, result as u32)
            } else {
                error!(
                    "Request failed: {:x?} {:?}",
                    request,
                    io::Error::from_raw_os_error(-result)
                );
                return Err(Error::AsyncRequestFailure);
            };

            mem.write_obj(status, request.status_addr)
                .map_err(Error::RequestStatus)?;

            let queue = &mut self.queue;

            queue
                .add_used(mem.deref(), desc_index, len)
                .map_err(Error::QueueAddUsed)?;
            queue
                .enable_notification(mem.deref())
                .map_err(Error::QueueEnableNotification)?;
        }

        self.counters
            .write_bytes
            .fetch_add(write_bytes.0, Ordering::AcqRel);
        self.counters
            .write_ops
            .fetch_add(write_ops.0, Ordering::AcqRel);

        self.counters
            .read_bytes
            .fetch_add(read_bytes.0, Ordering::AcqRel);
        self.counters
            .read_ops
            .fetch_add(read_ops.0, Ordering::AcqRel);

        Ok(())
    }

    fn signal_used_queue(&self) -> result::Result<(), DeviceError> {
        self.interrupt_cb
            .trigger(VirtioInterruptType::Queue(self.queue_index))
            .map_err(|e| {
                error!("Failed to signal used queue: {:?}", e);
                DeviceError::FailedSignalingUsedQueue(e)
            })
    }

    fn set_queue_thread_affinity(&self) {
        // Prepare the CPU set the current queue thread is expected to run on.
        let cpuset = self.host_cpus.as_ref().map(|host_cpus| {
            // SAFETY: all zeros is a valid pattern
            let mut cpuset: libc::cpu_set_t = unsafe { std::mem::zeroed() };
            // SAFETY: FFI call, trivially safe
            unsafe { libc::CPU_ZERO(&mut cpuset) };
            for host_cpu in host_cpus {
                // SAFETY: FFI call, trivially safe
                unsafe { libc::CPU_SET(*host_cpu, &mut cpuset) };
            }
            cpuset
        });

        // Schedule the thread to run on the expected CPU set
        if let Some(cpuset) = cpuset.as_ref() {
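            // A pid of 0 makes sched_setaffinity() apply the mask to the
            // calling thread, i.e. this virtqueue's epoll thread.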
            // SAFETY: FFI call with correct arguments
            let ret = unsafe {
                libc::sched_setaffinity(
                    0,
                    std::mem::size_of::<libc::cpu_set_t>(),
                    cpuset as *const libc::cpu_set_t,
                )
            };

            if ret != 0 {
                error!(
                    "Failed scheduling the virtqueue thread {} on the expected CPU set: {}",
                    self.queue_index,
                    io::Error::last_os_error()
                )
            }
        }
    }

    fn run(
        &mut self,
        paused: Arc<AtomicBool>,
        paused_sync: Arc<Barrier>,
    ) -> result::Result<(), EpollHelperError> {
        let mut helper = EpollHelper::new(&self.kill_evt, &self.pause_evt)?;
        helper.add_event(self.queue_evt.as_raw_fd(), QUEUE_AVAIL_EVENT)?;
        helper.add_event(self.disk_image.notifier().as_raw_fd(), COMPLETION_EVENT)?;
        if let Some(rate_limiter) = &self.rate_limiter {
            helper.add_event(rate_limiter.as_raw_fd(), RATE_LIMITER_EVENT)?;
        }
        self.set_queue_thread_affinity();
        helper.run(paused, paused_sync, self)?;

        Ok(())
    }
}

impl EpollHelperHandler for BlockEpollHandler {
    fn handle_event(
        &mut self,
        _helper: &mut EpollHelper,
        event: &epoll::Event,
    ) -> result::Result<(), EpollHelperError> {
        let ev_type = event.data as u16;
        match ev_type {
            QUEUE_AVAIL_EVENT => {
                self.queue_evt.read().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!("Failed to get queue event: {:?}", e))
                })?;

                let rate_limit_reached = self.rate_limiter.as_ref().is_some_and(|r| r.is_blocked());

                // Process the queue only when the rate limit is not reached
                if !rate_limit_reached {
                    self.process_queue_submit_and_signal()?
                }
            }
            COMPLETION_EVENT => {
                self.disk_image.notifier().read().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!(
                        "Failed to get completion event: {:?}",
                        e
                    ))
                })?;

                self.process_queue_complete().map_err(|e| {
                    EpollHelperError::HandleEvent(anyhow!(
                        "Failed to process queue (complete): {:?}",
                        e
                    ))
                })?;

                let rate_limit_reached = self.rate_limiter.as_ref().is_some_and(|r| r.is_blocked());

                // Process the queue only when the rate limit is not reached
                if !rate_limit_reached {
                    self.process_queue_submit().map_err(|e| {
                        EpollHelperError::HandleEvent(anyhow!(
                            "Failed to process queue (submit): {:?}",
                            e
                        ))
                    })?;
                }
                self.try_signal_used_queue()?;
            }
            RATE_LIMITER_EVENT => {
                if let Some(rate_limiter) = &mut self.rate_limiter {
                    // Upon rate limiter event, call the rate limiter handler
                    // and restart processing the queue.
                    rate_limiter.event_handler().map_err(|e| {
                        EpollHelperError::HandleEvent(anyhow!(
                            "Failed to process rate limiter event: {:?}",
                            e
                        ))
                    })?;

                    self.process_queue_submit_and_signal()?
                } else {
                    return Err(EpollHelperError::HandleEvent(anyhow!(
                        "Unexpected 'RATE_LIMITER_EVENT' when rate_limiter is not enabled."
                    )));
                }
            }
            _ => {
                return Err(EpollHelperError::HandleEvent(anyhow!(
                    "Unexpected event: {}",
                    ev_type
                )));
            }
        }
        Ok(())
    }
}

/// Virtio device for exposing block level read/write operations on a host file.
pub struct Block {
    common: VirtioCommon,
    id: String,
    disk_image: Box<dyn DiskFile>,
    disk_path: PathBuf,
    disk_nsectors: u64,
    config: VirtioBlockConfig,
    writeback: Arc<AtomicBool>,
    counters: BlockCounters,
    seccomp_action: SeccompAction,
    rate_limiter: Option<Arc<RateLimiterGroup>>,
    exit_evt: EventFd,
    read_only: bool,
    serial: Vec<u8>,
    queue_affinity: BTreeMap<u16, Vec<usize>>,
}

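/// Device state captured by snapshot() and fed back through Block::new()
/// on restore.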
#[derive(Serialize, Deserialize)]
pub struct BlockState {
    pub disk_path: String,
    pub disk_nsectors: u64,
    pub avail_features: u64,
    pub acked_features: u64,
    pub config: VirtioBlockConfig,
}

impl Block {
    /// Create a new virtio block device that operates on the given file.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        id: String,
        mut disk_image: Box<dyn DiskFile>,
        disk_path: PathBuf,
        read_only: bool,
        iommu: bool,
        num_queues: usize,
        queue_size: u16,
        serial: Option<String>,
        seccomp_action: SeccompAction,
        rate_limiter: Option<Arc<RateLimiterGroup>>,
        exit_evt: EventFd,
        state: Option<BlockState>,
        queue_affinity: BTreeMap<u16, Vec<usize>>,
    ) -> io::Result<Self> {
        let (disk_nsectors, avail_features, acked_features, config, paused) =
            if let Some(state) = state {
                info!("Restoring virtio-block {}", id);
                (
                    state.disk_nsectors,
                    state.avail_features,
                    state.acked_features,
                    state.config,
                    true,
                )
            } else {
                let disk_size = disk_image.size().map_err(|e| {
                    io::Error::new(
                        io::ErrorKind::Other,
                        format!("Failed getting disk size: {e}"),
                    )
                })?;
                if disk_size % SECTOR_SIZE != 0 {
                    warn!(
                        "Disk size {} is not a multiple of sector size {}; \
                         the remainder will not be visible to the guest.",
                        disk_size, SECTOR_SIZE
                    );
                }

                let mut avail_features = (1u64 << VIRTIO_F_VERSION_1)
                    | (1u64 << VIRTIO_BLK_F_FLUSH)
                    | (1u64 << VIRTIO_BLK_F_CONFIG_WCE)
                    | (1u64 << VIRTIO_BLK_F_BLK_SIZE)
                    | (1u64 << VIRTIO_BLK_F_TOPOLOGY)
                    | (1u64 << VIRTIO_BLK_F_SEG_MAX)
                    | (1u64 << VIRTIO_RING_F_EVENT_IDX)
                    | (1u64 << VIRTIO_RING_F_INDIRECT_DESC);
                if iommu {
                    avail_features |= 1u64 << VIRTIO_F_IOMMU_PLATFORM;
                }

                if read_only {
                    avail_features |= 1u64 << VIRTIO_BLK_F_RO;
                }

                let topology = disk_image.topology();
                info!("Disk topology: {:?}", topology);

                let logical_block_size = if topology.logical_block_size > 512 {
                    topology.logical_block_size
                } else {
                    512
                };

                // Calculate the exponent that maps physical block to logical block
                let mut physical_block_exp = 0;
                let mut size = logical_block_size;
                while size < topology.physical_block_size {
                    physical_block_exp += 1;
                    size <<= 1;
                }
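                // e.g. a 512-byte logical block on a disk with 4096-byte
                // physical blocks yields physical_block_exp = 3, since
                // 512 << 3 == 4096.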

                let disk_nsectors = disk_size / SECTOR_SIZE;
                let mut config = VirtioBlockConfig {
                    capacity: disk_nsectors,
                    writeback: 1,
                    blk_size: topology.logical_block_size as u32,
                    physical_block_exp,
                    min_io_size: (topology.minimum_io_size / logical_block_size) as u16,
                    opt_io_size: (topology.optimal_io_size / logical_block_size) as u32,
                    seg_max: (queue_size - MINIMUM_BLOCK_QUEUE_SIZE) as u32,
                    ..Default::default()
                };

                if num_queues > 1 {
                    avail_features |= 1u64 << VIRTIO_BLK_F_MQ;
                    config.num_queues = num_queues as u16;
                }

                (disk_nsectors, avail_features, 0, config, false)
            };

        let serial = serial
            .map(Vec::from)
            .unwrap_or_else(|| build_serial(&disk_path));

        Ok(Block {
            common: VirtioCommon {
                device_type: VirtioDeviceType::Block as u32,
                avail_features,
                acked_features,
                paused_sync: Some(Arc::new(Barrier::new(num_queues + 1))),
                queue_sizes: vec![queue_size; num_queues],
                min_queues: 1,
                paused: Arc::new(AtomicBool::new(paused)),
                ..Default::default()
            },
            id,
            disk_image,
            disk_path,
            disk_nsectors,
            config,
            writeback: Arc::new(AtomicBool::new(true)),
            counters: BlockCounters::default(),
            seccomp_action,
            rate_limiter,
            exit_evt,
            read_only,
            serial,
            queue_affinity,
        })
    }

    fn state(&self) -> BlockState {
        BlockState {
            disk_path: self.disk_path.to_str().unwrap().to_owned(),
            disk_nsectors: self.disk_nsectors,
            avail_features: self.common.avail_features,
            acked_features: self.common.acked_features,
            config: self.config,
        }
    }

    fn update_writeback(&mut self) {
        // Use the writeback value from the config space if
        // VIRTIO_BLK_F_CONFIG_WCE has been negotiated.
        let writeback = if self.common.feature_acked(VIRTIO_BLK_F_CONFIG_WCE.into()) {
            self.config.writeback == 1
        } else {
            // Otherwise, check whether VIRTIO_BLK_F_FLUSH was negotiated.
            self.common.feature_acked(VIRTIO_BLK_F_FLUSH.into())
        };

        info!(
            "Changing cache mode to {}",
            if writeback {
                "writeback"
            } else {
                "writethrough"
            }
        );
        self.writeback.store(writeback, Ordering::Release);
    }

    #[cfg(fuzzing)]
    pub fn wait_for_epoll_threads(&mut self) {
        self.common.wait_for_epoll_threads();
    }
}

impl Drop for Block {
    fn drop(&mut self) {
        if let Some(kill_evt) = self.common.kill_evt.take() {
            // Ignore the result because there is nothing we can do about it.
            let _ = kill_evt.write(1);
        }
        self.common.wait_for_epoll_threads();
    }
}

impl VirtioDevice for Block {
    fn device_type(&self) -> u32 {
        self.common.device_type
    }

    fn queue_max_sizes(&self) -> &[u16] {
        &self.common.queue_sizes
    }

    fn features(&self) -> u64 {
        self.common.avail_features
    }

    fn ack_features(&mut self, value: u64) {
        self.common.ack_features(value)
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        self.read_config_from_slice(self.config.as_slice(), offset, data);
    }

    fn write_config(&mut self, offset: u64, data: &[u8]) {
        // The "writeback" field is the only mutable field
        let writeback_offset =
            (&self.config.writeback as *const _ as u64) - (&self.config as *const _ as u64);
        if offset != writeback_offset || data.len() != std::mem::size_of_val(&self.config.writeback)
        {
            error!(
                "Attempt to write to read-only field: offset {:x} length {}",
                offset,
                data.len()
            );
            return;
        }

        self.config.writeback = data[0];
        self.update_writeback();
    }

    fn activate(
        &mut self,
        mem: GuestMemoryAtomic<GuestMemoryMmap>,
        interrupt_cb: Arc<dyn VirtioInterrupt>,
        mut queues: Vec<(usize, Queue, EventFd)>,
    ) -> ActivateResult {
        self.common.activate(&queues, &interrupt_cb)?;

        self.update_writeback();

        let mut epoll_threads = Vec::new();
        let event_idx = self.common.feature_acked(VIRTIO_RING_F_EVENT_IDX.into());

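        // Spawn one epoll handler thread per virtqueue, each with its own
        // asynchronous I/O context sized to the queue depth.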
        for i in 0..queues.len() {
            let (_, mut queue, queue_evt) = queues.remove(0);
            queue.set_event_idx(event_idx);

            let queue_size = queue.size();
            let (kill_evt, pause_evt) = self.common.dup_eventfds();
            let queue_idx = i as u16;

            let mut handler = BlockEpollHandler {
                queue_index: queue_idx,
                queue,
                mem: mem.clone(),
                disk_image: self
                    .disk_image
                    .new_async_io(queue_size as u32)
                    .map_err(|e| {
                        error!("failed to create new AsyncIo: {}", e);
                        ActivateError::BadActivate
                    })?,
                disk_nsectors: self.disk_nsectors,
                interrupt_cb: interrupt_cb.clone(),
                serial: self.serial.clone(),
                kill_evt,
                pause_evt,
                writeback: self.writeback.clone(),
                counters: self.counters.clone(),
                queue_evt,
                // Analysis during boot shows a maximum of around 40 in-flight
                // requests. A capacity of 64 gives headroom for systems with
                // slower I/O without incurring reallocation or excessive
                // memory overhead.
                inflight_requests: VecDeque::with_capacity(64),
                rate_limiter: self
                    .rate_limiter
                    .as_ref()
                    .map(|r| r.new_handle())
                    .transpose()
                    .unwrap(),
                access_platform: self.common.access_platform.clone(),
                read_only: self.read_only,
                host_cpus: self.queue_affinity.get(&queue_idx).cloned(),
            };

            let paused = self.common.paused.clone();
            let paused_sync = self.common.paused_sync.clone();

            spawn_virtio_thread(
                &format!("{}_q{}", self.id.clone(), i),
                &self.seccomp_action,
                Thread::VirtioBlock,
                &mut epoll_threads,
                &self.exit_evt,
                move || handler.run(paused, paused_sync.unwrap()),
            )?;
        }

        self.common.epoll_threads = Some(epoll_threads);
        event!("virtio-device", "activated", "id", &self.id);

        Ok(())
    }

    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
        let result = self.common.reset();
        event!("virtio-device", "reset", "id", &self.id);
        result
    }

    fn counters(&self) -> Option<HashMap<&'static str, Wrapping<u64>>> {
        let mut counters = HashMap::new();

        counters.insert(
            "read_bytes",
            Wrapping(self.counters.read_bytes.load(Ordering::Acquire)),
        );
        counters.insert(
            "write_bytes",
            Wrapping(self.counters.write_bytes.load(Ordering::Acquire)),
        );
        counters.insert(
            "read_ops",
            Wrapping(self.counters.read_ops.load(Ordering::Acquire)),
        );
        counters.insert(
            "write_ops",
            Wrapping(self.counters.write_ops.load(Ordering::Acquire)),
        );
        counters.insert(
            "write_latency_min",
            Wrapping(self.counters.write_latency_min.load(Ordering::Acquire)),
        );
        counters.insert(
            "write_latency_max",
            Wrapping(self.counters.write_latency_max.load(Ordering::Acquire)),
        );
        counters.insert(
            "write_latency_avg",
            Wrapping(self.counters.write_latency_avg.load(Ordering::Acquire) / LATENCY_SCALE),
        );
        counters.insert(
            "read_latency_min",
            Wrapping(self.counters.read_latency_min.load(Ordering::Acquire)),
        );
        counters.insert(
            "read_latency_max",
            Wrapping(self.counters.read_latency_max.load(Ordering::Acquire)),
        );
        counters.insert(
            "read_latency_avg",
            Wrapping(self.counters.read_latency_avg.load(Ordering::Acquire) / LATENCY_SCALE),
        );

        Some(counters)
    }

    fn set_access_platform(&mut self, access_platform: Arc<dyn AccessPlatform>) {
        self.common.set_access_platform(access_platform)
    }
}

impl Pausable for Block {
    fn pause(&mut self) -> result::Result<(), MigratableError> {
        self.common.pause()
    }

    fn resume(&mut self) -> result::Result<(), MigratableError> {
        self.common.resume()
    }
}

impl Snapshottable for Block {
    fn id(&self) -> String {
        self.id.clone()
    }

    fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
        Snapshot::new_from_state(&self.state())
    }
}
impl Transportable for Block {}
impl Migratable for Block {}