Skip to content

Commit 11b288c

Browse files
authored
[AMD] Use LLVM op for s8->bf16 conversion (triton-lang#6445)
- Use LLVM::SIToFpOp to replace s8->bf16 conversion inline assembly. - Enable bf16 conversions in test_core.py.
1 parent b93befc commit 11b288c

File tree

2 files changed

+8
-16
lines changed

2 files changed

+8
-16
lines changed

python/test/unit/language/test_core.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,8 +1905,12 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device):
19051905
check_type_supported(dtype_x, device)
19061906
check_type_supported(dtype_z, device)
19071907

1908-
if is_hip() and (dtype_z in ("bfloat16", "float8_e4m3fn") or dtype_x == "float8_e4m3fn"):
1909-
pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.')
1908+
if is_hip():
1909+
if not is_hip_cdna3() and not is_hip_cdna4() and (dtype_x == 'float8_e4m3fn' or dtype_z == 'float8_e4m3fn'):
1910+
pytest.skip(f'test_cast{(dtype_x, dtype_z)} only supported on HIP CDNA3/CDNA4.')
1911+
if (not is_hip_cdna4()) and ((dtype_x == 'bfloat16' and dtype_z == "float8_e4m3fn") or
1912+
(dtype_x == "float8_e4m3fn" and dtype_z == 'bfloat16')):
1913+
pytest.skip(f'test_cast{(dtype_x, dtype_z)} only supported on HIP CDNA4.')
19101914

19111915
torch.manual_seed(0)
19121916
# This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,24 +1494,12 @@ static SmallVector<Value> S8_to_Bf16(Location loc,
14941494
SmallVector<Value> inValues = {v[0], v[1], v[2], v[3]};
14951495
SmallVector<Value> outValues = {};
14961496
for (Value inVal : inValues) {
1497-
Value i32Val = b.sext(i32_ty, inVal);
1498-
1499-
GCNBuilder builder;
1500-
auto &cvt = *builder.create("v_cvt_f32_i32");
1501-
auto res = builder.newOperand("=v");
1502-
auto operand = builder.newOperand(i32Val, "v");
1503-
cvt(res, operand);
1504-
auto f32Val = builder.launch(rewriter, loc, f32_ty, false);
1505-
1506-
f32Val = b.bitcast(f32Val, i32_ty);
1507-
auto shifted = b.lshr(i32_ty, f32Val, b.i32_val(16));
1508-
auto truncated = b.trunc(i16_ty, shifted);
1509-
outValues.push_back(b.bitcast(truncated, bf16_ty));
1497+
Value bf16Val = rewriter.create<LLVM::SIToFPOp>(loc, bf16_ty, inVal);
1498+
outValues.push_back(bf16Val);
15101499
}
15111500
return outValues;
15121501
}
15131502

1514-
// Uses inline ptx to convert s8/u8 to bf16, since the
15151503
struct SIToFPOpConversion
15161504
: ElementwiseOpConversionBase<arith::SIToFPOp, SIToFPOpConversion> {
15171505
using ElementwiseOpConversionBase::ElementwiseOpConversionBase;

0 commit comments

Comments
 (0)