Skip to content

Commit 43aa893

Browse files
enricorampazzomarkpollack
authored andcommitted
Enhance Neo4jChatMemoryRepository, expand integration tests, improve configuration, and update documentation
- Enhanced Neo4jChatMemoryRepository to correctly restore custom metadata for SystemMessage using SystemMessage.Builder. - Refactored and clarified Neo4jChatMemoryRepository implementation code. - Added comprehensive integration tests for Neo4jChatMemoryConfig and Neo4jChatMemoryRepository, including: -- Index creation verification -- Custom label support -- Getter validation for all configuration properties -- Tests for saving and retrieving SystemMessage metadata -- Tests ensuring saveAll(conversationId, Collections.emptyList()) clears all messages and removes the conversation node -- Tests for handling of messages with empty content and empty metadata -- Improved overall test coverage for Neo4j persistence and configuration edge cases - Fixed resource management bugs in test classes (ensured proper driver/session closure). - Improved index creation logic in Neo4jChatMemoryConfig for reliability and logging. - Updated documentation to include Neo4jChatMemoryRepository usage and configuration Signed-off-by: enricorampazzo <[email protected]> Signed-off-by: Mark Pollack <[email protected]>
1 parent 4ec0676 commit 43aa893

File tree

11 files changed

+1113
-28
lines changed

11 files changed

+1113
-28
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright 2024-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.model.chat.memory.jdbc.autoconfigure;
18+
19+
import javax.sql.DataSource;
20+
21+
import org.junit.jupiter.api.Test;
22+
import org.testcontainers.containers.PostgreSQLContainer;
23+
import org.testcontainers.junit.jupiter.Container;
24+
import org.testcontainers.junit.jupiter.Testcontainers;
25+
import org.testcontainers.utility.DockerImageName;
26+
27+
import org.springframework.boot.autoconfigure.AutoConfigurations;
28+
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
29+
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
30+
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
31+
32+
import static org.assertj.core.api.Assertions.assertThat;
33+
34+
/**
35+
* @author Jonathan Leijendekker
36+
*/
37+
@Testcontainers
38+
class JdbcChatMemoryDataSourceScriptDatabaseInitializerPostgresqlTests {
39+
40+
static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("postgres:17");
41+
42+
@Container
43+
@SuppressWarnings("resource")
44+
static PostgreSQLContainer<?> postgresContainer = new PostgreSQLContainer<>(DEFAULT_IMAGE_NAME)
45+
.withDatabaseName("chat_memory_initializer_test")
46+
.withUsername("postgres")
47+
.withPassword("postgres");
48+
49+
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
50+
.withConfiguration(AutoConfigurations.of(JdbcChatMemoryAutoConfiguration.class,
51+
JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class))
52+
.withPropertyValues(String.format("spring.datasource.url=%s", postgresContainer.getJdbcUrl()),
53+
String.format("spring.datasource.username=%s", postgresContainer.getUsername()),
54+
String.format("spring.datasource.password=%s", postgresContainer.getPassword()));
55+
56+
@Test
57+
void getSettings_shouldHaveSchemaLocations() {
58+
this.contextRunner.run(context -> {
59+
var dataSource = context.getBean(DataSource.class);
60+
var settings = JdbcChatMemoryDataSourceScriptDatabaseInitializer.getSettings(dataSource);
61+
62+
assertThat(settings.getSchemaLocations())
63+
.containsOnly("classpath:org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql");
64+
});
65+
}
66+
67+
}

auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/src/main/java/org/springframework/ai/model/chat/memory/neo4j/autoconfigure/Neo4jChatMemoryAutoConfiguration.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
import org.neo4j.driver.Driver;
2020

21-
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemory;
2221
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig;
22+
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryRepository;
2323
import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration;
2424
import org.springframework.boot.autoconfigure.AutoConfiguration;
2525
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
@@ -29,19 +29,19 @@
2929
import org.springframework.context.annotation.Bean;
3030

