Skip to content

Commit 36442f4

Browse files
author
Saurabh Singh
committed
added numpy median for ov
1 parent 50dae30 commit 36442f4

File tree

3 files changed

+107
-24
lines changed

3 files changed

+107
-24
lines changed

.gitignore

+2-22
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,2 @@
1-
.DS_Store
2-
*.pyc
3-
.vscode-test
4-
__pycache__
5-
**/.vscode-test/**
6-
**/.vscode test/**
7-
**/.vscode-smoke/**
8-
**/.venv*/
9-
bin/**
10-
build/**
11-
obj/**
12-
.pytest_cache
13-
tmp/**
14-
.vs/
15-
dist/**
16-
*.egg-info/*
17-
.vscode
18-
examples/**/*.jpg
19-
.python-version
20-
.coverage
21-
*coverage.xml
22-
.ruff_cache
1+
# Created by venv; see https://docs.python.org/3/library/venv.html
2+
*

keras/src/backend/openvino/excluded_concrete_tests.txt

-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ NumpyOneInputOpsCorrectnessTest::test_log1p
9292
NumpyOneInputOpsCorrectnessTest::test_logaddexp
9393
NumpyOneInputOpsCorrectnessTest::test_max
9494
NumpyOneInputOpsCorrectnessTest::test_mean
95-
NumpyOneInputOpsCorrectnessTest::test_median
9695
NumpyOneInputOpsCorrectnessTest::test_meshgrid
9796
NumpyOneInputOpsCorrectnessTest::test_min
9897
NumpyOneInputOpsCorrectnessTest::test_moveaxis

keras/src/backend/openvino/numpy.py

+105-1
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,111 @@ def maximum(x1, x2):
996996

997997

