Skip to content

[MLIR][Transform] Introduce transform.tune.knob op #146732

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Transform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
add_subdirectory(Transforms)
add_subdirectory(TuneExtension)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS TuneExtensionOps.td)
mlir_tablegen(TuneExtensionOps.h.inc -gen-op-decls)
mlir_tablegen(TuneExtensionOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRTransformDialectTuneExtensionOpsIncGen)

add_mlir_doc(TuneExtensionOps TuneExtensionOps Dialects/ -gen-op-doc)
21 changes: 21 additions & 0 deletions mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtension.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- TuneExtension.h - Tune extension for Transform dialect ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H

namespace mlir {
class DialectRegistry;

namespace transform {
/// Registers the tune extension of the Transform dialect in the given registry.
void registerTuneExtension(DialectRegistry &dialectRegistry);
} // namespace transform
} // namespace mlir

#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===- TuneExtensionOps.h - Tune ext. for Transform dialect -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H

#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"

#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h.inc"

#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//===- TuneExtensionOps.td - Transform dialect operations --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonAttrConstraints.td"

def KnobOp : Op<Transform_Dialect, "tune.knob", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
]> {
let summary = "Represents a tunable parameter with a set of options";

let description = [{
Provides a representation for "tunables" within schedules.

Each op represents a single tunable, which has a `name` and a set
of valid `options` described by an attribute. Without a specified
`selected` option, this op represents a non-deterministic choice
that has yet to be resolved -- as such, the interpreter runtime
semantics is to raise a failure.

The non-deterministic choice is resolved through providing a
`selected` attribute. When provided, the interpreter runtime
semantics are to return the `selected` attribute as a param through
the op's result.

-----

In case the `options` attribute is an `ArrayAttr`, the verifier
checks that the provided `selected` attribute occurs in `options`.
}];
let cppNamespace = [{ mlir::transform::tune }];
let hasVerifier = 1;

let arguments = (ins Builtin_StringAttr:$name,
AnyAttr:$options,
OptionalAttr<AnyAttr>:$selected);
let results = (outs TransformParamTypeInterface:$result);

let assemblyFormat =
"`<` $name `>` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)";
}

#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllExtensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
Expand Down Expand Up @@ -107,6 +108,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
transform::registerIRDLExtension(registry);
transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
transform::registerTuneExtension(registry);
vector::registerTransformDialectExtension(registry);
arm_neon::registerTransformDialectExtension(registry);
arm_sve::registerTransformDialectExtension(registry);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Transform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
add_subdirectory(Transforms)
add_subdirectory(TuneExtension)
add_subdirectory(Utils)
11 changes: 11 additions & 0 deletions mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
add_mlir_dialect_library(MLIRTransformTuneExtension
TuneExtension.cpp
TuneExtensionOps.cpp

DEPENDS
MLIRTransformDialectTuneExtensionOpsIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRTransformDialect
)
32 changes: 32 additions & 0 deletions mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//===- TuneExtension.cpp - Tune extension for the Transform dialect -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"

#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
#include "mlir/IR/DialectRegistry.h"

using namespace mlir;

class TuneExtension
: public transform::TransformDialectExtension<TuneExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TuneExtension)

void init() {
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
>();
}
};

void mlir::transform::registerTuneExtension(DialectRegistry &dialectRegistry) {
dialectRegistry.addExtensions<TuneExtension>();
}
62 changes: 62 additions & 0 deletions mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//===- TuneExtensionOps.cpp - Tune extension for the Transform dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"

#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"

using namespace mlir;

#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"

#define DEBUG_TYPE "transform-tune"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")

//===----------------------------------------------------------------------===//
// KnobOp
//===----------------------------------------------------------------------===//

void transform::tune::KnobOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
producesHandle(getOperation()->getOpResults(), effects);
onlyReadsPayload(effects);
}

DiagnosedSilenceableFailure
transform::tune::KnobOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
if (getSelected()) {
results.setParams(llvm::cast<OpResult>(getResult()), *getSelected());
return DiagnosedSilenceableFailure::success();
}

return emitDefiniteFailure()
<< "non-deterministic choice " << getName()
<< " is only resolved through providing a `selected` attr";
}

LogicalResult transform::tune::KnobOp::verify() {
if (auto selected = getSelected()) {
if (auto optionsArray = dyn_cast<ArrayAttr>(getOptions())) {
if (!llvm::is_contained(optionsArray, selected))
return emitOpError("provided `selected` attribute is not an element of "
"`options` array of attributes");
} else
LLVM_DEBUG(DBGS() << "cannot verify `selected` attribute " << selected
<< " is an element of `options` attribute "
<< getOptions());
}

return success();
}
9 changes: 9 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
DIALECT_NAME transform
EXTENSION_NAME transform_debug_extension)

declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformTuneExtensionOps.td
SOURCES
dialects/transform/tune.py
DIALECT_NAME transform
EXTENSION_NAME transform_tune_extension)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Expand Down
19 changes: 19 additions & 0 deletions mlir/python/mlir/dialects/TransformTuneExtensionOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===-- TransformTuneExtensionOps.td - Binding entry point -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Entry point of the generated Python bindings for the Tune extension of the
// Transform dialect.
//
//===----------------------------------------------------------------------===//

#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS

include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td"

#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
82 changes: 82 additions & 0 deletions mlir/python/mlir/dialects/transform/tune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional, Sequence

from ...ir import (
Type,
Attribute,
ArrayAttr,
StringAttr,
F64Type,
IntegerType,
IntegerAttr,
FloatAttr,
BoolAttr,
)
from .._transform_tune_extension_ops_gen import *
from .._transform_tune_extension_ops_gen import _Dialect

try:
from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e

from typing import Union


@_ods_cext.register_operation(_Dialect, replace=True)
class KnobOp(KnobOp):
def __init__(
self,
result: Type, # !transform.any_param or !transform.param<Type>
name: Union[StringAttr, str],
options: Union[
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
],
*,
selected: Optional[Attribute] = None,
loc=None,
ip=None,
):
if isinstance(name, str):
name = StringAttr.get(name)

def map_to_attr(value):
if isinstance(value, bool):
return BoolAttr.get(value)
if isinstance(value, int):
return IntegerAttr.get(IntegerType.get_signless(64), value)
if isinstance(value, float):
return FloatAttr.get(F64Type.get(), value)
if isinstance(value, str):
return StringAttr.get(value)
assert isinstance(value, Attribute)
return value

if isinstance(options, Sequence) and not isinstance(options, ArrayAttr):
options = ArrayAttr.get([map_to_attr(opt) for opt in options])

super().__init__(
result,
name,
options,
selected=selected and map_to_attr(selected),
loc=loc,
ip=ip,
)


def knob(
result: Type, # !transform.any_param or !transform.param<Type>
name: Union[StringAttr, str],
options: Union[
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
],
*,
selected: Optional[Attribute] = None,
loc=None,
ip=None,
):
return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip)
21 changes: 21 additions & 0 deletions mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
// expected-error@below {{provided `selected` attribute is not an element of `options` array of attributes}}
%heads_or_tails = transform.tune.knob<"coin"> = 1 from options = [true, false] -> !transform.any_param
transform.yield
}
}

// -----

func.func private @f()

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
// expected-error@below {{non-deterministic choice "coin" is only resolved through providing a `selected` attr}}
%heads_or_tails = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param
transform.yield
}
}
Loading
Loading