-
Notifications
You must be signed in to change notification settings - Fork 107
dialects (arm): Add mixed vector/scalar fmul op #4053
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
9e6222c
a2c66e7
621b220
90317eb
71b9381
68c8055
59af0f2
bbff14a
7f3fbc7
2557d29
9dd596f
111916f
09d7e65
3c9394e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
// RUN: XDSL_ROUNDTRIP | ||
// RUN: xdsl-opt -t arm-asm %s | filecheck %s --check-prefix=CHECK-ASM | ||
// CHECK: %v1 = arm_neon.get_register : !arm_neon.reg<v1> | ||
%v1 = arm_neon.get_register : !arm_neon.reg<v1> | ||
// CHECK: %v2 = arm_neon.get_register : !arm_neon.reg<v2> | ||
%v2 = arm_neon.get_register : !arm_neon.reg<v2> | ||
// CHECK: %dss_fmulvec = arm_neon.dss.fmulvec %v1, %v2 {arrangement = "4S", comment = "floating-point vector multiply v1 by v2", scalar_idx = 0 : i8} : (!arm_neon.reg<v1>, !arm_neon.reg<v2>) -> !arm_neon.reg<v3> | ||
// CHECK-ASM: fmul v3.4S, v1.4S, v2.S[0] # floating-point vector multiply v1 by v2 | ||
%dss_fmulvec = arm_neon.dss.fmulvec %v1, %v2 {"arrangement" = "4S", "comment" = "floating-point vector multiply v1 by v2", "scalar_idx" = 0 : i8} : (!arm_neon.reg<v1>, !arm_neon.reg<v2>) -> !arm_neon.reg<v3> |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
// RUN: XDSL_ROUNDTRIP | ||
|
||
|
||
// CHECK: "test.op"() {unallocated = !arm_neon.reg} : () -> () | ||
"test.op"() {unallocated = !arm_neon.reg} : () -> () | ||
|
||
// CHECK: "test.op"() {allocated = !arm_neon.reg<v1>} : () -> () | ||
"test.op"() {allocated = !arm_neon.reg<v1>} : () -> () |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC | ||
from typing import Annotated | ||
|
||
from xdsl.backend.assembly_printer import AssemblyPrinter | ||
from xdsl.dialects.arm.assembly import assembly_arg_str | ||
from xdsl.dialects.arm.ops import ARMInstruction, ARMOperation | ||
from xdsl.dialects.arm.register import ARMRegisterType | ||
from xdsl.dialects.builtin import IntegerAttr, IntegerType, StringAttr, i8 | ||
from xdsl.ir import Dialect, Operation, SSAValue | ||
from xdsl.irdl import ( | ||
attr_def, | ||
base, | ||
irdl_attr_definition, | ||
irdl_op_definition, | ||
operand_def, | ||
result_def, | ||
) | ||
|
||
ARM_NEON_INDEX_BY_NAME = {f"v{i}": i for i in range(0, 32)} | ||
|
||
Imm8Attr = IntegerAttr[Annotated[IntegerType, i8]] | ||
|
||
|
||
@irdl_attr_definition | ||
class NEONRegisterType(ARMRegisterType): | ||
""" | ||
A 128-bit NEON ARM register type. | ||
""" | ||
|
||
name = "arm_neon.reg" | ||
|
||
@classmethod | ||
def instruction_set_name(cls) -> str: | ||
return "arm_neon" | ||
|
||
@classmethod | ||
def index_by_name(cls) -> dict[str, int]: | ||
return ARM_NEON_INDEX_BY_NAME | ||
|
||
@classmethod | ||
def infinite_register_prefix(cls): | ||
return "inf_" | ||
|
||
|
||
UNALLOCATED_NEON = NEONRegisterType.unallocated() | ||
V0 = NEONRegisterType.from_name("v0") | ||
V1 = NEONRegisterType.from_name("v1") | ||
V2 = NEONRegisterType.from_name("v2") | ||
V3 = NEONRegisterType.from_name("v3") | ||
V4 = NEONRegisterType.from_name("v4") | ||
V5 = NEONRegisterType.from_name("v5") | ||
V6 = NEONRegisterType.from_name("v6") | ||
V7 = NEONRegisterType.from_name("v7") | ||
V8 = NEONRegisterType.from_name("v8") | ||
V9 = NEONRegisterType.from_name("v9") | ||
V10 = NEONRegisterType.from_name("v10") | ||
V11 = NEONRegisterType.from_name("v11") | ||
V12 = NEONRegisterType.from_name("v12") | ||
V13 = NEONRegisterType.from_name("v13") | ||
V14 = NEONRegisterType.from_name("v14") | ||
V15 = NEONRegisterType.from_name("v15") | ||
V16 = NEONRegisterType.from_name("v16") | ||
V17 = NEONRegisterType.from_name("v17") | ||
V18 = NEONRegisterType.from_name("v18") | ||
V19 = NEONRegisterType.from_name("v19") | ||
V20 = NEONRegisterType.from_name("v20") | ||
V21 = NEONRegisterType.from_name("v21") | ||
V22 = NEONRegisterType.from_name("v22") | ||
V23 = NEONRegisterType.from_name("v23") | ||
V24 = NEONRegisterType.from_name("v24") | ||
V25 = NEONRegisterType.from_name("v25") | ||
V26 = NEONRegisterType.from_name("v26") | ||
V27 = NEONRegisterType.from_name("v27") | ||
V28 = NEONRegisterType.from_name("v28") | ||
V29 = NEONRegisterType.from_name("v29") | ||
V30 = NEONRegisterType.from_name("v30") | ||
V31 = NEONRegisterType.from_name("v31") | ||
|
||
|
||
class ARMNEONInstruction(ARMInstruction, ABC): | ||
""" | ||
Base class for operations in the NEON instruction set. | ||
The name of the operation will be used as the NEON assembly instruction name. | ||
|
||
The arrangement specifier for NEON instructions determines element size and count: | ||
- "4H" → 4 half-precision floats | ||
- "8H" → 8 half-precision floats | ||
- "2S" → 2 single-precision floats | ||
- "4S" → 4 single-precision floats | ||
- "2D" → 2 double-precision floats | ||
""" | ||
|
||
arrangement = attr_def(StringAttr) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or rather before this one, and after the register type |
||
|
||
def assembly_line(self) -> str | None: | ||
# default assembly code generator | ||
instruction_name = self.assembly_instruction_name() | ||
arg_str = ", ".join( | ||
f"{assembly_arg_str(arg)}.{self.arrangement.data}" | ||
for arg in self.assembly_line_args() | ||
if arg is not None | ||
) | ||
return AssemblyPrinter.assembly_line(instruction_name, arg_str, self.comment) | ||
|
||
|
||
@irdl_op_definition | ||
class GetRegisterOp(ARMOperation): | ||
superlopuh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
This instruction allows us to create an SSAValue for a given register name. | ||
""" | ||
|
||
name = "arm_neon.get_register" | ||
|
||
result = result_def(NEONRegisterType) | ||
assembly_format = "attr-dict `:` type($result)" | ||
|
||
def __init__(self, register_type: NEONRegisterType): | ||
super().__init__(result_types=[register_type]) | ||
|
||
def assembly_line(self): | ||
return None | ||
|
||
|
||
@irdl_op_definition | ||
class DSSFMulVecScalarOp(ARMNEONInstruction): | ||
""" | ||
Floating-point multiply (mixed: first source operand is a vector, second is a scalar. Destination is a vector) | ||
This instruction multiplies each of the floating-point values in the first source operand by the | ||
second source operand and writes the resulting values to the corresponding lanes of the destination. | ||
|
||
Encoding: FMUL <Vd>.<T>, <Vn>.<T>, <Vm>.<idx>. | ||
Vd, Vn, Vm specify the regs. The <T> specifier determines element arrangement (size and count). | ||
The <idx> specifier determines the index of Vm at which the second source operand (scalar) can be found, | ||
preceded by a size specifier. | ||
|
||
https://developer.arm.com/documentation/ddi0602/2024-12/SIMD-FP-Instructions/FMUL--vector---Floating-point-multiply--vector--?lang=en#T_option__4 | ||
""" | ||
|
||
name = "arm_neon.dss.fmulvec" | ||
d = result_def(NEONRegisterType) | ||
s1 = operand_def(NEONRegisterType) | ||
s2 = operand_def(NEONRegisterType) | ||
scalar_idx = attr_def(base(Imm8Attr)) | ||
|
||
assembly_format = ( | ||
"$s1 `,` $s2 attr-dict `:` `(` type($s1) `,` type($s2) `)` `->` type($d)" | ||
) | ||
|
||
def __init__( | ||
self, | ||
s1: Operation | SSAValue, | ||
s2: Operation | SSAValue, | ||
*, | ||
d: NEONRegisterType, | ||
arrangement: str | StringAttr, | ||
comment: str | StringAttr | None = None, | ||
): | ||
if isinstance(arrangement, str): | ||
valid_arrangements = {"4H", "8H", "2S", "4S", "2D"} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought the bitwidth specifiers were per register? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if I understand this correctly - they are specified with the register yes, but the same register can be used with different arrangement specifiers, as I understand There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I understand well, the arrangement specifier describes the registers, but is carried by the instruction (whereas in x86, the arrangement specifier is essentially carried by the registers' names) ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes that’s my understanding too There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes I'm aware of some of these instructions where the destination has a different specifier than the source. My initial approach had just been trying to get it to work for we wanted with the intention of adding handling for the different cases as we build it up. But maybe that's not how we want to go, in which case I'm happy to do some more digging around the docs to try to establish the rules for these specifiers:) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I now think it's worth not overthinking it, we can fix things later. It would be great to add more documentation around the place to explain the design for future readers. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I think it's fine. What we basically need is (to generate) runnable code implementing tiled matrix multiplications. We can refine it incrementally. |
||
if arrangement in valid_arrangements: | ||
arrangement = StringAttr(arrangement) | ||
else: | ||
raise ValueError(f"Invalid FMUL arrangement: {arrangement}") | ||
if isinstance(comment, str): | ||
comment = StringAttr(comment) | ||
super().__init__( | ||
operands=(s1, s2), | ||
attributes={ | ||
"arrangement": arrangement, | ||
"comment": comment, | ||
}, | ||
result_types=(d,), | ||
) | ||
|
||
def assembly_instruction_name(self) -> str: | ||
return "fmul" | ||
|
||
def assembly_line_args(self): | ||
return (self.d, self.s1, self.s2) | ||
|
||
def assembly_line(self) -> str | None: | ||
instruction_name = self.assembly_instruction_name() | ||
arg_str = ", ".join( | ||
f"{assembly_arg_str(arg)}.{self.arrangement.data}" | ||
for arg in self.assembly_line_args()[:2] | ||
if arg is not None | ||
) | ||
arg_str += f", {assembly_arg_str(self.assembly_line_args()[2])}.{self.arrangement.data[1]}[{self.scalar_idx.value.data}]" | ||
return AssemblyPrinter.assembly_line(instruction_name, arg_str, self.comment) | ||
|
||
|
||
ARM_NEON = Dialect( | ||
"arm_neon", | ||
[ | ||
DSSFMulVecScalarOp, | ||
GetRegisterOp, | ||
], | ||
[ | ||
NEONRegisterType, | ||
], | ||
) |
Uh oh!
There was an error while loading. Please reload this page.