Skip to content

Commit

Permalink
Avoid calling log(0) when generating gaussian random variables (#662)
Browse files Browse the repository at this point in the history
* owl_stats_ziggurat: avoid log(0)

sfmt_f64_1 is documented to include 0, which would result in `log(0) =
neg_infinity`.
use sfmt_f64_3 instead which is documented to return `(0, 1)` instead.

This should also match the paper, which says "UNI floats it to (0,1)":
> Marsaglia, George, and Wai Wan Tsang.
> "The ziggurat method for generating random variables."
> Journal of statistical software 5 (2000): 1-7.

See #661

Signed-off-by: Edwin Török <[email protected]>

* Owl_base_stats_dist_uniform: add a function that returns a random float in (0,1)

Signed-off-by: Edwin Török <[email protected]>

* Owl_base_stats_dist_gaussian.{std_gaussian_rvs,gaussian_rvs}: avoid infinity on Random.float returning 0

#661

Signed-off-by: Edwin Török <[email protected]>

---------

Signed-off-by: Edwin Török <[email protected]>
  • Loading branch information
edwintorok authored Mar 25, 2024
1 parent c66069e commit 621cbff
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
10 changes: 5 additions & 5 deletions src/base/stats/owl_base_stats_dist_gaussian.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* OWL - OCaml Scientific Computing
* Copyright (c) 2016-2022 Liang Wang <[email protected]>
*)

open Owl_base_stats_dist_uniform
let _u1 = ref 0.

let _u2 = ref 0.
Expand All @@ -20,8 +20,8 @@ let std_gaussian_rvs () =
!_z1)
else (
_case := true;
_u1 := Random.float 1.;
_u2 := Random.float 1.;
_u1 := rand01_exclusive ();
_u2 := rand01_exclusive ();
_z0 := sqrt (~-.2. *. log !_u1) *. cos (2. *. Owl_const.pi *. !_u2);
_z1 := sqrt (~-.2. *. log !_u1) *. sin (2. *. Owl_const.pi *. !_u2);
!_z0)
Expand All @@ -35,8 +35,8 @@ let gaussian_rvs ~mu ~sigma =
mu +. (sigma *. !_z1))
else (
_case := true;
_u1 := Random.float 1.;
_u2 := Random.float 1.;
_u1 := rand01_exclusive ();
_u2 := rand01_exclusive ();
_z0 := sqrt (~-.2. *. log !_u1) *. cos (2. *. Owl_const.pi *. !_u2);
_z1 := sqrt (~-.2. *. log !_u1) *. sin (2. *. Owl_const.pi *. !_u2);
mu +. (sigma *. !_z0))
5 changes: 5 additions & 0 deletions src/base/stats/owl_base_stats_dist_uniform.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@ let uniform_int_rvs n = Random.int n
let std_uniform_rvs () = Random.float 1.

let uniform_rvs ~a ~b = a +. ((b -. a) *. Random.float 1.)

(* The constants below are Printf.printf "%h,%h" (Float.succ 0.) (Float.pred 1.)
Also [Float.succ 0. +. Float.pred 1. < 1.]
*)
let rand01_exclusive () = 0x0.0000000000001p-1022 +. Random.float 0x1.fffffffffffffp-1
10 changes: 5 additions & 5 deletions src/owl/stats/owl_stats_ziggurat.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ inline double std_exponential_rvs ( ) {
else {
for ( ; ; ) {
if ( iz == 0 ) {
value = 7.69711 - log ( sfmt_f64_1 );
value = 7.69711 - log ( sfmt_f64_3 );
break;
}

x = jz * we[iz];

if ( fe[iz] + sfmt_f64_1 * ( fe[iz-1] - fe[iz] ) < exp ( - x ) ) {
if ( fe[iz] + sfmt_f64_3 * ( fe[iz-1] - fe[iz] ) < exp ( - x ) ) {
value = x;
break;
}
Expand Down Expand Up @@ -92,8 +92,8 @@ inline double std_gaussian_rvs ( ) {
for ( ; ; ) {
if ( iz == 0 ) {
for ( ; ; ) {
x = - 0.2904764 * log ( sfmt_f64_1 );
y = - log ( sfmt_f64_1 );
x = - 0.2904764 * log ( sfmt_f64_3 );
y = - log ( sfmt_f64_3 );
if ( x * x <= y + y )
break;
}
Expand All @@ -103,7 +103,7 @@ inline double std_gaussian_rvs ( ) {

x = hz * wn[iz];

if ( fn[iz] + ( sfmt_f64_1 ) * ( fn[iz-1] - fn[iz] ) < exp ( - 0.5 * x * x ) ) {
if ( fn[iz] + ( sfmt_f64_3 ) * ( fn[iz-1] - fn[iz] ) < exp ( - 0.5 * x * x ) ) {
value = x;
break;
}
Expand Down

0 comments on commit 621cbff

Please sign in to comment.