Skip to content

Commit 076cf06

Browse files
Support Streamable return type in AOT repositories.
1 parent 4ae1c5d commit 076cf06

File tree

6 files changed

+91
-2
lines changed

6 files changed

+91
-2
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
4242
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
4343
import org.springframework.data.util.ReflectionUtils;
44+
import org.springframework.data.util.Streamable;
4445
import org.springframework.javapoet.CodeBlock;
4546
import org.springframework.javapoet.CodeBlock.Builder;
4647
import org.springframework.util.ClassUtils;
@@ -145,8 +146,15 @@ CodeBlock build() {
145146
builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName,
146147
outputType);
147148
} else {
148-
builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
149+
150+
CodeBlock resultBlock = CodeBlock.of("$L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
149151
aggregationVariableName, outputType);
152+
153+
if (queryMethod.getReturnType().getType().equals(Streamable.class)) {
154+
resultBlock = CodeBlock.of("$T.of($L)", Streamable.class, resultBlock);
155+
}
156+
157+
builder.addStatement("return $L", resultBlock);
150158
}
151159
}
152160
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
3535
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
3636
import org.springframework.data.util.Lazy;
37+
import org.springframework.data.util.Streamable;
3738
import org.springframework.javapoet.CodeBlock;
3839
import org.springframework.javapoet.CodeBlock.Builder;
3940
import org.springframework.javapoet.TypeName;
@@ -145,8 +146,15 @@ CodeBlock build() {
145146
context.localVariable("finder"), query.name(), terminatingMethod, returnType);
146147

147148
} else {
148-
builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(),
149+
150+
CodeBlock resultBlock = CodeBlock.of("$L.matching($L).$L", context.localVariable("finder"), query.name(),
149151
terminatingMethod);
152+
153+
if (queryMethod.getReturnType().getType().equals(Streamable.class)) {
154+
resultBlock = CodeBlock.of("$T.of($L)", Streamable.class, resultBlock);
155+
}
156+
157+
builder.addStatement("return $L", resultBlock);
150158
}
151159
}
152160

spring-data-mongodb/src/test/java/example/aot/UserRepository.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.springframework.data.mongodb.repository.VectorSearch;
4949
import org.springframework.data.repository.CrudRepository;
5050
import org.springframework.data.repository.query.Param;
51+
import org.springframework.data.util.Streamable;
5152

