Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Makefile.manual
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ stdlib_stats_mean.o: \
stdlib_stats_median.o: \
stdlib_optval.o \
stdlib_kinds.o \
stdlib_sorting.o \
stdlib_selection.o \
stdlib_stats.o
stdlib_stats_moment.o: \
stdlib_optval.o \
Expand Down
66 changes: 40 additions & 26 deletions src/stdlib_stats_median.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ submodule (stdlib_stats) stdlib_stats_median
use, intrinsic:: ieee_arithmetic, only: ieee_value, ieee_quiet_nan, ieee_is_nan
use stdlib_error, only: error_stop
use stdlib_optval, only: optval
! Use "ord_sort" rather than "sort" because the former can be much faster for arrays
! that are already partly sorted. While it is slightly slower for random arrays,
! ord_sort seems a better overall choice.
use stdlib_sorting, only: sort => ord_sort
use stdlib_selection, only: select
implicit none

contains
Expand All @@ -24,6 +21,7 @@ contains
real(${o1}$) :: res

integer(kind = int64) :: c, n
${t1}$ :: val, val1
${t1}$, allocatable :: x_tmp(:)

if (.not.optval(mask, .true.) .or. size(x) == 0) then
Expand All @@ -43,16 +41,18 @@ contains

x_tmp = reshape(x, [n])

call sort(x_tmp)
call select(x_tmp, c, val)

if (mod(n, 2_int64) == 0) then
call select(x_tmp, c+1, val1, left = c)
#:if t1[0] == 'r'
res = sum(x_tmp(c:c+1)) / 2._${o1}$
res = (val + val1) / 2._${o1}$
#:else
res = sum( real(x_tmp(c:c+1), kind=${o1}$) ) / 2._${o1}$
res = (real(val, kind=${o1}$) + &
real(val1, kind=${o1}$)) / 2._${o1}$
#:endif
else
res = x_tmp(c)
res = val
end if

end function ${name}$
Expand All @@ -74,6 +74,7 @@ contains
integer :: j${fj}$
#:endfor
#:endif
${t1}$ :: val, val1
${t1}$, allocatable :: x_tmp(:)

if (.not.optval(mask, .true.) .or. size(x) == 0) then
Expand Down Expand Up @@ -107,17 +108,18 @@ contains
end if
#:endif

call sort(x_tmp)
call select(x_tmp, c, val)

if (mod(n, 2) == 0) then
call select(x_tmp, c+1, val1, left = c)
res${reduce_subvector('j', rank, fi)}$ = &
#:if t1[0] == 'r'
sum(x_tmp(c:c+1)) / 2._${o1}$
(val + val1) / 2._${o1}$
#:else
sum(real(x_tmp(c:c+1), kind=${o1}$) ) / 2._${o1}$
(real(val, kind=${o1}$) + real(val1, kind=${o1}$)) / 2._${o1}$
#:endif
else
res${reduce_subvector('j', rank, fi)}$ = x_tmp(c)
res${reduce_subvector('j', rank, fi)}$ = val
end if
#:for fj in range(1, rank)
end do
Expand All @@ -141,6 +143,7 @@ contains
real(${o1}$) :: res

integer(kind = int64) :: c, n
${t1}$ :: val, val1
${t1}$, allocatable :: x_tmp(:)

if (any(shape(x) .ne. shape(mask))) then
Expand All @@ -156,21 +159,26 @@ contains

x_tmp = pack(x, mask)

call sort(x_tmp)

n = size(x_tmp, kind=int64)
c = floor( (n + 1) / 2._${o1}$, kind=int64)

if (n == 0) then
res = ieee_value(1._${o1}$, ieee_quiet_nan)
else if (mod(n, 2_int64) == 0) then
return
end if

c = floor( (n + 1) / 2._${o1}$, kind=int64)

call select(x_tmp, c, val)

if (mod(n, 2_int64) == 0) then
call select(x_tmp, c+1, val1, left = c)
#:if t1[0] == 'r'
res = sum(x_tmp(c:c+1)) / 2._${o1}$
res = (val + val1) / 2._${o1}$
#:else
res = sum(real(x_tmp(c:c+1), kind=${o1}$)) / 2._${o1}$
res = (real(val, kind=${o1}$) + real(val1, kind=${o1}$)) / 2._${o1}$
#:endif
else if (mod(n, 2_int64) == 1) then
res = x_tmp(c)
res = val
end if

end function ${name}$
Expand All @@ -192,6 +200,7 @@ contains
integer :: j${fj}$
#:endfor
#:endif
${t1}$ :: val, val1
${t1}$, allocatable :: x_tmp(:)

if (any(shape(x) .ne. shape(mask))) then
Expand Down Expand Up @@ -220,23 +229,28 @@ contains
end if
#:endif

call sort(x_tmp)

n = size(x_tmp, kind=int64)
c = floor( (n + 1) / 2._${o1}$, kind=int64 )

if (n == 0) then
res${reduce_subvector('j', rank, fi)}$ = &
ieee_value(1._${o1}$, ieee_quiet_nan)
else if (mod(n, 2_int64) == 0) then
return
end if

c = floor( (n + 1) / 2._${o1}$, kind=int64 )

call select(x_tmp, c, val)

if (mod(n, 2_int64) == 0) then
call select(x_tmp, c+1, val1, left = c)
res${reduce_subvector('j', rank, fi)}$ = &
#:if t1[0] == 'r'
sum(x_tmp(c:c+1)) / 2._${o1}$
(val + val1) / 2._${o1}$
#:else
sum(real(x_tmp(c:c+1), kind=${o1}$)) / 2._${o1}$
(real(val, kind=${o1}$) + real(val1, kind=${o1}$)) / 2._${o1}$
#:endif
else if (mod(n, 2_int64) == 1) then
res${reduce_subvector('j', rank, fi)}$ = x_tmp(c)
res${reduce_subvector('j', rank, fi)}$ = val
end if

deallocate(x_tmp)
Expand Down
6 changes: 0 additions & 6 deletions src/tests/stats/test_median.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,6 @@ contains
call check(error, median(d2odd_${k1}$), 1._dp&
, 'median(d2odd_${k1}$): uncorrect answer'&
, thr = dptol)
if (allocated(error)) return

call check(error, median(d2odd_${k1}$), 1._dp&
, 'median(d2odd_${k1}$): uncorrect answer'&
, thr = dptol)
if (allocated(error)) return

end subroutine

Expand Down