Allow more output storage types for some dot algorithms. #24800

Merged 1 commit on Nov 12, 2024
69 changes: 52 additions & 17 deletions jax/_src/lax/lax.py
@@ -906,16 +906,25 @@ def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
return self.lhs_precision_type

@property
-  def accumulation_type(self) -> DTypeLike | None:
+  def accumulation_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
match self:
case (
DotAlgorithmPreset.DEFAULT |
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
):
return None
+      case (
+          DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
+          DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM
+      ):
+        return (np.float32, np.float16, dtypes.bfloat16, dtypes.float8_e4m3fn,
+                dtypes.float8_e5m2, dtypes.float8_e5m2fnuz,
+                dtypes.float8_e4m3fnuz, dtypes.float8_e4m3b11fnuz)
case DotAlgorithmPreset.F16_F16_F16:
return np.float16
+      case DotAlgorithmPreset.F16_F16_F32:
+        return (np.float32, np.float16)
case DotAlgorithmPreset.BF16_BF16_BF16:
return dtypes.bfloat16
case DotAlgorithmPreset.F64_F64_F64:
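
The practical effect of the property change above: presets whose result may be stored in more than one type now report a tuple of dtypes, with the default accumulation type first, while single-type presets keep returning a plain dtype. A minimal sketch of the new behavior (not part of the diff), assuming the enum and its accumulation_type property are reachable as jax.lax.DotAlgorithmPreset, as in recent JAX releases:

import numpy as np
from jax import lax

# Single allowed storage type: still a plain dtype.
assert lax.DotAlgorithmPreset.F16_F16_F16.accumulation_type == np.float16
# F16_F16_F32 accumulates in float32 but may now also be stored as float16;
# the first tuple entry remains the default.
assert lax.DotAlgorithmPreset.F16_F16_F32.accumulation_type == (np.float32, np.float16)
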
@@ -3619,6 +3628,37 @@ def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike,
return precision._convert_to_hlo_attr(lhs_dtype, rhs_dtype)


+def get_algorithm_compute_types(
+    algorithm: DotAlgorithm | DotAlgorithmPreset,
+    lhs_dtype: DTypeLike,
+    rhs_dtype: DTypeLike,
+    out_dtype: DTypeLike | None = None,
+) -> tuple[DTypeLike | None, DTypeLike | None, DTypeLike | None]:
+  def maybe_convert_dtype(input_dtype, target_dtype):
+    if target_dtype is None:
+      return input_dtype
+    if not isinstance(target_dtype, tuple):
+      target_dtype = (target_dtype,)
+    if any(input_dtype == d for d in target_dtype):
+      return input_dtype
+    return target_dtype[0]
+  if algorithm == DotAlgorithmPreset.BF16_BF16_F32:
+    lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type)
+    rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type)
+    if lhs_dtype == dtypes.bfloat16:
+      out_dtype = maybe_convert_dtype(out_dtype,
+                                      (np.float32, dtypes.bfloat16))
+    else:
+      out_dtype = maybe_convert_dtype(out_dtype, np.float32)
+    return lhs_dtype, rhs_dtype, out_dtype
+  else:
+    return (
+        maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type),
+        maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type),
+        maybe_convert_dtype(out_dtype, algorithm.accumulation_type),
+    )


def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
precision, preferred_element_type: np.dtype | None,
out_type, platform: str = "default"):
@@ -3656,20 +3696,17 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
# If an explicit algorithm was specified, we always cast the input types to
# the correct types.
def maybe_convert_dtype(operand, operand_aval, target_dtype):
-      if target_dtype is None:
-        return operand, operand_aval.dtype
-      if not isinstance(target_dtype, tuple):
-        target_dtype = (target_dtype,)
-      if any(operand_aval.dtype == d for d in target_dtype):
-        return operand, operand_aval.dtype
-      aval = core.ShapedArray(operand_aval.shape, target_dtype[0])
-      return mlir.convert_hlo(ctx, operand, operand_aval, aval), target_dtype[0]

-    lhs, lhs_dtype = maybe_convert_dtype(lhs, lhs_aval, precision.lhs_precision_type)
-    rhs, rhs_dtype = maybe_convert_dtype(rhs, rhs_aval, precision.rhs_precision_type)
-    accumulation_type = precision.accumulation_type
-    if accumulation_type is not None:
-      accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_type)
+      if target_dtype is None or operand_aval.dtype == target_dtype:
+        return operand
+      aval = core.ShapedArray(operand_aval.shape, target_dtype)
+      return mlir.convert_hlo(ctx, operand, operand_aval, aval)

+    lhs_dtype, rhs_dtype, accumulation_dtype = get_algorithm_compute_types(
+        precision, lhs_dtype, rhs_dtype, aval_out.dtype)
+    lhs = maybe_convert_dtype(lhs, lhs_aval, lhs_dtype)
+    rhs = maybe_convert_dtype(rhs, rhs_aval, rhs_dtype)
+    if accumulation_dtype is not None:
+      accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_dtype)

if precision != DotAlgorithmPreset.DEFAULT:
algorithm_kwarg = {
@@ -3690,15 +3727,13 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype):
core.ShapedArray(lhs_aval.shape, aval_out.dtype))
rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
core.ShapedArray(rhs_aval.shape, aval_out.dtype))
-        lhs_dtype = rhs_dtype = aval_out.dtype
else: # cpu and gpu
# Do not convert mixed fp8 types to output type.
if not _is_fp8_mixed_precision_matmul(lhs_dtype, rhs_dtype):
lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
core.ShapedArray(lhs_aval.shape, aval_out.dtype))
rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
core.ShapedArray(rhs_aval.shape, aval_out.dtype))
-        lhs_dtype = rhs_dtype = aval_out.dtype

result = hlo.dot_general(
mlir.aval_to_ir_type(accumulation_aval),
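
End to end, the lowering change means an explicit-algorithm dot no longer forces the result through the accumulation dtype when the preset allows the requested storage type. A hedged usage sketch mirroring the test added below (GPU is assumed, and the exact lowered HLO text may vary by backend and JAX version):

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

def fun(lhs, rhs):
  # F16_F16_F32 accumulates in float32, but float16 is now an allowed output
  # storage type, so no convert op should be emitted around the dot.
  return lax.dot(lhs, rhs, precision="F16_F16_F32",
                 preferred_element_type=np.float16)

lhs = jnp.ones((3, 4), dtype=jnp.float16)
rhs = jnp.ones((4, 3), dtype=jnp.float16)
print("convert" in jax.jit(fun).lower(lhs, rhs).as_text())  # expected: False on GPU
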
13 changes: 13 additions & 0 deletions tests/lax_test.py
@@ -1146,6 +1146,19 @@ def fun(lhs, rhs):
lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16)
self.assertEqual(fun(lhs, rhs).dtype, np.float16)

+  def testDotAlgorithmAllowedOutputStorage(self):
+    # see https://github.com/jax-ml/jax/issues/24794
+    if not jtu.test_device_matches(["gpu"]):
+      self.skipTest("Only supported on GPU.")
+    def fun(lhs, rhs):
+      return lax.dot(lhs, rhs, precision="F16_F16_F32",
+                     preferred_element_type=np.float16)
+    lhs_shape = (3, 4)
+    rhs_shape = (4, 3)
+    rng = jtu.rand_default(self.rng())
+    lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16)
+    self.assertNotIn("convert", jax.jit(fun).lower(lhs, rhs).as_text())

def testDotAlgorithmConfig(self):
lhs_shape = (3, 4)
rhs_shape = (4, 3)