diff --git a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs index e9dc788484d..0c5f9cf0c38 100644 --- a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs +++ b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs @@ -1289,6 +1289,7 @@ protected virtual SqlExpression VisitSqlBinary( right, leftNullable, rightNullable, + optimize, out nullable); if (optimized is SqlUnaryExpression { Operand: ColumnExpression optimizedUnaryColumnOperand } optimizedUnary) @@ -1623,6 +1624,7 @@ private SqlExpression ProcessJoinPredicate(SqlExpression predicate) right, leftNullable, rightNullable, + optimize: true, out _); return result; @@ -1649,6 +1651,7 @@ private SqlExpression OptimizeComparison( SqlExpression right, bool leftNullable, bool rightNullable, + bool optimize, out bool nullable) { var leftNullValue = leftNullable && left is SqlConstantExpression or SqlParameterExpression; @@ -1729,32 +1732,8 @@ private SqlExpression OptimizeComparison( && !rightNullable && sqlBinaryExpression.OperatorType is ExpressionType.Equal or ExpressionType.NotEqual) { - var leftUnary = left as SqlUnaryExpression; - var rightUnary = right as SqlUnaryExpression; - - var leftNegated = IsLogicalNot(leftUnary); - var rightNegated = IsLogicalNot(rightUnary); - - if (leftNegated) - { - left = leftUnary!.Operand; - } - - if (rightNegated) - { - right = rightUnary!.Operand; - } - - // a == b <=> !a == !b -> a == b - // !a == b <=> a == !b -> a != b - // a != b <=> !a != !b -> a != b - // !a != b <=> a != !b -> a == b - nullable = false; - - return sqlBinaryExpression.OperatorType == ExpressionType.Equal ^ leftNegated == rightNegated - ? _sqlExpressionFactory.NotEqual(left, right) - : _sqlExpressionFactory.Equal(left, right); + return OptimizeBooleanComparison(sqlBinaryExpression, left, right, optimize); } nullable = false; @@ -1762,14 +1741,11 @@ private SqlExpression OptimizeComparison( return sqlBinaryExpression.Update(left, right); } - private SqlExpression RewriteNullSemantics( + private SqlExpression OptimizeBooleanComparison( SqlBinaryExpression sqlBinaryExpression, SqlExpression left, SqlExpression right, - bool leftNullable, - bool rightNullable, - bool optimize, - out bool nullable) + bool optimize) { var leftUnary = left as SqlUnaryExpression; var rightUnary = right as SqlUnaryExpression; @@ -1787,22 +1763,49 @@ private SqlExpression RewriteNullSemantics( right = rightUnary!.Operand; } + var notEqual = sqlBinaryExpression.OperatorType == ExpressionType.Equal ^ leftNegated == rightNegated; + + // prefer equality in predicates + if (optimize && notEqual && left.Type == typeof(bool)) + { + if (right is ColumnExpression && (left is not ColumnExpression || leftNegated)) + { + left = _sqlExpressionFactory.Not(left); + } + else + { + right = _sqlExpressionFactory.Not(right); + } + + return _sqlExpressionFactory.Equal(left, right); + } + + // a == b <=> !a == !b -> a == b + // !a == b <=> a == !b -> a != b + // a != b <=> !a != !b -> a != b + // !a != b <=> a != !b -> a == b + + return notEqual + ? _sqlExpressionFactory.NotEqual(left, right) + : _sqlExpressionFactory.Equal(left, right); + } + + private SqlExpression RewriteNullSemantics( + SqlBinaryExpression sqlBinaryExpression, + SqlExpression left, + SqlExpression right, + bool leftNullable, + bool rightNullable, + bool optimize, + out bool nullable) + { var leftIsNull = ProcessNullNotNull(_sqlExpressionFactory.IsNull(left), leftNullable); var leftIsNotNull = _sqlExpressionFactory.Not(leftIsNull); var rightIsNull = ProcessNullNotNull(_sqlExpressionFactory.IsNull(right), rightNullable); var rightIsNotNull = _sqlExpressionFactory.Not(rightIsNull); - SqlExpression body; - if (leftNegated == rightNegated) - { - body = _sqlExpressionFactory.Equal(left, right); - } - else - { - // a == !b and !a == b in SQL evaluate the same as a != b - body = _sqlExpressionFactory.NotEqual(left, right); - } + var body = OptimizeBooleanComparison(sqlBinaryExpression, left, right, optimize); // optimized expansion which doesn't distinguish between null and false if (optimize && sqlBinaryExpression.OperatorType == ExpressionType.Equal) @@ -1815,6 +1818,12 @@ private SqlExpression RewriteNullSemantics( // doing a full null semantics rewrite - removing all nulls from truth table nullable = false; + if (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual) + { + // the factory takes care of simplifying equal <-> not-equal + body = _sqlExpressionFactory.Not(body); + } + // (a == b && (a != null && b != null)) || (a == null && b == null) body = _sqlExpressionFactory.OrElse( _sqlExpressionFactory.AndAlso(body, _sqlExpressionFactory.AndAlso(leftIsNotNull, rightIsNotNull)),