|
16 | 16 |
|
17 | 17 | package org.springframework.ai.chat.client.advisor;
|
18 | 18 |
|
| 19 | +import java.util.List; |
| 20 | + |
19 | 21 | import org.junit.jupiter.api.Test;
|
20 | 22 | import reactor.core.scheduler.Schedulers;
|
21 | 23 |
|
| 24 | +import org.springframework.ai.chat.client.ChatClientResponse; |
22 | 25 | import org.springframework.ai.chat.client.advisor.api.Advisor;
|
| 26 | +import org.springframework.ai.chat.client.advisor.api.AdvisorChain; |
23 | 27 | import org.springframework.ai.chat.memory.ChatMemory;
|
24 | 28 | import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
|
25 | 29 | import org.springframework.ai.chat.memory.MessageWindowChatMemory;
|
| 30 | +import org.springframework.ai.chat.messages.AssistantMessage; |
| 31 | +import org.springframework.ai.chat.messages.Message; |
| 32 | +import org.springframework.ai.chat.model.ChatResponse; |
| 33 | +import org.springframework.ai.chat.model.Generation; |
26 | 34 | import org.springframework.ai.chat.prompt.PromptTemplate;
|
27 | 35 |
|
28 | 36 | import static org.assertj.core.api.Assertions.assertThat;
|
29 | 37 | import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
| 38 | +import static org.mockito.Mockito.mock; |
| 39 | +import static org.mockito.Mockito.when; |
30 | 40 |
|
31 | 41 | /**
|
32 | 42 | * Unit tests for {@link PromptChatMemoryAdvisor}.
|
33 | 43 | *
|
34 | 44 | * @author Mark Pollack
|
35 | 45 | * @author Thomas Vitale
|
| 46 | + * @author Soby Chacko |
36 | 47 | */
|
37 | 48 | public class PromptChatMemoryAdvisorTests {
|
38 | 49 |
|
@@ -138,4 +149,120 @@ void testDefaultValues() {
|
138 | 149 | assertThat(advisor.getOrder()).isEqualTo(Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);
|
139 | 150 | }
|
140 | 151 |
|
| 152 | + @Test |
| 153 | + void testAfterMethodHandlesSingleGeneration() { |
| 154 | + ChatMemory chatMemory = MessageWindowChatMemory.builder() |
| 155 | + .chatMemoryRepository(new InMemoryChatMemoryRepository()) |
| 156 | + .build(); |
| 157 | + |
| 158 | + PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory) |
| 159 | + .conversationId("test-conversation") |
| 160 | + .build(); |
| 161 | + |
| 162 | + ChatClientResponse mockResponse = mock(ChatClientResponse.class); |
| 163 | + ChatResponse mockChatResponse = mock(ChatResponse.class); |
| 164 | + Generation mockGeneration = mock(Generation.class); |
| 165 | + AdvisorChain mockChain = mock(AdvisorChain.class); |
| 166 | + |
| 167 | + when(mockResponse.chatResponse()).thenReturn(mockChatResponse); |
| 168 | + when(mockChatResponse.getResults()).thenReturn(List.of(mockGeneration)); // Single |
| 169 | + // result |
| 170 | + when(mockGeneration.getOutput()).thenReturn(new AssistantMessage("Single response")); |
| 171 | + |
| 172 | + ChatClientResponse result = advisor.after(mockResponse, mockChain); |
| 173 | + |
| 174 | + assertThat(result).isEqualTo(mockResponse); // Should return the same response |
| 175 | + |
| 176 | + // Verify single message stored in memory |
| 177 | + List<Message> messages = chatMemory.get("test-conversation"); |
| 178 | + assertThat(messages).hasSize(1); |
| 179 | + assertThat(messages.get(0).getText()).isEqualTo("Single response"); |
| 180 | + } |
| 181 | + |
| 182 | + @Test |
| 183 | + void testAfterMethodHandlesMultipleGenerations() { |
| 184 | + ChatMemory chatMemory = MessageWindowChatMemory.builder() |
| 185 | + .chatMemoryRepository(new InMemoryChatMemoryRepository()) |
| 186 | + .build(); |
| 187 | + |
| 188 | + PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory) |
| 189 | + .conversationId("test-conversation") |
| 190 | + .build(); |
| 191 | + |
| 192 | + ChatClientResponse mockResponse = mock(ChatClientResponse.class); |
| 193 | + ChatResponse mockChatResponse = mock(ChatResponse.class); |
| 194 | + Generation mockGen1 = mock(Generation.class); |
| 195 | + Generation mockGen2 = mock(Generation.class); |
| 196 | + Generation mockGen3 = mock(Generation.class); |
| 197 | + AdvisorChain mockChain = mock(AdvisorChain.class); |
| 198 | + |
| 199 | + when(mockResponse.chatResponse()).thenReturn(mockChatResponse); |
| 200 | + when(mockChatResponse.getResults()).thenReturn(List.of(mockGen1, mockGen2, mockGen3)); // Multiple |
| 201 | + // results |
| 202 | + when(mockGen1.getOutput()).thenReturn(new AssistantMessage("Response 1")); |
| 203 | + when(mockGen2.getOutput()).thenReturn(new AssistantMessage("Response 2")); |
| 204 | + when(mockGen3.getOutput()).thenReturn(new AssistantMessage("Response 3")); |
| 205 | + |
| 206 | + ChatClientResponse result = advisor.after(mockResponse, mockChain); |
| 207 | + |
| 208 | + assertThat(result).isEqualTo(mockResponse); // Should return the same response |
| 209 | + |
| 210 | + // Verify all messages were stored in memory |
| 211 | + List<Message> messages = chatMemory.get("test-conversation"); |
| 212 | + assertThat(messages).hasSize(3); |
| 213 | + assertThat(messages.get(0).getText()).isEqualTo("Response 1"); |
| 214 | + assertThat(messages.get(1).getText()).isEqualTo("Response 2"); |
| 215 | + assertThat(messages.get(2).getText()).isEqualTo("Response 3"); |
| 216 | + } |
| 217 | + |
| 218 | + @Test |
| 219 | + void testAfterMethodHandlesEmptyResults() { |
| 220 | + ChatMemory chatMemory = MessageWindowChatMemory.builder() |
| 221 | + .chatMemoryRepository(new InMemoryChatMemoryRepository()) |
| 222 | + .build(); |
| 223 | + |
| 224 | + PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory) |
| 225 | + .conversationId("test-conversation") |
| 226 | + .build(); |
| 227 | + |
| 228 | + ChatClientResponse mockResponse = mock(ChatClientResponse.class); |
| 229 | + ChatResponse mockChatResponse = mock(ChatResponse.class); |
| 230 | + AdvisorChain mockChain = mock(AdvisorChain.class); |
| 231 | + |
| 232 | + when(mockResponse.chatResponse()).thenReturn(mockChatResponse); |
| 233 | + when(mockChatResponse.getResults()).thenReturn(List.of()); |
| 234 | + |
| 235 | + ChatClientResponse result = advisor.after(mockResponse, mockChain); |
| 236 | + |
| 237 | + assertThat(result).isEqualTo(mockResponse); |
| 238 | + |
| 239 | + // Verify no messages were stored in memory |
| 240 | + List<Message> messages = chatMemory.get("test-conversation"); |
| 241 | + assertThat(messages).isEmpty(); |
| 242 | + } |
| 243 | + |
| 244 | + @Test |
| 245 | + void testAfterMethodHandlesNullChatResponse() { |
| 246 | + ChatMemory chatMemory = MessageWindowChatMemory.builder() |
| 247 | + .chatMemoryRepository(new InMemoryChatMemoryRepository()) |
| 248 | + .build(); |
| 249 | + |
| 250 | + PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory) |
| 251 | + .conversationId("test-conversation") |
| 252 | + .build(); |
| 253 | + |
| 254 | + ChatClientResponse mockResponse = mock(ChatClientResponse.class); |
| 255 | + AdvisorChain mockChain = mock(AdvisorChain.class); |
| 256 | + |
| 257 | + when(mockResponse.chatResponse()).thenReturn(null); |
| 258 | + |
| 259 | + ChatClientResponse result = advisor.after(mockResponse, mockChain); |
| 260 | + |
| 261 | + assertThat(result).isEqualTo(mockResponse); |
| 262 | + |
| 263 | + // Verify no messages were stored in memory |
| 264 | + List<Message> messages = chatMemory.get("test-conversation"); |
| 265 | + assertThat(messages).isEmpty(); |
| 266 | + } |
| 267 | + |
141 | 268 | }
|
0 commit comments