Skip to content

Commit 0ed9126

Browse files
author
Saurabh Singh
committed
fix
1 parent 22ba3f2 commit 0ed9126

File tree

1 file changed

+40
-54
lines changed

1 file changed

+40
-54
lines changed

keras/src/backend/openvino/numpy.py

+40-54
Original file line numberDiff line numberDiff line change
@@ -996,64 +996,50 @@ def maximum(x1, x2):
996996

997997

998998
def median(x, axis=None, keepdims=False):
999-
x = get_ov_output(x)
1000-
999+
x_node = get_ov_output(x)
1000+
orig_dtype = x_node.get_element_type()
1001+
10011002
if axis is None:
1002-
flatten_shape = ov_opset.constant([-1], Type.i32).output(0)
1003-
x = ov_opset.reshape(x, flatten_shape, False).output(0)
1004-
axis = 0
1005-
1006-
shape = ov_opset.shape_of(x).output(0)
1007-
shape = ov_opset.convert(shape, Type.i64).output(0)
1008-
1009-
if axis is not None:
1010-
indices = ov_opset.constant([axis], Type.i32).output(0)
1011-
length = ov_opset.gather(shape, indices, 0).output(0)
1012-
length = ov_opset.reshape(length, ov_opset.constant([], Type.i32).output(0), False).output(0)
1003+
flatten_shape = ov_opset.constant([-1], Type.i64).output(0)
1004+
x_node = ov_opset.reshape(x_node, flatten_shape, False).output(0)
1005+
axis_val = 0
10131006
else:
1014-
length = ov_opset.shape_of(shape).output(0)
1015-
length = ov_opset.convert(length, Type.i64).output(0)
1016-
length = ov_opset.reshape(length, ov_opset.constant([], Type.i32).output(0), False).output(0)
1017-
1018-
sorted_x = ov_opset.topk(x, length, axis, "min", "value").output(0)
1019-
1020-
const_2 = ov_opset.constant(2, Type.i64).output(0)
1021-
mid_index = ov_opset.floor_mod(length, const_2).output(0)
1022-
is_odd = ov_opset.equal(mid_index, ov_opset.constant(1, Type.i64).output(0)).output(0)
1023-
1024-
half_length = ov_opset.divide(length, const_2).output(0)
1025-
floor_half_length = ov_opset.floor(half_length).output(0)
1026-
floor_half_length = ov_opset.convert(floor_half_length, Type.i64).output(0)
1027-
1028-
mid_index_scalar = ov_opset.convert(floor_half_length, Type.i32).output(0)
1029-
middle_elem = ov_opset.gather(sorted_x, mid_index_scalar, axis).output(0)
1030-
1031-
prev_mid_index = ov_opset.subtract(floor_half_length, ov_opset.constant(1, Type.i64).output(0)).output(0)
1032-
prev_mid_index_scalar = ov_opset.convert(prev_mid_index, Type.i32).output(0)
1033-
prev_middle_elem = ov_opset.gather(sorted_x, prev_mid_index_scalar, axis).output(0)
1034-
1035-
median_value = ov_opset.select(
1036-
is_odd,
1037-
middle_elem,
1038-
ov_opset.divide(
1039-
ov_opset.add(middle_elem, prev_middle_elem).output(0),
1040-
ov_opset.constant(2.0, middle_elem.get_element_type()).output(0)
1041-
).output(0)
1042-
).output(0)
1043-
1007+
axis_val = axis
1008+
1009+
shape = ov_opset.shape_of(x_node).output(0)
1010+
shape = ov_opset.convert(shape, Type.i64).output(0)
1011+
1012+
axis_const = ov_opset.constant([axis_val], Type.i64).output(0)
1013+
length = ov_opset.gather(shape, axis_const, 0).output(0)
1014+
length = ov_opset.reshape(length, ov_opset.constant([], Type.i64).output(0), False).output(0)
1015+
1016+
sorted_x = ov_opset.topk(x_node, length, axis_val, "min", "value").output(0)
1017+
1018+
two = ov_opset.constant(2, Type.i64).output(0)
1019+
1020+
half = ov_opset.divide(length, two).output(0)
1021+
mid_idx = ov_opset.floor(half).output(0)
1022+
rem = ov_opset.floor_mod(length, two).output(0)
1023+
is_odd = ov_opset.equal(rem, ov_opset.constant(1, Type.i64).output(0)).output(0)
1024+
1025+
mid_idx_i32 = ov_opset.convert(mid_idx, Type.i32).output(0)
1026+
middle = ov_opset.gather(sorted_x, mid_idx_i32, axis_val).output(0)
1027+
1028+
prev_idx = ov_opset.subtract(mid_idx, ov_opset.constant(1, Type.i64).output(0)).output(0)
1029+
prev_idx_i32 = ov_opset.convert(prev_idx, Type.i32).output(0)
1030+
prev = ov_opset.gather(sorted_x, prev_idx_i32, axis_val).output(0)
1031+
sum_val = ov_opset.add(middle, prev).output(0)
1032+
avg = ov_opset.divide(sum_val, ov_opset.constant(2.0, sum_val.get_element_type()).output(0)).output(0)
1033+
1034+
median_val = ov_opset.select(is_odd, middle, avg).output(0)
1035+
10441036
if keepdims:
1045-
keep_shape = shape
1046-
if axis is not None:
1047-
one_tensor = ov_opset.constant(1, Type.i64).output(0)
1048-
indices = ov_opset.constant([axis], Type.i32).output(0)
1049-
keep_shape = ov_opset.scatter_elements_update(shape, indices, one_tensor, 0).output(0)
1050-
median_value = ov_opset.reshape(median_value, keep_shape, False).output(0)
1051-
elif axis is None and x.get_partial_shape().rank.get_length() > 1:
1052-
scalar_shape = ov_opset.constant([], Type.i32).output(0)
1053-
median_value = ov_opset.reshape(median_value, scalar_shape, False).output(0)
1037+
one = ov_opset.constant(1, Type.i64).output(0)
1038+
keep_shape = ov_opset.scatter_elements_update(shape, axis_const, one, 0).output(0)
1039+
median_val = ov_opset.reshape(median_val, keep_shape, False).output(0)
10541040

1055-
median_value = ov_opset.convert(median_value, x.get_element_type()).output(0)
1056-
return OpenVINOKerasTensor(median_value)
1041+
median_val = ov_opset.convert(median_val, orig_dtype).output(0)
1042+
return OpenVINOKerasTensor(median_val)
10571043

10581044

10591045
def meshgrid(*x, indexing="xy"):

0 commit comments

Comments
 (0)