Skip to content

Commit d083981

Browse files
committed
fixed shape and incorrect function name
1 parent c95b832 commit d083981

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

keras/src/backend/openvino/numpy.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,11 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
924924
start = ov_opset.convert(start, ov_dtype).output(0)
925925
stop = ov_opset.convert(stop, ov_dtype).output(0)
926926

927+
# Check if we're dealing with arrays
928+
start_shape = start.get_shape()
929+
stop_shape = stop.get_shape()
930+
is_array_input = len(start_shape) > 0 or len(stop_shape) > 0
931+
927932
div = num - 1 if endpoint else num
928933
div_const = ov_opset.constant(div, ov_dtype).output(0)
929934
delta = ov_opset.subtract(stop, start).output(0)
@@ -944,31 +949,39 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
944949
Type.u64: "u64"
945950
}
946951

947-
type_str = type_to_str.get(ov_dtype, "f32")
952+
type_str = type_to_str.get(ov_dtype, "f32")
948953

949954
indices = ov_opset.range(
950955
ov_opset.constant(0, Type.i32).output(0),
951956
ov_opset.constant(num, Type.i32).output(0),
952957
ov_opset.constant(1, Type.i32).output(0),
953-
type_str
958+
type_str
954959
).output(0)
955960

961+
if is_array_input:
962+
# For array inputs, we need to reshape the indices
963+
# to properly broadcast with multidimensional steps
964+
new_shape = []
965+
new_shape.append(num)
966+
for _ in range(len(start_shape)):
967+
new_shape.append(1)
968+
indices = ov_opset.reshape(indices, ov_opset.constant(new_shape, Type.i64).output(0)).output(0)
969+
956970
scaled_indices = ov_opset.multiply(indices, step).output(0)
957971
result = ov_opset.add(start, scaled_indices).output(0)
958972

959973
if endpoint and num > 1:
960974
last_idx = ov_opset.constant(num - 1, Type.i32).output(0)
961-
result = ov_opset.scatter_element_update(result, last_idx, stop).output(0)
975+
# Fix 1: Use the correct function name
976+
result = ov_opset.scatter_elements_update(result, last_idx, stop).output(0)
962977

963978
if axis != 0:
964979
axis_const = ov_opset.constant([axis], Type.i64).output(0)
965980
result = ov_opset.unsqueeze(result, axis_const).output(0)
966981

967982
if retstep:
968983
return OpenVINOKerasTensor(result), OpenVINOKerasTensor(step)
969-
return OpenVINOKerasTensor(result)
970-
971-
984+
return OpenVINOKerasTensor(result)
972985
def log(x):
973986
x = get_ov_output(x)
974987
x_type = x.get_element_type()

0 commit comments

Comments
 (0)