Skip to content

[MLIR][Transform] Introduce transform.tune.knob op #146732

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

rolfmorel
Copy link
Contributor

@rolfmorel rolfmorel commented Jul 2, 2025

A new transform op to represent that an attribute is to be chosen from a set of alternatives and that this choice is made available as a !transform.param. When a selected argument is provided, the op's apply() semantics is that of just making this selected attribute available as the result. When selected is not provided, apply() complains that nothing has resolved the non-determinism that the op is representing.

rolfmorel added 2 commits July 2, 2025 08:40
A new op to represent that an attribute is to be chosen from a set of
alternatives and that this choice is made available as a
`!transform.param`. When a `selected` argument is provided, the op's
`apply()` semantics is that of just making this selected attribute
available as the result. When `selected` is not provided, `apply()`
complains that nothing has resolved the non-determinism that the op is
representing.
@rolfmorel rolfmorel changed the title [MLIR][Transform] Introduce transform.tune.select op [MLIR][Transform] Introduce transform.tune.knob op Jul 3, 2025
@rolfmorel rolfmorel requested review from rengolin and fschlimb July 3, 2025 14:03
@rolfmorel rolfmorel marked this pull request as ready for review July 3, 2025 14:03
@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Jul 3, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 3, 2025

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

Changes

A new transform op to represent that an attribute is to be chosen from a set of alternatives and that this choice is made available as a !transform.param. When a selected argument is provided, the op's apply() semantics is that of just making this selected attribute available as the result. When selected is not provided, apply() complains that nothing has resolved the non-determinism that the op is representing.


Patch is 24.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146732.diff

16 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/CMakeLists.txt (+1)
  • (added) mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt (+6)
  • (added) mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtension.h (+21)
  • (added) mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h (+19)
  • (added) mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td (+54)
  • (modified) mlir/include/mlir/InitAllExtensions.h (+2)
  • (modified) mlir/lib/Dialect/Transform/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt (+12)
  • (added) mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp (+32)
  • (added) mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp (+62)
  • (modified) mlir/python/CMakeLists.txt (+9)
  • (added) mlir/python/mlir/dialects/TransformTuneExtensionOps.td (+18)
  • (added) mlir/python/mlir/dialects/transform/tune.py (+82)
  • (added) mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir (+21)
  • (added) mlir/test/Dialect/Transform/test-tune-extension.mlir (+61)
  • (added) mlir/test/python/dialects/transform_tune_ext.py (+56)
diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
index b6155b5f573f1..e70479b2a39f2 100644
--- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
@@ -5,3 +5,4 @@ add_subdirectory(IRDLExtension)
 add_subdirectory(LoopExtension)
 add_subdirectory(PDLExtension)
 add_subdirectory(Transforms)
