From 3386c6310aeb426c1936ffe5ff8107d84156f6c2 Mon Sep 17 00:00:00 2001
From: Yanming Zhou <zhouyanming@gmail.com>
Date: Thu, 3 Apr 2025 10:30:46 +0800
Subject: [PATCH] Discard further rows once maxRows has been reached

See https://github.com/spring-projects/spring-framework/issues/34666#issuecomment-2773151317

Signed-off-by: Yanming Zhou <zhouyanming@gmail.com>
---
 .../jdbc/core/JdbcTemplate.java               | 40 +++++++++-------
 .../core/RowMapperResultSetExtractor.java     | 17 ++++++-
 .../jdbc/core/JdbcTemplateTests.java          | 47 +++++++++++++++++++
 3 files changed, 87 insertions(+), 17 deletions(-)

diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java
index 1b3e14d36866..11833ec090f0 100644
--- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java
+++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java
@@ -102,6 +102,7 @@
  * @author Rod Johnson
  * @author Juergen Hoeller
  * @author Thomas Risberg
+ * @author Yanming Zhou
  * @since May 3, 2001
  * @see JdbcOperations
  * @see PreparedStatementCreator
@@ -493,12 +494,12 @@ public String getSql() {
 
 	@Override
 	public void query(String sql, RowCallbackHandler rch) throws DataAccessException {
-		query(sql, new RowCallbackHandlerResultSetExtractor(rch));
+		query(sql, new RowCallbackHandlerResultSetExtractor(rch, this.maxRows));
 	}
 
 	@Override
 	public <T> List<T> query(String sql, RowMapper<T> rowMapper) throws DataAccessException {
-		return result(query(sql, new RowMapperResultSetExtractor<>(rowMapper)));
+		return result(query(sql, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
 	}
 
 	@Override
@@ -508,7 +509,7 @@ class StreamStatementCallback implements StatementCallback<Stream<T>>, SqlProvid
 			public Stream<T> doInStatement(Statement stmt) throws SQLException {
 				ResultSet rs = stmt.executeQuery(sql);
 				Connection con = stmt.getConnection();
-				return new ResultSetSpliterator<>(rs, rowMapper).stream().onClose(() -> {
+				return new ResultSetSpliterator<>(rs, rowMapper, JdbcTemplate.this.maxRows).stream().onClose(() -> {
 					JdbcUtils.closeResultSet(rs);
 					JdbcUtils.closeStatement(stmt);
 					DataSourceUtils.releaseConnection(con, getDataSource());
@@ -773,12 +774,12 @@ private String appendSql(@Nullable String sql, String statement) {
 
 	@Override
 	public void query(PreparedStatementCreator psc, RowCallbackHandler rch) throws DataAccessException {
-		query(psc, new RowCallbackHandlerResultSetExtractor(rch));
+		query(psc, new RowCallbackHandlerResultSetExtractor(rch, this.maxRows));
 	}
 
 	@Override
 	public void query(String sql, @Nullable PreparedStatementSetter pss, RowCallbackHandler rch) throws DataAccessException {
-		query(sql, pss, new RowCallbackHandlerResultSetExtractor(rch));
+		query(sql, pss, new RowCallbackHandlerResultSetExtractor(rch, this.maxRows));
 	}
 
 	@Override
@@ -799,28 +800,28 @@ public void query(String sql, RowCallbackHandler rch, @Nullable Object @Nullable
 
 	@Override
 	public <T> List<T> query(PreparedStatementCreator psc, RowMapper<T> rowMapper) throws DataAccessException {
-		return result(query(psc, new RowMapperResultSetExtractor<>(rowMapper)));
+		return result(query(psc, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
 	}
 
 	@Override
 	public <T> List<T> query(String sql, @Nullable PreparedStatementSetter pss, RowMapper<T> rowMapper) throws DataAccessException {
-		return result(query(sql, pss, new RowMapperResultSetExtractor<>(rowMapper)));
+		return result(query(sql, pss, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
 	}
 
 	@Override
 	public <T> List<T> query(String sql, @Nullable Object @Nullable [] args, int[] argTypes, RowMapper<T> rowMapper) throws DataAccessException {
-		return result(query(sql, args, argTypes, new RowMapperResultSetExtractor<>(rowMapper)));
+		return result(query(sql, args, argTypes, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
 	}
 
 	@Deprecated(since = "5.3")
 	@Override
 	public <T> List<T> query(String sql, @Nullable Object @Nullable [] args, RowMapper<T> rowMapper) throws DataAccessException {
-		return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper)));
+		return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
 	}
 
 	@Override
 	public <T> List<T> query(String sql, RowMapper<T> rowMapper, @Nullable Object @Nullable ... args) throws DataAccessException {
-		return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper)));
+		return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
 	}
 
 	/**
@@ -845,7 +846,7 @@ public <T> Stream<T> queryForStream(PreparedStatementCreator psc, @Nullable Prep
 			}
 			ResultSet rs = ps.executeQuery();
 			Connection con = ps.getConnection();
-			return new ResultSetSpliterator<>(rs, rowMapper).stream().onClose(() -> {
+			return new ResultSetSpliterator<>(rs, rowMapper, this.maxRows).stream().onClose(() -> {
 				JdbcUtils.closeResultSet(rs);
 				if (pss instanceof ParameterDisposer parameterDisposer) {
 					parameterDisposer.cleanupParameters();
@@ -1364,7 +1365,7 @@ protected Map<String, Object> processResultSet(
 				}
 				else if (param.getRowCallbackHandler() != null) {
 					RowCallbackHandler rch = param.getRowCallbackHandler();
-					(new RowCallbackHandlerResultSetExtractor(rch)).extractData(rs);
+					(new RowCallbackHandlerResultSetExtractor(rch, -1)).extractData(rs);
 					return Collections.singletonMap(param.getName(),
 							"ResultSet returned from stored procedure was processed");
 				}
@@ -1747,13 +1748,17 @@ private static class RowCallbackHandlerResultSetExtractor implements ResultSetEx
 
 		private final RowCallbackHandler rch;
 
-		public RowCallbackHandlerResultSetExtractor(RowCallbackHandler rch) {
+		private final int maxRows;
+
+		public RowCallbackHandlerResultSetExtractor(RowCallbackHandler rch, int maxRows) {
 			this.rch = rch;
+			this.maxRows = maxRows;
 		}
 
 		@Override
 		public @Nullable Object extractData(ResultSet rs) throws SQLException {
-			while (rs.next()) {
+			int processed = 0;
+			while (rs.next() && (this.maxRows == -1 || (processed++) < this.maxRows)) {
 				this.rch.processRow(rs);
 			}
 			return null;
@@ -1771,17 +1776,20 @@ private static class ResultSetSpliterator<T> implements Spliterator<T> {
 
 		private final RowMapper<T> rowMapper;
 
+		private final int maxRows;
+
 		private int rowNum = 0;
 
-		public ResultSetSpliterator(ResultSet rs, RowMapper<T> rowMapper) {
+		public ResultSetSpliterator(ResultSet rs, RowMapper<T> rowMapper, int maxRows) {
 			this.rs = rs;
 			this.rowMapper = rowMapper;
+			this.maxRows = maxRows;
 		}
 
 		@Override
 		public boolean tryAdvance(Consumer<? super T> action) {
 			try {
-				if (this.rs.next()) {
+				if (this.rs.next() && (this.maxRows == -1 || this.rowNum < this.maxRows)) {
 					action.accept(this.rowMapper.mapRow(this.rs, this.rowNum++));
 					return true;
 				}
diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/RowMapperResultSetExtractor.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/RowMapperResultSetExtractor.java
index 66311a18c930..e353c850b09f 100644
--- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/RowMapperResultSetExtractor.java
+++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/RowMapperResultSetExtractor.java
@@ -52,6 +52,7 @@
  * you can have executable query objects (containing row-mapping logic) there.
  *
  * @author Juergen Hoeller
+ * @author Yanming Zhou
  * @since 1.0.2
  * @param <T> the result element type
  * @see RowMapper
@@ -64,6 +65,8 @@ public class RowMapperResultSetExtractor<T> implements ResultSetExtractor<List<T
 
 	private final int rowsExpected;
 
+	private final int maxRows;
+
 
 	/**
 	 * Create a new RowMapperResultSetExtractor.
@@ -80,9 +83,21 @@ public RowMapperResultSetExtractor(RowMapper<T> rowMapper) {
 	 * (just used for optimized collection handling)
 	 */
 	public RowMapperResultSetExtractor(RowMapper<T> rowMapper, int rowsExpected) {
+		this(rowMapper, rowsExpected, -1);
+	}
+
+	/**
+	 * Create a new RowMapperResultSetExtractor.
+	 * @param rowMapper the RowMapper which creates an object for each row
+	 * @param rowsExpected the number of expected rows
+	 * (just used for optimized collection handling)
+	 * @param maxRows the number of max rows
+	 */
+	public RowMapperResultSetExtractor(RowMapper<T> rowMapper, int rowsExpected, int maxRows) {
 		Assert.notNull(rowMapper, "RowMapper must not be null");
 		this.rowMapper = rowMapper;
 		this.rowsExpected = rowsExpected;
+		this.maxRows = maxRows;
 	}
 
 
@@ -90,7 +105,7 @@ public RowMapperResultSetExtractor(RowMapper<T> rowMapper, int rowsExpected) {
 	public List<T> extractData(ResultSet rs) throws SQLException {
 		List<T> results = (this.rowsExpected > 0 ? new ArrayList<>(this.rowsExpected) : new ArrayList<>());
 		int rowNum = 0;
-		while (rs.next()) {
+		while (rs.next() && (this.maxRows == -1 || rowNum < this.maxRows)) {
 			results.add(this.rowMapper.mapRow(rs, rowNum++));
 		}
 		return results;
diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java
index 6389af717356..a019470cbdff 100644
--- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java
+++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java
@@ -32,7 +32,9 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.function.BiFunction;
 import java.util.function.Consumer;
+import java.util.stream.Stream;
 
 import javax.sql.DataSource;
 
@@ -77,6 +79,7 @@
  * @author Thomas Risberg
  * @author Juergen Hoeller
  * @author Phillip Webb
+ * @author Yanming Zhou
  */
 class JdbcTemplateTests {
 
@@ -1236,6 +1239,50 @@ public int getBatchSize() {
 				Collections.singletonMap("someId", 456));
 	}
 
+	@Test
+	void testSkipFurtherRowsOnceMaxRowsHasBeenReachedForRowMapper() throws Exception {
+		testDiscardFurtherRowsOnceMaxRowsHasBeenReached((template, sql) ->
+				template.query(sql, (rs, rowNum) -> rs.getString(1)));
+	}
+
+	@Test
+	void testDiscardFurtherRowsOnceMaxRowsHasBeenReachedForRowCallbackHandler() throws Exception {
+		testDiscardFurtherRowsOnceMaxRowsHasBeenReached((template, sql) -> {
+			List<String> list = new ArrayList<>();
+			template.query(sql, (RowCallbackHandler) rs -> list.add(rs.getString(1)));
+			return list;
+		});
+	}
+
+	@Test
+	void testDiscardFurtherRowsOnceMaxRowsHasBeenReachedForStream() throws Exception {
+		testDiscardFurtherRowsOnceMaxRowsHasBeenReached((template, sql) -> {
+			try (Stream<String> stream = template.queryForStream(sql, (rs, rowNum) -> rs.getString(1))) {
+				return stream.toList();
+			}
+		});
+	}
+
+	private void testDiscardFurtherRowsOnceMaxRowsHasBeenReached(BiFunction<JdbcTemplate,String,List<String>> function) throws Exception {
+		String sql = "SELECT FORENAME FROM CUSTMR";
+		String[] results = {"rod", "gary", " portia"};
+		int maxRows = 2;
+
+		given(this.resultSet.next()).willReturn(true, true, true, false);
+		given(this.resultSet.getString(1)).willReturn(results[0], results[1], results[2]);
+		given(this.connection.createStatement()).willReturn(this.preparedStatement);
+
+		JdbcTemplate template = new JdbcTemplate();
+		template.setDataSource(this.dataSource);
+		template.setMaxRows(maxRows);
+
+		assertThat(function.apply(template, sql)).as("same length").hasSize(maxRows);
+
+		verify(this.resultSet).close();
+		verify(this.preparedStatement).close();
+		verify(this.connection).close();
+	}
+
 	private void mockDatabaseMetaData(boolean supportsBatchUpdates) throws SQLException {
 		DatabaseMetaData databaseMetaData = mock();
 		given(databaseMetaData.getDatabaseProductName()).willReturn("MySQL");