Skip to content

[mlir][SymbolOpInterface] Easier visibility overriding #151036

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

maerhart
Copy link
Member

When overriding 'getVisibility and/or 'setVisibility' the interface methods calling them do not pick up the overriden version. Instead it is necessary to override all the other methods as well. This adjusts these interface methods to use the overriden version when available.

When overriding 'getVisibility and/or 'setVisibility' the interface methods calling them do not pick up the overriden version. Instead it is necessary to override all the other methods as well. This adjusts these interface methods to use the overriden version when available.
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:ods labels Jul 28, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 28, 2025

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-ods

Author: Martin Erhart (maerhart)

Changes

When overriding 'getVisibility and/or 'setVisibility' the interface methods calling them do not pick up the overriden version. Instead it is necessary to override all the other methods as well. This adjusts these interface methods to use the overriden version when available.


Full diff: https://github.com/llvm/llvm-project/pull/151036.diff

4 Files Affected:

  • (modified) mlir/include/mlir/IR/SymbolInterfaces.td (+7-7)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+26)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+7)
  • (modified) mlir/unittests/IR/SymbolTableTest.cpp (+34)
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index a8b04d0453110..bbfa30815bd4a 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -55,19 +55,19 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
     InterfaceMethod<"Returns true if this symbol has nested visibility.",
       "bool", "isNested", (ins),  [{}],
       /*defaultImplementation=*/[{
-        return getVisibility() == mlir::SymbolTable::Visibility::Nested;
+        return $_op.getVisibility() == mlir::SymbolTable::Visibility::Nested;
       }]
     >,
     InterfaceMethod<"Returns true if this symbol has private visibility.",
       "bool", "isPrivate", (ins),  [{}],
       /*defaultImplementation=*/[{
-        return getVisibility() == mlir::SymbolTable::Visibility::Private;
+        return $_op.getVisibility() == mlir::SymbolTable::Visibility::Private;
       }]
     >,
     InterfaceMethod<"Returns true if this symbol has public visibility.",
       "bool", "isPublic", (ins),  [{}],
       /*defaultImplementation=*/[{
-        return getVisibility() == mlir::SymbolTable::Visibility::Public;
+        return $_op.getVisibility() == mlir::SymbolTable::Visibility::Public;
       }]
     >,
     InterfaceMethod<"Sets the visibility of this symbol.",
@@ -79,19 +79,19 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
     InterfaceMethod<"Sets the visibility of this symbol to be nested.",
       "void", "setNested", (ins),  [{}],
       /*defaultImplementation=*/[{
-        setVisibility(mlir::SymbolTable::Visibility::Nested);
+        $_op.setVisibility(mlir::SymbolTable::Visibility::Nested);
       }]
     >,
     InterfaceMethod<"Sets the visibility of this symbol to be private.",
       "void", "setPrivate", (ins),  [{}],
       /*defaultImplementation=*/[{
-        setVisibility(mlir::SymbolTable::Visibility::Private);
+        $_op.setVisibility(mlir::SymbolTable::Visibility::Private);
       }]
     >,
     InterfaceMethod<"Sets the visibility of this symbol to be public.",
       "void", "setPublic", (ins),  [{}],
       /*defaultImplementation=*/[{
-        setVisibility(mlir::SymbolTable::Visibility::Public);
+        $_op.setVisibility(mlir::SymbolTable::Visibility::Public);
       }]
     >,
     InterfaceMethod<[{
@@ -144,7 +144,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
         // By default, base this on the visibility alone. A symbol can be
         // discarded as long as it is not public. Only public symbols may be
         // visible from outside of the IR.
-        return getVisibility() != ::mlir::SymbolTable::Visibility::Public;
+        return $_op.getVisibility() != ::mlir::SymbolTable::Visibility::Public;
       }]
     >,
     InterfaceMethod<[{
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index f79e2cfbcb259..53055fea215b7 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -17,6 +17,32 @@
 using namespace mlir;
 using namespace test;
 
+//===----------------------------------------------------------------------===//
+// OverridenSymbolVisibilityOp
+//===----------------------------------------------------------------------===//
+
+SymbolTable::Visibility OverriddenSymbolVisibilityOp::getVisibility() {
+  return SymbolTable::Visibility::Private;
+}
+
+static StringLiteral getVisibilityString(SymbolTable::Visibility visibility) {
+  switch (visibility) {
+  case SymbolTable::Visibility::Private:
+    return "private";
+  case SymbolTable::Visibility::Nested:
+    return "nested";
+  case SymbolTable::Visibility::Public:
+    return "public";
+  }
+}
+
+void OverriddenSymbolVisibilityOp::setVisibility(
+    SymbolTable::Visibility visibility) {
+
+  emitOpError("cannot change visibility of symbol to ")
+      << getVisibilityString(visibility);
+}
+
 //===----------------------------------------------------------------------===//
 // TestBranchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index a7c6cd60a0ee4..927c98225bf2f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -119,6 +119,13 @@ def SymbolOp : TEST_Op<"symbol", [NoMemoryEffect, Symbol]> {
                        OptionalAttr<StrAttr>:$sym_visibility);
 }
 
