diff --git a/src/vmm/src/devices/virtio/device.rs b/src/vmm/src/devices/virtio/device.rs
index 62131e775f5..ba1ca6b279e 100644
--- a/src/vmm/src/devices/virtio/device.rs
+++ b/src/vmm/src/devices/virtio/device.rs
@@ -182,9 +182,9 @@ pub trait VirtioDevice: AsAny + Send {
     }
 
     /// Mark pages used by queues as dirty.
-    fn mark_queue_memory_dirty(&self, mem: &GuestMemoryMmap) -> Result<(), QueueError> {
-        for queue in self.queues() {
-            queue.mark_memory_dirty(mem)?
+    fn mark_queue_memory_dirty(&mut self, mem: &GuestMemoryMmap) -> Result<(), QueueError> {
+        for queue in self.queues_mut() {
+            queue.initialize(mem)?
         }
         Ok(())
     }
diff --git a/src/vmm/src/devices/virtio/mmio.rs b/src/vmm/src/devices/virtio/mmio.rs
index 30cf18c5efb..4114838bdd3 100644
--- a/src/vmm/src/devices/virtio/mmio.rs
+++ b/src/vmm/src/devices/virtio/mmio.rs
@@ -95,13 +95,6 @@ impl MmioTransport {
         self.device_status & (set | clr) == set
     }
 
-    fn are_queues_valid(&self) -> bool {
-        self.locked_device()
-            .queues()
-            .iter()
-            .all(|q| q.is_valid(&self.mem))
-    }
-
     fn with_queue<U, F>(&self, d: U, f: F) -> U
     where
         F: FnOnce(&Queue) -> U,
@@ -185,7 +178,7 @@ impl MmioTransport {
             DRIVER_OK if self.device_status == (ACKNOWLEDGE | DRIVER | FEATURES_OK) => {
                 self.device_status = status;
                 let device_activated = self.locked_device().is_activated();
-                if !device_activated && self.are_queues_valid() {
+                if !device_activated {
                     // temporary variable needed for borrow checker
                     let activate_result = self.locked_device().activate(self.mem.clone());
                     if let Err(err) = activate_result {
@@ -486,8 +479,6 @@ pub(crate) mod tests {
 
         assert_eq!(d.locked_device().queue_events().len(), 2);
 
-        assert!(!d.are_queues_valid());
-
         d.queue_select = 0;
        assert_eq!(d.with_queue(0, |q| q.max_size), 16);
        assert!(d.with_queue_mut(|q| q.size = 16));
@@ -501,8 +492,6 @@ pub(crate) mod tests {
         d.queue_select = 2;
         assert_eq!(d.with_queue(0, |q| q.max_size), 0);
         assert!(!d.with_queue_mut(|q| q.size = 16));
-
-        assert!(!d.are_queues_valid());
     }
 
     #[test]
@@ -761,7 +750,6 @@ pub(crate) mod tests {
         let m = single_region_mem(0x1000);
         let mut d = MmioTransport::new(m, Arc::new(Mutex::new(DummyDevice::new())), false);
 
-        assert!(!d.are_queues_valid());
         assert!(!d.locked_device().is_activated());
         assert_eq!(d.device_status, device_status::INIT);
 
@@ -800,7 +788,6 @@ pub(crate) mod tests {
             write_le_u32(&mut buf[..], 1);
             d.bus_write(0x44, &buf[..]);
         }
-        assert!(d.are_queues_valid());
         assert!(!d.locked_device().is_activated());
 
         // Device should be ready for activation now.
@@ -860,7 +847,6 @@ pub(crate) mod tests {
             write_le_u32(&mut buf[..], 1);
             d.bus_write(0x44, &buf[..]);
         }
-        assert!(d.are_queues_valid());
         assert_eq!(
             d.locked_device().interrupt_status().load(Ordering::SeqCst),
             0
@@ -910,7 +896,6 @@ pub(crate) mod tests {
             write_le_u32(&mut buf[..], 1);
             d.bus_write(0x44, &buf[..]);
         }
-        assert!(d.are_queues_valid());
         assert!(!d.locked_device().is_activated());
 
         // Device should be ready for activation now.
@@ -937,7 +922,6 @@ pub(crate) mod tests {
         let mut d = MmioTransport::new(m, Arc::new(Mutex::new(DummyDevice::new())), false);
         let mut buf = [0; 4];
 
-        assert!(!d.are_queues_valid());
         assert!(!d.locked_device().is_activated());
         assert_eq!(d.device_status, 0);
         activate_device(&mut d);
diff --git a/src/vmm/src/devices/virtio/persist.rs b/src/vmm/src/devices/virtio/persist.rs
index 38dd50e7c7f..dab1bb34104 100644
--- a/src/vmm/src/devices/virtio/persist.rs
+++ b/src/vmm/src/devices/virtio/persist.rs
@@ -184,14 +184,7 @@ impl VirtioDeviceState {
 
         for q in &queues {
             // Sanity check queue size and queue max size.
-            if q.max_size != expected_queue_max_size || q.size > expected_queue_max_size {
-                return Err(PersistError::InvalidInput);
-            }
-            // Snapshot can happen at any time, including during device configuration/activation
-            // when fields are only partially configured.
-            //
-            // Only if the device was activated, check `q.is_valid()`.
-            if self.activated && !q.is_valid(mem) {
+            if q.max_size != expected_queue_max_size {
                 return Err(PersistError::InvalidInput);
             }
         }
diff --git a/src/vmm/src/devices/virtio/queue.rs b/src/vmm/src/devices/virtio/queue.rs
index 0660faf4689..20614762cd8 100644
--- a/src/vmm/src/devices/virtio/queue.rs
+++ b/src/vmm/src/devices/virtio/queue.rs
@@ -8,8 +8,10 @@ use std::num::Wrapping;
 use std::sync::atomic::{Ordering, fence};
 
+use crate::arch::host_page_size;
 use crate::logger::error;
-use crate::vstate::memory::{Address, Bitmap, ByteValued, GuestAddress, GuestMemory};
+use crate::utils::u64_to_usize;
+use crate::vstate::memory::{Bitmap, ByteValued, GuestAddress, GuestMemory};
 
 pub const VIRTQ_DESC_F_NEXT: u16 = 0x1;
 pub const VIRTQ_DESC_F_WRITE: u16 = 0x2;
@@ -32,7 +34,11 @@ pub enum QueueError {
     /// Failed to write value into the virtio queue used ring: {0}
     MemoryError(#[from] vm_memory::GuestMemoryError),
     /// Pointer is not aligned properly: {0:#x} not {1}-byte aligned.
-    PointerNotAligned(usize, u8),
+    PointerNotAligned(usize, usize),
+    /// Attempt to use virtio queue that is not marked ready
+    NotReady,
+    /// Virtio queue with invalid size: {0}
+    InvalidSize(u16),
 }
 
 /// Error type indicating the guest configured a virtio queue such that the avail_idx field would
@@ -310,31 +316,42 @@ impl Queue {
             + std::mem::size_of::<u16>()
     }
 
-    fn get_slice_ptr<M: GuestMemory>(
+    fn get_aligned_slice_ptr<M: GuestMemory, T>(
         &self,
         mem: &M,
         addr: GuestAddress,
         len: usize,
-    ) -> Result<*mut u8, QueueError> {
+        alignment: usize,
+    ) -> Result<*mut T, QueueError> {
+        assert_eq!(host_page_size() % alignment, 0);
+
+        // Guest memory base address is page aligned, so as long as alignment divides page size,
+        // it suffices to check that the GPA is properly aligned (i.e. we don't need to recheck
+        // the HVA).
+        if addr.0 & (alignment as u64 - 1) != 0 {
+            return Err(QueueError::PointerNotAligned(
+                u64_to_usize(addr.0),
+                alignment,
+            ));
+        }
+
         let slice = mem.get_slice(addr, len).map_err(QueueError::MemoryError)?;
         slice.bitmap().mark_dirty(0, len);
-        Ok(slice.ptr_guard_mut().as_ptr())
+        Ok(slice.ptr_guard_mut().as_ptr().cast())
     }
 
     /// Set up pointers to the queue objects in the guest memory
     /// and mark memory dirty for those objects
     pub fn initialize<M: GuestMemory>(&mut self, mem: &M) -> Result<(), QueueError> {
-        self.desc_table_ptr = self
-            .get_slice_ptr(mem, self.desc_table_address, self.desc_table_size())?
-            .cast();
-        self.avail_ring_ptr = self
-            .get_slice_ptr(mem, self.avail_ring_address, self.avail_ring_size())?
-            .cast();
-        self.used_ring_ptr = self
-            .get_slice_ptr(mem, self.used_ring_address, self.used_ring_size())?
-            .cast();
-
-        // All the above pointers are expected to be aligned properly; otherwise some methods (e.g.
+        if !self.ready {
+            return Err(QueueError::NotReady);
+        }
+
+        if self.size > self.max_size || self.size == 0 || (self.size & (self.size - 1)) != 0 {
+            return Err(QueueError::InvalidSize(self.size));
+        }
+
+        // All the below pointers are verified to be aligned properly; otherwise some methods (e.g.
         // `read_volatile()`) will panic. Such an unalignment is possible when restored from a
         // broken/fuzzed snapshot.
         //
@@ -347,36 +364,16 @@ impl Queue {
         // > Available Ring   2
         // > Used Ring        4
         // > ================ ==========
-        if !self.desc_table_ptr.cast::<u128>().is_aligned() {
-            return Err(QueueError::PointerNotAligned(
-                self.desc_table_ptr as usize,
-                16,
-            ));
-        }
-        if !self.avail_ring_ptr.is_aligned() {
-            return Err(QueueError::PointerNotAligned(
-                self.avail_ring_ptr as usize,
-                2,
-            ));
-        }
-        if !self.used_ring_ptr.cast::<u32>().is_aligned() {
-            return Err(QueueError::PointerNotAligned(
-                self.used_ring_ptr as usize,
-                4,
-            ));
-        }
+        self.desc_table_ptr =
+            self.get_aligned_slice_ptr(mem, self.desc_table_address, self.desc_table_size(), 16)?;
+        self.avail_ring_ptr =
+            self.get_aligned_slice_ptr(mem, self.avail_ring_address, self.avail_ring_size(), 2)?;
+        self.used_ring_ptr =
+            self.get_aligned_slice_ptr(mem, self.used_ring_address, self.used_ring_size(), 4)?;
 
         Ok(())
     }
 
-    /// Mark memory used for queue objects as dirty.
-    pub fn mark_memory_dirty<M: GuestMemory>(&self, mem: &M) -> Result<(), QueueError> {
-        _ = self.get_slice_ptr(mem, self.desc_table_address, self.desc_table_size())?;
-        _ = self.get_slice_ptr(mem, self.avail_ring_address, self.avail_ring_size())?;
-        _ = self.get_slice_ptr(mem, self.used_ring_address, self.used_ring_size())?;
-        Ok(())
-    }
-
     /// Get AvailRing.idx
     #[inline(always)]
     pub fn avail_ring_idx_get(&self) -> u16 {
@@ -461,58 +458,6 @@ impl Queue {
         }
     }
 
-    /// Validates the queue's in-memory layout is correct.
-    pub fn is_valid<M: GuestMemory>(&self, mem: &M) -> bool {
-        let desc_table = self.desc_table_address;
-        let desc_table_size = self.desc_table_size();
-        let avail_ring = self.avail_ring_address;
-        let avail_ring_size = self.avail_ring_size();
-        let used_ring = self.used_ring_address;
-        let used_ring_size = self.used_ring_size();
-
-        if !self.ready {
-            error!("attempt to use virtio queue that is not marked ready");
-            false
-        } else if self.size > self.max_size || self.size == 0 || (self.size & (self.size - 1)) != 0
-        {
-            error!("virtio queue with invalid size: {}", self.size);
-            false
-        } else if desc_table.raw_value() & 0xf != 0 {
-            error!("virtio queue descriptor table breaks alignment constraints");
-            false
-        } else if avail_ring.raw_value() & 0x1 != 0 {
-            error!("virtio queue available ring breaks alignment constraints");
-            false
-        } else if used_ring.raw_value() & 0x3 != 0 {
-            error!("virtio queue used ring breaks alignment constraints");
-            false
-        // range check entire descriptor table to be assigned valid guest physical addresses
-        } else if mem.get_slice(desc_table, desc_table_size).is_err() {
-            error!(
-                "virtio queue descriptor table goes out of bounds: start:0x{:08x} size:0x{:08x}",
-                desc_table.raw_value(),
-                desc_table_size
-            );
-            false
-        } else if mem.get_slice(avail_ring, avail_ring_size).is_err() {
-            error!(
-                "virtio queue available ring goes out of bounds: start:0x{:08x} size:0x{:08x}",
-                avail_ring.raw_value(),
-                avail_ring_size
-            );
-            false
-        } else if mem.get_slice(used_ring, used_ring_size).is_err() {
-            error!(
-                "virtio queue used ring goes out of bounds: start:0x{:08x} size:0x{:08x}",
-                used_ring.raw_value(),
-                used_ring_size
-            );
-            false
-        } else {
-            true
-        }
-    }
-
     /// Returns the number of yet-to-be-popped descriptor chains in the avail ring.
     pub fn len(&self) -> u16 {
         (Wrapping(self.avail_ring_idx_get()) - self.next_avail).0
     }
@@ -916,8 +861,6 @@ mod verification {
             let mut queue = less_arbitrary_queue();
             queue.initialize(&mem).unwrap();
 
-            assert!(queue.is_valid(&mem));
-
             ProofContext(queue, mem)
         }
     }
@@ -927,8 +870,7 @@ mod verification {
             let mem = setup_kani_guest_memory();
             let mut queue: Queue = kani::any();
 
-            kani::assume(queue.is_valid(&mem));
-            queue.initialize(&mem).unwrap();
+            kani::assume(queue.initialize(&mem).is_ok());
 
             ProofContext(queue, mem)
         }
@@ -1095,10 +1037,10 @@ mod verification {
     #[kani::proof]
     #[kani::unwind(0)]
     #[kani::solver(cadical)]
-    fn verify_is_valid() {
-        let ProofContext(queue, mem) = kani::any();
+    fn verify_initialize() {
+        let ProofContext(mut queue, mem) = kani::any();
 
-        if queue.is_valid(&mem) {
+        if queue.initialize(&mem).is_ok() {
             // Section 2.6: Alignment of descriptor table, available ring and used ring; size of
             // queue
             fn alignment_of(val: u64) -> u64 {
@@ -1115,15 +1057,6 @@ mod verification {
         }
     }
 
-    #[kani::proof]
-    #[kani::unwind(0)]
-    fn verify_size() {
-        let ProofContext(queue, _) = kani::any();
-
-        assert!(queue.size <= queue.max_size);
-        assert!(queue.size <= queue.size);
-    }
-
     #[kani::proof]
     #[kani::unwind(0)]
     fn verify_avail_ring_idx_get() {
@@ -1182,7 +1115,7 @@ mod verification {
 
         // This is an assertion in pop which we use to abort firecracker in a ddos scenario
        // This condition being false means that the guest is asking us to process every element
-        // in the queue multiple times. It cannot be checked by is_valid, as that function
+        // in the queue multiple times. It cannot be checked by initialize, as that function
         // is called when the queue is being initialized, e.g. empty. We compute it using
         // local variables here to make things easier on kani: One less roundtrip through vm-memory.
         let queue_len = queue.len();
@@ -1267,7 +1200,7 @@ mod verification {
 
 #[cfg(test)]
 mod tests {
-    use vm_memory::Bytes;
+    use vm_memory::{Address, Bytes};
 
     pub use super::*;
     use crate::devices::virtio::queue::QueueError::DescIndexOutOfBounds;
@@ -1327,26 +1260,35 @@ mod tests {
         let mut q = vq.create_queue();
 
         // q is currently valid
-        assert!(q.is_valid(m));
+        q.initialize(m).unwrap();
 
         // shouldn't be valid when not marked as ready
         q.ready = false;
-        assert!(!q.is_valid(m));
+        assert!(matches!(q.initialize(m).unwrap_err(), QueueError::NotReady));
         q.ready = true;
 
         // or when size > max_size
         q.size = q.max_size << 1;
-        assert!(!q.is_valid(m));
+        assert!(matches!(
+            q.initialize(m).unwrap_err(),
+            QueueError::InvalidSize(_)
+        ));
         q.size = q.max_size;
 
         // or when size is 0
         q.size = 0;
-        assert!(!q.is_valid(m));
+        assert!(matches!(
+            q.initialize(m).unwrap_err(),
+            QueueError::InvalidSize(_)
+        ));
         q.size = q.max_size;
 
         // or when size is not a power of 2
         q.size = 11;
-        assert!(!q.is_valid(m));
+        assert!(matches!(
+            q.initialize(m).unwrap_err(),
+            QueueError::InvalidSize(_)
+        ));
         q.size = q.max_size;
 
         // reset dirtied values
@@ -1357,22 +1299,40 @@ mod tests {
 
         // or if the various addresses are off
-        q.desc_table_address = GuestAddress(0xffff_ffff);
-        assert!(!q.is_valid(m));
+        q.desc_table_address = GuestAddress(0xffff_ff00);
+        assert!(matches!(
+            q.initialize(m).unwrap_err(),
+            QueueError::MemoryError(_)
+        ));
         q.desc_table_address = GuestAddress(0x1001);
-        assert!(!q.is_valid(m));
+        assert!(matches!(
+            q.initialize(m).unwrap_err(),
+            QueueError::PointerNotAligned(_, _)
+        ));
         q.desc_table_address = vq.dtable_start();
 
-        q.avail_ring_address = GuestAddress(0xffff_ffff);
-        assert!(!q.is_valid(m));
+        q.avail_ring_address = GuestAddress(0xffff_ff00);
+        assert!(matches!(
+            q.initialize(m).unwrap_err(),
+            QueueError::MemoryError(_)
+        ));
         q.avail_ring_address = GuestAddress(0x1001);
-        assert!(!q.is_valid(m));
+        assert!(matches!(
+            q.initialize(m).unwrap_err(),
+            QueueError::PointerNotAligned(_, _)
+        ));
         q.avail_ring_address = vq.avail_start();
 
-        q.used_ring_address = GuestAddress(0xffff_ffff);
-        assert!(!q.is_valid(m));
+        q.used_ring_address = GuestAddress(0xffff_ff00);
+        assert!(matches!(
+            q.initialize(m).unwrap_err(),
+            QueueError::MemoryError(_)
+        ));
         q.used_ring_address = GuestAddress(0x1001);
-        assert!(!q.is_valid(m));
+        assert!(matches!(
+            q.initialize(m).unwrap_err(),
+            QueueError::PointerNotAligned(_, _)
+        ));
         q.used_ring_address = vq.used_start();
     }
 
@@ -1688,23 +1648,27 @@ mod tests {
 
     #[test]
     fn test_initialize_with_aligned_pointer() {
-        let mut q = Queue::new(0);
+        let mut q = Queue::new(FIRECRACKER_MAX_QUEUE_SIZE);
+
+        q.ready = true;
+        q.size = q.max_size;
 
-        let random_addr = 0x321;
         // Descriptor table must be 16-byte aligned.
-        q.desc_table_address = GuestAddress(random_addr / 16 * 16);
+        q.desc_table_address = GuestAddress(16);
         // Available ring must be 2-byte aligned.
-        q.avail_ring_address = GuestAddress(random_addr / 2 * 2);
+        q.avail_ring_address = GuestAddress(2);
         // Used ring must be 4-byte aligned.
-        q.avail_ring_address = GuestAddress(random_addr / 4 * 4);
+        q.avail_ring_address = GuestAddress(4);
 
-        let mem = single_region_mem(0x1000);
+        let mem = single_region_mem(0x10000);
         q.initialize(&mem).unwrap();
     }
 
     #[test]
     fn test_initialize_with_misaligned_pointer() {
-        let mut q = Queue::new(0);
+        let mut q = Queue::new(FIRECRACKER_MAX_QUEUE_SIZE);
+        q.ready = true;
+        q.size = q.max_size;
 
         let mem = single_region_mem(0x1000);
         // Descriptor table must be 16-byte aligned.
diff --git a/src/vmm/src/persist.rs b/src/vmm/src/persist.rs
index 4111d8d6c34..4699b80b185 100644
--- a/src/vmm/src/persist.rs
+++ b/src/vmm/src/persist.rs
@@ -166,14 +166,14 @@ pub fn create_snapshot(
         .snapshot_memory_to_file(&params.mem_file_path, params.snapshot_type)?;
 
     // We need to mark queues as dirty again for all activated devices. The reason we
-    // do it here is because we don't mark pages as dirty during runtime
+    // do it here is that we don't mark pages as dirty during runtime
     // for queue objects.
     // SAFETY:
     // This should never fail as we only mark pages only if device has already been activated,
     // and the address validation was already performed on device activation.
     vmm.mmio_device_manager
         .for_each_virtio_device(|_, _, _, dev| {
-            let d = dev.lock().unwrap();
+            let mut d = dev.lock().unwrap();
             if d.is_activated() {
                 d.mark_queue_memory_dirty(vmm.vm.guest_memory())
             } else {
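
Note: the sketch below is not part of the patch. It is a minimal, self-contained model of the validation order that the reworked `Queue::initialize` enforces (ready flag, then queue size, then guest-physical-address alignment, then a bounds check that doubles as dirty-page marking). `SketchQueue`, `guest_len`, and the `OutOfBounds` variant are hypothetical stand-ins for the real `Queue` and guest memory types; the real implementation goes through `GuestMemory::get_slice`, which yields `QueueError::MemoryError` for out-of-bounds addresses.

// Minimal sketch of the consolidated validation flow, under the assumptions stated above.
#[derive(Debug, PartialEq)]
enum QueueError {
    NotReady,
    InvalidSize(u16),
    PointerNotAligned(usize, usize),
    OutOfBounds(u64),
}

struct SketchQueue {
    ready: bool,
    size: u16,
    max_size: u16,
    desc_table_address: u64,
    desc_table_size: usize,
}

impl SketchQueue {
    fn initialize(&mut self, guest_len: u64) -> Result<(), QueueError> {
        // 1. The driver must have marked the queue ready.
        if !self.ready {
            return Err(QueueError::NotReady);
        }
        // 2. Size must be non-zero, a power of two, and no larger than the device maximum.
        if self.size > self.max_size || self.size == 0 || !self.size.is_power_of_two() {
            return Err(QueueError::InvalidSize(self.size));
        }
        // 3. Checking the guest physical address is enough when the region base is page aligned.
        if self.desc_table_address & 0xf != 0 {
            return Err(QueueError::PointerNotAligned(
                self.desc_table_address as usize,
                16,
            ));
        }
        // 4. Stand-in for `mem.get_slice(...)`, which also marks the pages dirty in the real code.
        if self.desc_table_address + self.desc_table_size as u64 > guest_len {
            return Err(QueueError::OutOfBounds(self.desc_table_address));
        }
        Ok(())
    }
}

fn main() {
    let mut q = SketchQueue {
        ready: false,
        size: 256,
        max_size: 256,
        desc_table_address: 0x1000,
        desc_table_size: 16 * 256,
    };
    // A queue the driver has not marked ready is rejected up front.
    assert_eq!(q.initialize(0x10000), Err(QueueError::NotReady));
    q.ready = true;
    assert_eq!(q.initialize(0x10000), Ok(()));
}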