Skip to content

Commit

Permalink
[GR-18163] Reimplement Float#round in a way it is done in MRI
Browse files Browse the repository at this point in the history
PullRequest: truffleruby/4458
  • Loading branch information
andrykonchin authored and eregon committed Jan 31, 2025
2 parents 7722e65 + fed380d commit d94ffd1
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 267 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Compatibility:
* Fix `Integer#ceil` when self is 0 (@andrykonchin).
* Fix `Module#remove_const` and emit warning when constant is deprecated (@andrykonchin).
* Add `Module#set_temporary_name` (#3681, @andrykonchin).
* Modify `Float#round` to match MRI behavior (#3676, @andrykonchin).

Performance:

Expand Down
9 changes: 9 additions & 0 deletions spec/ruby/core/float/round_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@
12.345678.round(3.999).should == 12.346
end

it "correctly rounds exact floats with a numerous digits in a fraction part" do
0.8241000000000004.round(10).should == 0.8241
0.8241000000000002.round(10).should == 0.8241
end

it "returns zero when passed a negative argument with magnitude greater than magnitude of the whole number portion of the Float" do
0.8346268.round(-1).should eql(0)
end
Expand Down Expand Up @@ -68,6 +73,10 @@
0.42.round(2.0**30).should == 0.42
end

it "returns rounded values for not so big argument" do
0.42.round(2.0**23).should == 0.42
end

it "returns big values rounded to nearest" do
+2.5e20.round(-20).should eql( +3 * 10 ** 20 )
-2.5e20.round(-20).should eql( -3 * 10 ** 20 )
Expand Down
1 change: 1 addition & 0 deletions spec/tags/core/float/round_tags.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
fails:Float#round returns rounded values for big argument
223 changes: 0 additions & 223 deletions src/main/java/org/truffleruby/core/numeric/FloatNodes.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import com.oracle.truffle.api.profiles.InlinedConditionProfile;
import com.oracle.truffle.api.strings.TruffleString;
import org.truffleruby.annotations.Split;
import org.truffleruby.annotations.SuppressFBWarnings;
import org.truffleruby.annotations.CoreMethod;
import org.truffleruby.builtins.CoreMethodArrayArgumentsNode;
import org.truffleruby.annotations.CoreModule;
Expand Down Expand Up @@ -623,228 +622,6 @@ public abstract static class FloatFloorNDigitsPrimitiveNode extends PrimitiveArr

}

@ImportStatic(FloatRoundGuards.class)
@Primitive(name = "float_round_up")
public abstract static class FloatRoundUpPrimitiveNode extends PrimitiveArrayArgumentsNode {

@Specialization(guards = "fitsInInteger(n)")
int roundFittingInt(double n) {
int l = (int) n;
int signum = (int) Math.signum(n);
double d = Math.abs(n - l);
if (d >= 0.5) {
l += signum;
}
return l;
}

@Specialization(guards = "fitsInLong(n)", replaces = "roundFittingInt")
long roundFittingLong(double n) {
long l = (long) n;
long signum = (long) Math.signum(n);
double d = Math.abs(n - l);
if (d >= 0.5) {
l += signum;
}
return l;
}

@Specialization(replaces = "roundFittingLong")
Object round(double n,
@Cached FloatToIntegerNode floatToIntegerNode) {
double signum = Math.signum(n);
double f = Math.floor(Math.abs(n));
double d = Math.abs(n) - f;
if (d >= 0.5) {
f += 1;
}
return floatToIntegerNode.execute(this, f * signum);
}
}

@ImportStatic(FloatRoundGuards.class)
@Primitive(name = "float_round_up_decimal", lowerFixnum = 1)
public abstract static class FloatRoundUpDecimalPrimitiveNode extends PrimitiveArrayArgumentsNode {

@Specialization
double roundNDecimal(double n, int ndigits,
@Cached InlinedConditionProfile boundaryCase) {
long intPart = (long) n;
double s = Math.pow(10.0, ndigits) * Math.signum(n);
double f = (n % 1) * s;
long fInt = (long) f;
double d = f % 1;
int limit = Math.getExponent(n) + Math.getExponent(s) - 51;
if (boundaryCase.profile(this, (Math.getExponent(d) <= limit) ||
(Math.getExponent(1.0 - d) <= limit))) {
return findClosest(n, s, d);
} else if (d > 0.5 || Math.abs(n) - Math.abs((intPart + (fInt + 0.5) / s)) >= 0) {
fInt += 1;
}
return intPart + fInt / s;
}
}

/* If the rounding result is very near to an integer boundary then we need to find the number that is closest to the
* correct result. If we don't do this then it's possible to get errors in the least significant bit of the result.
* We'll test the adjacent double in the direction closest to the boundary and compare the fractional portions. If
* we're already at the minimum error we'll return the original number as it is already rounded as well as it can
* be. In the case of a tie we return the lower number, otherwise we check the go round again. */
private static double findClosest(double n, double s, double d) {
double n2;
while (true) {
if (d > 0.5) {
n2 = Math.nextAfter(n, n + s);
} else {
n2 = Math.nextAfter(n, n - s);
}
double f = (n2 % 1) * s;
double d2 = f % 1;
if (((d > 0.5) ? 1 - d : d) < ((d2 > 0.5) ? 1 - d2 : d2)) {
return n;
} else if (((d > 0.5) ? 1 - d : d) == ((d2 > 0.5) ? 1 - d2 : d2)) {
return Math.abs(n) < Math.abs(n2) ? n : n2;
} else {
d = d2;
n = n2;
}
}
}

@SuppressFBWarnings("FE_FLOATING_POINT_EQUALITY")
@ImportStatic(FloatRoundGuards.class)
@Primitive(name = "float_round_even")
public abstract static class FloatRoundEvenPrimitiveNode extends PrimitiveArrayArgumentsNode {

@Specialization(guards = { "fitsInInteger(n)" })
int roundFittingInt(double n) {
int l = (int) n;
int signum = (int) Math.signum(n);
double d = Math.abs(n - l);
if (d > 0.5) {
l += signum;
} else if (d == 0.5) {
l += l % 2;
}
return l;
}

@Specialization(guards = "fitsInLong(n)", replaces = "roundFittingInt")
long roundFittingLong(double n) {
long l = (long) n;
long signum = (long) Math.signum(n);
double d = Math.abs(n - l);
if (d > 0.5) {
l += signum;
} else if (d == 0.5) {
l += l % 2;
}
return l;
}

@Specialization(replaces = "roundFittingLong")
Object round(double n,
@Cached FloatToIntegerNode floatToIntegerNode) {
double signum = Math.signum(n);
double f = Math.floor(Math.abs(n));
double d = Math.abs(n) - f;
if (d > 0.5) {
f += signum;
} else if (d == 0.5) {
f += f % 2;
}
return floatToIntegerNode.execute(this, f * signum);
}
}

@SuppressFBWarnings("FE_FLOATING_POINT_EQUALITY")
@ImportStatic(FloatRoundGuards.class)
@Primitive(name = "float_round_even_decimal", lowerFixnum = 1)
public abstract static class FloatRoundEvenDecimalPrimitiveNode extends PrimitiveArrayArgumentsNode {

@Specialization
double roundNDecimal(double n, int ndigits,
@Cached InlinedConditionProfile boundaryCase) {
long intPart = (long) n;
double s = Math.pow(10.0, ndigits) * Math.signum(n);
double f = (n % 1) * s;
long fInt = (long) f;
double d = f % 1;
int limit = Math.getExponent(n) + Math.getExponent(s) - 51;
if (boundaryCase.profile(this, (Math.getExponent(d) <= limit) ||
(Math.getExponent(1.0 - d) <= limit))) {
return findClosest(n, s, d);
} else if (d > 0.5) {
fInt += 1;
} else if (d == 0.5 || Math.abs(n) - Math.abs((intPart + (fInt + 0.5) / s)) >= 0) {
fInt += fInt % 2;
}
return intPart + fInt / s;
}
}

@ImportStatic(FloatRoundGuards.class)
@Primitive(name = "float_round_down")
public abstract static class FloatRoundDownPrimitiveNode extends PrimitiveArrayArgumentsNode {

@Specialization(guards = "fitsInInteger(n)")
int roundFittingInt(double n) {
int l = (int) n;
int signum = (int) Math.signum(n);
double d = Math.abs(n - l);
if (d > 0.5) {
l += signum;
}
return l;
}

@Specialization(guards = "fitsInLong(n)", replaces = "roundFittingInt")
long roundFittingLong(double n) {
long l = (long) n;
long signum = (long) Math.signum(n);
double d = Math.abs(n - l);
if (d > 0.5) {
l += signum;
}
return l;
}

@Specialization(replaces = "roundFittingLong")
Object round(double n,
@Cached FloatToIntegerNode floatToIntegerNode) {
double signum = Math.signum(n);
double f = Math.floor(Math.abs(n));
double d = Math.abs(n) - f;
if (d > 0.5) {
f += 1;
}
return floatToIntegerNode.execute(this, f * signum);
}
}

@ImportStatic(FloatRoundGuards.class)
@Primitive(name = "float_round_down_decimal", lowerFixnum = 1)
public abstract static class FloatRoundDownDecimalPrimitiveNode extends PrimitiveArrayArgumentsNode {

@Specialization
double roundNDecimal(double n, int ndigits,
@Cached InlinedConditionProfile boundaryCase) {
long intPart = (long) n;
double s = Math.pow(10.0, ndigits) * Math.signum(n);
double f = (n % 1) * s;
long fInt = (long) f;
double d = f % 1;
int limit = Math.getExponent(n) + Math.getExponent(s) - 51;
if (boundaryCase.profile(this, (Math.getExponent(d) <= limit) ||
(Math.getExponent(1.0 - d) <= limit))) {
return findClosest(n, s, d);
} else if (d > 0.5 && Math.abs(n) - Math.abs((intPart + (fInt + 0.5) / s)) > 0) {
fInt += 1;
}
return intPart + fInt / s;
}
}

@Primitive(name = "float_exp")
public abstract static class FloatExpNode extends PrimitiveArrayArgumentsNode {

Expand Down
56 changes: 29 additions & 27 deletions src/main/ruby/truffleruby/core/float.rb
Original file line number Diff line number Diff line change
Expand Up @@ -193,42 +193,44 @@ def floor(ndigits = undefined)
end
end

def round(ndigits = undefined, half: nil)
ndigits = if Primitive.undefined?(ndigits)
nil
else
Truffle::Type.coerce_to(ndigits, Integer, :to_int)
end
def round(ndigits = 0, half: :up)
ndigits = Truffle::Type.coerce_to(ndigits, Integer, :to_int)

if self == 0.0
return ndigits && ndigits > 0 ? self : 0
return ndigits > 0 ? self : 0
end

half = :up if Primitive.nil?(half)
if half != :up && half != :down && half != :even
raise ArgumentError, "invalid rounding mode: #{half}"
end
if Primitive.nil?(ndigits)
if infinite?

if ndigits <= 0
if self.infinite?
raise FloatDomainError, 'Infinite'
elsif nan?
elsif self.nan?
raise FloatDomainError, 'NaN'
else
case half
when nil, :up
Primitive.float_round_up(self)
when :even
Primitive.float_round_even(self)
when :down
Primitive.float_round_down(self)
else
raise ArgumentError, "invalid rounding mode: #{half}"
end
end
else
if ndigits == 0
round(half: half)
elsif ndigits < 0
to_i.round(ndigits, :half => half)
elsif infinite? or nan?
end

if ndigits < 0
to_i.round(ndigits, half: half)
elsif ndigits == 0
Truffle::FloatOperations.round_to_n_place(self, ndigits, half)
elsif !infinite? && !nan?
exponent = Primitive.float_exp(self)

if Truffle::FloatOperations.round_overflow?(ndigits, exponent)
self
elsif Truffle::FloatOperations.round_overflow?(ndigits, exponent)
0.0
elsif ndigits > 14
to_r.round(ndigits, half: half).to_f
else
Truffle::FloatOperations.round_to_n_place(self, ndigits, half)
end
else
self # Infinity or NaN
end
end

Expand Down
Loading

0 comments on commit d94ffd1

Please sign in to comment.