From 6b232d903407d91d111180e6069694705f789565 Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Sun, 7 Sep 2025 20:12:28 -0400 Subject: [PATCH 1/6] added fixed function logic --- .../function_libs/torch_lib/ops/core.py | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ab992e0580..e98e3c3f12 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3101,27 +3101,44 @@ def aten_embedding_bag_padding_idx( sparse: bool = False, per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, - padding_idx: int = -1, + padding_idx: Optional[int] = None, ) -> Tuple[TFloat, TFloat, TFloat, TFloat]: """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor) We add default values for the attributes to accommodate _embedding_bag as well: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) """ - assert padding_idx is not None, ( - "padding_idx must not be None. This is likely a dispatcher error" - ) if per_sample_weights is None: per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices)) per_sample_weights = op.CastLike(per_sample_weights, weight) - # Change padding_idx to positive value, -1 means the last index - if padding_idx < 0: - padding_idx = weight.shape[0] + padding_idx + if padding_idx is not None: + # Call the existing function for handling padding_idx + result, offset2bag, bag_size, max_indices =_aten_embedding_bag_1d_padding_idx_onnx( + weight, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + ) - result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx( - weight, indices, offsets, mode, per_sample_weights, include_last_offset, padding_idx + return result, offset2bag, bag_size, max_indices + + # When padding_idx is None, use the standard embedding_bag implementation + result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx( + weight, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, ) return result, offset2bag, bag_size, max_indices From 40f487bcb5ce0a7f158119623b1b64ea8729ec26 Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Mon, 8 Sep 2025 14:12:28 -0400 Subject: [PATCH 2/6] added test cases for aten_embedding_bag_padding_idx --- tests/function_libs/torch_lib/extra_opinfo.py | 38 +++++++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 41 +++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index ca80cf5172..4e607ff36c 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -2210,6 +2210,44 @@ def __init__(self): sample_inputs_func=sample_inputs_embedding_bag_padding_idx, supports_out=False, ), + opinfo_core.OpInfo( + "test_embedding_bag_with_padding_idx_none", + op=torch.nn.functional.embedding_bag, + dtypes=(torch.float32,), + sample_inputs_func=lambda op_info, device, dtype, requires_grad: [ + opinfo_core.SampleInput( + torch.tensor( + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]], + dtype=dtype, + device=device, + ), + args=( + torch.tensor([0, 1, 2, 3], dtype=torch.int64, device=device), + torch.tensor([0, 2], dtype=torch.int64, device=device), + ), + kwargs={"padding_idx": None}, + ) + ], + ), + opinfo_core.OpInfo( + "test_embedding_bag_with_padding_idx_int", + op=torch.nn.functional.embedding_bag, + dtypes=(torch.float32,), + sample_inputs_func=lambda op_info, device, dtype, requires_grad: [ + opinfo_core.SampleInput( + torch.tensor( + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], + dtype=dtype, + device=device, + ), + args=( + torch.tensor([0, 1, 2], dtype=torch.int64, device=device), + torch.tensor([0, 2], dtype=torch.int64, device=device), + ), + kwargs={"padding_idx": 0}, + ) + ], + ), opinfo_core.OpInfo( "ops.aten.embedding_renorm", aten_name="embedding_renorm", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 7af7413185..2b06003162 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -185,6 +185,24 @@ def xfail( # Modify this section ########################################################## +def _embedding_bag_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + # ONNX attributes cannot be None; omit padding_idx if it’s None. + padding_idx = kwargs.pop("padding_idx", "___MISSING___") + if padding_idx is not "___MISSING___": + if padding_idx is not None: + kwargs["padding_idx"] = int(padding_idx) + + # Ensure indices/offsets are int64 (positional: weight, indices, offsets, ...) + if len(args) >= 3: + if isinstance(args[1], torch.Tensor): + args[1] = args[1].to(torch.long) + if isinstance(args[2], torch.Tensor): + args[2] = args[2].to(torch.long) + + return args, kwargs + def _amin_amax_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -1035,15 +1053,38 @@ def _where_input_wrangler( core_ops.aten_embedding_bag, tolerance={torch.float32: (1e-4, 5e-4)}, compare_shape_only_for_output=(1, 2, 3), + input_wrangler=_embedding_bag_input_wrangler, ).skip( dtypes=(torch.float16,), reason="fixme: results mismatch in torch nightly.", ), + TorchLibOpInfo( + "test_embedding_bag_with_padding_idx_none", + core_ops.aten_embedding_bag, + input_wrangler=_embedding_bag_input_wrangler, + ), + TorchLibOpInfo( + "test_embedding_bag_with_padding_idx_int", + core_ops.aten_embedding_bag, + input_wrangler=_embedding_bag_input_wrangler, + ), + TorchLibOpInfo( "ops.aten.embedding_bag.padding_idx", core_ops.aten_embedding_bag_padding_idx, tolerance={torch.float16: (1e-2, 1e-2)}, compare_shape_only_for_output=(1, 2, 3), + input_wrangler=_embedding_bag_input_wrangler, + ), + TorchLibOpInfo( + "test_embedding_bag_with_padding_idx_none", + core_ops.aten_embedding_bag, + input_wrangler=_embedding_bag_input_wrangler, + ), + TorchLibOpInfo( + "test_embedding_bag_with_padding_idx_int", + core_ops.aten_embedding_bag, + input_wrangler=_embedding_bag_input_wrangler, ), TorchLibOpInfo( "ops.aten.embedding_renorm", From 294eca3e2d4691c552ff14c7dcabe812bdb8139d Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Mon, 8 Sep 2025 23:28:50 -0400 Subject: [PATCH 3/6] fix: resolve lint warnings and comparison issue --- onnxscript/function_libs/torch_lib/ops/core.py | 4 ---- tests/function_libs/torch_lib/ops_test_data.py | 4 ++-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e98e3c3f12..589a5f4ba1 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3119,9 +3119,7 @@ def aten_embedding_bag_padding_idx( weight, indices, offsets, - scale_grad_by_freq, mode, - sparse, per_sample_weights, include_last_offset, padding_idx, @@ -3134,9 +3132,7 @@ def aten_embedding_bag_padding_idx( weight, indices, offsets, - scale_grad_by_freq, mode, - sparse, per_sample_weights, include_last_offset, ) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 2b06003162..6612c7c72a 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -188,9 +188,9 @@ def xfail( def _embedding_bag_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: - # ONNX attributes cannot be None; omit padding_idx if it’s None. + # ONNX attributes cannot be None; omit padding_idx if it's None. padding_idx = kwargs.pop("padding_idx", "___MISSING___") - if padding_idx is not "___MISSING___": + if padding_idx != "___MISSING___": if padding_idx is not None: kwargs["padding_idx"] = int(padding_idx) From 71920357875e3904fe372452572a34bc2cde9201 Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Tue, 9 Sep 2025 14:04:10 -0400 Subject: [PATCH 4/6] fix: clean up _embedding_bag_input_wrangler padding_idx check --- tests/function_libs/torch_lib/ops_test_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 6612c7c72a..7ec639fad0 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -189,8 +189,8 @@ def _embedding_bag_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: # ONNX attributes cannot be None; omit padding_idx if it's None. - padding_idx = kwargs.pop("padding_idx", "___MISSING___") - if padding_idx != "___MISSING___": + if "padding_idx" in kwargs: + padding_idx = kwargs.pop("padding_idx") if padding_idx is not None: kwargs["padding_idx"] = int(padding_idx) From 6e41cfe796809770469b27c8269f287baadd3c9d Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Sat, 13 Sep 2025 20:14:16 -0400 Subject: [PATCH 5/6] fix: fixed bugs and issues with test cases --- tests/function_libs/torch_lib/extra_opinfo.py | 8 ++++---- tests/function_libs/torch_lib/ops_test_data.py | 14 ++------------ 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 4e607ff36c..3d81896187 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -2211,9 +2211,9 @@ def __init__(self): supports_out=False, ), opinfo_core.OpInfo( - "test_embedding_bag_with_padding_idx_none", + "ops.aten.embedding_bag.padding_idx_none", op=torch.nn.functional.embedding_bag, - dtypes=(torch.float32,), + dtypes=common_dtype.floating_types_and_half(), sample_inputs_func=lambda op_info, device, dtype, requires_grad: [ opinfo_core.SampleInput( torch.tensor( @@ -2230,9 +2230,9 @@ def __init__(self): ], ), opinfo_core.OpInfo( - "test_embedding_bag_with_padding_idx_int", + "ops.aten.embedding_bag.padding_idx_int", op=torch.nn.functional.embedding_bag, - dtypes=(torch.float32,), + dtypes=common_dtype.floating_types_and_half(), sample_inputs_func=lambda op_info, device, dtype, requires_grad: [ opinfo_core.SampleInput( torch.tensor( diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 7ec639fad0..6c560ee126 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1059,12 +1059,12 @@ def _where_input_wrangler( reason="fixme: results mismatch in torch nightly.", ), TorchLibOpInfo( - "test_embedding_bag_with_padding_idx_none", + "ops.aten.embedding_bag.padding_idx_none", core_ops.aten_embedding_bag, input_wrangler=_embedding_bag_input_wrangler, ), TorchLibOpInfo( - "test_embedding_bag_with_padding_idx_int", + "ops.aten.embedding_bag.padding_idx_int", core_ops.aten_embedding_bag, input_wrangler=_embedding_bag_input_wrangler, ), @@ -1076,16 +1076,6 @@ def _where_input_wrangler( compare_shape_only_for_output=(1, 2, 3), input_wrangler=_embedding_bag_input_wrangler, ), - TorchLibOpInfo( - "test_embedding_bag_with_padding_idx_none", - core_ops.aten_embedding_bag, - input_wrangler=_embedding_bag_input_wrangler, - ), - TorchLibOpInfo( - "test_embedding_bag_with_padding_idx_int", - core_ops.aten_embedding_bag, - input_wrangler=_embedding_bag_input_wrangler, - ), TorchLibOpInfo( "ops.aten.embedding_renorm", core_ops.aten_embedding_renorm, From d2da96ecf5c7639d8162612d2f07c89d646084a4 Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Sat, 13 Sep 2025 20:24:50 -0400 Subject: [PATCH 6/6] style: fixed linting issues in code --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- tests/function_libs/torch_lib/ops_test_data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 589a5f4ba1..6eb9fb4cbb 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3115,7 +3115,7 @@ def aten_embedding_bag_padding_idx( if padding_idx is not None: # Call the existing function for handling padding_idx - result, offset2bag, bag_size, max_indices =_aten_embedding_bag_1d_padding_idx_onnx( + result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx( weight, indices, offsets, diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 6c560ee126..183b23cc4c 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -203,6 +203,7 @@ def _embedding_bag_input_wrangler( return args, kwargs + def _amin_amax_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -1068,7 +1069,6 @@ def _where_input_wrangler( core_ops.aten_embedding_bag, input_wrangler=_embedding_bag_input_wrangler, ), - TorchLibOpInfo( "ops.aten.embedding_bag.padding_idx", core_ops.aten_embedding_bag_padding_idx,