998998
def median(x, axis=None, keepdims=False):
999-
raise NotImplementedError("`median` is not supported with openvino backend")
999+
x = get_ov_output(x)
1000+
1001+
# Flatten the tensor if axis is None
1002+
if axis is None:
1003+
original_shape = ov_opset.shape_of(x, dtype=Type.i64).output(0)
1004+
flatten_shape = ov_opset.constant([-1], Type.i32).output(0)
1005+
x = ov_opset.reshape(x, flatten_shape, False).output(0)
1006+
axis = 0
1007+
1008+
# Convert axis to constant
1009+
axis_const = ov_opset.constant(axis, dtype=Type.i32).output(0)
1010+
1011+
# Get the shape of the tensor
1012+
shape = ov_opset.shape_of(x, dtype=Type.i64).output(0)
1013+
1014+
# Compute the length of the axis
1015+
if axis is not None:
1016+
indices = ov_opset.constant([axis], dtype=Type.i32).output(0)
1017+
length = ov_opset.gather(shape, indices, 0).output(0)
1018+
else:
1019+
length = ov_opset.shape_of(shape, dtype=Type.i64).output(0)
1020+
length = ov_opset.reshape(length, ov_opset.constant([], dtype=Type.i32).output(0), False).output(0)
1021+
1022+
# Sort the tensor along the axis
1023+
sorted_x = ov_opset.topk(x, length, axis, "value", "ascending", "f32").output(0)
1024+
1025+
# Get the indices of the middle elements
1026+
const_2 = ov_opset.constant(2, dtype=Type.i64).output(0)
1027+
mid_index = ov_opset.floor_mod(length, const_2).output(0)
1028+
is_odd = ov_opset.equal(mid_index, ov_opset.constant(1, dtype=Type.i64).output(0)).output(0)
1029+
1030+
# Calculate indices for middle elements
1031+
half_length = ov_opset.divide(length, const_2).output(0)
1032+
floor_half_length = ov_opset.floor(half_length).output(0)
1033+
floor_half_length = ov_opset.convert(floor_half_length, Type.i64).output(0)
1034+
ceil_half_length = ov_opset.ceiling(half_length).output(0)
1035+
ceil_half_length = ov_opset.convert(ceil_half_length, Type.i64).output(0)
1036+
1037+
# Create a slice to extract the median value(s)
1038+
slice_begin = ov_opset.constant([0], dtype=Type.i64).output(0)
1039+
slice_begin_with_axis = ov_opset.broadcast(slice_begin, ov_opset.shape_of(shape, dtype=Type.i64).output(0)).output(0)
1040+
1041+
# For odd length, take the middle element
1042+
# For even length, take the average of two middle elements
1043+
mid_elem_indices = ov_opset.select(is_odd, floor_half_length, floor_half_length).output(0)
1044+
1045+
# Get the middle element(s)
1046+
if axis is not None:
1047+
# Prepare indices for gather
1048+
gather_indices = ov_opset.range(
1049+
ov_opset.constant(0, dtype=Type.i64).output(0),
1050+
mid_elem_indices,
1051+
ov_opset.constant(1, dtype=Type.i64).output(0),
1052+
"i64"
1053+
).output(0)
1054+
1055+
# Get the middle element
1056+
middle_elem = ov_opset.gather(sorted_x, mid_elem_indices, axis).output(0)
1057+
1058+
# If even length, get the element before the middle and calculate average
1059+
prev_mid_elem_indices = ov_opset.subtract(mid_elem_indices, ov_opset.constant(1, dtype=Type.i64).output(0)).output(0)
1060+
prev_middle_elem = ov_opset.gather(sorted_x, prev_mid_elem_indices, axis).output(0)
1061+
1062+
# Calculate the median: if odd use middle element, if even use average of two middle elements
1063+
median_value = ov_opset.select(
1064+
is_odd,
1065+
middle_elem,
1066+
ov_opset.divide(
1067+
ov_opset.add(middle_elem, prev_middle_elem).output(0),
1068+
ov_opset.constant(2.0, dtype=middle_elem.get_element_type()).output(0)
1069+
).output(0)
1070+
).output(0)
1071+
else:
1072+
# For flattened tensor
1073+
mid_index_scalar = ov_opset.convert(mid_elem_indices, Type.i32).output(0)
1074+
middle_elem = ov_opset.gather(sorted_x, mid_index_scalar, 0).output(0)
1075+
1076+
prev_mid_elem_indices = ov_opset.subtract(mid_elem_indices, ov_opset.constant(1, dtype=Type.i64).output(0)).output(0)
1077+
prev_mid_index_scalar = ov_opset.convert(prev_mid_elem_indices, Type.i32).output(0)
1078+
prev_middle_elem = ov_opset.gather(sorted_x, prev_mid_index_scalar, 0).output(0)
1079+
1080+
median_value = ov_opset.select(
1081+
is_odd,
1082+
middle_elem,
1083+
ov_opset.divide(
1084+
ov_opset.add(middle_elem, prev_middle_elem).output(0),
1085+
ov_opset.constant(2.0, dtype=middle_elem.get_element_type()).output(0)
1086+
).output(0)
1087+
).output(0)
1088+
1089+
# Reshape if needed
1090+
if keepdims:
1091+
# Create keepdims shape
1092+
keep_shape = shape
1093+
if axis is not None:
1094+
one_tensor = ov_opset.constant(1, dtype=Type.i64).output(0)
1095+
indices = ov_opset.constant([axis], dtype=Type.i32).output(0)
1096+
keep_shape = ov_opset.scatter_elements_update(shape, indices, one_tensor, 0).output(0)
1097+
median_value = ov_opset.reshape(median_value, keep_shape, False).output(0)
1098+
elif axis is None and x.get_partial_shape().rank.get_length() > 1:
1099+
# Reshape back to scalar for axis=None case if original input was not a scalar
1100+
scalar_shape = ov_opset.constant([], dtype=Type.i32).output(0)
1101+
median_value = ov_opset.reshape(median_value, scalar_shape, False).output(0)
1102+
1103+
return OpenVINOKerasTensor(median_value)
10001104

10011105

10021106
def meshgrid(*x, indexing="xy"):

0 commit comments

Comments
 (0)