Skip to content

Commit c6b745b

Browse files
committed
minor fixups on the attention kernel
1 parent 42809d2 commit c6b745b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ir/linalg_ext.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func.func @custom_op_symbolic_dims(
3535
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>],
3636
iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
3737
#iree_linalg_ext.iterator_type<parallel>,
38-
#iree_linalg_ext.iterator_type<reduction>,
38+
#iree_linalg_ext.iterator_type<parallel>,
3939
#iree_linalg_ext.iterator_type<reduction>,
4040
#iree_linalg_ext.iterator_type<reduction>]}
4141
ins(%q, %k, %v : !q_type, !k_type, !v_type)
@@ -111,7 +111,7 @@ func.func @custom_op_symbolic_dims(
111111
linalg.yield %13 : f32
112112
} -> tensor<?x?x?xf32>
113113

114-
// Perform the second matmul for hte output tile
114+
// Perform the second matmul for the output tile
115115
%new_ot = linalg.batch_matmul ins(%new_qk, %vt : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%norm_ot : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
116116

117117
iree_linalg_ext.yield %new_ot, %new_max, %new_sum : tensor<?x?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>

0 commit comments

Comments
 (0)