+def OverriddenSymbolVisibilityOp : TEST_Op<"overridden_symbol_visibility", [
+  DeclareOpInterfaceMethods<Symbol, ["getVisibility", "setVisibility"]>,
+]> {
+  let summary =  "operation overridden symbol visibility accessors";
+  let arguments = (ins StrAttr:$sym_name);
+}
+
 def SymbolScopeOp : TEST_Op<"symbol_scope",
     [SymbolTable, SingleBlockImplicitTerminator<"TerminatorOp">]> {
   let summary =  "operation which defines a new symbol table";
diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp
index cfc3fe0cb1c5b..4b3545bce1952 100644
--- a/mlir/unittests/IR/SymbolTableTest.cpp
+++ b/mlir/unittests/IR/SymbolTableTest.cpp
@@ -132,4 +132,38 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
   });
 }
 
+TEST(SymbolOpInterface, Visibility) {
+  DialectRegistry registry;
+  ::test::registerTestDialect(registry);
+  MLIRContext context(registry);
+
+  constexpr static StringLiteral kInput = R"MLIR(
+    "test.overridden_symbol_visibility"() {sym_name = "symbol_name"} : () -> ()
+  )MLIR";
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(kInput, &context);
+  auto symOp = cast<SymbolOpInterface>(module->getBody()->front());
+
+  ASSERT_TRUE(symOp.isPrivate());
+  ASSERT_FALSE(symOp.isPublic());
+  ASSERT_FALSE(symOp.isNested());
+  ASSERT_TRUE(symOp.canDiscardOnUseEmpty());
+
+  std::string diagStr;
+  context.getDiagEngine().registerHandler(
+      [&](Diagnostic &diag) { diagStr += diag.str(); });
+
+  std::string expectedDiag;
+  symOp.setPublic();
+  expectedDiag += "'test.overridden_symbol_visibility' op cannot change "
+                  "visibility of symbol to public";
+  symOp.setNested();
+  expectedDiag += "'test.overridden_symbol_visibility' op cannot change "
+                  "visibility of symbol to nested";
+  symOp.setPrivate();
+  expectedDiag += "'test.overridden_symbol_visibility' op cannot change "
+                  "visibility of symbol to private";
+
+  ASSERT_EQ(diagStr, expectedDiag);
+}
+
 } // namespace

@llvmbot
Copy link
Member

llvmbot commented Jul 28, 2025

@llvm/pr-subscribers-mlir

Author: Martin Erhart (maerhart)

Changes

When overriding 'getVisibility and/or 'setVisibility' the interface methods calling them do not pick up the overriden version. Instead it is necessary to override all the other methods as well. This adjusts these interface methods to use the overriden version when available.


Full diff: https://github.com/llvm/llvm-project/pull/151036.diff

4 Files Affected:

  • (modified) mlir/include/mlir/IR/SymbolInterfaces.td (+7-7)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+26)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+7)
  • (modified) mlir/unittests/IR/SymbolTableTest.cpp (+34)
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index a8b04d0453110..bbfa30815bd4a 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -55,19 +55,19 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
     InterfaceMethod<"Returns true if this symbol has nested visibility.",
       "bool", "isNested", (ins),  [{}],
       /*defaultImplementation=*/[{
-        return getVisibility() == mlir::SymbolTable::Visibility::Nested;
+        return $_op.getVisibility() == mlir::SymbolTable::Visibility::Nested;
       }]
     >,
     InterfaceMethod<"Returns true if this symbol has private visibility.",
       "bool", "isPrivate", (ins),  [{}],
       /*defaultImplementation=*/[{
-        return getVisibility() == mlir::SymbolTable::Visibility::Private;
+        return $_op.getVisibility() == mlir::SymbolTable::Visibility::Private;
       }]
     >,
     InterfaceMethod<"Returns true if this symbol has public visibility.",
       "bool", "isPublic", (ins),  [{}],
       /*defaultImplementation=*/[{
-        return getVisibility() == mlir::SymbolTable::Visibility::Public;
+        return $_op.getVisibility() == mlir::SymbolTable::Visibility::Public;
       }]
     >,
     InterfaceMethod<"Sets the visibility of this symbol.",
