Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions experiments/gemmini/SUPPORT.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Buddy Gemmini lowering coverage (Sparsh)

This table is a quick view of what we’ve stress-tested and what Buddy lowers into.

| Test | Input dialect/op | Layout | Proof of Gemmini match | Proof of Gemmini command expansion | Notes |
|---|---|---|---|---|---|
| matmul | linalg.matmul | (varies) | `gemmini.tile_matmul` | `gemmini.intr.loop_ws_config*` + `gemmini.intr.loop_ws` | matmul lowered end-to-end |
| batch_matmul | linalg.batch_matmul | (varies) | `gemmini.tile_*` | `gemmini.intr.*` | batched path works |
| conv (NHWC/HWCF) | linalg.conv_2d_nhwc_hwcf | NHWC x HWCF | `gemmini.tile_conv` | `gemmini.intr.loop_conv_ws_config*` + `gemmini.intr.loop_conv_ws` | conv lowered to WS loop |
| conv (NCHW/FCHW) | linalg.conv_2d_nchw_fchw | NCHW x FCHW | `gemmini.tile_conv` | `gemmini.intr.loop_conv_ws_config*` + `gemmini.intr.loop_conv_ws` | alternate layout works |
| mini CNN block | 2x conv + copy | NCHW/FCHW | 2x `gemmini.tile_conv` | `gemmini.intr.loop_conv_ws*` appears | multi-layer block lowers |
30 changes: 30 additions & 0 deletions experiments/gemmini/inputs/batch_matmul.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: buddy-opt %s \
// RUN: --convert-linalg-to-gemmini | \
// RUN: FileCheck %s

func.func @main() -> i8 {
%0 = arith.constant 0 : i8
%1 = arith.constant 1 : i8
%2 = arith.constant 2 : i8
%input0 = memref.alloc() : memref<3x3x3xi8>
%input1 = memref.alloc() : memref<3x3x3xi8>
%output = memref.alloc() : memref<3x3x3xi8>
linalg.fill
ins(%1 : i8)
outs(%input0 : memref<3x3x3xi8>)
linalg.fill
ins(%2 : i8)
outs(%input1 : memref<3x3x3xi8>)
// CHECK: gemmini.tile_matmul %subview %subview_2 %subview_3 %alloc_4 :
// CHECK-SAME: memref<3x3xi8, strided<[3, 1]>> memref<3x3xi8, strided<[3, 1]>> memref<3x3xi8, strided<[3, 1]>> memref<3x3xi32>
// CHECK: gemmini.tile_matmul %subview_5 %subview_6 %subview_7 %alloc_8 :
// CHECK-SAME: memref<3x3xi8, strided<[3, 1], offset: 9>> memref<3x3xi8, strided<[3, 1], offset: 9>> memref<3x3xi8, strided<[3, 1], offset: 9>> memref<3x3xi32>
// CHECK: gemmini.tile_matmul %subview_10 %subview_11 %subview_12 %alloc_13 :
// CHECK-SAME: memref<3x3xi8, strided<[3, 1], offset: 18>> memref<3x3xi8, strided<[3, 1], offset: 18>> memref<3x3xi8, strided<[3, 1], offset: 18>> memref<3x3xi32>
linalg.batch_matmul
ins(%input0, %input1: memref<3x3x3xi8>, memref<3x3x3xi8>)
outs(%output : memref<3x3x3xi8>)
gemmini.print %output : memref<3x3x3xi8>
memref.dealloc %output : memref<3x3x3xi8>
return %0 : i8
}
51 changes: 51 additions & 0 deletions experiments/gemmini/inputs/conv_2d_nchw_fchw_f32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// RUN: buddy-opt %s \
// RUN: --convert-linalg-to-gemmini="acc_t=f32" | \
// RUN: FileCheck %s