3131
/**
32-
* {@link AutoConfiguration Auto-configuration} for {@link Neo4jChatMemory}.
32+
* {@link AutoConfiguration Auto-configuration} for {@link Neo4jChatMemoryRepository}.
3333
*
3434
* @author Enrico Rampazzo
3535
* @since 1.0.0
3636
*/
3737
@AutoConfiguration(after = Neo4jAutoConfiguration.class, before = ChatMemoryAutoConfiguration.class)
38-
@ConditionalOnClass({ Neo4jChatMemory.class, Driver.class })
38+
@ConditionalOnClass({ Neo4jChatMemoryRepository.class, Driver.class })
3939
@EnableConfigurationProperties(Neo4jChatMemoryProperties.class)
4040
public class Neo4jChatMemoryAutoConfiguration {
4141

4242
@Bean
4343
@ConditionalOnMissingBean
44-
public Neo4jChatMemory chatMemory(Neo4jChatMemoryProperties properties, Driver driver) {
44+
public Neo4jChatMemoryRepository chatMemoryRepository(Neo4jChatMemoryProperties properties, Driver driver) {
4545

4646
var builder = Neo4jChatMemoryConfig.builder()
4747
.withMediaLabel(properties.getMediaLabel())
@@ -52,7 +52,7 @@ public Neo4jChatMemory chatMemory(Neo4jChatMemoryProperties properties, Driver d
5252
.withToolResponseLabel(properties.getToolResponseLabel())
5353
.withDriver(driver);
5454

55-
return Neo4jChatMemory.create(builder.build());
55+
return new Neo4jChatMemoryRepository(builder.build());
5656
}
5757

5858
}
Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@
2323
import java.util.UUID;
2424

2525
import org.junit.jupiter.api.Test;
26+
import org.springframework.ai.chat.memory.ChatMemory;
27+
import org.springframework.ai.chat.memory.ChatMemoryRepository;
28+
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryRepository;
2629
import org.testcontainers.containers.Neo4jContainer;
2730
import org.testcontainers.junit.jupiter.Container;
2831
import org.testcontainers.junit.jupiter.Testcontainers;
2932
import org.testcontainers.utility.DockerImageName;
3033

31-
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemory;
3234
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig;
3335
import org.springframework.ai.chat.messages.AssistantMessage;
3436
import org.springframework.ai.chat.messages.Message;
@@ -51,7 +53,7 @@
5153
* @since 1.0.0
5254
*/
5355
@Testcontainers
54-
class Neo4jChatMemoryAutoConfigurationIT {
56+
class Neo4jChatMemoryRepositoryAutoConfigurationIT {
5557

5658
static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("neo4j");
5759

@@ -67,31 +69,31 @@ class Neo4jChatMemoryAutoConfigurationIT {
6769
@Test
6870
void addAndGet() {
6971
this.contextRunner.withPropertyValues("spring.neo4j.uri=" + neo4jContainer.getBoltUrl()).run(context -> {
70-
Neo4jChatMemory memory = context.getBean(Neo4jChatMemory.class);
72+
ChatMemoryRepository memory = context.getBean(ChatMemoryRepository.class);
7173

7274
String sessionId = UUID.randomUUID().toString();
73-
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty();
75+
assertThat(memory.findByConversationId(sessionId)).isEmpty();
7476

7577
UserMessage userMessage = new UserMessage("test question");
7678

77-
memory.add(sessionId, userMessage);
78-
List<Message> messages = memory.get(sessionId, Integer.MAX_VALUE);
79+
memory.saveAll(sessionId, List.of(userMessage));
80+
List<Message> messages = memory.findByConversationId(sessionId);
7981
assertThat(messages).hasSize(1);
8082
assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(userMessage);
8183

82-
memory.clear(sessionId);
83-
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty();
84+
memory.deleteByConversationId(sessionId);
85+
assertThat(memory.findByConversationId(sessionId)).isEmpty();
8486

8587
AssistantMessage assistantMessage = new AssistantMessage("test answer", Map.of(),
8688
List.of(new AssistantMessage.ToolCall("id", "type", "name", "arguments")));
8789

88-
memory.add(sessionId, List.of(userMessage, assistantMessage));
89-
messages = memory.get(sessionId, Integer.MAX_VALUE);
90+
memory.saveAll(sessionId, List.of(userMessage, assistantMessage));
91+
messages = memory.findByConversationId(sessionId);
9092
assertThat(messages).hasSize(2);
91-
assertThat(messages.get(1)).isEqualTo(userMessage);
93+
assertThat(messages.get(0)).isEqualTo(userMessage);
9294

93-
assertThat(messages.get(0)).isEqualTo(assistantMessage);
94-
memory.clear(sessionId);
95+
assertThat(messages.get(1)).isEqualTo(assistantMessage);
96+
memory.deleteByConversationId(sessionId);
9597
MimeType textPlain = MimeType.valueOf("text/plain");
9698
List<Media> media = List.of(
9799
Media.builder()
@@ -102,28 +104,28 @@ void addAndGet() {
102104
.build(),
103105
Media.builder().data(URI.create("http://www.google.com")).mimeType(textPlain).build());
104106
UserMessage userMessageWithMedia = UserMessage.builder().text("Message with media").media(media).build();
105-
memory.add(sessionId, userMessageWithMedia);
107+
memory.saveAll(sessionId, List.of(userMessageWithMedia));
106108

107-
messages = memory.get(sessionId, Integer.MAX_VALUE);
109+
messages = memory.findByConversationId(sessionId);
108110
assertThat(messages.size()).isEqualTo(1);
109111
assertThat(messages.get(0)).isEqualTo(userMessageWithMedia);
110112
assertThat(((UserMessage) messages.get(0)).getMedia()).hasSize(2);
111113
assertThat(((UserMessage) messages.get(0)).getMedia()).usingRecursiveFieldByFieldElementComparator()
112114
.isEqualTo(media);
113-
memory.clear(sessionId);
115+
memory.deleteByConversationId(sessionId);
114116
ToolResponseMessage toolResponseMessage = new ToolResponseMessage(
115117
List.of(new ToolResponse("id", "name", "responseData"),
116118
new ToolResponse("id2", "name2", "responseData2")),
117119
Map.of("id", "id", "metadataKey", "metadata"));
118-
memory.add(sessionId, toolResponseMessage);
119-
messages = memory.get(sessionId, Integer.MAX_VALUE);
120+
memory.saveAll(sessionId, List.of(toolResponseMessage));
121+
messages = memory.findByConversationId(sessionId);
120122
assertThat(messages.size()).isEqualTo(1);
121123
assertThat(messages.get(0)).isEqualTo(toolResponseMessage);
122124

123-
memory.clear(sessionId);
125+
memory.deleteByConversationId(sessionId);
124126
SystemMessage sm = new SystemMessage("this is a System message");
125-
memory.add(sessionId, sm);
126-
messages = memory.get(sessionId, Integer.MAX_VALUE);
127+
memory.saveAll(sessionId, List.of(sm));
128+
messages = memory.findByConversationId(sessionId);
127129
assertThat(messages).hasSize(1);
128130
assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(sm);
129131
});
@@ -148,7 +150,7 @@ void setCustomConfiguration() {
148150
propertyBase.formatted("toolresponselabel", toolResponseLabel),
149151
propertyBase.formatted("medialabel", mediaLabel))
150152
.run(context -> {
151-
Neo4jChatMemory chatMemory = context.getBean(Neo4jChatMemory.class);
153+
Neo4jChatMemoryRepository chatMemory = context.getBean(Neo4jChatMemoryRepository.class);
152154
Neo4jChatMemoryConfig config = chatMemory.getConfig();
153155
assertThat(config.getMessageLabel()).isEqualTo(messageLabel);
154156
assertThat(config.getMediaLabel()).isEqualTo(mediaLabel);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.model.embedding.observation.autoconfigure;
18+
19+
import io.micrometer.core.instrument.MeterRegistry;
20+
21+
import org.springframework.ai.embedding.EmbeddingModel;
22+
import org.springframework.ai.embedding.observation.EmbeddingModelMeterObservationHandler;
23+
import org.springframework.beans.factory.ObjectProvider;
24+
import org.springframework.boot.autoconfigure.AutoConfiguration;
25+
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
26+
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
27+
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
28+
import org.springframework.context.annotation.Bean;
29+
30+
/**
31+
* Auto-configuration for Spring AI embedding model observations.
32+
*
33+
* @author Thomas Vitale
34+
* @since 1.0.0
35+
*/
36+
@AutoConfiguration(
37+
afterName = "org.springframework.boot.actuate.autoconfigure.observation.ObservationAutoConfiguration")
38+
@ConditionalOnClass(EmbeddingModel.class)
39+
public class EmbeddingObservationAutoConfiguration {
40+
41+
@Bean
42+
@ConditionalOnMissingBean
43+
@ConditionalOnBean(MeterRegistry.class)
44+
EmbeddingModelMeterObservationHandler embeddingModelMeterObservationHandler(
45+
ObjectProvider<MeterRegistry> meterRegistry) {
46+
return new EmbeddingModelMeterObservationHandler(meterRegistry.getObject());
47+
}
48+
49+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.model.embedding.observation.autoconfigure;
18+
19+
import io.micrometer.core.instrument.composite.CompositeMeterRegistry;
20+
import org.junit.jupiter.api.Test;
21+
22+
import org.springframework.ai.embedding.observation.EmbeddingModelMeterObservationHandler;
23+
import org.springframework.boot.autoconfigure.AutoConfigurations;
24+
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
25+
26+
import static org.assertj.core.api.Assertions.assertThat;
27+
28+
/**
29+
* Unit tests for {@link EmbeddingObservationAutoConfiguration}.
30+
*
31+
* @author Thomas Vitale
32+
*/
33+
class EmbeddingObservationAutoConfigurationTests {
34+
35+
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
36+
.withConfiguration(AutoConfigurations.of(EmbeddingObservationAutoConfiguration.class));
37+
38+
@Test
39+
void meterObservationHandlerEnabled() {
40+
this.contextRunner.withBean(CompositeMeterRegistry.class)
41+
.run(context -> assertThat(context).hasSingleBean(EmbeddingModelMeterObservationHandler.class));
42+
}
43+
44+
@Test
45+
void meterObservationHandlerDisabled() {
46+
this.contextRunner
47+
.run(context -> assertThat(context).doesNotHaveBean(EmbeddingModelMeterObservationHandler.class));
48+
}
49+
50+
}

memory/spring-ai-model-chat-memory-neo4j/pom.xml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
<artifactId>spring-data-neo4j</artifactId>
5151
</dependency>
5252

53+
<!-- TESTING -->
5354
<dependency>
5455
<groupId>org.springframework.boot</groupId>
5556
<artifactId>spring-boot-starter-test</artifactId>
@@ -68,6 +69,12 @@
6869
<artifactId>spring-boot-testcontainers</artifactId>
6970
<scope>test</scope>
7071
</dependency>
72+
73+
<dependency>
74+
<groupId>org.testcontainers</groupId>
75+
<artifactId>testcontainers</artifactId>
76+
<scope>test</scope>
77+
</dependency>
7178

7279
<dependency>
7380
<groupId>org.neo4j.driver</groupId>
@@ -79,6 +86,12 @@
7986
<artifactId>neo4j</artifactId>
8087
<scope>test</scope>
8188
</dependency>
89+
90+
<dependency>
91+
<groupId>org.testcontainers</groupId>
92+
<artifactId>junit-jupiter</artifactId>
93+
<scope>test</scope>
94+
</dependency>
8295
</dependencies>
8396

8497
</project>

memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfig.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,33 @@ private Neo4jChatMemoryConfig(Builder builder) {
9393
this.toolCallLabel = builder.toolCallLabel;
9494
this.metadataLabel = builder.metadataLabel;
9595
this.toolResponseLabel = builder.toolResponseLabel;
96+
ensureIndexes();
97+
}
98+
99+
/**
100+
* Ensures that indexes exist on conversationId for Session nodes and index for
101+
* Message nodes. This improves query performance for lookups and ordering.
102+
*/
103+
private void ensureIndexes() {
104+
if (this.driver == null) {
105+
logger.warn("Neo4j Driver is null, cannot ensure indexes.");
106+
return;
107+
}
108+
try (var session = this.driver.session()) {
109+
// Index for conversationId on Session nodes
110+
String sessionIndexCypher = String.format(
111+
"CREATE INDEX session_conversation_id_index IF NOT EXISTS FOR (n:%s) ON (n.conversationId)",
112+
this.sessionLabel);
113+
// Index for index on Message nodes
114+
String messageIndexCypher = String
115+
.format("CREATE INDEX message_index_index IF NOT EXISTS FOR (n:%s) ON (n.index)", this.messageLabel);
116+
session.run(sessionIndexCypher);
117+
session.run(messageIndexCypher);
118+
logger.info("Ensured Neo4j indexes for conversationId and message index.");
119+
}
120+
catch (Exception e) {
121+
logger.warn("Failed to ensure Neo4j indexes for chat memory: {}", e.getMessage());
122+
}
96123
}
97124

98125
public static Builder builder() {

0 commit comments

Comments
 (0)