From abdd6c4ff87f6af463d4085ad7bfc7512bdf7865 Mon Sep 17 00:00:00 2001 From: Nicholas Bishop Date: Sun, 12 Oct 2025 14:57:20 -0400 Subject: [PATCH] uefi: Wrap AtaPassThruProtocol in UnsafeCell The existing code had some potential UB; it created a mutable pointer from a const reference and passed it across the FFI boundary. (Whether the pointee is actually mutated depends on the firmware implementation.) An `UnsafeCell` allows the interior data to be mutated through a const reference. `AtaPassThru` now contains an `UnsafeCell`, which allows a mutable pointer to be created with less risk of UB. (Note that it's still not allowed to create multiple mutable _references_ to the data, but as long as only raw pointers are used, it should be OK.) The `AtaDevice` and `AtaDeviceIterator` types have been adjusted to take a reference to the `UnsafeCell`. --- uefi/src/proto/ata/pass_thru.rs | 40 ++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/uefi/src/proto/ata/pass_thru.rs b/uefi/src/proto/ata/pass_thru.rs index 76adcef17..241a8c9c9 100644 --- a/uefi/src/proto/ata/pass_thru.rs +++ b/uefi/src/proto/ata/pass_thru.rs @@ -7,6 +7,7 @@ use crate::StatusExt; use crate::mem::{AlignedBuffer, PoolAllocation}; use crate::proto::device_path::PoolDevicePathNode; use core::alloc::LayoutError; +use core::cell::UnsafeCell; use core::ptr::{self, NonNull}; use uefi_macros::unsafe_protocol; use uefi_raw::Status; @@ -33,7 +34,7 @@ pub type AtaPassThruMode = uefi_raw::protocol::ata::AtaPassThruMode; #[derive(Debug)] #[repr(transparent)] #[unsafe_protocol(AtaPassThruProtocol::GUID)] -pub struct AtaPassThru(AtaPassThruProtocol); +pub struct AtaPassThru(UnsafeCell); impl AtaPassThru { /// Retrieves the mode structure for the Extended SCSI Pass Thru protocol. @@ -42,7 +43,7 @@ impl AtaPassThru { /// The [`AtaPassThruMode`] structure containing configuration details of the protocol. #[must_use] pub fn mode(&self) -> AtaPassThruMode { - let mut mode = unsafe { (*self.0.mode).clone() }; + let mut mode = unsafe { (*(*self.0.get()).mode).clone() }; mode.io_align = mode.io_align.max(1); // 0 and 1 is the same, says UEFI spec mode } @@ -101,16 +102,12 @@ impl AtaPassThru { /// available / connected device using [`AtaDevice::execute_command`] before doing anything meaningful. #[derive(Debug)] pub struct AtaDevice<'a> { - proto: &'a AtaPassThruProtocol, + proto: &'a UnsafeCell, port: u16, pmp: u16, } impl AtaDevice<'_> { - const fn proto_mut(&mut self) -> *mut AtaPassThruProtocol { - ptr::from_ref(self.proto).cast_mut() - } - /// Returns the port number of the device. /// /// # Details @@ -142,7 +139,9 @@ impl AtaDevice<'_> { /// - [`Status::DEVICE_ERROR`] A device error occurred while attempting to reset the specified ATA device. /// - [`Status::TIMEOUT`] A timeout occurred while attempting to reset the specified ATA device. pub fn reset(&mut self) -> crate::Result<()> { - unsafe { (self.proto.reset_device)(self.proto_mut(), self.port, self.pmp).to_result() } + unsafe { + ((*self.proto.get()).reset_device)(self.proto.get(), self.port, self.pmp).to_result() + } } /// Get the final device path node for this device. @@ -152,8 +151,13 @@ impl AtaDevice<'_> { pub fn path_node(&self) -> crate::Result { unsafe { let mut path_ptr: *const DevicePathProtocol = ptr::null(); - (self.proto.build_device_path)(self.proto, self.port, self.pmp, &mut path_ptr) - .to_result()?; + ((*self.proto.get()).build_device_path)( + self.proto.get(), + self.port, + self.pmp, + &mut path_ptr, + ) + .to_result()?; NonNull::new(path_ptr.cast_mut()) .map(|p| PoolDevicePathNode(PoolAllocation::new(p.cast()))) .ok_or_else(|| Status::OUT_OF_RESOURCES.into()) @@ -184,8 +188,8 @@ impl AtaDevice<'_> { ) -> crate::Result> { req.packet.acb = &req.acb; unsafe { - (self.proto.pass_thru)( - self.proto_mut(), + ((*self.proto.get()).pass_thru)( + self.proto.get(), self.port, self.pmp, &mut req.packet, @@ -203,7 +207,7 @@ impl AtaDevice<'_> { /// is actually available and connected! #[derive(Debug)] pub struct AtaDeviceIterator<'a> { - proto: &'a AtaPassThruProtocol, + proto: &'a UnsafeCell, // when there are no more devices on this port -> get next port end_of_port: bool, prev_port: u16, @@ -216,7 +220,9 @@ impl<'a> Iterator for AtaDeviceIterator<'a> { fn next(&mut self) -> Option { loop { if self.end_of_port { - let result = unsafe { (self.proto.get_next_port)(self.proto, &mut self.prev_port) }; + let result = unsafe { + ((*self.proto.get()).get_next_port)(self.proto.get(), &mut self.prev_port) + }; match result { Status::SUCCESS => self.end_of_port = false, Status::NOT_FOUND => return None, // no more ports / devices. End of list @@ -233,7 +239,11 @@ impl<'a> Iterator for AtaDeviceIterator<'a> { // to the port! A port where the device is directly connected uses a pmp-value of 0xFFFF. let was_first = self.prev_pmp == 0xFFFF; let result = unsafe { - (self.proto.get_next_device)(self.proto, self.prev_port, &mut self.prev_pmp) + ((*self.proto.get()).get_next_device)( + self.proto.get(), + self.prev_port, + &mut self.prev_pmp, + ) }; match result { Status::SUCCESS => {