@@ -79,19 +79,19 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
     InterfaceMethod<"Sets the visibility of this symbol to be nested.",
       "void", "setNested", (ins),  [{}],
       /*defaultImplementation=*/[{
-        setVisibility(mlir::SymbolTable::Visibility::Nested);
+        $_op.setVisibility(mlir::SymbolTable::Visibility::Nested);
       }]
     >,
     InterfaceMethod<"Sets the visibility of this symbol to be private.",
       "void", "setPrivate", (ins),  [{}],
       /*defaultImplementation=*/[{
-        setVisibility(mlir::SymbolTable::Visibility::Private);
+        $_op.setVisibility(mlir::SymbolTable::Visibility::Private);
       }]
     >,
     InterfaceMethod<"Sets the visibility of this symbol to be public.",
       "void", "setPublic", (ins),  [{}],
       /*defaultImplementation=*/[{
-        setVisibility(mlir::SymbolTable::Visibility::Public);
+        $_op.setVisibility(mlir::SymbolTable::Visibility::Public);
       }]
     >,
     InterfaceMethod<[{
@@ -144,7 +144,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
         // By default, base this on the visibility alone. A symbol can be
         // discarded as long as it is not public. Only public symbols may be
         // visible from outside of the IR.
-        return getVisibility() != ::mlir::SymbolTable::Visibility::Public;
+        return $_op.getVisibility() != ::mlir::SymbolTable::Visibility::Public;
       }]
     >,
     InterfaceMethod<[{
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index f79e2cfbcb259..53055fea215b7 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -17,6 +17,32 @@
 using namespace mlir;
 using namespace test;
 
+//===----------------------------------------------------------------------===//
+// OverridenSymbolVisibilityOp
+//===----------------------------------------------------------------------===//
+
+SymbolTable::Visibility OverriddenSymbolVisibilityOp::getVisibility() {
+  return SymbolTable::Visibility::Private;
+}
+
+static StringLiteral getVisibilityString(SymbolTable::Visibility visibility) {
+  switch (visibility) {
+  case SymbolTable::Visibility::Private:
+    return "private";
+  case SymbolTable::Visibility::Nested:
+    return "nested";
+  case SymbolTable::Visibility::Public:
+    return "public";
+  }
+}
+
+void OverriddenSymbolVisibilityOp::setVisibility(
+    SymbolTable::Visibility visibility) {
+
+  emitOpError("cannot change visibility of symbol to ")
+      << getVisibilityString(visibility);
+}
+
 //===----------------------------------------------------------------------===//
 // TestBranchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index a7c6cd60a0ee4..927c98225bf2f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -119,6 +119,13 @@ def SymbolOp : TEST_Op<"symbol", [NoMemoryEffect, Symbol]> {
                        OptionalAttr<StrAttr>:$sym_visibility);
 }
 
+def OverriddenSymbolVisibilityOp : TEST_Op<"overridden_symbol_visibility", [
+  DeclareOpInterfaceMethods<Symbol, ["getVisibility", "setVisibility"]>,
+]> {
+  let summary =  "operation overridden symbol visibility accessors";
+  let arguments = (ins StrAttr:$sym_name);
+}
+
 def SymbolScopeOp : TEST_Op<"symbol_scope",
     [SymbolTable, SingleBlockImplicitTerminator<"TerminatorOp">]> {
   let summary =  "operation which defines a new symbol table";
diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp
index cfc3fe0cb1c5b..4b3545bce1952 100644
--- a/mlir/unittests/IR/SymbolTableTest.cpp
+++ b/mlir/unittests/IR/SymbolTableTest.cpp
@@ -132,4 +132,38 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
   });
 }
 
+TEST(SymbolOpInterface, Visibility) {
+  DialectRegistry registry;
+  ::test::registerTestDialect(registry);
+  MLIRContext context(registry);
+
+  constexpr static StringLiteral kInput = R"MLIR(
+    "test.overridden_symbol_visibility"() {sym_name = "symbol_name"} : () -> ()
+  )MLIR";
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(kInput, &context);
+  auto symOp = cast<SymbolOpInterface>(module->getBody()->front());
+
+  ASSERT_TRUE(symOp.isPrivate());
+  ASSERT_FALSE(symOp.isPublic());
+  ASSERT_FALSE(symOp.isNested());
+  ASSERT_TRUE(symOp.canDiscardOnUseEmpty());
+
+  std::string diagStr;
+  context.getDiagEngine().registerHandler(
+      [&](Diagnostic &diag) { diagStr += diag.str(); });
+
+  std::string expectedDiag;
+  symOp.setPublic();
+  expectedDiag += "'test.overridden_symbol_visibility' op cannot change "
+                  "visibility of symbol to public";
+  symOp.setNested();
+  expectedDiag += "'test.overridden_symbol_visibility' op cannot change "
+                  "visibility of symbol to nested";
+  symOp.setPrivate();
+  expectedDiag += "'test.overridden_symbol_visibility' op cannot change "
+                  "visibility of symbol to private";
+
+  ASSERT_EQ(diagStr, expectedDiag);
+}
+
 } // namespace

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:ods mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants