-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[MLIR][Transform] Introduce transform.tune.knob
op
#146732
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
25ab701
[MLIR][Transform] Introduce `transform.tune.select` op
rolfmorel 0205620
Move to new name, KnobOp, and new syntax and add test cases and docs
rolfmorel 8e0a4d4
Fix up header includes
rolfmorel 91cc2f3
\n
rolfmorel 2f24acf
Python formatting
rolfmorel de0d81e
Demonstrate options-spec as a dict attr from Python
rolfmorel 8afb10a
Minor fixes
rolfmorel File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6 changes: 6 additions & 0 deletions
6
mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
set(LLVM_TARGET_DEFINITIONS TuneExtensionOps.td) | ||
mlir_tablegen(TuneExtensionOps.h.inc -gen-op-decls) | ||
mlir_tablegen(TuneExtensionOps.cpp.inc -gen-op-defs) | ||
add_public_tablegen_target(MLIRTransformDialectTuneExtensionOpsIncGen) | ||
|
||
add_mlir_doc(TuneExtensionOps TuneExtensionOps Dialects/ -gen-op-doc) |
21 changes: 21 additions & 0 deletions
21
mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtension.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
//===- TuneExtension.h - Tune extension for Transform dialect ---*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H | ||
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H | ||
|
||
namespace mlir { | ||
class DialectRegistry; | ||
|
||
namespace transform { | ||
/// Registers the tune extension of the Transform dialect in the given registry. | ||
void registerTuneExtension(DialectRegistry &dialectRegistry); | ||
} // namespace transform | ||
} // namespace mlir | ||
|
||
#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H |
19 changes: 19 additions & 0 deletions
19
mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
//===- TuneExtensionOps.h - Tune ext. for Transform dialect -----*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H | ||
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H | ||
|
||
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" | ||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/OpDefinition.h" | ||
|
||
#define GET_OP_CLASSES | ||
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h.inc" | ||
|
||
#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H |
55 changes: 55 additions & 0 deletions
55
mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
//===- TuneExtensionOps.td - Transform dialect operations --*- tablegen -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS | ||
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS | ||
|
||
include "mlir/Dialect/Transform/IR/TransformDialect.td" | ||
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" | ||
include "mlir/Interfaces/SideEffectInterfaces.td" | ||
include "mlir/IR/BuiltinAttributes.td" | ||
include "mlir/IR/CommonAttrConstraints.td" | ||
|
||
def KnobOp : Op<Transform_Dialect, "tune.knob", [ | ||
DeclareOpInterfaceMethods<TransformOpInterface>, | ||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, | ||
]> { | ||
let summary = "Represents a tunable parameter with a set of options"; | ||
|
||
let description = [{ | ||
Provides a representation for "tunables" within schedules. | ||
|
||
Each op represents a single tunable, which has a `name` and a set | ||
of valid `options` described by an attribute. Without a specified | ||
`selected` option, this op represents a non-deterministic choice | ||
that has yet to be resolved -- as such, the interpreter runtime | ||
semantics is to raise a failure. | ||
|
||
The non-deterministic choice is resolved through providing a | ||
`selected` attribute. When provided, the interpreter runtime | ||
semantics are to return the `selected` attribute as a param through | ||
the op's result. | ||
|
||
----- | ||
|
||
In case the `options` attribute is an `ArrayAttr`, the verifier | ||
checks that the provided `selected` attribute occurs in `options`. | ||
}]; | ||
let cppNamespace = [{ mlir::transform::tune }]; | ||
let hasVerifier = 1; | ||
|
||
let arguments = (ins Builtin_StringAttr:$name, | ||
AnyAttr:$options, | ||
OptionalAttr<AnyAttr>:$selected); | ||
let results = (outs TransformParamTypeInterface:$result); | ||
|
||
let assemblyFormat = | ||
"`<` $name `>` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)"; | ||
} | ||
|
||
#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
add_mlir_dialect_library(MLIRTransformTuneExtension | ||
TuneExtension.cpp | ||
TuneExtensionOps.cpp | ||
|
||
DEPENDS | ||
MLIRTransformDialectTuneExtensionOpsIncGen | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRIR | ||
MLIRTransformDialect | ||
) |
32 changes: 32 additions & 0 deletions
32
mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
//===- TuneExtension.cpp - Tune extension for the Transform dialect -------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" | ||
|
||
#include "mlir/Dialect/Transform/IR/TransformDialect.h" | ||
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h" | ||
#include "mlir/IR/DialectRegistry.h" | ||
|
||
using namespace mlir; | ||
|
||
class TuneExtension | ||
: public transform::TransformDialectExtension<TuneExtension> { | ||
public: | ||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TuneExtension) | ||
|
||
void init() { | ||
registerTransformOps< | ||
#define GET_OP_LIST | ||
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc" | ||
>(); | ||
} | ||
}; | ||
|
||
void mlir::transform::registerTuneExtension(DialectRegistry &dialectRegistry) { | ||
dialectRegistry.addExtensions<TuneExtension>(); | ||
} |
62 changes: 62 additions & 0 deletions
62
mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
//===- TuneExtensionOps.cpp - Tune extension for the Transform dialect ----===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Transform/IR/TransformOps.h" | ||
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" | ||
#include "mlir/IR/OpImplementation.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "llvm/Support/Debug.h" | ||
|
||
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h" | ||
|
||
using namespace mlir; | ||
|
||
#define GET_OP_CLASSES | ||
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc" | ||
|
||
#define DEBUG_TYPE "transform-tune" | ||
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") | ||
|
||
//===----------------------------------------------------------------------===// | ||
// KnobOp | ||
//===----------------------------------------------------------------------===// | ||
|
||
void transform::tune::KnobOp::getEffects( | ||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { | ||
producesHandle(getOperation()->getOpResults(), effects); | ||
onlyReadsPayload(effects); | ||
} | ||
|
||
DiagnosedSilenceableFailure | ||
transform::tune::KnobOp::apply(transform::TransformRewriter &rewriter, | ||
transform::TransformResults &results, | ||
transform::TransformState &state) { | ||
if (getSelected()) { | ||
results.setParams(llvm::cast<OpResult>(getResult()), *getSelected()); | ||
return DiagnosedSilenceableFailure::success(); | ||
} | ||
|
||
return emitDefiniteFailure() | ||
<< "non-deterministic choice " << getName() | ||
<< " is only resolved through providing a `selected` attr"; | ||
} | ||
|
||
LogicalResult transform::tune::KnobOp::verify() { | ||
if (auto selected = getSelected()) { | ||
if (auto optionsArray = dyn_cast<ArrayAttr>(getOptions())) { | ||
if (!llvm::is_contained(optionsArray, selected)) | ||
return emitOpError("provided `selected` attribute is not an element of " | ||
"`options` array of attributes"); | ||
} else | ||
LLVM_DEBUG(DBGS() << "cannot verify `selected` attribute " << selected | ||
<< " is an element of `options` attribute " | ||
<< getOptions()); | ||
} | ||
|
||
return success(); | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
//===-- TransformTuneExtensionOps.td - Binding entry point -*- tablegen -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Entry point of the generated Python bindings for the Tune extension of the | ||
// Transform dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS | ||
#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS | ||
|
||
include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td" | ||
|
||
#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from typing import Optional, Sequence | ||
|
||
from ...ir import ( | ||
Type, | ||
Attribute, | ||
ArrayAttr, | ||
StringAttr, | ||
F64Type, | ||
IntegerType, | ||
IntegerAttr, | ||
FloatAttr, | ||
BoolAttr, | ||
) | ||
from .._transform_tune_extension_ops_gen import * | ||
from .._transform_tune_extension_ops_gen import _Dialect | ||
|
||
try: | ||
from .._ods_common import _cext as _ods_cext | ||
except ImportError as e: | ||
raise RuntimeError("Error loading imports from extension module") from e | ||
|
||
from typing import Union | ||
|
||
|
||
@_ods_cext.register_operation(_Dialect, replace=True) | ||
class KnobOp(KnobOp): | ||
def __init__( | ||
self, | ||
result: Type, # !transform.any_param or !transform.param<Type> | ||
name: Union[StringAttr, str], | ||
options: Union[ | ||
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute | ||
], | ||
*, | ||
selected: Optional[Attribute] = None, | ||
loc=None, | ||
ip=None, | ||
): | ||
if isinstance(name, str): | ||
name = StringAttr.get(name) | ||
|
||
def map_to_attr(value): | ||
if isinstance(value, bool): | ||
return BoolAttr.get(value) | ||
if isinstance(value, int): | ||
return IntegerAttr.get(IntegerType.get_signless(64), value) | ||
if isinstance(value, float): | ||
return FloatAttr.get(F64Type.get(), value) | ||
if isinstance(value, str): | ||
return StringAttr.get(value) | ||
assert isinstance(value, Attribute) | ||
return value | ||
|
||
if isinstance(options, Sequence) and not isinstance(options, ArrayAttr): | ||
options = ArrayAttr.get([map_to_attr(opt) for opt in options]) | ||
|
||
super().__init__( | ||
result, | ||
name, | ||
options, | ||
selected=selected and map_to_attr(selected), | ||
loc=loc, | ||
ip=ip, | ||
) | ||
|
||
|
||
def knob( | ||
result: Type, # !transform.any_param or !transform.param<Type> | ||
name: Union[StringAttr, str], | ||
options: Union[ | ||
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute | ||
], | ||
*, | ||
selected: Optional[Attribute] = None, | ||
loc=None, | ||
ip=None, | ||
): | ||
return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip) |
21 changes: 21 additions & 0 deletions
21
mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | ||
|
||
module attributes {transform.with_named_sequence} { | ||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { | ||
// expected-error@below {{provided `selected` attribute is not an element of `options` array of attributes}} | ||
%heads_or_tails = transform.tune.knob<"coin"> = 1 from options = [true, false] -> !transform.any_param | ||
transform.yield | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
func.func private @f() | ||
|
||
module attributes {transform.with_named_sequence} { | ||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { | ||
// expected-error@below {{non-deterministic choice "coin" is only resolved through providing a `selected` attr}} | ||
%heads_or_tails = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param | ||
transform.yield | ||
} | ||
} |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.