@@ -996,7 +996,111 @@ def maximum(x1, x2):
996
996
997
997
998
998
def median (x , axis = None , keepdims = False ):
999
- raise NotImplementedError ("`median` is not supported with openvino backend" )
999
+ x = get_ov_output (x )
1000
+
1001
+ # Flatten the tensor if axis is None
1002
+ if axis is None :
1003
+ original_shape = ov_opset .shape_of (x , dtype = Type .i64 ).output (0 )
1004
+ flatten_shape = ov_opset .constant ([- 1 ], Type .i32 ).output (0 )
1005
+ x = ov_opset .reshape (x , flatten_shape , False ).output (0 )
1006
+ axis = 0
1007
+
1008
+ # Convert axis to constant
1009
+ axis_const = ov_opset .constant (axis , dtype = Type .i32 ).output (0 )
1010
+
1011
+ # Get the shape of the tensor
1012
+ shape = ov_opset .shape_of (x , dtype = Type .i64 ).output (0 )
1013
+
1014
+ # Compute the length of the axis
1015
+ if axis is not None :
1016
+ indices = ov_opset .constant ([axis ], dtype = Type .i32 ).output (0 )
1017
+ length = ov_opset .gather (shape , indices , 0 ).output (0 )
1018
+ else :
1019
+ length = ov_opset .shape_of (shape , dtype = Type .i64 ).output (0 )
1020
+ length = ov_opset .reshape (length , ov_opset .constant ([], dtype = Type .i32 ).output (0 ), False ).output (0 )
1021
+
1022
+ # Sort the tensor along the axis
1023
+ sorted_x = ov_opset .topk (x , length , axis , "value" , "ascending" , "f32" ).output (0 )
1024
+
1025
+ # Get the indices of the middle elements
1026
+ const_2 = ov_opset .constant (2 , dtype = Type .i64 ).output (0 )
1027
+ mid_index = ov_opset .floor_mod (length , const_2 ).output (0 )
1028
+ is_odd = ov_opset .equal (mid_index , ov_opset .constant (1 , dtype = Type .i64 ).output (0 )).output (0 )
1029
+
1030
+ # Calculate indices for middle elements
1031
+ half_length = ov_opset .divide (length , const_2 ).output (0 )
1032
+ floor_half_length = ov_opset .floor (half_length ).output (0 )
1033
+ floor_half_length = ov_opset .convert (floor_half_length , Type .i64 ).output (0 )
1034
+ ceil_half_length = ov_opset .ceiling (half_length ).output (0 )
1035
+ ceil_half_length = ov_opset .convert (ceil_half_length , Type .i64 ).output (0 )
1036
+
1037
+ # Create a slice to extract the median value(s)
1038
+ slice_begin = ov_opset .constant ([0 ], dtype = Type .i64 ).output (0 )
1039
+ slice_begin_with_axis = ov_opset .broadcast (slice_begin , ov_opset .shape_of (shape , dtype = Type .i64 ).output (0 )).output (0 )
1040
+
1041
+ # For odd length, take the middle element
1042
+ # For even length, take the average of two middle elements
1043
+ mid_elem_indices = ov_opset .select (is_odd , floor_half_length , floor_half_length ).output (0 )
1044
+
1045
+ # Get the middle element(s)
1046
+ if axis is not None :
1047
+ # Prepare indices for gather
1048
+ gather_indices = ov_opset .range (
1049
+ ov_opset .constant (0 , dtype = Type .i64 ).output (0 ),
1050
+ mid_elem_indices ,
1051
+ ov_opset .constant (1 , dtype = Type .i64 ).output (0 ),
1052
+ "i64"
1053
+ ).output (0 )
1054
+
1055
+ # Get the middle element
1056
+ middle_elem = ov_opset .gather (sorted_x , mid_elem_indices , axis ).output (0 )
1057
+
1058
+ # If even length, get the element before the middle and calculate average
1059
+ prev_mid_elem_indices = ov_opset .subtract (mid_elem_indices , ov_opset .constant (1 , dtype = Type .i64 ).output (0 )).output (0 )
1060
+ prev_middle_elem = ov_opset .gather (sorted_x , prev_mid_elem_indices , axis ).output (0 )
1061
+
1062
+ # Calculate the median: if odd use middle element, if even use average of two middle elements
1063
+ median_value = ov_opset .select (
1064
+ is_odd ,
1065
+ middle_elem ,
1066
+ ov_opset .divide (
1067
+ ov_opset .add (middle_elem , prev_middle_elem ).output (0 ),
1068
+ ov_opset .constant (2.0 , dtype = middle_elem .get_element_type ()).output (0 )
1069
+ ).output (0 )
1070
+ ).output (0 )
1071
+ else :
1072
+ # For flattened tensor
1073
+ mid_index_scalar = ov_opset .convert (mid_elem_indices , Type .i32 ).output (0 )
1074
+ middle_elem = ov_opset .gather (sorted_x , mid_index_scalar , 0 ).output (0 )
1075
+
1076
+ prev_mid_elem_indices = ov_opset .subtract (mid_elem_indices , ov_opset .constant (1 , dtype = Type .i64 ).output (0 )).output (0 )
1077
+ prev_mid_index_scalar = ov_opset .convert (prev_mid_elem_indices , Type .i32 ).output (0 )
1078
+ prev_middle_elem = ov_opset .gather (sorted_x , prev_mid_index_scalar , 0 ).output (0 )
1079
+
1080
+ median_value = ov_opset .select (
1081
+ is_odd ,
1082
+ middle_elem ,
1083
+ ov_opset .divide (
1084
+ ov_opset .add (middle_elem , prev_middle_elem ).output (0 ),
1085
+ ov_opset .constant (2.0 , dtype = middle_elem .get_element_type ()).output (0 )
1086
+ ).output (0 )
1087
+ ).output (0 )
1088
+
1089
+ # Reshape if needed
1090
+ if keepdims :
1091
+ # Create keepdims shape
1092
+ keep_shape = shape
1093
+ if axis is not None :
1094
+ one_tensor = ov_opset .constant (1 , dtype = Type .i64 ).output (0 )
1095
+ indices = ov_opset .constant ([axis ], dtype = Type .i32 ).output (0 )
1096
+ keep_shape = ov_opset .scatter_elements_update (shape , indices , one_tensor , 0 ).output (0 )
1097
+ median_value = ov_opset .reshape (median_value , keep_shape , False ).output (0 )
1098
+ elif axis is None and x .get_partial_shape ().rank .get_length () > 1 :
1099
+ # Reshape back to scalar for axis=None case if original input was not a scalar
1100
+ scalar_shape = ov_opset .constant ([], dtype = Type .i32 ).output (0 )
1101
+ median_value = ov_opset .reshape (median_value , scalar_shape , False ).output (0 )
1102
+
1103
+ return OpenVINOKerasTensor (median_value )
1000
1104
1001
1105
1002
1106
def meshgrid (* x , indexing = "xy" ):
0 commit comments