Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions uefi/src/proto/ata/pass_thru.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::StatusExt;
use crate::mem::{AlignedBuffer, PoolAllocation};
use crate::proto::device_path::PoolDevicePathNode;
use core::alloc::LayoutError;
use core::cell::UnsafeCell;
use core::ptr::{self, NonNull};
use uefi_macros::unsafe_protocol;
use uefi_raw::Status;
Expand All @@ -33,7 +34,7 @@ pub type AtaPassThruMode = uefi_raw::protocol::ata::AtaPassThruMode;
#[derive(Debug)]
#[repr(transparent)]
#[unsafe_protocol(AtaPassThruProtocol::GUID)]
pub struct AtaPassThru(AtaPassThruProtocol);
pub struct AtaPassThru(UnsafeCell<AtaPassThruProtocol>);

impl AtaPassThru {
/// Retrieves the mode structure for the Extended SCSI Pass Thru protocol.
Expand All @@ -42,7 +43,7 @@ impl AtaPassThru {
/// The [`AtaPassThruMode`] structure containing configuration details of the protocol.
#[must_use]
pub fn mode(&self) -> AtaPassThruMode {
let mut mode = unsafe { (*self.0.mode).clone() };
let mut mode = unsafe { (*(*self.0.get()).mode).clone() };
mode.io_align = mode.io_align.max(1); // 0 and 1 is the same, says UEFI spec
mode
}
Expand Down Expand Up @@ -101,16 +102,12 @@ impl AtaPassThru {
/// available / connected device using [`AtaDevice::execute_command`] before doing anything meaningful.
#[derive(Debug)]
pub struct AtaDevice<'a> {
proto: &'a AtaPassThruProtocol,
proto: &'a UnsafeCell<AtaPassThruProtocol>,
port: u16,
pmp: u16,
}

impl AtaDevice<'_> {
const fn proto_mut(&mut self) -> *mut AtaPassThruProtocol {
ptr::from_ref(self.proto).cast_mut()
}

/// Returns the port number of the device.
///
/// # Details
Expand Down Expand Up @@ -142,7 +139,9 @@ impl AtaDevice<'_> {
/// - [`Status::DEVICE_ERROR`] A device error occurred while attempting to reset the specified ATA device.
/// - [`Status::TIMEOUT`] A timeout occurred while attempting to reset the specified ATA device.
pub fn reset(&mut self) -> crate::Result<()> {
unsafe { (self.proto.reset_device)(self.proto_mut(), self.port, self.pmp).to_result() }
unsafe {
((*self.proto.get()).reset_device)(self.proto.get(), self.port, self.pmp).to_result()
}
}

/// Get the final device path node for this device.
Expand All @@ -152,8 +151,13 @@ impl AtaDevice<'_> {
pub fn path_node(&self) -> crate::Result<PoolDevicePathNode> {
unsafe {
let mut path_ptr: *const DevicePathProtocol = ptr::null();
(self.proto.build_device_path)(self.proto, self.port, self.pmp, &mut path_ptr)
.to_result()?;
((*self.proto.get()).build_device_path)(
self.proto.get(),
self.port,
self.pmp,
&mut path_ptr,
)
.to_result()?;
NonNull::new(path_ptr.cast_mut())
.map(|p| PoolDevicePathNode(PoolAllocation::new(p.cast())))
.ok_or_else(|| Status::OUT_OF_RESOURCES.into())
Expand Down Expand Up @@ -184,8 +188,8 @@ impl AtaDevice<'_> {
) -> crate::Result<AtaResponse<'req>> {
req.packet.acb = &req.acb;
unsafe {
(self.proto.pass_thru)(
self.proto_mut(),
((*self.proto.get()).pass_thru)(
self.proto.get(),
self.port,
self.pmp,
&mut req.packet,
Expand All @@ -203,7 +207,7 @@ impl AtaDevice<'_> {
/// is actually available and connected!
#[derive(Debug)]
pub struct AtaDeviceIterator<'a> {
proto: &'a AtaPassThruProtocol,
proto: &'a UnsafeCell<AtaPassThruProtocol>,
// when there are no more devices on this port -> get next port
end_of_port: bool,
prev_port: u16,
Expand All @@ -216,7 +220,9 @@ impl<'a> Iterator for AtaDeviceIterator<'a> {
fn next(&mut self) -> Option<Self::Item> {
loop {
if self.end_of_port {
let result = unsafe { (self.proto.get_next_port)(self.proto, &mut self.prev_port) };
let result = unsafe {
((*self.proto.get()).get_next_port)(self.proto.get(), &mut self.prev_port)
};
match result {
Status::SUCCESS => self.end_of_port = false,
Status::NOT_FOUND => return None, // no more ports / devices. End of list
Expand All @@ -233,7 +239,11 @@ impl<'a> Iterator for AtaDeviceIterator<'a> {
// to the port! A port where the device is directly connected uses a pmp-value of 0xFFFF.
let was_first = self.prev_pmp == 0xFFFF;
let result = unsafe {
(self.proto.get_next_device)(self.proto, self.prev_port, &mut self.prev_pmp)
((*self.proto.get()).get_next_device)(
self.proto.get(),
self.prev_port,
&mut self.prev_pmp,
)
};
match result {
Status::SUCCESS => {
Expand Down
Loading