+add_subdirectory(TuneExtension)
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt
new file mode 100644
index 0000000000000..9afca813afda6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS TuneExtensionOps.td)
+mlir_tablegen(TuneExtensionOps.h.inc -gen-op-decls)
+mlir_tablegen(TuneExtensionOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTransformDialectTuneExtensionOpsIncGen)
+
+add_mlir_doc(TuneExtensionOps TuneExtensionOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtension.h b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtension.h
new file mode 100644
index 0000000000000..1453d1754297f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtension.h
@@ -0,0 +1,21 @@
+//===- TuneExtension.h - Tune extension for Transform dialect ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H
+#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace transform {
+/// Registers the tune extension of the Transform dialect in the given registry.
+void registerTuneExtension(DialectRegistry &dialectRegistry);
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
new file mode 100644
index 0000000000000..74e1d28ffac82
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
@@ -0,0 +1,19 @@
+//===- TuneExtensionOps.h - Tune ext. for Transform dialect -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
+#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h.inc"
+
+#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
new file mode 100644
index 0000000000000..afb67e8fef250
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
@@ -0,0 +1,54 @@
+//===- TuneExtensionOps.td - Transform dialect operations --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
+#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/BuiltinAttributes.td"
+include "mlir/IR/CommonAttrConstraints.td"
+
+def KnobOp : Op<Transform_Dialect, "tune.knob", [
+  DeclareOpInterfaceMethods<TransformOpInterface>,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+]> {
+  let summary = "Represents a tunable parameter with a set of options";
+
+  let description = [{
+    Provides a representation for "tunables" within schedules.
+
+    Each op represents a single tunable, which has a `name` and a set
+    of valid `options` described by an attribute. Without a specified
+    `selected` option, this op represents a non-deterministic choice
+    that has yet to be resolved -- as such, the interpreter runtime
+    semantics is to raise a failure.
+
+    The non-deterministic choice is resolved through providing a
+    `selected` attribute. When provided, the interpreter runtime
+    semantics are to return the `selected` attribute as a param through
+    the op's result.
+
+    -----
+
+    In case the `options` attribute is an `ArrayAttr`, the verifier checks that the provided `selected` attribute occurs in `options`.
+  }];
+  let cppNamespace = [{ mlir::transform::tune }];
+  let hasVerifier = 1;
+
+  let arguments = (ins Builtin_StringAttr:$name,
+                       AnyAttr:$options,
+                       OptionalAttr<AnyAttr>:$selected);
+  let results = (outs TransformParamTypeInterface:$result);
+
+  let assemblyFormat =
+      "`<` $name `>` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)";
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index f356b91b1b6c0..0f2d0e45008cc 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -52,6 +52,7 @@
 #include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
 #include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
 #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
@@ -107,6 +108,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   transform::registerIRDLExtension(registry);
   transform::registerLoopExtension(registry);
   transform::registerPDLExtension(registry);
+  transform::registerTuneExtension(registry);
   vector::registerTransformDialectExtension(registry);
   arm_neon::registerTransformDialectExtension(registry);
   arm_sve::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt
index 0c0d5ebe0c212..6e628353258d6 100644
--- a/mlir/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/CMakeLists.txt
@@ -5,4 +5,5 @@ add_subdirectory(IRDLExtension)
 add_subdirectory(LoopExtension)
 add_subdirectory(PDLExtension)
 add_subdirectory(Transforms)
+add_subdirectory(TuneExtension)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt
new file mode 100644
index 0000000000000..ff01d25e57f68
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_dialect_library(MLIRTransformTuneExtension
+  TuneExtension.cpp
+  TuneExtensionOps.cpp
+
+  DEPENDS
+  MLIRTransformDialectTuneExtensionOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRTransformDialect
+  MLIRTransforms
+)
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp
new file mode 100644
index 0000000000000..e18f1e2748540
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp
@@ -0,0 +1,32 @@
+//===- TuneExtension.cpp - Tune extension for the Transform dialect -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
+#include "mlir/IR/DialectRegistry.h"
+
+using namespace mlir;
+
+class TuneExtension
+    : public transform::TransformDialectExtension<TuneExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TuneExtension)
+
+  void init() {
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
+        >();
+  }
+};
+
+void mlir::transform::registerTuneExtension(DialectRegistry &dialectRegistry) {
+  dialectRegistry.addExtensions<TuneExtension>();
+}
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
new file mode 100644
index 0000000000000..0c77dbb0f05dd
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -0,0 +1,62 @@
+//===- TuneExtensionOps.cpp - Tune extension for the Transform dialect ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
+
+using namespace mlir;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
+
+#define DEBUG_TYPE "transform-tune"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
+
+//===----------------------------------------------------------------------===//
+// KnobOp
+//===----------------------------------------------------------------------===//
+
+void transform::tune::KnobOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  producesHandle(getOperation()->getOpResults(), effects);
+  onlyReadsPayload(effects);
+}
+
+DiagnosedSilenceableFailure
+transform::tune::KnobOp::apply(transform::TransformRewriter &rewriter,
+                               transform::TransformResults &results,
+                               transform::TransformState &state) {
+  if (getSelected()) {
+    results.setParams(getOperation()->getOpResults()[0], *getSelected());
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  return emitDefiniteFailure()
+         << "non-deterministic choice " << getName()
+         << " is only resolved through providing a `selected` attr";
+}
+
+LogicalResult transform::tune::KnobOp::verify() {
+  if (auto selected = getSelected()) {
+    if (auto optionsArray = dyn_cast<ArrayAttr>(getOptions())) {
+      if (!llvm::is_contained(optionsArray, selected))
+        return emitOpError("provided `selected` attribute is not an element of "
+                           "`options` array of attributes");
+    } else
+      LLVM_DEBUG(DBGS() << "cannot verify `selected` attribute " << selected
+                        << " is an element of `options` attribute "
+                        << getOptions());
+  }
+
+  return success();
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index b2daabb2a5957..7a0c95ebb8200 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -180,6 +180,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   DIALECT_NAME transform
   EXTENSION_NAME transform_debug_extension)
 
+declare_mlir_dialect_extension_python_bindings(
+ADD_TO_PARENT MLIRPythonSources.Dialects
+ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/TransformTuneExtensionOps.td
+  SOURCES
+    dialects/transform/tune.py
+  DIALECT_NAME transform
+  EXTENSION_NAME transform_tune_extension)
+
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/TransformTuneExtensionOps.td b/mlir/python/mlir/dialects/TransformTuneExtensionOps.td
new file mode 100644
index 0000000000000..60ed95d110762
--- /dev/null
+++ b/mlir/python/mlir/dialects/TransformTuneExtensionOps.td
@@ -0,0 +1,18 @@
+//===-- TransformTuneExtensionOps.td - Binding entry point -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the generated Python bindings for the Tune extension of the
+// Transform dialect.
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
+#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
+
+include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td"
+
+#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
diff --git a/mlir/python/mlir/dialects/transform/tune.py b/mlir/python/mlir/dialects/transform/tune.py
new file mode 100644
index 0000000000000..15c43aba795eb
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/tune.py
@@ -0,0 +1,82 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Optional, Sequence
+
+from ...ir import (
+    Type,
+    Attribute,
+    ArrayAttr,
+    StringAttr,
+    F64Type,
+    IntegerType,
+    IntegerAttr,
+    FloatAttr,
+    BoolAttr,
+)
+from .._transform_tune_extension_ops_gen import *
+from .._transform_tune_extension_ops_gen import _Dialect
+
+try:
+    from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class KnobOp(KnobOp):
+    def __init__(
+        self,
+        result: Type, # !transform.any_param or !transform.param<Type>
+        name: Union[StringAttr, str],
+        options: Union[
+            ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
+        ],
+        *,
+        selected: Optional[Attribute] = None,
+        loc=None,
+        ip=None,
+    ):
+        if isinstance(name, str):
+            name = StringAttr.get(name)
+
+        def map_to_attr(value):
+            if isinstance(value, bool):
+                return BoolAttr.get(value)
+            if isinstance(value, int):
+                return IntegerAttr.get(IntegerType.get_signless(64), value)
+            if isinstance(value, float):
+                return FloatAttr.get(F64Type.get(), value)
+            if isinstance(value, str):
+                return StringAttr.get(value)
+            assert isinstance(value, Attribute)
+            return value
+
+        if isinstance(options, Sequence) and not isinstance(options, ArrayAttr):
+            options = ArrayAttr.get([map_to_attr(opt) for opt in options])
+
+        super().__init__(
+            result,
+            name,
+            options,
+            selected=selected and map_to_attr(selected),
+            loc=loc,
+            ip=ip,
+        )
+
+
+def knob(
+    result: Type, # !transform.any_param or !transform.param<Type>
+    name: Union[StringAttr, str],
+    options: Union[
+        ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
+    ],
+    *,
+    selected: Optional[Attribute] = None,
+    loc=None,
+    ip=None,
+):
+    return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip)
diff --git a/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
new file mode 100644
index 0000000000000..2e5f433abeb71
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    // expected-error@below {{provided `selected` attribute is not an element of `options` array of attributes}}
+    %heads_or_tails = transform.tune.knob<"coin"> = 1 from options = [true, false] -> !transform.any_param
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    // expected-error@below {{non-deterministic choice "coin" is only resolved through providing a `selected` attr}}
+    %heads_or_tails = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-tune-extension.mlir b/mlir/test/Dialect/Transform/test-tune-extension.mlir
new file mode 100644
index 0000000000000..0a253c6d5f837
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-tune-extension.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file \
+// RUN:     --verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @schedule_with_nondet_knobs
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @schedule_with_nondet_knobs(%arg0: !transform.any_op {transform.readonly}) {
+    // CHECK: %[[HEADS_OR_TAILS:.*]] = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param
+    %heads_or_tails = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param
+    // CHECK: transform.tune.knob<"animal"> options = ["cat", "dog", unit] -> !transform.any_param
+    %chosen_category = transform.tune.knob<"animal"> options = ["cat", "dog", unit] -> !transform.any_param
+    // CHECK: transform.tune.knob<"tile_size"> options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
+    %chosen_tile_size = transform.tune.knob<"tile_size"> options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
+    // CHECK: transform.tune.knob<"magic_value"> options = [2.000000e+00 : f32, 2.250000e+00 : f32, 2.500000e+00 : f32, 2.750000e+00 : f32, 3.000000e+00 : f32] -> !transform.any_param
+    %chosen_constant = transform.tune.knob<"magic_value"> options = [2.0 : f32, 2.25 : f32, 2.5 : f32, 2.75 : f32, 3.0 : f32] -> !transform.any_param
+    // CHECK: transform.debug.emit_param_as_remark %[[HEADS_OR_TAILS]]
+    transform.debug.emit_param_as_remark %heads_or_tails : !transform.any_param
+    transform.yield
+  }
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    // Dummy sequence to appease -transform-interpreter invocation
+    transform.yield
+  }
+}
+
+// -----
+
+// Schedule where non-determinism on knobs has been resolved by selecting a valid option.
+
+// CHECK-LABEL: payload_for_schedule_with_selected_knobs
+func.func private @payload_for_schedule_with_selected_knobs()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    // CHECK: %[[HEADS_OR_TAILS:.*]] = transform.tune.knob<"coin"> = true from options = [true, false] -> !transform.any_param
+    %heads_or_tails = transform.tune.knob<"coin"> = true from options = [true, false] -> !transform.any_param
+    // expected-remark@below {{true}...
[truncated]

