Skip to content

dialects (arm): Add mixed vector/scalar fmul op #4053

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from
Closed
9 changes: 9 additions & 0 deletions tests/filecheck/dialects/arm_neon/test_ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: XDSL_ROUNDTRIP
// RUN: xdsl-opt -t arm-asm %s | filecheck %s --check-prefix=CHECK-ASM
// CHECK: %v1 = arm_neon.get_register : !arm_neon.reg<v1>
%v1 = arm_neon.get_register : !arm_neon.reg<v1>
// CHECK: %v2 = arm_neon.get_register : !arm_neon.reg<v2>
%v2 = arm_neon.get_register : !arm_neon.reg<v2>
// CHECK: %dss_fmulvec = arm_neon.dss.fmulvec %v1, %v2 {arrangement = "4S", comment = "floating-point vector multiply v1 by v2", scalar_idx = 0 : i8} : (!arm_neon.reg<v1>, !arm_neon.reg<v2>) -> !arm_neon.reg<v3>
// CHECK-ASM: fmul v3.4S, v1.4S, v2.S[0] # floating-point vector multiply v1 by v2
%dss_fmulvec = arm_neon.dss.fmulvec %v1, %v2 {"arrangement" = "4S", "comment" = "floating-point vector multiply v1 by v2", "scalar_idx" = 0 : i8} : (!arm_neon.reg<v1>, !arm_neon.reg<v2>) -> !arm_neon.reg<v3>
8 changes: 8 additions & 0 deletions tests/filecheck/dialects/arm_neon/test_registers.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: XDSL_ROUNDTRIP


// CHECK: "test.op"() {unallocated = !arm_neon.reg} : () -> ()
"test.op"() {unallocated = !arm_neon.reg} : () -> ()

// CHECK: "test.op"() {allocated = !arm_neon.reg<v1>} : () -> ()
"test.op"() {allocated = !arm_neon.reg<v1>} : () -> ()
6 changes: 6 additions & 0 deletions xdsl/dialects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def get_arm_func():

return ARM_FUNC

def get_arm_neon():
from xdsl.dialects.arm_neon import ARM_NEON

return ARM_NEON

def get_bufferization():
from xdsl.dialects.bufferization import Bufferization

Expand Down Expand Up @@ -320,6 +325,7 @@ def get_transform():
"arith": get_arith,
"arm": get_arm,
"arm_func": get_arm_func,
"arm_neon": get_arm_neon,
"bufferization": get_bufferization,
"builtin": get_builtin,
"cf": get_cf,
Expand Down
203 changes: 203 additions & 0 deletions xdsl/dialects/arm_neon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
from __future__ import annotations

from abc import ABC
from typing import Annotated

from xdsl.backend.assembly_printer import AssemblyPrinter
from xdsl.dialects.arm.assembly import assembly_arg_str
from xdsl.dialects.arm.ops import ARMInstruction, ARMOperation
from xdsl.dialects.arm.register import ARMRegisterType
from xdsl.dialects.builtin import IntegerAttr, IntegerType, StringAttr, i8
from xdsl.ir import Dialect, Operation, SSAValue
from xdsl.irdl import (
attr_def,
base,
irdl_attr_definition,
irdl_op_definition,
operand_def,
result_def,
)

ARM_NEON_INDEX_BY_NAME = {f"v{i}": i for i in range(0, 32)}

Imm8Attr = IntegerAttr[Annotated[IntegerType, i8]]


@irdl_attr_definition
class NEONRegisterType(ARMRegisterType):
"""
A 128-bit NEON ARM register type.
"""

name = "arm_neon.reg"

@classmethod
def instruction_set_name(cls) -> str:
return "arm_neon"

@classmethod
def index_by_name(cls) -> dict[str, int]:
return ARM_NEON_INDEX_BY_NAME

@classmethod
def infinite_register_prefix(cls):
return "inf_"


UNALLOCATED_NEON = NEONRegisterType.unallocated()
V0 = NEONRegisterType.from_name("v0")
V1 = NEONRegisterType.from_name("v1")
V2 = NEONRegisterType.from_name("v2")
V3 = NEONRegisterType.from_name("v3")
V4 = NEONRegisterType.from_name("v4")
V5 = NEONRegisterType.from_name("v5")
V6 = NEONRegisterType.from_name("v6")
V7 = NEONRegisterType.from_name("v7")
V8 = NEONRegisterType.from_name("v8")
V9 = NEONRegisterType.from_name("v9")
V10 = NEONRegisterType.from_name("v10")
V11 = NEONRegisterType.from_name("v11")
V12 = NEONRegisterType.from_name("v12")
V13 = NEONRegisterType.from_name("v13")
V14 = NEONRegisterType.from_name("v14")
V15 = NEONRegisterType.from_name("v15")
V16 = NEONRegisterType.from_name("v16")
V17 = NEONRegisterType.from_name("v17")
V18 = NEONRegisterType.from_name("v18")
V19 = NEONRegisterType.from_name("v19")
V20 = NEONRegisterType.from_name("v20")
V21 = NEONRegisterType.from_name("v21")
V22 = NEONRegisterType.from_name("v22")
V23 = NEONRegisterType.from_name("v23")
V24 = NEONRegisterType.from_name("v24")
V25 = NEONRegisterType.from_name("v25")
V26 = NEONRegisterType.from_name("v26")
V27 = NEONRegisterType.from_name("v27")
V28 = NEONRegisterType.from_name("v28")
V29 = NEONRegisterType.from_name("v29")
V30 = NEONRegisterType.from_name("v30")
V31 = NEONRegisterType.from_name("v31")


