@@ -996,64 +996,50 @@ def maximum(x1, x2):
996
996
997
997
998
998
def median (x , axis = None , keepdims = False ):
999
- x = get_ov_output (x )
1000
-
999
+ x_node = get_ov_output (x )
1000
+ orig_dtype = x_node .get_element_type ()
1001
+
1001
1002
if axis is None :
1002
- flatten_shape = ov_opset .constant ([- 1 ], Type .i32 ).output (0 )
1003
- x = ov_opset .reshape (x , flatten_shape , False ).output (0 )
1004
- axis = 0
1005
-
1006
- shape = ov_opset .shape_of (x ).output (0 )
1007
- shape = ov_opset .convert (shape , Type .i64 ).output (0 )
1008
-
1009
- if axis is not None :
1010
- indices = ov_opset .constant ([axis ], Type .i32 ).output (0 )
1011
- length = ov_opset .gather (shape , indices , 0 ).output (0 )
1012
- length = ov_opset .reshape (length , ov_opset .constant ([], Type .i32 ).output (0 ), False ).output (0 )
1003
+ flatten_shape = ov_opset .constant ([- 1 ], Type .i64 ).output (0 )
1004
+ x_node = ov_opset .reshape (x_node , flatten_shape , False ).output (0 )
1005
+ axis_val = 0
1013
1006
else :
1014
- length = ov_opset .shape_of (shape ).output (0 )
1015
- length = ov_opset .convert (length , Type .i64 ).output (0 )
1016
- length = ov_opset .reshape (length , ov_opset .constant ([], Type .i32 ).output (0 ), False ).output (0 )
1017
-
1018
- sorted_x = ov_opset .topk (x , length , axis , "min" , "value" ).output (0 )
1019
-
1020
- const_2 = ov_opset .constant (2 , Type .i64 ).output (0 )
1021
- mid_index = ov_opset .floor_mod (length , const_2 ).output (0 )
1022
- is_odd = ov_opset .equal (mid_index , ov_opset .constant (1 , Type .i64 ).output (0 )).output (0 )
1023
-
1024
- half_length = ov_opset .divide (length , const_2 ).output (0 )
1025
- floor_half_length = ov_opset .floor (half_length ).output (0 )
1026
- floor_half_length = ov_opset .convert (floor_half_length , Type .i64 ).output (0 )
1027
-
1028
- mid_index_scalar = ov_opset .convert (floor_half_length , Type .i32 ).output (0 )
1029
- middle_elem = ov_opset .gather (sorted_x , mid_index_scalar , axis ).output (0 )
1030
-
1031
- prev_mid_index = ov_opset .subtract (floor_half_length , ov_opset .constant (1 , Type .i64 ).output (0 )).output (0 )
1032
- prev_mid_index_scalar = ov_opset .convert (prev_mid_index , Type .i32 ).output (0 )
1033
- prev_middle_elem = ov_opset .gather (sorted_x , prev_mid_index_scalar , axis ).output (0 )
1034
-
1035
- median_value = ov_opset .select (
1036
- is_odd ,
1037
- middle_elem ,
1038
- ov_opset .divide (
1039
- ov_opset .add (middle_elem , prev_middle_elem ).output (0 ),
1040
- ov_opset .constant (2.0 , middle_elem .get_element_type ()).output (0 )
1041
- ).output (0 )
1042
- ).output (0 )
1043
-
1007
+ axis_val = axis
1008
+
1009
+ shape = ov_opset .shape_of (x_node ).output (0 )
1010
+ shape = ov_opset .convert (shape , Type .i64 ).output (0 )
1011
+
1012
+ axis_const = ov_opset .constant ([axis_val ], Type .i64 ).output (0 )
1013
+ length = ov_opset .gather (shape , axis_const , 0 ).output (0 )
1014
+ length = ov_opset .reshape (length , ov_opset .constant ([], Type .i64 ).output (0 ), False ).output (0 )
1015
+
1016
+ sorted_x = ov_opset .topk (x_node , length , axis_val , "min" , "value" ).output (0 )
1017
+
1018
+ two = ov_opset .constant (2 , Type .i64 ).output (0 )
1019
+
1020
+ half = ov_opset .divide (length , two ).output (0 )
1021
+ mid_idx = ov_opset .floor (half ).output (0 )
1022
+ rem = ov_opset .floor_mod (length , two ).output (0 )
1023
+ is_odd = ov_opset .equal (rem , ov_opset .constant (1 , Type .i64 ).output (0 )).output (0 )
1024
+
1025
+ mid_idx_i32 = ov_opset .convert (mid_idx , Type .i32 ).output (0 )
1026
+ middle = ov_opset .gather (sorted_x , mid_idx_i32 , axis_val ).output (0 )
1027
+
1028
+ prev_idx = ov_opset .subtract (mid_idx , ov_opset .constant (1 , Type .i64 ).output (0 )).output (0 )
1029
+ prev_idx_i32 = ov_opset .convert (prev_idx , Type .i32 ).output (0 )
1030
+ prev = ov_opset .gather (sorted_x , prev_idx_i32 , axis_val ).output (0 )
1031
+ sum_val = ov_opset .add (middle , prev ).output (0 )
1032
+ avg = ov_opset .divide (sum_val , ov_opset .constant (2.0 , sum_val .get_element_type ()).output (0 )).output (0 )
1033
+
1034
+ median_val = ov_opset .select (is_odd , middle , avg ).output (0 )
1035
+
1044
1036
if keepdims :
1045
- keep_shape = shape
1046
- if axis is not None :
1047
- one_tensor = ov_opset .constant (1 , Type .i64 ).output (0 )
1048
- indices = ov_opset .constant ([axis ], Type .i32 ).output (0 )
1049
- keep_shape = ov_opset .scatter_elements_update (shape , indices , one_tensor , 0 ).output (0 )
1050
- median_value = ov_opset .reshape (median_value , keep_shape , False ).output (0 )
1051
- elif axis is None and x .get_partial_shape ().rank .get_length () > 1 :
1052
- scalar_shape = ov_opset .constant ([], Type .i32 ).output (0 )
1053
- median_value = ov_opset .reshape (median_value , scalar_shape , False ).output (0 )
1037
+ one = ov_opset .constant (1 , Type .i64 ).output (0 )
1038
+ keep_shape = ov_opset .scatter_elements_update (shape , axis_const , one , 0 ).output (0 )
1039
+ median_val = ov_opset .reshape (median_val , keep_shape , False ).output (0 )
1054
1040
1055
- median_value = ov_opset .convert (median_value , x . get_element_type () ).output (0 )
1056
- return OpenVINOKerasTensor (median_value )
1041
+ median_val = ov_opset .convert (median_val , orig_dtype ).output (0 )
1042
+ return OpenVINOKerasTensor (median_val )
1057
1043
1058
1044
1059
1045
def meshgrid (* x , indexing = "xy" ):
0 commit comments