Skip to content

Commit 5012e69

Browse files
committed
Fix np.linspace implementation using OpenVINO opset
1 parent 10ce0e8 commit 5012e69

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

keras/src/backend/openvino/numpy.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -929,13 +929,28 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
929929
delta = ov_opset.subtract(stop, start).output(0)
930930
step = ov_opset.divide(delta, div_const).output(0)
931931

932-
dtype_str = str(ov_dtype).split('.')[-1]
932+
type_to_str = {
933+
Type.f16: "f16",
934+
Type.f32: "f32",
935+
Type.f64: "f64",
936+
Type.bf16: "bf16",
937+
Type.i8: "i8",
938+
Type.i16: "i16",
939+
Type.i32: "i32",
940+
Type.i64: "i64",
941+
Type.u8: "u8",
942+
Type.u16: "u16",
943+
Type.u32: "u32",
944+
Type.u64: "u64"
945+
}
946+
947+
type_str = type_to_str.get(ov_dtype, "f32")
933948

934949
indices = ov_opset.range(
935950
ov_opset.constant(0, Type.i32).output(0),
936951
ov_opset.constant(num, Type.i32).output(0),
937952
ov_opset.constant(1, Type.i32).output(0),
938-
dtype_str
953+
type_str
939954
).output(0)
940955

941956
scaled_indices = ov_opset.multiply(indices, step).output(0)

0 commit comments

Comments
 (0)