class ARMNEONInstruction(ARMInstruction, ABC):
"""
Base class for operations in the NEON instruction set.
The name of the operation will be used as the NEON assembly instruction name.

The arrangement specifier for NEON instructions determines element size and count:
- "4H" → 4 half-precision floats
- "8H" → 8 half-precision floats
- "2S" → 2 single-precision floats
- "4S" → 4 single-precision floats
- "2D" → 2 double-precision floats
"""

arrangement = attr_def(StringAttr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be an EnumAttribute, maybe we can do this as a first PR?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or rather before this one, and after the register type


def assembly_line(self) -> str | None:
# default assembly code generator
instruction_name = self.assembly_instruction_name()
arg_str = ", ".join(
f"{assembly_arg_str(arg)}.{self.arrangement.data}"
for arg in self.assembly_line_args()
if arg is not None
)
return AssemblyPrinter.assembly_line(instruction_name, arg_str, self.comment)


@irdl_op_definition
class GetRegisterOp(ARMOperation):
"""
This instruction allows us to create an SSAValue for a given register name.
"""

name = "arm_neon.get_register"

result = result_def(NEONRegisterType)
assembly_format = "attr-dict `:` type($result)"

def __init__(self, register_type: NEONRegisterType):
super().__init__(result_types=[register_type])

def assembly_line(self):
return None


@irdl_op_definition
class DSSFMulVecScalarOp(ARMNEONInstruction):
"""
Floating-point multiply (mixed: first source operand is a vector, second is a scalar. Destination is a vector)
This instruction multiplies each of the floating-point values in the first source operand by the
second source operand and writes the resulting values to the corresponding lanes of the destination.

Encoding: FMUL <Vd>.<T>, <Vn>.<T>, <Vm>.<idx>.
Vd, Vn, Vm specify the regs. The <T> specifier determines element arrangement (size and count).
The <idx> specifier determines the index of Vm at which the second source operand (scalar) can be found,
preceded by a size specifier.

https://developer.arm.com/documentation/ddi0602/2024-12/SIMD-FP-Instructions/FMUL--vector---Floating-point-multiply--vector--?lang=en#T_option__4
"""

name = "arm_neon.dss.fmulvec"
d = result_def(NEONRegisterType)
s1 = operand_def(NEONRegisterType)
s2 = operand_def(NEONRegisterType)
scalar_idx = attr_def(base(Imm8Attr))

assembly_format = (
"$s1 `,` $s2 attr-dict `:` `(` type($s1) `,` type($s2) `)` `->` type($d)"
)

def __init__(
self,
s1: Operation | SSAValue,
s2: Operation | SSAValue,
*,
d: NEONRegisterType,
arrangement: str | StringAttr,
comment: str | StringAttr | None = None,
):
if isinstance(arrangement, str):
valid_arrangements = {"4H", "8H", "2S", "4S", "2D"}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought the bitwidth specifiers were per register?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I understand this correctly - they are specified with the register yes, but the same register can be used with different arrangement specifiers, as I understand

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand well, the arrangement specifier describes the registers, but is carried by the instruction (whereas in x86, the arrangement specifier is essentially carried by the registers' names) ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes that’s my understanding too

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about SADDL2 V0.2D, V1.4S, V2.4S from page 102 of https://cs140e.sergio.bz/docs/ARMv8-A-Programmer-Guide.pdf ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I'm aware of some of these instructions where the destination has a different specifier than the source. My initial approach had just been trying to get it to work for we wanted with the intention of adding handling for the different cases as we build it up. But maybe that's not how we want to go, in which case I'm happy to do some more digging around the docs to try to establish the rules for these specifiers:)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I now think it's worth not overthinking it, we can fix things later. It would be great to add more documentation around the place to explain the design for future readers.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think it's fine. What we basically need is (to generate) runnable code implementing tiled matrix multiplications. We can refine it incrementally.

if arrangement in valid_arrangements:
arrangement = StringAttr(arrangement)
else:
raise ValueError(f"Invalid FMUL arrangement: {arrangement}")
if isinstance(comment, str):
comment = StringAttr(comment)
super().__init__(
operands=(s1, s2),
attributes={
"arrangement": arrangement,
"comment": comment,
},
result_types=(d,),
)

def assembly_instruction_name(self) -> str:
return "fmul"

def assembly_line_args(self):
return (self.d, self.s1, self.s2)

def assembly_line(self) -> str | None:
instruction_name = self.assembly_instruction_name()
arg_str = ", ".join(
f"{assembly_arg_str(arg)}.{self.arrangement.data}"
for arg in self.assembly_line_args()[:2]
if arg is not None
)
arg_str += f", {assembly_arg_str(self.assembly_line_args()[2])}.{self.arrangement.data[1]}[{self.scalar_idx.value.data}]"
return AssemblyPrinter.assembly_line(instruction_name, arg_str, self.comment)


ARM_NEON = Dialect(
"arm_neon",
[
DSSFMulVecScalarOp,
GetRegisterOp,
],
[
NEONRegisterType,
],
)