Copy link

github-actions bot commented Jul 3, 2025

✅ With the latest revision this PR passed the Python code formatter.

@rengolin rengolin requested a review from silvasean July 3, 2025 15:48
@rengolin
Copy link
Member

rengolin commented Jul 3, 2025

@tkarna

Copy link
Contributor

@fschlimb fschlimb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks!
I would like to follow up on ways to use such knobs to alter attributes/constants and even types in the payload, not only within the transforms.

@rolfmorel
Copy link
Contributor Author

Nice, thanks! I would like to follow up on ways to use such knobs to alter attributes/constants and even types in the payload, not only within the transforms.

I expect the nicest way to do that - given your knobs are in schedule IR - is to compose this op with other transform ops. For example, we can feed the param produced by transform.tune.knob into transform.annotate and thereby modify the attributes in the payload.

Comment on lines +46 to +59
def map_to_attr(value):
if isinstance(value, bool):
return BoolAttr.get(value)
if isinstance(value, int):
return IntegerAttr.get(IntegerType.get_signless(64), value)
if isinstance(value, float):
return FloatAttr.get(F64Type.get(), value)
if isinstance(value, str):
return StringAttr.get(value)
assert isinstance(value, Attribute)
return value

if isinstance(options, Sequence) and not isinstance(options, ArrayAttr):
options = ArrayAttr.get([map_to_attr(opt) for opt in options])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI we have a bunch of these auto-mapped via @register_attribute_builder (that's how the auto-generated classes handle this kind of thing).

Copy link
Contributor Author

@rolfmorel rolfmorel Jul 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. I don't think there's any general facility at the moment* to recursively convert nested Python values to their wrapped-up equivalents. That is, the registered builder for ArrayAttr does nothing to the provided (list) argument:

@register_attribute_builder("ArrayAttr")
def _arrayAttr(x, context):
return ArrayAttr.get(x, context=context)

There are some wrappers for things like StrArrayAttr that do convert the elements, but for this op the array can be heterogeneous and hence none of those builders would be suitable. I (and I think @nicolasvasilache as well) would be in favour of more of these general converter utilities to do automagical conversions, also when that doesn't map to a pre-specified nested type (like StrArrayAttr) on the C++ side. Preferably these would live in a central place and it would be amazing if these were called as part of the auto-generated __init__ methods of the bindings. This actually seems pretty feasible to do.

*: something like the following, which does the recursive thing, but this one also deals with a special attribute:

def option_value_to_attr(value):
nonlocal cur_param_operand_idx
if isinstance(value, (Value, Operation, OpView)):
dynamic_options.append(_get_op_result_or_value(value))
cur_param_operand_idx += 1
return ParamOperandAttr(cur_param_operand_idx - 1, context)
elif isinstance(value, Attribute):
return value
# The following cases auto-convert Python values to attributes.
elif isinstance(value, bool):
return BoolAttr.get(value)
elif isinstance(value, int):
default_int_type = IntegerType.get_signless(64, context)
return IntegerAttr.get(default_int_type, value)
elif isinstance(value, str):
return StringAttr.get(value)
elif isinstance(value, Sequence):
return ArrayAttr.get([option_value_to_attr(elt) for elt in value])
else:
raise TypeError(f"Unsupported option type: {type(value)}")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

automagical conversions

In general you would need something like pytree, which would let you "map/transform" a python value to an Attribute.

But I don't think calling such a thing from the autogenerated __init__s is a good idea. The reason being is that those classes/initializers are already a bottleneck for some people (eg Jax people frequently send PRs to speed up certain aspects of op construction).

Preferably these would live in a central place

You could promote map_to_attr to ir.py and/or combine it with option_value_to_attr?

Copy link
Contributor Author

@rolfmorel rolfmorel Jul 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason being is that those classes/initializers are already a bottleneck for some people (eg Jax people frequently send PRs to speed up certain aspects of op construction).

This is a fair take. Would still be nice to have it both ways, e.g. the snake_case wrappers have the bells and whistles while we expect the auto-generated CamelCaseOps to be as vanilla and low-overhead as can be.

You could promote map_to_attr to ir.py and/or combine it with option_value_to_attr?

👍 but will keep it as a TODO for another PR

@rolfmorel
Copy link
Contributor Author

Just a heads-up: unless there are more comments, I am looking to merge this in the next day or so.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants