diff --git a/keras/src/backend/openvino/excluded_concrete_tests.txt b/keras/src/backend/openvino/excluded_concrete_tests.txt index f1ae053e623..33f07cf84b9 100644 --- a/keras/src/backend/openvino/excluded_concrete_tests.txt +++ b/keras/src/backend/openvino/excluded_concrete_tests.txt @@ -33,7 +33,6 @@ NumpyDtypeTest::test_logspace NumpyDtypeTest::test_matmul_ NumpyDtypeTest::test_max NumpyDtypeTest::test_mean -NumpyDtypeTest::test_median NumpyDtypeTest::test_meshgrid NumpyDtypeTest::test_min NumpyDtypeTest::test_moveaxis @@ -92,7 +91,6 @@ NumpyOneInputOpsCorrectnessTest::test_log1p NumpyOneInputOpsCorrectnessTest::test_logaddexp NumpyOneInputOpsCorrectnessTest::test_max NumpyOneInputOpsCorrectnessTest::test_mean -NumpyOneInputOpsCorrectnessTest::test_median NumpyOneInputOpsCorrectnessTest::test_meshgrid NumpyOneInputOpsCorrectnessTest::test_min NumpyOneInputOpsCorrectnessTest::test_moveaxis diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index a242a3c8d56..45d7ebc657d 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -996,7 +996,75 @@ def maximum(x1, x2): def median(x, axis=None, keepdims=False): - raise NotImplementedError("`median` is not supported with openvino backend") + x_node = get_ov_output(x) + orig_dtype = x_node.get_element_type() + orig_shape = ov_opset.shape_of(x_node).output(0) + + comp_dtype = orig_dtype + if orig_dtype.is_integral() or orig_dtype == Type.boolean: + comp_dtype = OPENVINO_DTYPES[config.floatx()] + x_node = ov_opset.convert(x_node, comp_dtype).output(0) + + rank_dim = x_node.get_partial_shape().rank + if not rank_dim.is_static: + raise RuntimeError("median: dynamic rank not supported for keepdims") + orig_rank = rank_dim.get_length() + + if axis is None: + reshape_shape = ov_opset.constant([-1], Type.i64).output(0) + x_proc = ov_opset.reshape(x_node, reshape_shape, False).output(0) + reduction_axis = 0 + keep_axes = list(range(orig_rank)) + else: + if isinstance(axis, (tuple, list)): + if len(axis) == 1: + reduction_axis = int(axis[0]) + else: + raise ValueError("median: multi-axis reduction not supported") + else: + reduction_axis = int(axis) + x_proc = x_node + keep_axes = [reduction_axis] + + shape_i64 = ov_opset.convert(ov_opset.shape_of(x_proc).output(0), Type.i64).output(0) + length = ov_opset.gather(shape_i64, ov_opset.constant([reduction_axis], Type.i64).output(0), 0).output(0) + length = ov_opset.reshape(length, ov_opset.constant([], Type.i64).output(0), False).output(0) + + sorted_x = ov_opset.topk(x_proc, length, reduction_axis, "min", "value").output(0) + + two = ov_opset.constant(2, Type.i64).output(0) + half = ov_opset.floor(ov_opset.divide(length, two).output(0)).output(0) + mid_idx = ov_opset.convert(half, Type.i32).output(0) + prev_idx = ov_opset.convert( + ov_opset.subtract(half, ov_opset.constant(1, Type.i64).output(0)).output(0), + Type.i32 + ).output(0) + + mid_val = ov_opset.gather(sorted_x, mid_idx, reduction_axis).output(0) + prev_val = ov_opset.gather(sorted_x, prev_idx, reduction_axis).output(0) + + avg = ov_opset.divide( + ov_opset.add(mid_val, prev_val).output(0), + ov_opset.constant(2.0, mid_val.get_element_type()).output(0) + ).output(0) + + rem = ov_opset.floor_mod(length, two).output(0) + is_odd = ov_opset.equal(rem, ov_opset.constant(1, Type.i64).output(0)).output(0) + median_val = ov_opset.select(is_odd, mid_val, avg).output(0) + + if keepdims: + idx_const = ov_opset.constant(keep_axes, Type.i64).output(0) + ones = ov_opset.constant([1] * len(keep_axes), Type.i64).output(0) + new_shape = ov_opset.scatter_elements_update( + ov_opset.convert(orig_shape, Type.i64).output(0), + idx_const, + ones, + 0 + ).output(0) + median_val = ov_opset.reshape(median_val, new_shape, False).output(0) + + result = ov_opset.convert(median_val, comp_dtype).output(0) + return OpenVINOKerasTensor(result) def meshgrid(*x, indexing="xy"):