Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions lib/cublas/libcublas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3048,6 +3048,7 @@ end
CUBLAS_PEDANTIC_MATH = 2
CUBLAS_TF32_TENSOR_OP_MATH = 3
CUBLAS_FP32_EMULATED_BF16X9_MATH = 4
CUBLAS_FP64_EMULATED_FIXEDPOINT_MATH = 8
CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION = 16
end

Expand All @@ -3064,6 +3065,7 @@ const cublasDataType_t = cudaDataType
CUBLAS_COMPUTE_32F_EMULATED_16BFX9 = 78
CUBLAS_COMPUTE_64F = 70
CUBLAS_COMPUTE_64F_PEDANTIC = 71
CUBLAS_COMPUTE_64F_EMULATED_FIXEDPOINT = 79
CUBLAS_COMPUTE_32I = 72
CUBLAS_COMPUTE_32I_PEDANTIC = 73
end
Expand Down Expand Up @@ -3134,6 +3136,70 @@ end
emulationStrategy::cublasEmulationStrategy_t)::cublasStatus_t
end

@checked function cublasGetEmulationSpecialValuesSupport(handle, mask)
initialize_context()
@gcsafe_ccall libcublas.cublasGetEmulationSpecialValuesSupport(handle::cublasHandle_t,
mask::Ptr{cudaEmulationSpecialValuesSupport})::cublasStatus_t
end

@checked function cublasSetEmulationSpecialValuesSupport(handle, mask)
initialize_context()
@gcsafe_ccall libcublas.cublasSetEmulationSpecialValuesSupport(handle::cublasHandle_t,
mask::cudaEmulationSpecialValuesSupport)::cublasStatus_t
end

@checked function cublasGetFixedPointEmulationMantissaControl(handle, mantissaControl)
initialize_context()
@gcsafe_ccall libcublas.cublasGetFixedPointEmulationMantissaControl(handle::cublasHandle_t,
mantissaControl::Ptr{cudaEmulationMantissaControl})::cublasStatus_t
end

@checked function cublasSetFixedPointEmulationMantissaControl(handle, mantissaControl)
initialize_context()
@gcsafe_ccall libcublas.cublasSetFixedPointEmulationMantissaControl(handle::cublasHandle_t,
mantissaControl::cudaEmulationMantissaControl)::cublasStatus_t
end

@checked function cublasGetFixedPointEmulationMaxMantissaBitCount(handle,
maxMantissaBitCount)
initialize_context()
@gcsafe_ccall libcublas.cublasGetFixedPointEmulationMaxMantissaBitCount(handle::cublasHandle_t,
maxMantissaBitCount::Ptr{Cint})::cublasStatus_t
end

@checked function cublasSetFixedPointEmulationMaxMantissaBitCount(handle,
maxMantissaBitCount)
initialize_context()
@gcsafe_ccall libcublas.cublasSetFixedPointEmulationMaxMantissaBitCount(handle::cublasHandle_t,
maxMantissaBitCount::Cint)::cublasStatus_t
end

@checked function cublasGetFixedPointEmulationMantissaBitOffset(handle, mantissaBitOffset)
initialize_context()
@gcsafe_ccall libcublas.cublasGetFixedPointEmulationMantissaBitOffset(handle::cublasHandle_t,
mantissaBitOffset::Ptr{Cint})::cublasStatus_t
end

@checked function cublasSetFixedPointEmulationMantissaBitOffset(handle, mantissaBitOffset)
initialize_context()
@gcsafe_ccall libcublas.cublasSetFixedPointEmulationMantissaBitOffset(handle::cublasHandle_t,
mantissaBitOffset::Cint)::cublasStatus_t
end

@checked function cublasGetFixedPointEmulationMantissaBitCountPointer(handle,
mantissaBitCount)
initialize_context()
@gcsafe_ccall libcublas.cublasGetFixedPointEmulationMantissaBitCountPointer(handle::cublasHandle_t,
mantissaBitCount::Ptr{Ptr{Cint}})::cublasStatus_t
end

@checked function cublasSetFixedPointEmulationMantissaBitCountPointer(handle,
mantissaBitCount)
initialize_context()
@gcsafe_ccall libcublas.cublasSetFixedPointEmulationMantissaBitCountPointer(handle::cublasHandle_t,
mantissaBitCount::Ptr{Cint})::cublasStatus_t
end

function cublasGetStatusName(status)
initialize_context()
@gcsafe_ccall libcublas.cublasGetStatusName(status::cublasStatus_t)::Cstring
Expand Down
57 changes: 56 additions & 1 deletion lib/cublas/libcublasLt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ end

