Skip to content

Commit 10ce0e8

Browse files
committed
Fix np.linspace implementation using OpenVINO opset
1 parent ebcd971 commit 10ce0e8

File tree

1 file changed

+24
-20
lines changed

1 file changed

+24
-20
lines changed

keras/src/backend/openvino/numpy.py

+24-20
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,7 @@ def less_equal(x1, x2):
906906
x1, x2 = _align_operand_types(x1, x2, "less_equal()")
907907
return OpenVINOKerasTensor(ov_opset.less_equal(x1, x2).output(0))
908908

909+
909910
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0):
910911
if not isinstance(num, int):
911912
raise TypeError("num must be an integer")
@@ -916,38 +917,41 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
916917
stop = get_ov_output(stop)
917918

918919
if dtype is None:
919-
dtype = OPENVINO_DTYPES[config.floatx()]
920+
ov_dtype = OPENVINO_DTYPES[config.floatx()]
921+
else:
922+
ov_dtype = OPENVINO_DTYPES[dtype]
920923

921-
# Convert inputs to specified dtype
922-
start = ov_opset.convert(start, dtype)
923-
stop = ov_opset.convert(stop, dtype)
924+
start = ov_opset.convert(start, ov_dtype).output(0)
925+
stop = ov_opset.convert(stop, ov_dtype).output(0)
924926

925927
div = num - 1 if endpoint else num
926-
div_const = ov_opset.constant(div, dtype=dtype)
927-
928-
delta = ov_opset.subtract(stop, start)
929-
step = ov_opset.divide(delta, div_const)
928+
div_const = ov_opset.constant(div, ov_dtype).output(0)
929+
delta = ov_opset.subtract(stop, start).output(0)
930+
step = ov_opset.divide(delta, div_const).output(0)
930931

932+
dtype_str = str(ov_dtype).split('.')[-1]
933+
931934
indices = ov_opset.range(
932-
ov_opset.constant(0, dtype=dtype),
933-
ov_opset.constant(num, dtype=dtype),
934-
ov_opset.constant(1, dtype=dtype)
935-
)
935+
ov_opset.constant(0, Type.i32).output(0),
936+
ov_opset.constant(num, Type.i32).output(0),
937+
ov_opset.constant(1, Type.i32).output(0),
938+
dtype_str
939+
).output(0)
936940

937-
scaled_indices = ov_opset.multiply(indices, step)
938-
result = ov_opset.add(start, scaled_indices)
941+
scaled_indices = ov_opset.multiply(indices, step).output(0)
942+
result = ov_opset.add(start, scaled_indices).output(0)
939943

940944
if endpoint and num > 1:
941-
last_idx = ov_opset.constant(num - 1, dtype=Type.i32)
942-
result = ov_opset.scatter_element_update(result, last_idx, stop)
945+
last_idx = ov_opset.constant(num - 1, Type.i32).output(0)
946+
result = ov_opset.scatter_element_update(result, last_idx, stop).output(0)
943947

944948
if axis != 0:
945-
axis_const = ov_opset.constant([axis], dtype=Type.i64)
946-
result = ov_opset.unsqueeze(result, axis_const)
949+
axis_const = ov_opset.constant([axis], Type.i64).output(0)
950+
result = ov_opset.unsqueeze(result, axis_const).output(0)
947951

948952
if retstep:
949-
return result, step
950-
return result
953+
return OpenVINOKerasTensor(result), OpenVINOKerasTensor(step)
954+
return OpenVINOKerasTensor(result)
951955

952956

953957
def log(x):

0 commit comments

Comments
 (0)