Skip to content

Commit 4a4808d

Browse files
sobychackoeeaters
authored andcommitted
refactor: simplify assistant message extraction using Optional chaining
Replace nested null checks and redundant branching with streamlined Optional-based approach. Since getResult() == getResults().get(0), processing all results handles both single and multiple result cases. Adding tests to verify. Fixes #4292 Co-Authored-By: eeaters <[email protected]> Signed-off-by: Soby Chacko <[email protected]>
1 parent 74e6417 commit 4a4808d

File tree

2 files changed

+140
-10
lines changed

2 files changed

+140
-10
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.ArrayList;
2020
import java.util.List;
2121
import java.util.Map;
22+
import java.util.Optional;
2223
import java.util.stream.Collectors;
2324

2425
import org.slf4j.Logger;
@@ -141,18 +142,20 @@ public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChai
141142
@Override
142143
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
143144
List<Message> assistantMessages = new ArrayList<>();
144-
// Handle streaming case where we have a single result
145-
if (chatClientResponse.chatResponse() != null && chatClientResponse.chatResponse().getResult() != null
146-
&& chatClientResponse.chatResponse().getResult().getOutput() != null) {
147-
assistantMessages = List.of((Message) chatClientResponse.chatResponse().getResult().getOutput());
148-
}
149-
else if (chatClientResponse.chatResponse() != null) {
150-
assistantMessages = chatClientResponse.chatResponse()
151-
.getResults()
145+
// Extract assistant messages from chat client response.
146+
// Processes all results from getResults() which automatically handles both single
147+
// and multiple
148+
// result scenarios (since getResult() == getResults().get(0)). Uses Optional
149+
// chaining for
150+
// null safety and returns empty list if no results are available.
151+
assistantMessages = Optional.ofNullable(chatClientResponse)
152+
.map(ChatClientResponse::chatResponse)
153+
.filter(response -> response.getResults() != null && !response.getResults().isEmpty())
154+
.map(response -> response.getResults()
152155
.stream()
153156
.map(g -> (Message) g.getOutput())
154-
.toList();
155-
}
157+
.collect(Collectors.toList()))
158+
.orElse(List.of());
156159

157160
if (!assistantMessages.isEmpty()) {
158161
this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId),

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisorTests.java

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,34 @@
1616

1717
package org.springframework.ai.chat.client.advisor;
1818

19+
import java.util.List;
20+
1921
import org.junit.jupiter.api.Test;
2022
import reactor.core.scheduler.Schedulers;
2123

24+
import org.springframework.ai.chat.client.ChatClientResponse;
2225
import org.springframework.ai.chat.client.advisor.api.Advisor;
26+
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
2327
import org.springframework.ai.chat.memory.ChatMemory;
2428
import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
2529
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
30+
import org.springframework.ai.chat.messages.AssistantMessage;
31+
import org.springframework.ai.chat.messages.Message;
32+
import org.springframework.ai.chat.model.ChatResponse;
33+
import org.springframework.ai.chat.model.Generation;
2634
import org.springframework.ai.chat.prompt.PromptTemplate;
2735

2836
import static org.assertj.core.api.Assertions.assertThat;
2937
import static org.assertj.core.api.Assertions.assertThatThrownBy;
38+
import static org.mockito.Mockito.mock;
39+
import static org.mockito.Mockito.when;
3040

