diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt index b6155b5f573f1..e70479b2a39f2 100644 --- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt @@ -5,3 +5,4 @@ add_subdirectory(IRDLExtension) add_subdirectory(LoopExtension) add_subdirectory(PDLExtension) add_subdirectory(Transforms) +add_subdirectory(TuneExtension) diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt new file mode 100644 index 0000000000000..9afca813afda6 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS TuneExtensionOps.td) +mlir_tablegen(TuneExtensionOps.h.inc -gen-op-decls) +mlir_tablegen(TuneExtensionOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTransformDialectTuneExtensionOpsIncGen) + +add_mlir_doc(TuneExtensionOps TuneExtensionOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtension.h b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtension.h new file mode 100644 index 0000000000000..1453d1754297f --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtension.h @@ -0,0 +1,21 @@ +//===- TuneExtension.h - Tune extension for Transform dialect ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H +#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H + +namespace mlir { +class DialectRegistry; + +namespace transform { +/// Registers the tune extension of the Transform dialect in the given registry. +void registerTuneExtension(DialectRegistry &dialectRegistry); +} // namespace transform +} // namespace mlir + +#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h new file mode 100644 index 0000000000000..74e1d28ffac82 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h @@ -0,0 +1,19 @@ +//===- TuneExtensionOps.h - Tune ext. for Transform dialect -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H +#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpDefinition.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h.inc" + +#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td new file mode 100644 index 0000000000000..d68d451afac40 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td @@ -0,0 +1,55 @@ +//===- TuneExtensionOps.td - Transform dialect operations --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS +#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/CommonAttrConstraints.td" + +def KnobOp : Op, + DeclareOpInterfaceMethods, +]> { + let summary = "Represents a tunable parameter with a set of options"; + + let description = [{ + Provides a representation for "tunables" within schedules. + + Each op represents a single tunable, which has a `name` and a set + of valid `options` described by an attribute. Without a specified + `selected` option, this op represents a non-deterministic choice + that has yet to be resolved -- as such, the interpreter runtime + semantics is to raise a failure. + + The non-deterministic choice is resolved through providing a + `selected` attribute. When provided, the interpreter runtime + semantics are to return the `selected` attribute as a param through + the op's result. + + ----- + + In case the `options` attribute is an `ArrayAttr`, the verifier + checks that the provided `selected` attribute occurs in `options`. + }]; + let cppNamespace = [{ mlir::transform::tune }]; + let hasVerifier = 1; + + let arguments = (ins Builtin_StringAttr:$name, + AnyAttr:$options, + OptionalAttr:$selected); + let results = (outs TransformParamTypeInterface:$result); + + let assemblyFormat = + "`<` $name `>` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)"; +} + +#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index f356b91b1b6c0..0f2d0e45008cc 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -52,6 +52,7 @@ #include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h" #include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h" #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" +#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" @@ -107,6 +108,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) { transform::registerIRDLExtension(registry); transform::registerLoopExtension(registry); transform::registerPDLExtension(registry); + transform::registerTuneExtension(registry); vector::registerTransformDialectExtension(registry); arm_neon::registerTransformDialectExtension(registry); arm_sve::registerTransformDialectExtension(registry); diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt index 0c0d5ebe0c212..6e628353258d6 100644 --- a/mlir/lib/Dialect/Transform/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/CMakeLists.txt @@ -5,4 +5,5 @@ add_subdirectory(IRDLExtension) add_subdirectory(LoopExtension) add_subdirectory(PDLExtension) add_subdirectory(Transforms) +add_subdirectory(TuneExtension) add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt new file mode 100644 index 0000000000000..56b90a9a04edf --- /dev/null +++ b/mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt @@ -0,0 +1,11 @@ +add_mlir_dialect_library(MLIRTransformTuneExtension + TuneExtension.cpp + TuneExtensionOps.cpp + + DEPENDS + MLIRTransformDialectTuneExtensionOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRTransformDialect +) diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp new file mode 100644 index 0000000000000..e18f1e2748540 --- /dev/null +++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp @@ -0,0 +1,32 @@ +//===- TuneExtension.cpp - Tune extension for the Transform dialect -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" + +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h" +#include "mlir/IR/DialectRegistry.h" + +using namespace mlir; + +class TuneExtension + : public transform::TransformDialectExtension { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TuneExtension) + + void init() { + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc" + >(); + } +}; + +void mlir::transform::registerTuneExtension(DialectRegistry &dialectRegistry) { + dialectRegistry.addExtensions(); +} diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp new file mode 100644 index 0000000000000..75c1cc53e2606 --- /dev/null +++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp @@ -0,0 +1,62 @@ +//===- TuneExtensionOps.cpp - Tune extension for the Transform dialect ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" + +#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h" + +using namespace mlir; + +#define GET_OP_CLASSES +#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc" + +#define DEBUG_TYPE "transform-tune" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") + +//===----------------------------------------------------------------------===// +// KnobOp +//===----------------------------------------------------------------------===// + +void transform::tune::KnobOp::getEffects( + SmallVectorImpl &effects) { + producesHandle(getOperation()->getOpResults(), effects); + onlyReadsPayload(effects); +} + +DiagnosedSilenceableFailure +transform::tune::KnobOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + if (getSelected()) { + results.setParams(llvm::cast(getResult()), *getSelected()); + return DiagnosedSilenceableFailure::success(); + } + + return emitDefiniteFailure() + << "non-deterministic choice " << getName() + << " is only resolved through providing a `selected` attr"; +} + +LogicalResult transform::tune::KnobOp::verify() { + if (auto selected = getSelected()) { + if (auto optionsArray = dyn_cast(getOptions())) { + if (!llvm::is_contained(optionsArray, selected)) + return emitOpError("provided `selected` attribute is not an element of " + "`options` array of attributes"); + } else + LLVM_DEBUG(DBGS() << "cannot verify `selected` attribute " << selected + << " is an element of `options` attribute " + << getOptions()); + } + + return success(); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index b2daabb2a5957..7a0c95ebb8200 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -180,6 +180,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" DIALECT_NAME transform EXTENSION_NAME transform_debug_extension) +declare_mlir_dialect_extension_python_bindings( +ADD_TO_PARENT MLIRPythonSources.Dialects +ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TransformTuneExtensionOps.td + SOURCES + dialects/transform/tune.py + DIALECT_NAME transform + EXTENSION_NAME transform_tune_extension) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/TransformTuneExtensionOps.td b/mlir/python/mlir/dialects/TransformTuneExtensionOps.td new file mode 100644 index 0000000000000..ff3047592ab12 --- /dev/null +++ b/mlir/python/mlir/dialects/TransformTuneExtensionOps.td @@ -0,0 +1,19 @@ +//===-- TransformTuneExtensionOps.td - Binding entry point -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the generated Python bindings for the Tune extension of the +// Transform dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS +#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS + +include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS diff --git a/mlir/python/mlir/dialects/transform/tune.py b/mlir/python/mlir/dialects/transform/tune.py new file mode 100644 index 0000000000000..f63f88a382422 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/tune.py @@ -0,0 +1,82 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional, Sequence + +from ...ir import ( + Type, + Attribute, + ArrayAttr, + StringAttr, + F64Type, + IntegerType, + IntegerAttr, + FloatAttr, + BoolAttr, +) +from .._transform_tune_extension_ops_gen import * +from .._transform_tune_extension_ops_gen import _Dialect + +try: + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class KnobOp(KnobOp): + def __init__( + self, + result: Type, # !transform.any_param or !transform.param + name: Union[StringAttr, str], + options: Union[ + ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute + ], + *, + selected: Optional[Attribute] = None, + loc=None, + ip=None, + ): + if isinstance(name, str): + name = StringAttr.get(name) + + def map_to_attr(value): + if isinstance(value, bool): + return BoolAttr.get(value) + if isinstance(value, int): + return IntegerAttr.get(IntegerType.get_signless(64), value) + if isinstance(value, float): + return FloatAttr.get(F64Type.get(), value) + if isinstance(value, str): + return StringAttr.get(value) + assert isinstance(value, Attribute) + return value + + if isinstance(options, Sequence) and not isinstance(options, ArrayAttr): + options = ArrayAttr.get([map_to_attr(opt) for opt in options]) + + super().__init__( + result, + name, + options, + selected=selected and map_to_attr(selected), + loc=loc, + ip=ip, + ) + + +def knob( + result: Type, # !transform.any_param or !transform.param + name: Union[StringAttr, str], + options: Union[ + ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute + ], + *, + selected: Optional[Attribute] = None, + loc=None, + ip=None, +): + return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip) diff --git a/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir new file mode 100644 index 0000000000000..2e5f433abeb71 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + // expected-error@below {{provided `selected` attribute is not an element of `options` array of attributes}} + %heads_or_tails = transform.tune.knob<"coin"> = 1 from options = [true, false] -> !transform.any_param + transform.yield + } +} + +// ----- + +func.func private @f() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + // expected-error@below {{non-deterministic choice "coin" is only resolved through providing a `selected` attr}} + %heads_or_tails = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param + transform.yield + } +} diff --git a/mlir/test/Dialect/Transform/test-tune-extension.mlir b/mlir/test/Dialect/Transform/test-tune-extension.mlir new file mode 100644 index 0000000000000..0a253c6d5f837 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-tune-extension.mlir @@ -0,0 +1,61 @@ +// RUN: mlir-opt %s --transform-interpreter --split-input-file \ +// RUN: --verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @schedule_with_nondet_knobs +module attributes {transform.with_named_sequence} { + transform.named_sequence @schedule_with_nondet_knobs(%arg0: !transform.any_op {transform.readonly}) { + // CHECK: %[[HEADS_OR_TAILS:.*]] = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param + %heads_or_tails = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param + // CHECK: transform.tune.knob<"animal"> options = ["cat", "dog", unit] -> !transform.any_param + %chosen_category = transform.tune.knob<"animal"> options = ["cat", "dog", unit] -> !transform.any_param + // CHECK: transform.tune.knob<"tile_size"> options = [2, 4, 8, 16, 24, 32] -> !transform.any_param + %chosen_tile_size = transform.tune.knob<"tile_size"> options = [2, 4, 8, 16, 24, 32] -> !transform.any_param + // CHECK: transform.tune.knob<"magic_value"> options = [2.000000e+00 : f32, 2.250000e+00 : f32, 2.500000e+00 : f32, 2.750000e+00 : f32, 3.000000e+00 : f32] -> !transform.any_param + %chosen_constant = transform.tune.knob<"magic_value"> options = [2.0 : f32, 2.25 : f32, 2.5 : f32, 2.75 : f32, 3.0 : f32] -> !transform.any_param + // CHECK: transform.debug.emit_param_as_remark %[[HEADS_OR_TAILS]] + transform.debug.emit_param_as_remark %heads_or_tails : !transform.any_param + transform.yield + } + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + // Dummy sequence to appease -transform-interpreter invocation + transform.yield + } +} + +// ----- + +// Schedule where non-determinism on knobs has been resolved by selecting a valid option. + +// CHECK-LABEL: payload_for_schedule_with_selected_knobs +func.func private @payload_for_schedule_with_selected_knobs() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + // CHECK: %[[HEADS_OR_TAILS:.*]] = transform.tune.knob<"coin"> = true from options = [true, false] -> !transform.any_param + %heads_or_tails = transform.tune.knob<"coin"> = true from options = [true, false] -> !transform.any_param + // expected-remark@below {{true}} + transform.debug.emit_param_as_remark %heads_or_tails : !transform.any_param + + // CHECK: transform.tune.knob<"animal"> = "dog" from options = ["cat", "dog", unit] -> !transform.any_param + %chosen_category = transform.tune.knob<"animal"> = "dog" from options = ["cat", "dog", unit] -> !transform.any_param + // CHECK: transform.tune.knob<"tile_size"> = 8 : i64 from options = [2, 4, 8, 16, 24, 32] -> !transform.any_param + %chosen_tile_size = transform.tune.knob<"tile_size"> = 8 from options = [2, 4, 8, 16, 24, 32] -> !transform.any_param + // CHECK: transform.tune.knob<"magic_value"> = 2.500000e+00 : f32 from options = [2.000000e+00 : f32, 2.250000e+00 : f32, 2.500000e+00 : f32, 2.750000e+00 : f32, 3.000000e+00 : f32] -> !transform.any_param + %chosen_constant = transform.tune.knob<"magic_value"> = 2.5 : f32 from options = [2.0 : f32, 2.25 : f32, 2.5 : f32, 2.75 : f32, 3.0 : f32] -> !transform.any_param + transform.yield + } +} + +// ----- + +// CHECK: #[[AFFINE_SET:.*]] = affine_set<(d0) : (d0 - 2 >= 0)> +// CHECK: payload_for_schedule_where_selected_knob_being_a_member_of_options_is_unverified +func.func private @payload_for_schedule_where_selected_knob_being_a_member_of_options_is_unverified() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + // CHECK: transform.tune.knob<"bounded"> = 4242 : i64 from options = #[[AFFINE_SET]] -> !transform.any_param + %value_in_half_range = transform.tune.knob<"bounded"> = 4242 from options = affine_set<(d0) : (d0 - 2 >= 0)> -> !transform.any_param + transform.yield + } +} diff --git a/mlir/test/python/dialects/transform_tune_ext.py b/mlir/test/python/dialects/transform_tune_ext.py new file mode 100644 index 0000000000000..dfb93594bca52 --- /dev/null +++ b/mlir/test/python/dialects/transform_tune_ext.py @@ -0,0 +1,72 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import transform +from mlir.dialects.transform import tune, debug + + +def run(f): + print("\nTEST:", f.__name__) + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.AnyOpType.get(), + ) + with InsertionPoint(sequence.body): + f(sequence.bodyTarget) + transform.YieldOp() + print(module) + return f + + +# CHECK-LABEL: TEST: testKnobOp +@run +def testKnobOp(target): + any_param = transform.AnyParamType.get() + + # CHECK: %[[HEADS_OR_TAILS:.*]] = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param + heads_or_tails = tune.KnobOp( + result=any_param, name=StringAttr.get("coin"), options=[True, False] + ) + # CHECK: transform.tune.knob<"animal"> options = ["cat", "dog", unit] -> !transform.any_param + tune.KnobOp(any_param, name="animal", options=["cat", "dog", UnitAttr.get()]) + # CHECK: transform.tune.knob<"tile_size"> options = [2, 4, 8, 16, 24, 32] -> !transform.any_param + tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32]) + # CHECK: transform.tune.knob<"magic_value"> options = [2.000000e+00, 2.250000e+00, 2.500000e+00, 2.750000e+00, 3.000000e+00] -> !transform.any_param + tune.knob(any_param, "magic_value", [2.0, 2.25, 2.5, 2.75, 3.0]) + + # CHECK: transform.debug.emit_param_as_remark %[[HEADS_OR_TAILS]] + debug.emit_param_as_remark(heads_or_tails) + + # CHECK: %[[HEADS:.*]] = transform.tune.knob<"coin"> = true from options = [true, false] -> !transform.any_param + heads = tune.KnobOp(any_param, "coin", options=[True, False], selected=True) + # CHECK: transform.tune.knob<"animal"> = "dog" from options = ["cat", "dog", unit] -> !transform.any_param + tune.KnobOp( + any_param, name="animal", options=["cat", "dog", UnitAttr.get()], selected="dog" + ) + # CHECK: transform.tune.knob<"tile_size"> = 8 : i64 from options = [2, 4, 8, 16, 24, 32] -> !transform.any_param + tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32], selected=8) + # CHECK: transform.tune.knob<"magic_value"> = 2.500000e+00 : f64 from options = [2.000000e+00, 2.250000e+00, 2.500000e+00, 2.750000e+00, 3.000000e+00] -> !transform.any_param + tune.knob(any_param, "magic_value", [2.0, 2.25, 2.5, 2.75, 3.0], selected=2.5) + + # CHECK: transform.debug.emit_param_as_remark %[[HEADS]] + debug.emit_param_as_remark(heads) + + # CHECK: transform.tune.knob<"range_as_a_dict"> = 4 : i64 from options = {start = 2 : i64, step = 2 : i64, stop = 16 : i64} -> !transform.any_param + # NB: Membership of `selected` in non-ArrayAttr `options` is _not_ verified. + i64 = IntegerType.get_signless(64) + tune.knob( + any_param, + "range_as_a_dict", + DictAttr.get( + { + "start": IntegerAttr.get(i64, 2), + "stop": IntegerAttr.get(i64, 16), + "step": IntegerAttr.get(i64, 2), + } + ), + selected=4, + )