const cublasLtMatmulPreference_t = Ptr{cublasLtMatmulPreferenceOpaque_t}

struct cublasLtEmulationDescOpaque_t
data::NTuple{8,UInt64}
end

const cublasLtEmulationDesc_t = Ptr{cublasLtEmulationDescOpaque_t}

@cenum cublasLtMatmulTile_t::UInt32 begin
CUBLASLT_MATMUL_TILE_UNDEFINED = 0
CUBLASLT_MATMUL_TILE_8x8 = 1
Expand Down Expand Up @@ -998,6 +1004,7 @@ end
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_MODE = 35
CUBLASLT_MATMUL_DESC_D_OUT_SCALE_POINTER = 36
CUBLASLT_MATMUL_DESC_D_OUT_SCALE_MODE = 37
CUBLASLT_MATMUL_DESC_EMULATION_DESCRIPTOR = 38
end

@checked function cublasLtMatmulDescInit_internal(matmulDesc, size, computeType, scaleType)
Expand Down Expand Up @@ -1095,6 +1102,54 @@ end
sizeWritten::Ptr{Csize_t})::cublasStatus_t
end

@cenum cublasLtEmulationDescAttributes_t::UInt32 begin
CUBLASLT_EMULATION_DESC_STRATEGY = 0
CUBLASLT_EMULATION_DESC_SPECIAL_VALUES_SUPPORT = 1
CUBLASLT_EMULATION_DESC_FIXEDPOINT_MANTISSA_CONTROL = 2
CUBLASLT_EMULATION_DESC_FIXEDPOINT_MAX_MANTISSA_BIT_COUNT = 3
CUBLASLT_EMULATION_DESC_FIXEDPOINT_MANTISSA_BIT_OFFSET = 4
CUBLASLT_EMULATION_DESC_FIXEDPOINT_MANTISSA_BIT_COUNT_POINTER = 5
end

@checked function cublasLtEmulationDescInit_internal(emulationDesc, size)
initialize_context()
@gcsafe_ccall libcublasLt.cublasLtEmulationDescInit_internal(emulationDesc::cublasLtEmulationDesc_t,
size::Csize_t)::cublasStatus_t
end

@checked function cublasLtEmulationDescInit(emulationDesc)
initialize_context()
@gcsafe_ccall libcublasLt.cublasLtEmulationDescInit(emulationDesc::cublasLtEmulationDesc_t)::cublasStatus_t
end

@checked function cublasLtEmulationDescCreate(emulationDesc)
initialize_context()
@gcsafe_ccall libcublasLt.cublasLtEmulationDescCreate(emulationDesc::Ptr{cublasLtEmulationDesc_t})::cublasStatus_t
end

@checked function cublasLtEmulationDescDestroy(emulationDesc)
initialize_context()
@gcsafe_ccall libcublasLt.cublasLtEmulationDescDestroy(emulationDesc::cublasLtEmulationDesc_t)::cublasStatus_t
end

@checked function cublasLtEmulationDescSetAttribute(emulationDesc, attr, buf, sizeInBytes)
initialize_context()
@gcsafe_ccall libcublasLt.cublasLtEmulationDescSetAttribute(emulationDesc::cublasLtEmulationDesc_t,
attr::cublasLtEmulationDescAttributes_t,
buf::Ptr{Cvoid},
sizeInBytes::Csize_t)::cublasStatus_t
end

@checked function cublasLtEmulationDescGetAttribute(emulationDesc, attr, buf, sizeInBytes,
sizeWritten)
initialize_context()
@gcsafe_ccall libcublasLt.cublasLtEmulationDescGetAttribute(emulationDesc::cublasLtEmulationDesc_t,
attr::cublasLtEmulationDescAttributes_t,
buf::Ptr{Cvoid},
sizeInBytes::Csize_t,
sizeWritten::Ptr{Csize_t})::cublasStatus_t
end

@cenum cublasLtReductionScheme_t::UInt32 begin
CUBLASLT_REDUCTION_SCHEME_NONE = 0
CUBLASLT_REDUCTION_SCHEME_INPLACE = 1
Expand Down Expand Up @@ -1344,7 +1399,7 @@ end
@gcsafe_ccall libcublasLt.cublasLtLoggerSetMask(mask::Cint)::cublasStatus_t
end

# no prototype is found for this function at cublasLt.h:2521:29, please use with caution
# no prototype is found for this function at cublasLt.h:2670:29, please use with caution
@checked function cublasLtLoggerForceDisable()
initialize_context()
@gcsafe_ccall libcublasLt.cublasLtLoggerForceDisable()::cublasStatus_t
Expand Down
Loading
Loading