memref.global "private" @input : memref<2x2x5x5xf32> = dense<[[[[1., 0., -1., 0., 1.],
[1., 0., -1., 0., 1.],
[1., 0., -1., 0., 1.],
[1., 0., -1., 0., 1.],
[-1., 0., 1., 0., -1.]],
[[-1., 0., 1., 0., -1.],
[-1., 0., 1., 0., -1.],
[-1., 0., 1., 0., -1.],
[-1., 0., 1., 0., -1.],
[-1., 0., 1., 0., -1.]]],
[[[1., 0., 2., 0., 1.],
[1., 0., 2., 0., 1.],
[1., 0., 2., 0., 1.],
[1., 0., 2., 0., 1.],
[-1., 0., 2., 0., -1.]],
[[-1., 0., 2., 0., -1.],
[-1., 0., 2., 0., -1.],
[-1., 0., 2., 0., -1.],
[-1., 0., 2., 0., -1.],
[-1., 0., 2., 0., -1.]]]]>

memref.global "private" @weight : memref<2x2x3x3xf32> = dense<[[[[1., 2., 3.],
[3., 2., 1.],
[1., 2., 3.]],
[[3., 2., 1.],
[1., 2., 3.],
[3., 2., 1.]]],
[[[1., 2., 3.],
[3., 2., 1.],
[1., 2., 3.]],
[[3., 2., 1.],
[1., 2., 3.],
[3., 2., 1.]]]]>

func.func @main() -> i8 {
%0 = arith.constant 0 : i8
%mem0 = memref.get_global @input : memref<2x2x5x5xf32>
%mem1 = memref.get_global @weight : memref<2x2x3x3xf32>
%mem2 = memref.alloc() : memref<2x2x3x3xf32>
// CHECK: gemmini.tile_conv %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %{{.+}} %{{.+}} :
// CHECK-SAME: memref<2x5x5x2xf32> memref<18x2xf32> memref<2xf32> memref<18x2xf32> i64 i64
linalg.conv_2d_nchw_fchw
ins (%mem0, %mem1 : memref<2x2x5x5xf32>, memref<2x2x3x3xf32>)
outs(%mem2 : memref<2x2x3x3xf32>)
gemmini.print %mem2 : memref<2x2x3x3xf32>
return %0 : i8
}
7 changes: 7 additions & 0 deletions experiments/gemmini/inputs/matmul.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module {
func.func @matmul(%A: memref<64x64xf16>, %B: memref<64x64xf16>, %C: memref<64x64xf32>) {
linalg.matmul ins(%A, %B : memref<64x64xf16>, memref<64x64xf16>)
outs(%C : memref<64x64xf32>)
return
}
}
49 changes: 49 additions & 0 deletions experiments/gemmini/inputs/tile-matmul-ws-softmax.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// RUN: buddy-opt %s \
// RUN: --lower-gemmini | \
// RUN: FileCheck %s

memref.global "private" @g1 : memref<5x5xi8> = dense<[[1, 0, 0, 1, 0], [1, -1, 1, 0, 0], [-1, 0, 1, -1, 1], [1, 0, 0, 1, 0], [-1, 0, 0, -1, 0]]>
memref.global "private" @g2 : memref<5x5xi8> = dense<[[1, -1, 0, 0, 1], [1, 0, -1, 0, -1], [-1, -1, 0, -1, 1], [-1, 0, 0, 1, 0], [1, 0, 0, -1, 0]]>


