diff --git a/spring-aop/src/main/java/org/springframework/aop/support/AopUtils.java b/spring-aop/src/main/java/org/springframework/aop/support/AopUtils.java index d27f4b66c32e..daec4be97621 100644 --- a/spring-aop/src/main/java/org/springframework/aop/support/AopUtils.java +++ b/spring-aop/src/main/java/org/springframework/aop/support/AopUtils.java @@ -381,7 +381,7 @@ public static Object invokeSuspendingFunction(Method method, @Nullable Object ta Continuation<?> continuation = (Continuation<?>) args[args.length -1]; Assert.state(continuation != null, "No Continuation available"); CoroutineContext context = continuation.getContext().minusKey(Job.Key); - return CoroutinesUtils.invokeSuspendingFunction(context, method, target, args); + return CoroutinesUtils.invokeSuspendingFunctionPreserveNulls(context, method, target, args); } } diff --git a/spring-aop/src/test/kotlin/org/springframework/aop/support/AopUtilsKotlinTests.kt b/spring-aop/src/test/kotlin/org/springframework/aop/support/AopUtilsKotlinTests.kt index f2051dbcb8cc..a5f8de3eeb5b 100644 --- a/spring-aop/src/test/kotlin/org/springframework/aop/support/AopUtilsKotlinTests.kt +++ b/spring-aop/src/test/kotlin/org/springframework/aop/support/AopUtilsKotlinTests.kt @@ -43,6 +43,17 @@ class AopUtilsKotlinTests { } } + @Test + fun `Invoking suspending function with null argument should not return default value`() { + val method = ReflectionUtils.findMethod(WithoutInterface::class.java, "handleWithDefaultParam", + String::class. java, Continuation::class.java)!! + val continuation = Continuation<Any>(CoroutineName("test")) { } + val result = AopUtils.invokeJoinpointUsingReflection(WithoutInterface(), method, arrayOf(null, continuation)) + assertThat(result).isInstanceOfSatisfying(Mono::class.java) { + assertThat(it.block()).isEqualTo(null) + } + } + @Test fun `Invoking suspending function on bridged method should return Mono`() { val value = "foo" @@ -65,6 +76,11 @@ class AopUtilsKotlinTests { delay(1) return value } + + suspend fun handleWithDefaultParam(value: String? = "defaultVal") : String? { + delay(1) + return value + } } interface ProxyInterface<T> { diff --git a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java index f53847044d8c..8222780a8b07 100644 --- a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java +++ b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java @@ -112,6 +112,35 @@ public static Publisher<?> invokeSuspendingFunction(Method method, Object target @SuppressWarnings({"DataFlowIssue", "NullAway"}) public static Publisher<?> invokeSuspendingFunction( CoroutineContext context, Method method, @Nullable Object target, @Nullable Object... args) { + return invokeSuspendingFunctionCore(context, method, target, args, false); + } + + /** + * Invoke a suspending function and convert it to {@link Mono} or + * {@link Flux}. + * @param context the coroutine context to use + * @param method the suspending function to invoke + * @param target the target to invoke {@code method} on + * @param args the function arguments. If the {@code Continuation} argument is specified as the last argument + * (typically {@code null}), it is ignored. + * @return the method invocation result as reactive stream + * @throws IllegalArgumentException if {@code method} is not a suspending function + * @since 6.0 + * This function preservers the null parameter passed in argument + */ + @SuppressWarnings({"DataFlowIssue", "NullAway"}) + public static Publisher<?> invokeSuspendingFunctionPreserveNulls( + CoroutineContext context, Method method, @Nullable Object target, @Nullable Object... args) { + return invokeSuspendingFunctionCore(context, method, target, args, true); + } + + private static Publisher<?> invokeSuspendingFunctionCore( + CoroutineContext context, + Method method, + @Nullable Object target, + @Nullable Object[] args, + boolean preserveNulls) + { Assert.isTrue(KotlinDetector.isSuspendingFunction(method), "Method must be a suspending function"); KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method); @@ -120,26 +149,7 @@ public static Publisher<?> invokeSuspendingFunction( KCallablesJvm.setAccessible(function, true); } Mono<Object> mono = MonoKt.mono(context, (scope, continuation) -> { - Map<KParameter, Object> argMap = CollectionUtils.newHashMap(args.length + 1); - int index = 0; - for (KParameter parameter : function.getParameters()) { - switch (parameter.getKind()) { - case INSTANCE -> argMap.put(parameter, target); - case VALUE, EXTENSION_RECEIVER -> { - Object arg = args[index]; - if (!(parameter.isOptional() && arg == null)) { - KType type = parameter.getType(); - if (!(type.isMarkedNullable() && arg == null) && - type.getClassifier() instanceof KClass<?> kClass && - KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(kClass))) { - arg = box(kClass, arg); - } - argMap.put(parameter, arg); - } - index++; - } - } - } + Map<KParameter, Object> argMap = buildArgMap(function, target, args, preserveNulls); return KCallables.callSuspendBy(function, argMap, continuation); }) .filter(result -> result != Unit.INSTANCE) @@ -158,6 +168,40 @@ public static Publisher<?> invokeSuspendingFunction( return mono; } + private static Map<KParameter, Object> buildArgMap( + KFunction<?> function, + @Nullable Object target, + @Nullable Object[] args, + boolean preserveNulls) { + + Map<KParameter, Object> argMap = CollectionUtils.newHashMap(args.length + 1); + int index = 0; + + for (KParameter parameter : function.getParameters()) { + switch (parameter.getKind()) { + case INSTANCE -> argMap.put(parameter, target); + case VALUE, EXTENSION_RECEIVER -> { + Object arg = args[index]; + + if (!(parameter.isOptional() && arg == null)) { + KType type = parameter.getType(); + if (!(type.isMarkedNullable() && arg == null) && + type.getClassifier() instanceof KClass<?> kClass && + KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(kClass))) { + arg = box(kClass, arg); + } + argMap.put(parameter, arg); + } else if(preserveNulls) { + argMap.put(parameter, arg); + } + index++; + } + } + } + return argMap; + } + + private static Object box(KClass<?> kClass, @Nullable Object arg) { KFunction<?> constructor = Objects.requireNonNull(KClasses.getPrimaryConstructor(kClass)); KType type = constructor.getParameters().get(0).getType(); diff --git a/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt b/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt index 24a98d61ccb6..a8270f738ba7 100644 --- a/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt +++ b/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt @@ -93,6 +93,16 @@ class CoroutinesUtilsTests { } } + @Test + fun invokeSuspendingFunctionWithNullableParameterPreservesNull() { + val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithOptionalParameterAndDefaultValue", String::class.java, Continuation::class.java) + val mono = CoroutinesUtils.invokeSuspendingFunctionPreserveNulls(Dispatchers.Unconfined, method, this, null, null) as Mono + runBlocking { + Assertions.assertThat(mono.awaitSingleOrNull()).isNull() + } + } + + @Test fun invokePrivateSuspendingFunction() { val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("privateSuspendingFunction", String::class.java, Continuation::class.java) @@ -300,6 +310,12 @@ class CoroutinesUtilsTests { return value } + suspend fun suspendingFunctionWithOptionalParameterAndDefaultValue(value: String? = "foo"): String? { + delay(1) + return value + } + + suspend fun suspendingFunctionWithMono(): Mono<String> { delay(1) return Mono.just("foo")