5253
/**
5354
* @author Christoph Strobl
@@ -58,6 +59,8 @@ public interface UserRepository extends CrudRepository<User, String> {
5859

5960
List<User> findUserNoArgumentsBy();
6061

62+
Streamable<User> streamUserNoArgumentsBy();
63+
6164
User findOneByUsername(String username);
6265

6366
Optional<User> findOptionalOneByUsername(String username);
@@ -267,6 +270,11 @@ public interface UserRepository extends CrudRepository<User, String> {
267270
"{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" })
268271
Stream<UserAggregate> streamGroupByLastnameAndAsAggregationResults(String property);
269272

273+
@Aggregation(pipeline = { //
274+
"{ '$match' : { 'last_name' : { '$ne' : null } } }", //
275+
"{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" })
276+
Streamable<UserAggregate> streamAsStreamableGroupByLastnameAndAsAggregationResults(String property);
277+
270278
@Aggregation(pipeline = { //
271279
"{ '$match' : { 'posts' : { '$ne' : null } } }", //
272280
"{ '$project': { 'nrPosts' : {'$size': '$posts' } } }", //

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
import org.springframework.data.mongodb.test.util.DirtiesStateExtension.ProvidesState;
8080
import org.springframework.data.mongodb.test.util.EnableIfMongoServerVersion;
8181
import org.springframework.data.querydsl.QSort;
82+
import org.springframework.data.util.Streamable;
8283
import org.springframework.test.context.junit.jupiter.SpringExtension;
8384
import org.springframework.test.util.ReflectionTestUtils;
8485

@@ -312,6 +313,33 @@ void findsPersonByAddressCorrectly() {
312313
assertThat(result).hasSize(1).contains(dave);
313314
}
314315

316+
@Test // GH-5089
317+
void streamPersonByAddressCorrectly() {
318+
319+
Address address = new Address("Foo Street 1", "C0123", "Bar");
320+
dave.setAddress(address);
321+
repository.save(dave);
322+
323+
Streamable<Person> result = repository.streamByAddress(address);
324+
assertThat(result).hasSize(1).contains(dave);
325+
}
326+
327+
@Test // GH-5089
328+
void streamPersonByAddressCorrectlyWhenPaged() {
329+
330+
Address address = new Address("Foo Street 1", "C0123", "Bar");
331+
dave.setAddress(address);
332+
oliver.setAddress(address);
333+
repository.saveAll(List.of(dave, oliver));
334+
335+
Streamable<Person> result = repository.streamByAddress(address,
336+
PageRequest.of(0, 1, Sort.by(Direction.DESC, "firstname")));
337+
assertThat(result).containsExactly(oliver);
338+
339+
result = repository.streamByAddress(address, PageRequest.of(1, 1, Sort.by(Direction.DESC, "firstname")));
340+
assertThat(result).containsExactly(dave);
341+
}
342+
315343
@Test
316344
void findsPeopleByZipCode() {
317345

@@ -1516,6 +1544,16 @@ void annotatedAggregationWithPageable() {
15161544
new PersonAggregate("Matthews", Arrays.asList("Dave", "Oliver August")));
15171545
}
15181546

1547+
@Test // GH-5089
1548+
void annotatedAggregationReturningStreamable() {
1549+
1550+
assertThat(repository.streamGroupByLastnameAnd("firstname", PageRequest.of(1, 2, Sort.by("lastname")))) //
1551+
.isInstanceOf(Streamable.class) //
1552+
.containsExactly( //
1553+
new PersonAggregate("Lessard", Collections.singletonList("Stefan")), //
1554+
new PersonAggregate("Matthews", Arrays.asList("Dave", "Oliver August")));
1555+
}
1556+
15191557
@Test // DATAMONGO-2153
15201558
void annotatedAggregationWithSingleSimpleResult() {
15211559
assertThat(repository.sumAge()).isEqualTo(245);

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import org.springframework.data.mongodb.repository.Person.Sex;
4646
import org.springframework.data.querydsl.QuerydslPredicateExecutor;
4747
import org.springframework.data.repository.query.Param;
48+
import org.springframework.data.util.Streamable;
4849

4950
/**
5051
* Sample repository managing {@link Person} entities.
@@ -211,6 +212,10 @@ Window<Person> findByLastnameLikeOrderByLastnameAscFirstnameAsc(Pattern lastname
211212
*/
212213
List<Person> findByAddress(Address address);
213214

215+
Streamable<Person> streamByAddress(Address address);
216+
217+
Streamable<Person> streamByAddress(Address address, Pageable pageable);
218+
214219
List<Person> findByAddressZipCode(String zipCode);
215220

216221
List<Person> findByLastnameLikeAndAgeBetween(String lastname, int from, int to);
@@ -442,6 +447,9 @@ Page<Person> findByCustomQueryLastnameAndAddressStreetInList(String lastname, Li
442447
@Aggregation("{ '$group': { '_id' : '$lastname', names : { $addToSet : '$?0' } } }")
443448
List<PersonAggregate> groupByLastnameAnd(String property, Pageable page);
444449

450+
@Aggregation("{ '$group': { '_id' : '$lastname', names : { $addToSet : '$?0' } } }")
451+
Streamable<PersonAggregate> streamGroupByLastnameAnd(String property, Pageable page);
452+
445453
@Aggregation(pipeline = "{ '$group' : { '_id' : null, 'total' : { $sum: '$age' } } }")
446454
int sumAge();
447455

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,25 @@ void rendersVectorSearchOrderByWithScoreLast() throws NoSuchMethodException {
387387
"Document(\"$sort\", mappedSort.append(\"__score__\", -1))");
388388
}
389389

390+
@Test // GH-5089
391+
void rendersStreamableReturnType() throws NoSuchMethodException {
392+
393+
MethodSpec methodSpec = codeOf(UserRepository.class, "streamUserNoArgumentsBy");
394+
395+
assertThat(methodSpec.toString()) //
396+
.containsSubsequence("return", "Streamable.of(", "all())");
397+
}
398+
399+
@Test // GH-5089
400+
void rendersStreamableReturnTypeForAggregation() throws NoSuchMethodException {
401+
402+
MethodSpec methodSpec = codeOf(UserRepository.class, "streamAsStreamableGroupByLastnameAndAsAggregationResults",
403+
String.class);
404+
405+
assertThat(methodSpec.toString()) //
406+
.containsSubsequence("return", "Streamable.of(", "getMappedResults())");
407+
}
408+
390409
private static MethodSpec codeOf(Class<?> repository, String methodName, Class<?>... args)
391410
throws NoSuchMethodException {
392411

0 commit comments

Comments
 (0)