Commit
Allow more output storage types for some dot algorithms.
As reported in #24794, some dot products were lowered with an unnecessary conversion of the result. This change makes the selection of the output storage type more flexible.

Fixes #24794

PiperOrigin-RevId: 694596944
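
The scenario from #24794 can be sketched as follows. This is a minimal example mirroring the new test added below; the shapes, the jitted wrapper, and the substring check are illustrative, and the F16_F16_F32 algorithm is only expected to lower on GPU backends.

import jax
import numpy as np
from jax import lax

def f16_dot(lhs, rhs):
  # Ask for f16 inputs with f32 accumulation, but f16 output storage.
  return lax.dot(lhs, rhs, precision="F16_F16_F32",
                 preferred_element_type=np.float16)

lhs = np.ones((3, 4), np.float16)
rhs = np.ones((4, 3), np.float16)

# Previously the lowering forced an f32 dot_general result followed by a
# convert back to f16; with this change f16 is accepted as an output storage
# type, so no convert should appear in the lowered module.
print("convert" in jax.jit(f16_dot).lower(lhs, rhs).as_text())
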
dfm authored and Google-ML-Automation committed Nov 9, 2024
1 parent 85dae9e commit b6e9e18
Showing 2 changed files with 65 additions and 17 deletions.
69 changes: 52 additions & 17 deletions jax/_src/lax/lax.py
@@ -906,16 +906,25 @@ def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
     return self.lhs_precision_type
 
   @property
-  def accumulation_type(self) -> DTypeLike | None:
+  def accumulation_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
     match self:
       case (
           DotAlgorithmPreset.DEFAULT |
           DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
           DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
       ):
         return None
+      case (
+          DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
+          DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM
+      ):
+        return (np.float32, np.float16, dtypes.bfloat16, dtypes.float8_e4m3fn,
+                dtypes.float8_e5m2, dtypes.float8_e5m2fnuz,
+                dtypes.float8_e4m3fnuz, dtypes.float8_e4m3b11fnuz)
       case DotAlgorithmPreset.F16_F16_F16:
         return np.float16
+      case DotAlgorithmPreset.F16_F16_F32:
+        return (np.float32, np.float16)
       case DotAlgorithmPreset.BF16_BF16_BF16:
         return dtypes.bfloat16
       case DotAlgorithmPreset.F64_F64_F64:
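
With this change the property can report a tuple of acceptable output storage types instead of a single dtype, and the first entry serves as the fallback when the requested output type is not allowed. A quick way to inspect this, assuming a JAX version that includes the change and exposes the preset enum as jax.lax.DotAlgorithmPreset:

from jax import lax

preset = lax.DotAlgorithmPreset.F16_F16_F32
# For this preset the property now returns (float32, float16): float32 is the
# default accumulation/storage type, and float16 is also accepted for output
# storage, which is what avoids the extra convert reported in #24794.
print(preset.accumulation_type)
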
@@ -3619,6 +3628,37 @@ def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike,
   return precision._convert_to_hlo_attr(lhs_dtype, rhs_dtype)
 
 
+def get_algorithm_compute_types(
+    algorithm: DotAlgorithm | DotAlgorithmPreset,
+    lhs_dtype: DTypeLike,
+    rhs_dtype: DTypeLike,
+    out_dtype: DTypeLike | None = None,
+) -> tuple[DTypeLike | None, DTypeLike | None, DTypeLike | None]:
+  def maybe_convert_dtype(input_dtype, target_dtype):
+    if target_dtype is None:
+      return input_dtype
+    if not isinstance(target_dtype, tuple):
+      target_dtype = (target_dtype,)
+    if any(input_dtype == d for d in target_dtype):
+      return input_dtype
+    return target_dtype[0]
+  if algorithm == DotAlgorithmPreset.BF16_BF16_F32:
+    lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type)
+    rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type)
+    if lhs_dtype == dtypes.bfloat16:
+      out_dtype = maybe_convert_dtype(out_dtype,
+                                      (np.float32, dtypes.bfloat16))
+    else:
+      out_dtype = maybe_convert_dtype(out_dtype, np.float32)
+    return lhs_dtype, rhs_dtype, out_dtype
+  else:
+    return (
+        maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type),
+        maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type),
+        maybe_convert_dtype(out_dtype, algorithm.accumulation_type),
+    )
+
+
 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                        precision, preferred_element_type: np.dtype | None,
                        out_type, platform: str = "default"):
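
The helper keeps the caller's requested output dtype whenever the algorithm's accumulation_type allows it, and otherwise falls back to the first allowed type. A standalone sketch of that selection rule; the names pick_dtype, requested, and allowed are illustrative rather than JAX internals:

import numpy as np

def pick_dtype(requested, allowed):
  # None means the algorithm imposes no constraint, so leave the dtype alone.
  if allowed is None:
    return requested
  if not isinstance(allowed, tuple):
    allowed = (allowed,)
  # Keep the requested dtype if the algorithm can honor it...
  if any(requested == d for d in allowed):
    return requested
  # ...otherwise fall back to the algorithm's default (the first entry).
  return allowed[0]

# F16_F16_F32 now advertises (float32, float16) as valid output storage types,
# so a requested float16 output is kept instead of being widened to float32.
print(pick_dtype(np.float16, (np.float32, np.float16)))  # keeps float16
print(pick_dtype(np.float64, (np.float32, np.float16)))  # falls back to float32
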
@@ -3656,20 +3696,17 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
     # If an explicit algorithm was specified, we always cast the input types to
     # the correct types.
     def maybe_convert_dtype(operand, operand_aval, target_dtype):
-      if target_dtype is None:
-        return operand, operand_aval.dtype
-      if not isinstance(target_dtype, tuple):
-        target_dtype = (target_dtype,)
-      if any(operand_aval.dtype == d for d in target_dtype):
-        return operand, operand_aval.dtype
-      aval = core.ShapedArray(operand_aval.shape, target_dtype[0])
-      return mlir.convert_hlo(ctx, operand, operand_aval, aval), target_dtype[0]
-
-    lhs, lhs_dtype = maybe_convert_dtype(lhs, lhs_aval, precision.lhs_precision_type)
-    rhs, rhs_dtype = maybe_convert_dtype(rhs, rhs_aval, precision.rhs_precision_type)
-    accumulation_type = precision.accumulation_type
-    if accumulation_type is not None:
-      accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_type)
+      if target_dtype is None or operand_aval.dtype == target_dtype:
+        return operand
+      aval = core.ShapedArray(operand_aval.shape, target_dtype)
+      return mlir.convert_hlo(ctx, operand, operand_aval, aval)
+
+    lhs_dtype, rhs_dtype, accumulation_dtype = get_algorithm_compute_types(
+        precision, lhs_dtype, rhs_dtype, aval_out.dtype)
+    lhs = maybe_convert_dtype(lhs, lhs_aval, lhs_dtype)
+    rhs = maybe_convert_dtype(rhs, rhs_aval, rhs_dtype)
+    if accumulation_dtype is not None:
+      accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_dtype)
 
     if precision != DotAlgorithmPreset.DEFAULT:
       algorithm_kwarg = {
@@ -3690,15 +3727,13 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype):
                              core.ShapedArray(lhs_aval.shape, aval_out.dtype))
       rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
                              core.ShapedArray(rhs_aval.shape, aval_out.dtype))
-      lhs_dtype = rhs_dtype = aval_out.dtype
     else:  # cpu and gpu
       # Do not convert mixed fp8 types to output type.
       if not _is_fp8_mixed_precision_matmul(lhs_dtype, rhs_dtype):
         lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
                                core.ShapedArray(lhs_aval.shape, aval_out.dtype))
         rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
                                core.ShapedArray(rhs_aval.shape, aval_out.dtype))
-        lhs_dtype = rhs_dtype = aval_out.dtype
 
   result = hlo.dot_general(
       mlir.aval_to_ir_type(accumulation_aval),
13 changes: 13 additions & 0 deletions tests/lax_test.py
@@ -1146,6 +1146,19 @@ def fun(lhs, rhs):
     lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16)
     self.assertEqual(fun(lhs, rhs).dtype, np.float16)
 
+  def testDotAlgorithmAllowedOutputStorage(self):
+    # see https://github.com/jax-ml/jax/issues/24794
+    if not jtu.test_device_matches(["gpu"]):
+      self.skipTest("Only supported on GPU.")
+    def fun(lhs, rhs):
+      return lax.dot(lhs, rhs, precision="F16_F16_F32",
+                     preferred_element_type=np.float16)
+    lhs_shape = (3, 4)
+    rhs_shape = (4, 3)
+    rng = jtu.rand_default(self.rng())
+    lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16)
+    self.assertNotIn("convert", jax.jit(fun).lower(lhs, rhs).as_text())
+
   def testDotAlgorithmConfig(self):
     lhs_shape = (3, 4)
     rhs_shape = (4, 3)
