Skip to content

Commit 7abc003

Browse files
committed
Fix match type usage for Gather op
1 parent 896e437 commit 7abc003

File tree

2 files changed

+23
-22
lines changed

2 files changed

+23
-22
lines changed

core/src/main/scala/ONNX.scala

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -756,30 +756,28 @@ package object onnx {
756756
Float | Double | String | Boolean |
757757
Complex[
758758
Float
759-
] | Complex[Double],
760-
@sp Tind <: Int: Numeric, // Spec also supports long
759+
] | Complex[Double],
761760
Tt <: TensorTypeDenotation,
762761
Td <: TensorShapeDenotation,
763762
S <: Shape,
764-
Tt1 <: TensorTypeDenotation,
765-
Td1 <: TensorShapeDenotation,
766-
S1 <: Shape,
767763
Tt2 <: TensorTypeDenotation,
768764
Td2 <: TensorShapeDenotation,
769765
AxisIndex <: Index ::: INil,
770-
AxisIndices <: Indices
766+
AxisIndices <: Indices,
767+
IndicesSize <: Index
771768
](
772769
name: String,
773770
axis: AxisIndex = 0 ::: INil,
774771
data: Tensor[T, Tuple3[Tt, Td, S]],
775-
indices: AxisIndices
772+
indices: AxisIndices,
773+
// indicesSize: IndicesSize[AxisIndices]
776774
)(using
777775
tt: ValueOf[Tt2],
778776
td: TensorShapeDenotationOf[Td2],
779-
s: ShapeOf[GatheredShape[S, AxisIndex, AxisIndices]],
777+
s: ShapeOf[GatheredShape[S, AxisIndex, AxisIndices, IndicesSize]],
780778
i: IndicesOf[AxisIndex],
781-
i2: IndicesOf[AxisIndices]
782-
): Tensor[T, Tuple3[Tt2, Td2, GatheredShape[S, AxisIndex, AxisIndices]]] = {
779+
i2: IndicesOf[AxisIndices],
780+
): Tensor[T, Tuple3[Tt2, Td2, GatheredShape[S, AxisIndex, AxisIndices, IndicesSize]]] = {
783781
val map: Map[String, Any] = Map("axis" -> indicesOf[AxisIndex].indices.toArray.head)
784782
val allInputs = Tuple2(
785783
data,

core/src/main/scala/Tensors.scala

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -144,29 +144,32 @@ object Tensors {
144144
type GatheredShape[
145145
S <: Shape,
146146
AxisIndex <: None.type | Indices,
147-
AxisIndices <: Indices
147+
AxisIndices <: Indices,
148+
IndicesSize <: Index
148149
] <: Shape = AxisIndex match {
149150
case None.type => SNil
150-
case Indices => GatheredShapeLoop[S, AxisIndex, 0, AxisIndices]
151+
case Indices => GatheredShapeLoop[S, AxisIndex, 0, AxisIndices, IndicesSize]
151152
}
152153

153154
protected type GatheredShapeLoop[
154155
ToGather <: Shape,
155156
AxisIndex <: Indices,
156157
I <: Index,
157-
AxisIndices <: Indices
158+
AxisIndices <: Indices,
159+
IndicesSize <: Index
158160
] <: Shape = ToGather match {
159161
case head #: tail =>
160162
Indices.Contains[AxisIndex, I] match {
161-
case true =>
162-
IndicesSize[AxisIndices] #:
163+
case true =>
164+
IndicesSize #:
163165
GatheredShapeLoop[
164166
tail,
165167
Indices.RemoveValue[AxisIndex, I],
166168
S[I],
167-
AxisIndices
169+
AxisIndices,
170+
IndicesSize
168171
]
169-
case false => head #: GatheredShapeLoop[tail, AxisIndex, S[I], AxisIndices]
172+
case false => head #: GatheredShapeLoop[tail, AxisIndex, S[I], AxisIndices, IndicesSize]
170173
}
171174
case SNil =>
172175
AxisIndex match {
@@ -176,12 +179,12 @@ object Tensors {
176179
}
177180
}
178181

179-
type IndicesSize[AxisIndices <: Indices] = IndicesSizeLoop[AxisIndices, 0]
182+
type IndicesSizeOf[AxisIndices <: Indices] = IndicesSizeLoop[AxisIndices, 0]
180183

181-
type IndicesSizeLoop[AxisIndices <: Indices, Acc <: Dimension] = AxisIndices match {
182-
case head ::: tail => IndicesSizeLoop[tail, S[Acc]]
183-
case INil => Acc
184-
}
184+
protected type IndicesSizeLoop[AxisIndices <: Indices, Acc <: Dimension] <: Index = AxisIndices match {
185+
case head ::: tail => IndicesSizeLoop[tail, S[Acc]]
186+
case INil => Acc
187+
}
185188

186189
type FlattenedShape[S <: Shape, AxisIndex <: None.type | Indices] <: Shape = AxisIndex match {
187190
case None.type => SNil

0 commit comments

Comments
 (0)