3141
/**
3242
* Unit tests for {@link PromptChatMemoryAdvisor}.
3343
*
3444
* @author Mark Pollack
3545
* @author Thomas Vitale
46+
* @author Soby Chacko
3647
*/
3748
public class PromptChatMemoryAdvisorTests {
3849

@@ -138,4 +149,120 @@ void testDefaultValues() {
138149
assertThat(advisor.getOrder()).isEqualTo(Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);
139150
}
140151

152+
@Test
153+
void testAfterMethodHandlesSingleGeneration() {
154+
ChatMemory chatMemory = MessageWindowChatMemory.builder()
155+
.chatMemoryRepository(new InMemoryChatMemoryRepository())
156+
.build();
157+
158+
PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory)
159+
.conversationId("test-conversation")
160+
.build();
161+
162+
ChatClientResponse mockResponse = mock(ChatClientResponse.class);
163+
ChatResponse mockChatResponse = mock(ChatResponse.class);
164+
Generation mockGeneration = mock(Generation.class);
165+
AdvisorChain mockChain = mock(AdvisorChain.class);
166+
167+
when(mockResponse.chatResponse()).thenReturn(mockChatResponse);
168+
when(mockChatResponse.getResults()).thenReturn(List.of(mockGeneration)); // Single
169+
// result
170+
when(mockGeneration.getOutput()).thenReturn(new AssistantMessage("Single response"));
171+
172+
ChatClientResponse result = advisor.after(mockResponse, mockChain);
173+
174+
assertThat(result).isEqualTo(mockResponse); // Should return the same response
175+
176+
// Verify single message stored in memory
177+
List<Message> messages = chatMemory.get("test-conversation");
178+
assertThat(messages).hasSize(1);
179+
assertThat(messages.get(0).getText()).isEqualTo("Single response");
180+
}
181+
182+
@Test
183+
void testAfterMethodHandlesMultipleGenerations() {
184+
ChatMemory chatMemory = MessageWindowChatMemory.builder()
185+
.chatMemoryRepository(new InMemoryChatMemoryRepository())
186+
.build();
187+
188+
PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory)
189+
.conversationId("test-conversation")
190+
.build();
191+
192+
ChatClientResponse mockResponse = mock(ChatClientResponse.class);
193+
ChatResponse mockChatResponse = mock(ChatResponse.class);
194+
Generation mockGen1 = mock(Generation.class);
195+
Generation mockGen2 = mock(Generation.class);
196+
Generation mockGen3 = mock(Generation.class);
197+
AdvisorChain mockChain = mock(AdvisorChain.class);
198+
199+
when(mockResponse.chatResponse()).thenReturn(mockChatResponse);
200+
when(mockChatResponse.getResults()).thenReturn(List.of(mockGen1, mockGen2, mockGen3)); // Multiple
201+
// results
202+
when(mockGen1.getOutput()).thenReturn(new AssistantMessage("Response 1"));
203+
when(mockGen2.getOutput()).thenReturn(new AssistantMessage("Response 2"));
204+
when(mockGen3.getOutput()).thenReturn(new AssistantMessage("Response 3"));
205+
206+
ChatClientResponse result = advisor.after(mockResponse, mockChain);
207+
208+
assertThat(result).isEqualTo(mockResponse); // Should return the same response
209+
210+
// Verify all messages were stored in memory
211+
List<Message> messages = chatMemory.get("test-conversation");
212+
assertThat(messages).hasSize(3);
213+
assertThat(messages.get(0).getText()).isEqualTo("Response 1");
214+
assertThat(messages.get(1).getText()).isEqualTo("Response 2");
215+
assertThat(messages.get(2).getText()).isEqualTo("Response 3");
216+
}
217+
218+
@Test
219+
void testAfterMethodHandlesEmptyResults() {
220+
ChatMemory chatMemory = MessageWindowChatMemory.builder()
221+
.chatMemoryRepository(new InMemoryChatMemoryRepository())
222+
.build();
223+
224+
PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory)
225+
.conversationId("test-conversation")
226+
.build();
227+
228+
ChatClientResponse mockResponse = mock(ChatClientResponse.class);
229+
ChatResponse mockChatResponse = mock(ChatResponse.class);
230+
AdvisorChain mockChain = mock(AdvisorChain.class);
231+
232+
when(mockResponse.chatResponse()).thenReturn(mockChatResponse);
233+
when(mockChatResponse.getResults()).thenReturn(List.of());
234+
235+
ChatClientResponse result = advisor.after(mockResponse, mockChain);
236+
237+
assertThat(result).isEqualTo(mockResponse);
238+
239+
// Verify no messages were stored in memory
240+
List<Message> messages = chatMemory.get("test-conversation");
241+
assertThat(messages).isEmpty();
242+
}
243+
244+
@Test
245+
void testAfterMethodHandlesNullChatResponse() {
246+
ChatMemory chatMemory = MessageWindowChatMemory.builder()
247+
.chatMemoryRepository(new InMemoryChatMemoryRepository())
248+
.build();
249+
250+
PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory)
251+
.conversationId("test-conversation")
252+
.build();
253+
254+
ChatClientResponse mockResponse = mock(ChatClientResponse.class);
255+
AdvisorChain mockChain = mock(AdvisorChain.class);
256+
257+
when(mockResponse.chatResponse()).thenReturn(null);
258+
259+
ChatClientResponse result = advisor.after(mockResponse, mockChain);
260+
261+
assertThat(result).isEqualTo(mockResponse);
262+
263+
// Verify no messages were stored in memory
264+
List<Message> messages = chatMemory.get("test-conversation");
265+
assertThat(messages).isEmpty();
266+
}
267+
141268
}

0 commit comments

Comments
 (0)