@@ -924,6 +924,11 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
924
924
start = ov_opset .convert (start , ov_dtype ).output (0 )
925
925
stop = ov_opset .convert (stop , ov_dtype ).output (0 )
926
926
927
+ # Check if we're dealing with arrays
928
+ start_shape = start .get_shape ()
929
+ stop_shape = stop .get_shape ()
930
+ is_array_input = len (start_shape ) > 0 or len (stop_shape ) > 0
931
+
927
932
div = num - 1 if endpoint else num
928
933
div_const = ov_opset .constant (div , ov_dtype ).output (0 )
929
934
delta = ov_opset .subtract (stop , start ).output (0 )
@@ -944,31 +949,39 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
944
949
Type .u64 : "u64"
945
950
}
946
951
947
- type_str = type_to_str .get (ov_dtype , "f32" )
952
+ type_str = type_to_str .get (ov_dtype , "f32" )
948
953
949
954
indices = ov_opset .range (
950
955
ov_opset .constant (0 , Type .i32 ).output (0 ),
951
956
ov_opset .constant (num , Type .i32 ).output (0 ),
952
957
ov_opset .constant (1 , Type .i32 ).output (0 ),
953
- type_str
958
+ type_str
954
959
).output (0 )
955
960
961
+ if is_array_input :
962
+ # For array inputs, we need to reshape the indices
963
+ # to properly broadcast with multidimensional steps
964
+ new_shape = []
965
+ new_shape .append (num )
966
+ for _ in range (len (start_shape )):
967
+ new_shape .append (1 )
968
+ indices = ov_opset .reshape (indices , ov_opset .constant (new_shape , Type .i64 ).output (0 )).output (0 )
969
+
956
970
scaled_indices = ov_opset .multiply (indices , step ).output (0 )
957
971
result = ov_opset .add (start , scaled_indices ).output (0 )
958
972
959
973
if endpoint and num > 1 :
960
974
last_idx = ov_opset .constant (num - 1 , Type .i32 ).output (0 )
961
- result = ov_opset .scatter_element_update (result , last_idx , stop ).output (0 )
975
+ # Fix 1: Use the correct function name
976
+ result = ov_opset .scatter_elements_update (result , last_idx , stop ).output (0 )
962
977
963
978
if axis != 0 :
964
979
axis_const = ov_opset .constant ([axis ], Type .i64 ).output (0 )
965
980
result = ov_opset .unsqueeze (result , axis_const ).output (0 )
966
981
967
982
if retstep :
968
983
return OpenVINOKerasTensor (result ), OpenVINOKerasTensor (step )
969
- return OpenVINOKerasTensor (result )
970
-
971
-
984
+ return OpenVINOKerasTensor (result )
972
985
def log (x ):
973
986
x = get_ov_output (x )
974
987
x_type = x .get_element_type ()
0 commit comments