func.func @main() -> i8 {
%i0 = arith.constant 0 : i8
%i1I8 = arith.constant 1 : i8
%minus1 = arith.constant -2 : i8
%i2I8 = arith.constant 2 : i8
%i2I32 = arith.constant 2 : i32
%dI32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%aArray = memref.get_global @g1 : memref<5x5xi8>
%bArray = memref.get_global @g2 : memref<5x5xi8>
%cArray = memref.alloc() : memref<5x5xi8>
%dArray = memref.alloc() : memref<5x5xi32>
%dim_I = memref.dim %aArray, %c0 : memref<5x5xi8>
%dim_J = memref.dim %bArray, %c1 : memref<5x5xi8>
%dim_K = memref.dim %aArray, %c1 : memref<5x5xi8>

scf.for %i3 = %c0 to %dim_I step %c1 {
scf.for %j3 = %c0 to %dim_J step %c1 {
memref.store %dI32, %dArray[%i3, %j3] : memref<5x5xi32>
}
}

gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32>
gemmini.print %cArray : memref<5x5xi8>

// CHECK: "gemmini.intr.config_ex"
// CHECK: "gemmini.intr.config_st"
// CHECK: "gemmini.intr.config_ld"
// CHECK: "gemmini.intr.config_norm"
// CHECK: "gemmini.intr.loop_ws_config_bounds"
// CHECK: "gemmini.intr.loop_ws_config_addrs_ab"
// CHECK: "gemmini.intr.loop_ws_config_addrs_dc"
// CHECK: "gemmini.intr.loop_ws_config_strides_ab"
// CHECK: "gemmini.intr.loop_ws_config_strides_dc"
// CHECK: "gemmini.intr.loop_ws"
// CHECk: "gemmini.intr.flush"
gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1, act=4, bertScale=0.05:f32}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32>
gemmini.print %cArray : memref<5x5xi8>
return %i0 : i8
}
191 changes: 191 additions & 0 deletions experiments/gemmini/iree_inputs/mini_cnn_block.gemmini_tile.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
module {
func.func @mini_cnn_block(%arg0: memref<1x3x32x32xf32>, %arg1: memref<16x3x3x3xf32>, %arg2: memref<32x16x3x3xf32>, %arg3: memref<1x32x26x26xf32>) {
%alloc = memref.alloc() : memref<1x16x30x30xf32>
%alloc_0 = memref.alloc() : memref<1x32x26x26xf32>
%alloc_1 = memref.alloc() : memref<1x32x32x3xf32>
%alloc_2 = memref.alloc() : memref<27x16xf32>
%alloc_3 = memref.alloc() : memref<16xi32>
%alloc_4 = memref.alloc() : memref<900x16xf32>
%c30_i64 = arith.constant 30 : i64
%c3 = arith.constant 3 : index
%c3_5 = arith.constant 3 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c1_6 = arith.constant 1 : index
scf.for %arg4 = %c0 to %c1 step %c1_6 {
%c0_27 = arith.constant 0 : index
%c3_28 = arith.constant 3 : index
%c1_29 = arith.constant 1 : index
scf.for %arg5 = %c0_27 to %c3_28 step %c1_29 {
%c0_30 = arith.constant 0 : index
%c32_31 = arith.constant 32 : index
%c1_32 = arith.constant 1 : index
scf.for %arg6 = %c0_30 to %c32_31 step %c1_32 {
%c0_33 = arith.constant 0 : index
%c32_34 = arith.constant 32 : index
%c1_35 = arith.constant 1 : index
scf.for %arg7 = %c0_33 to %c32_34 step %c1_35 {
%0 = memref.load %arg0[%arg4, %arg5, %arg6, %arg7] : memref<1x3x32x32xf32>
memref.store %0, %alloc_1[%arg4, %arg6, %arg7, %arg5] : memref<1x32x32x3xf32>
}
}
}
}
%c0_7 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c1_8 = arith.constant 1 : index
scf.for %arg4 = %c0_7 to %c16 step %c1_8 {
%c0_27 = arith.constant 0 : index
%c3_28 = arith.constant 3 : index
%c1_29 = arith.constant 1 : index
scf.for %arg5 = %c0_27 to %c3_28 step %c1_29 {
%c0_30 = arith.constant 0 : index
%c3_31 = arith.constant 3 : index
%c1_32 = arith.constant 1 : index
scf.for %arg6 = %c0_30 to %c3_31 step %c1_32 {
%c0_33 = arith.constant 0 : index
%c3_34 = arith.constant 3 : index
%c1_35 = arith.constant 1 : index
scf.for %arg7 = %c0_33 to %c3_34 step %c1_35 {
%0 = arith.muli %arg6, %c3 : index
%1 = arith.muli %0, %c3_5 : index
%2 = arith.muli %arg7, %c3_5 : index
%3 = arith.addi %1, %2 : index
%4 = arith.addi %3, %arg5 : index
%5 = memref.load %arg1[%arg4, %arg5, %arg6, %arg7] : memref<16x3x3x3xf32>
memref.store %5, %alloc_2[%4, %arg4] : memref<27x16xf32>
}
}
}
}
%c3_i64 = arith.constant 3 : i64
gemmini.tile_conv %alloc_1 %alloc_2 %alloc_3 %alloc_4 %c30_i64 %c30_i64 %c3_i64 : memref<1x32x32x3xf32> memref<27x16xf32> memref<16xi32> memref<900x16xf32> i64 i64 i64
%c0_9 = arith.constant 0 : index
%c1_10 = arith.constant 1 : index
%c1_11 = arith.constant 1 : index
scf.for %arg4 = %c0_9 to %c1_10 step %c1_11 {
%c0_27 = arith.constant 0 : index
%c16_28 = arith.constant 16 : index
%c1_29 = arith.constant 1 : index
scf.for %arg5 = %c0_27 to %c16_28 step %c1_29 {
%c0_30 = arith.constant 0 : index
%c30 = arith.constant 30 : index
%c1_31 = arith.constant 1 : index
scf.for %arg6 = %c0_30 to %c30 step %c1_31 {
%c0_32 = arith.constant 0 : index
%c30_33 = arith.constant 30 : index
%c1_34 = arith.constant 1 : index
scf.for %arg7 = %c0_32 to %c30_33 step %c1_34 {
%c30_35 = arith.constant 30 : index
%0 = arith.muli %arg4, %c30_35 : index
%1 = arith.muli %0, %c30_35 : index
%2 = arith.muli %arg6, %c30_35 : index
%3 = arith.addi %1, %2 : index
%4 = arith.addi %3, %arg7 : index
%5 = memref.load %alloc_4[%4, %arg5] : memref<900x16xf32>
memref.store %5, %alloc[%arg4, %arg5, %arg6, %arg7] : memref<1x16x30x30xf32>
}
}
}
}
memref.dealloc %alloc_1 : memref<1x32x32x3xf32>
memref.dealloc %alloc_2 : memref<27x16xf32>
memref.dealloc %alloc_4 : memref<900x16xf32>
memref.dealloc %alloc_3 : memref<16xi32>
%alloc_12 = memref.alloc() : memref<1x30x30x16xf32>
%alloc_13 = memref.alloc() : memref<144x32xf32>
%alloc_14 = memref.alloc() : memref<32xi32>
%alloc_15 = memref.alloc() : memref<676x32xf32>
%c26_i64 = arith.constant 26 : i64
%c3_16 = arith.constant 3 : index
%c16_17 = arith.constant 16 : index
%c0_18 = arith.constant 0 : index
%c1_19 = arith.constant 1 : index
%c1_20 = arith.constant 1 : index
scf.for %arg4 = %c0_18 to %c1_19 step %c1_20 {
%c0_27 = arith.constant 0 : index
%c16_28 = arith.constant 16 : index
%c1_29 = arith.constant 1 : index
scf.for %arg5 = %c0_27 to %c16_28 step %c1_29 {
%c0_30 = arith.constant 0 : index
%c30 = arith.constant 30 : index
%c1_31 = arith.constant 1 : index
scf.for %arg6 = %c0_30 to %c30 step %c1_31 {
%c0_32 = arith.constant 0 : index
%c30_33 = arith.constant 30 : index
%c1_34 = arith.constant 1 : index
scf.for %arg7 = %c0_32 to %c30_33 step %c1_34 {
%0 = memref.load %alloc[%arg4, %arg5, %arg6, %arg7] : memref<1x16x30x30xf32>
memref.store %0, %alloc_12[%arg4, %arg6, %arg7, %arg5] : memref<1x30x30x16xf32>
}
}
}
}
%c0_21 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c1_22 = arith.constant 1 : index
scf.for %arg4 = %c0_21 to %c32 step %c1_22 {
%c0_27 = arith.constant 0 : index
%c16_28 = arith.constant 16 : index
%c1_29 = arith.constant 1 : index
scf.for %arg5 = %c0_27 to %c16_28 step %c1_29 {
%c0_30 = arith.constant 0 : index
%c3_31 = arith.constant 3 : index
%c1_32 = arith.constant 1 : index
scf.for %arg6 = %c0_30 to %c3_31 step %c1_32 {
%c0_33 = arith.constant 0 : index
%c3_34 = arith.constant 3 : index
%c1_35 = arith.constant 1 : index
scf.for %arg7 = %c0_33 to %c3_34 step %c1_35 {
%0 = arith.muli %arg6, %c3_16 : index
%1 = arith.muli %0, %c16_17 : index
%2 = arith.muli %arg7, %c16_17 : index
%3 = arith.addi %1, %2 : index
%4 = arith.addi %3, %arg5 : index
%5 = memref.load %arg2[%arg4, %arg5, %arg6, %arg7] : memref<32x16x3x3xf32>
memref.store %5, %alloc_13[%4, %arg4] : memref<144x32xf32>
}
}
}
}
%c3_i64_23 = arith.constant 3 : i64
gemmini.tile_conv %alloc_12 %alloc_13 %alloc_14 %alloc_15 %c26_i64 %c26_i64 %c3_i64_23 : memref<1x30x30x16xf32> memref<144x32xf32> memref<32xi32> memref<676x32xf32> i64 i64 i64
%c0_24 = arith.constant 0 : index
%c1_25 = arith.constant 1 : index
%c1_26 = arith.constant 1 : index
scf.for %arg4 = %c0_24 to %c1_25 step %c1_26 {
%c0_27 = arith.constant 0 : index
%c32_28 = arith.constant 32 : index
%c1_29 = arith.constant 1 : index
scf.for %arg5 = %c0_27 to %c32_28 step %c1_29 {
%c0_30 = arith.constant 0 : index
%c26 = arith.constant 26 : index
%c1_31 = arith.constant 1 : index
scf.for %arg6 = %c0_30 to %c26 step %c1_31 {
%c0_32 = arith.constant 0 : index
%c26_33 = arith.constant 26 : index
%c1_34 = arith.constant 1 : index
scf.for %arg7 = %c0_32 to %c26_33 step %c1_34 {
%c26_35 = arith.constant 26 : index
%0 = arith.muli %arg4, %c26_35 : index
%1 = arith.muli %0, %c26_35 : index
%2 = arith.muli %arg6, %c26_35 : index
%3 = arith.addi %1, %2 : index
%4 = arith.addi %3, %arg7 : index
%5 = memref.load %alloc_15[%4, %arg5] : memref<676x32xf32>
memref.store %5, %alloc_0[%arg4, %arg5, %arg6, %arg7] : memref<1x32x26x26xf32>
}
}
}
}
memref.dealloc %alloc_12 : memref<1x30x30x16xf32>
memref.dealloc %alloc_13 : memref<144x32xf32>
memref.dealloc %alloc_15 : memref<676x32xf32>
memref.dealloc %alloc_14 : memref<32xi32>
linalg.copy ins(%alloc_0 : memref<1x32x26x26xf32>) outs(%arg3 : memref<1x32x26x26xf32>)
memref.dealloc %alloc : memref<1x16x30x30xf32>
memref.dealloc %alloc_0 : memref<1x32x26x26xf32>
return
}
}

13 changes: 13 additions & 0 deletions experiments/gemmini/libgemmini_status.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Mac setup notes:

- buddy-mlir Gemmini lowering works (matmul, batch_matmul, conv, matmul+softmax).
- Generated LLVM IR (log.ll) and RISC-V asm (log.s) via Makefile targets, asm has Gemmini ops (config_ex, config_st, loop_ws, etc.).
- Spike + pk + riscv64-unknown-elf-gcc work for a simple "hello" test.

Blocked on:
- Installing libgemmini (Spike extension) from https://github.com/ucb-bar/libgemmini.
- `make libgemmini.so` fails on macOS with `ld: symbol(s) not found for architecture arm64` and RISCV-dependent paths that assume a full Chipyard/Gemmini tree.

Plan:
- Use Mac primarily for IR/pipeline experiments.
- Do Spike+Gemmini execution on a SLICE Linux machine with Chipyard/Gemmini installed.
Loading