Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring back the minimal implementations of runtime API. #45

Merged
merged 9 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ members = [
"zluda_redirect",
"zluda_rt",
"zluda_rtc",
"zluda_runtime",
"zluda_sparse",
]

Expand Down
21 changes: 1 addition & 20 deletions hip_runtime-sys/src/hip_runtime_api_v6.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7495,17 +7495,9 @@ extern "C" {
extern "C" {
#[must_use]
pub fn __hipRegisterFatBinary(
data: *mut ::std::os::raw::c_void,
data: *const ::std::os::raw::c_void,
) -> *mut *mut ::std::os::raw::c_void;
}
/*
extern "C" {
#[must_use]
pub fn __hipRegisterFatBinaryEnd(
fatCubinHandle: *mut *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_void;
}
*/
extern "C" {
#[must_use]
pub fn __hipRegisterFunction(
Expand All @@ -7521,17 +7513,6 @@ extern "C" {
wSize: *mut ::std::os::raw::c_int,
) -> ::std::os::raw::c_void;
}
/*
extern "C" {
#[must_use]
pub fn __hipRegisterHostVar(
fatCubinHandle: *mut *mut ::std::os::raw::c_void,
deviceName: *const ::std::os::raw::c_char,
hostVar: *mut ::std::os::raw::c_char,
size: usize,
) -> ::std::os::raw::c_void;
}
*/
extern "C" {
#[must_use]
pub fn __hipRegisterManagedVar(
Expand Down
22 changes: 10 additions & 12 deletions zluda_inject/src/bin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,15 @@ struct ProgramArguments {
pub fn main_impl() -> Result<(), Box<dyn Error>> {
for argument in env::args_os() {
match argument.to_str() {
Some(argument) =>
match argument {
"--version" => {
println!("ZLUDA 3.8.3");
process::exit(0);
},
"--" => break,
_ => {},
},
None => {},
Some(argument) => match argument {
"--version" => {
println!("ZLUDA 3.8.3");
process::exit(0);
}
"--" => break,
_ => {}
},
None => {}
}
}

Expand Down Expand Up @@ -202,8 +201,7 @@ struct NormalizedArguments {
impl NormalizedArguments {
fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> {
let current_exe = env::current_exe()?;
let nccl_path =
Self::get_absolute_path_or_default(&current_exe, prog_args.nccl, NCCL_DLL)?;
let nccl_path = Self::get_absolute_path_or_default(&current_exe, prog_args.nccl, NCCL_DLL)?;
let nvrtc_path = prog_args.nvrtc.map(Self::get_absolute_path).transpose()?;
let nvcuda_path =
Self::get_absolute_path_or_default(&current_exe, prog_args.nvcuda, NVCUDA_DLL)?;
Expand Down
22 changes: 22 additions & 0 deletions zluda_runtime/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[package]
name = "zluda_runtime"
version = "0.0.0"
authors = ["Seunghoon Lee <[email protected]>"]
edition = "2018"

[lib]
name = "cudart"
crate-type = ["cdylib"]

[features]
rocm5 = ["hip_common/rocm5", "hip_runtime-sys/rocm5", "zluda_dark_api/rocm5"]

[dependencies]
cuda_types = { path = "../cuda_types" }
hip_common = { path = "../hip_common" }
hip_runtime-sys = { path = "../hip_runtime-sys" }
zluda_dark_api = { path = "../zluda_dark_api" }

[package.metadata.zluda]
linux_names = ["libcudart.so.10", "libcudart.so.11"]
dump_names = ["libcudart.so"]
4 changes: 4 additions & 0 deletions zluda_runtime/README
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
bindgen include/cuda_runtime.h -o src/cudart.rs --allowlist-function="^cuda.*" --default-enum-style=newtype --no-layout-tests --no-derive-debug -- -I""
sed -i -e 's/extern "C" {//g' -e 's/-> cudaError_t;/-> cudaError_t { crate::unsupported()/g' -e 's/pub fn /#[no_mangle] pub extern "system" fn /g' src/cudart.rs
bindgen include/cuda_profiler_api.h -o src/profiler.rs --allowlist-function="^cuda.*" --default-enum-style=newtype --no-layout-tests --no-derive-debug -- -I""
sed -i -e 's/extern "C" {//g' -e 's/-> cudaError_t;/-> cudaError_t { crate::unsupported()/g' -e 's/pub fn /#[no_mangle] pub extern "system" fn /g' src/profiler.rs
5,976 changes: 5,976 additions & 0 deletions zluda_runtime/src/cudart.rs

Large diffs are not rendered by default.

208 changes: 208 additions & 0 deletions zluda_runtime/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
#![allow(warnings)]
mod cudart;
pub use cudart::*;

use hip_runtime_sys::*;
use std::mem;

#[cfg(debug_assertions)]
fn unsupported() -> cudaError_t {
unimplemented!()
}

#[cfg(not(debug_assertions))]
fn unsupported() -> cudaError_t {
cudaError_t::cudaErrorNotSupported
}

fn to_cuda(status: hipError_t) -> cudaError_t {
match status {
hipError_t::hipSuccess => cudaError_t::cudaSuccess,
hipError_t::hipErrorInvalidValue => cudaError_t::cudaErrorInvalidValue,
hipError_t::hipErrorOutOfMemory => cudaError_t::cudaErrorMemoryAllocation,
hipError_t::hipErrorInvalidContext => cudaError_t::cudaErrorDeviceUninitialized,
hipError_t::hipErrorInvalidResourceHandle => cudaError_t::cudaErrorInvalidResourceHandle,
hipError_t::hipErrorNotSupported => cudaError_t::cudaErrorNotSupported,
err => panic!("[ZLUDA] HIP Runtime failed: {}", err.0),
}
}

unsafe fn to_stream(stream: cudaStream_t) -> hipStream_t {
let lib = hip_common::zluda_ext::get_cuda_library().unwrap();
let cu_get_export_table = lib
.get::<unsafe extern "C" fn(
ppExportTable: *mut *const ::std::os::raw::c_void,
pExportTableId: *const cuda_types::CUuuid,
) -> cuda_types::CUresult>(b"cuGetExportTable\0")
.unwrap();
let mut export_table = std::ptr::null();
let error = (cu_get_export_table)(&mut export_table, &zluda_dark_api::ZludaExt::GUID);
assert_eq!(error, cuda_types::CUresult::CUDA_SUCCESS);
let zluda_ext = zluda_dark_api::ZludaExt::new(export_table);
let maybe_hip_stream: Result<_, _> = zluda_ext.get_hip_stream(stream as _).into();
maybe_hip_stream.unwrap() as _
}

fn memcpy_kind(kind: cudaMemcpyKind) -> hipMemcpyKind {
match kind {
cudaMemcpyKind::cudaMemcpyHostToHost => hipMemcpyKind::hipMemcpyHostToHost,
cudaMemcpyKind::cudaMemcpyHostToDevice => hipMemcpyKind::hipMemcpyHostToDevice,
cudaMemcpyKind::cudaMemcpyDeviceToHost => hipMemcpyKind::hipMemcpyDeviceToHost,
cudaMemcpyKind::cudaMemcpyDeviceToDevice => hipMemcpyKind::hipMemcpyDeviceToDevice,
cudaMemcpyKind::cudaMemcpyDefault => hipMemcpyKind::hipMemcpyDefault,
_ => panic!(),
}
}

fn to_cuda_stream_capture_status(status: hipStreamCaptureStatus) -> cudaStreamCaptureStatus {
match status {
hipStreamCaptureStatus::hipStreamCaptureStatusNone => {
cudaStreamCaptureStatus::cudaStreamCaptureStatusNone
}
hipStreamCaptureStatus::hipStreamCaptureStatusActive => {
cudaStreamCaptureStatus::cudaStreamCaptureStatusActive
}
hipStreamCaptureStatus::hipStreamCaptureStatusInvalidated => {
cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated
}
_ => panic!(),
}
}

unsafe fn register_fat_binary(
fat_cubin: *mut ::std::os::raw::c_void,
) -> *mut *mut ::std::os::raw::c_void {
__hipRegisterFatBinary(fat_cubin)
}

unsafe fn register_function(
fat_cubin_handle: *mut *mut ::std::os::raw::c_void,
host_fun: *const ::std::os::raw::c_char,
device_fun: *mut ::std::os::raw::c_char,
device_name: *const ::std::os::raw::c_char,
thread_limit: i32,
tid: *mut uint3,
bid: *mut uint3,
b_dim: *mut cudart::dim3,
g_dim: *mut cudart::dim3,
w_size: *mut i32,
) -> ::std::os::raw::c_void {
__hipRegisterFunction(
fat_cubin_handle,
host_fun.cast(),
device_fun,
device_name,
thread_limit as _,
tid.cast(),
bid.cast(),
b_dim.cast(),
g_dim.cast(),
w_size,
)
}

unsafe fn register_var(
fat_cubin_handle: *mut *mut ::std::os::raw::c_void,
host_var: *mut ::std::os::raw::c_char,
device_address: *mut ::std::os::raw::c_char,
device_name: *const ::std::os::raw::c_char,
ext: i32,
size: usize,
constant: i32,
global: i32,
) -> ::std::os::raw::c_void {
__hipRegisterVar(
fat_cubin_handle,
device_address.cast(),
host_var,
device_name.cast_mut(),
ext,
size,
constant,
global,
)
}

unsafe fn unregister_fat_binary(
fat_cubin_handle: *mut *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_void {
__hipUnregisterFatBinary(fat_cubin_handle)
}

unsafe fn device_get_stream_priority_range(
least_priority: *mut i32,
greatest_priority: *mut i32,
) -> cudaError_t {
to_cuda(hipDeviceGetStreamPriorityRange(
least_priority,
greatest_priority,
))
}

unsafe fn get_last_error() -> cudaError_t {
to_cuda(hipGetLastError())
}

unsafe fn get_device_count(count: *mut i32) -> cudaError_t {
to_cuda(hipGetDeviceCount(count))
}

unsafe fn get_device(device: *mut i32) -> cudaError_t {
to_cuda(hipGetDevice(device))
}

unsafe fn stream_create_with_priority(
p_stream: *mut cudaStream_t,
flags: u32,
priority: i32,
) -> cudaError_t {
let lib = hip_common::zluda_ext::get_cuda_library().unwrap();
let cu_stream_create_with_priority = lib
.get::<unsafe extern "C" fn(
phStream: *mut cuda_types::CUstream,
flags: ::std::os::raw::c_uint,
priority: ::std::os::raw::c_int,
) -> cuda_types::CUresult>(b"cuStreamCreateWithPriority\0")
.unwrap();
cudaError_t((cu_stream_create_with_priority)(p_stream.cast(), flags, priority).0)
}

unsafe fn stream_synchronize(stream: cudaStream_t) -> cudaError_t {
let stream = to_stream(stream);
to_cuda(hipStreamSynchronize(stream))
}

unsafe fn stream_is_capturing(
stream: cudaStream_t,
p_capture_status: *mut cudaStreamCaptureStatus,
) -> cudaError_t {
let stream = to_stream(stream);
let mut capture_status = mem::zeroed();
let status = to_cuda(hipStreamIsCapturing(stream, &mut capture_status));
*p_capture_status = to_cuda_stream_capture_status(capture_status);
status
}

unsafe fn malloc(dev_ptr: *mut *mut ::std::os::raw::c_void, size: usize) -> cudaError_t {
to_cuda(hipMalloc(dev_ptr, size))
}

unsafe fn free(dev_ptr: *mut ::std::os::raw::c_void) -> cudaError_t {
to_cuda(hipFree(dev_ptr))
}

unsafe fn mem_get_info(free: *mut usize, total: *mut usize) -> cudaError_t {
to_cuda(hipMemGetInfo(free, total))
}

unsafe fn memcpy_async(
dst: *mut ::std::os::raw::c_void,
src: *const ::std::os::raw::c_void,
count: usize,
kind: cudaMemcpyKind,
stream: cudaStream_t,
) -> cudaError_t {
let kind = memcpy_kind(kind);
let stream = to_stream(stream);
to_cuda(hipMemcpyAsync(dst, src, count, kind, stream))
}