Skip to content

Commit f773edf

Browse files
l46kokcopybara-github
authored andcommitted
Allow constant folding to fold equals operator
PiperOrigin-RevId: 810518582
1 parent f3a1b2b commit f773edf

File tree

2 files changed

+72
-2
lines changed

2 files changed

+72
-2
lines changed

optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,16 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) {
167167
&& cond.constant().getKind().equals(CelConstant.Kind.BOOLEAN_VALUE);
168168
}
169169

170+
if (functionName.equals(Operator.EQUALS.getFunction())
171+
|| functionName.equals(Operator.NOT_EQUALS.getFunction())) {
172+
if (mutableCall.args().stream()
173+
.anyMatch(node -> isExprConstantOfKind(node, CelConstant.Kind.BOOLEAN_VALUE))
174+
|| mutableCall.args().stream()
175+
.allMatch(node -> node.getKind().equals(Kind.CONSTANT))) {
176+
return true;
177+
}
178+
}
179+
170180
if (functionName.equals(Operator.IN.getFunction())) {
171181
return canFoldInOperator(navigableExpr);
172182
}
@@ -393,6 +403,38 @@ private Optional<CelMutableAst> maybePruneBranches(
393403
}
394404
}
395405
}
406+
} else if (function.equals(Operator.EQUALS.getFunction())
407+
|| function.equals(Operator.NOT_EQUALS.getFunction())) {
408+
CelMutableExpr lhs = call.args().get(0);
409+
CelMutableExpr rhs = call.args().get(1);
410+
boolean lhsIsBoolean = isExprConstantOfKind(lhs, CelConstant.Kind.BOOLEAN_VALUE);
411+
boolean rhsIsBoolean = isExprConstantOfKind(rhs, CelConstant.Kind.BOOLEAN_VALUE);
412+
boolean invertCondition = function.equals(Operator.NOT_EQUALS.getFunction());
413+
Optional<CelMutableExpr> replacementExpr = Optional.empty();
414+
415+
if (lhs.getKind().equals(Kind.CONSTANT) && rhs.getKind().equals(Kind.CONSTANT)) {
416+
// If both args are const, don't prune any branches and let maybeFold method evaluate this
417+
// subExpr
418+
return Optional.empty();
419+
} else if (lhsIsBoolean) {
420+
boolean cond = invertCondition != lhs.constant().booleanValue();
421+
replacementExpr =
422+
Optional.of(
423+
cond
424+
? rhs
425+
: CelMutableExpr.ofCall(
426+
CelMutableCall.create(Operator.LOGICAL_NOT.getFunction(), rhs)));
427+
} else if (rhsIsBoolean) {
428+
boolean cond = invertCondition != rhs.constant().booleanValue();
429+
replacementExpr =
430+
Optional.of(
431+
cond
432+
? lhs
433+
: CelMutableExpr.ofCall(
434+
CelMutableCall.create(Operator.LOGICAL_NOT.getFunction(), lhs)));
435+
}
436+
437+
return replacementExpr.map(node -> astMutator.replaceSubtree(mutableAst, node, expr.id()));
396438
}
397439

398440
return Optional.empty();
@@ -663,6 +705,10 @@ public static Builder newBuilder() {
663705
ConstantFoldingOptions() {}
664706
}
665707

708+
private static boolean isExprConstantOfKind(CelMutableExpr expr, CelConstant.Kind constantKind) {
709+
return expr.getKind().equals(Kind.CONSTANT) && expr.constant().getKind().equals(constantKind);
710+
}
711+
666712
private ConstantFoldingOptimizer(ConstantFoldingOptions constantFoldingOptions) {
667713
this.constantFoldingOptions = constantFoldingOptions;
668714
this.astMutator = AstMutator.newInstance(constantFoldingOptions.maxIterationLimit());

optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,13 @@ public class ConstantFoldingOptimizerTest {
6666
CelExtensions.math(CelOptions.DEFAULT),
6767
CelExtensions.strings(),
6868
CelExtensions.sets(CelOptions.DEFAULT),
69-
CelExtensions.encoders())
69+
CelExtensions.encoders(CelOptions.DEFAULT))
7070
.addRuntimeLibraries(
7171
CelOptionalLibrary.INSTANCE,
7272
CelExtensions.math(CelOptions.DEFAULT),
7373
CelExtensions.strings(),
7474
CelExtensions.sets(CelOptions.DEFAULT),
75-
CelExtensions.encoders())
75+
CelExtensions.encoders(CelOptions.DEFAULT))
7676
.build();
7777

7878
private static final CelOptimizer CEL_OPTIMIZER =
@@ -189,6 +189,28 @@ public class ConstantFoldingOptimizerTest {
189189
@TestParameters("{source: 'sets.contains([1], [1])', expected: 'true'}")
190190
@TestParameters(
191191
"{source: 'cel.bind(r0, [1, 2, 3], cel.bind(r1, 1 in r0, r1))', expected: 'true'}")
192+
@TestParameters("{source: 'x == true', expected: 'x'}")
193+
@TestParameters("{source: 'true == x', expected: 'x'}")
194+
@TestParameters("{source: 'x == false', expected: '!x'}")
195+
@TestParameters("{source: 'false == x', expected: '!x'}")
196+
@TestParameters("{source: 'true == false', expected: 'false'}")
197+
@TestParameters("{source: 'true == true', expected: 'true'}")
198+
@TestParameters("{source: 'false == true', expected: 'false'}")
199+
@TestParameters("{source: 'false == false', expected: 'true'}")
200+
@TestParameters("{source: '10 == 42', expected: 'false'}")
201+
@TestParameters("{source: '42 == 42', expected: 'true'}")
202+
@TestParameters("{source: 'x != true', expected: '!x'}")
203+
@TestParameters("{source: 'true != x', expected: '!x'}")
204+
@TestParameters("{source: 'x != false', expected: 'x'}")
205+
@TestParameters("{source: 'false != x', expected: 'x'}")
206+
@TestParameters("{source: 'true != false', expected: 'true'}")
207+
@TestParameters("{source: 'true != true', expected: 'false'}")
208+
@TestParameters("{source: 'false != true', expected: 'true'}")
209+
@TestParameters("{source: 'false != false', expected: 'false'}")
210+
@TestParameters("{source: '10 != 42', expected: 'true'}")
211+
@TestParameters("{source: '42 != 42', expected: 'false'}")
212+
@TestParameters("{source: '[\"foo\",\"bar\"] == [\"foo\",\"bar\"]', expected: 'true'}")
213+
@TestParameters("{source: '[\"bar\",\"foo\"] == [\"foo\",\"bar\"]', expected: 'false'}")
192214
// TODO: Support folding lists with mixed types. This requires mutable lists.
193215
// @TestParameters("{source: 'dyn([1]) + [1.0]'}")
194216
public void constantFold_success(String source, String expected) throws Exception {
@@ -324,6 +346,8 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E
324346
@TestParameters("{source: 'TestAllTypes{single_int32: x, repeated_int32: [1, 2, 3]}'}")
325347
@TestParameters("{source: 'get_true() == get_true()'}")
326348
@TestParameters("{source: 'get_true() == true'}")
349+
@TestParameters("{source: 'x == x'}")
350+
@TestParameters("{source: 'x == 42'}")
327351
public void constantFold_noOp(String source) throws Exception {
328352
CelAbstractSyntaxTree ast = CEL.compile(source).getAst();
329353

0 commit comments

Comments
 (0)