Skip to content

Commit

Permalink
feat:添加一种向量存储实现pgvector。
Browse files Browse the repository at this point in the history
  • Loading branch information
hkh1012 committed Aug 16, 2024
1 parent 9ca4b33 commit 526361c
Show file tree
Hide file tree
Showing 10 changed files with 263 additions and 16 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
| stream-chat | SSE |
| LLMs | openaichatglm2文心一言智谱AIKimi |
| embeddings | openaitext2vec-transformers文心一言 |
| vector store | weaviatemilvus |
| vector store | weaviatemilvuspgvector |

## langchain rag原理
<img src="src/main/resources/assets/langchain+chatglm.png" alt="原理图"/>
Expand Down
28 changes: 23 additions & 5 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,29 @@
<artifactId>freemarker-java8</artifactId>
<version>1.3.0</version>
</dependency>
<!-- <dependency>-->
<!-- <groupId>me.zhyd.oauth</groupId>-->
<!-- <artifactId>JustAuth</artifactId>-->
<!-- <version>1.16.6</version>-->
<!-- </dependency>-->

<dependency>
<groupId>io.github.amikos-tech</groupId>
<artifactId>chromadb-java-client</artifactId>
<version>0.1.5</version>
</dependency>

<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
<version>42.7.3</version>
</dependency>
<dependency>
<groupId>com.pgvector</groupId>
<artifactId>pgvector</artifactId>
<version>0.1.6</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.16.0</version>
</dependency>

</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public String visionCompletion(String content, List<String> imageUrlList) {
// 构建请求体
JSONObject body = new JSONObject();
body.put("messages",messages);
body.put("model","gpt-4-vision-preview");
body.put("model","gpt-4o");
// body.put("request_id", UUID.fastUUID().toString(true));


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class MilvusVectorStore implements VectorStore{
@Value("${chain.vector.store.milvus.host}")
private String milvusHost;
@Value("${chain.vector.store.milvus.port}")
private Integer milvausPort;
private Integer milvusPort;

@Value("${chain.vector.store.milvus.dimension}")
private Integer dimension;
Expand All @@ -51,7 +51,7 @@ public void init(){
milvusServiceClient = new MilvusServiceClient(
ConnectParam.newBuilder()
.withHost(milvusHost)
.withPort(milvausPort)
.withPort(milvusPort)
.withDatabaseName("default")
.build()
);
Expand Down
187 changes: 187 additions & 0 deletions src/main/java/com/hkh/ai/chain/vectorstore/PgVectorStore.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package com.hkh.ai.chain.vectorstore;

import com.hkh.ai.chain.retrieve.PromptRetrieverProperties;
import com.pgvector.PGvector;
import jakarta.annotation.PostConstruct;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import java.sql.*;
import java.util.ArrayList;
import java.util.List;

@Service
@Slf4j
public class PgVectorStore implements VectorStore{


@Value("${chain.vector.store.pgvector.host}")
private String pgHost;
@Value("${chain.vector.store.pgvector.port}")
private Integer pgPort;

@Value("${chain.vector.store.pgvector.dimension}")
private Integer dimension;

@Value("${chain.vector.store.pgvector.collection}")
private String collectionName;

private Connection connection;

private final PromptRetrieverProperties promptRetrieverProperties;

public PgVectorStore(PromptRetrieverProperties promptRetrieverProperties) {
this.promptRetrieverProperties = promptRetrieverProperties;
}

@PostConstruct
public void init(){
try {
Class.forName("org.postgresql.Driver");
// replace user and password with the configuration of your pg database
connection = DriverManager.getConnection("jdbc:postgresql://" + pgHost + ":"+ pgPort +"/pg","postgres","pg123456");
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
} catch (SQLException e) {
throw new RuntimeException(e);
}

}

private void createSchema(String kid) {
Statement createStmt;
try {
createStmt = connection.createStatement();
createStmt.executeUpdate("CREATE TABLE " + collectionName + kid +" (id bigserial PRIMARY KEY, content text, kid varchar(20), docId varchar(20),fid varchar(20),embedding vector(" + dimension + "))");
} catch (SQLException e) {
throw new RuntimeException(e);
}finally {
try {
connection.close();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}
}

@Override
public void newSchema(String kid) {
createSchema(kid);
}

@Override
public void removeByKidAndFid(String kid, String fid) {
PreparedStatement stmt;
try {
stmt = connection.prepareStatement("delete from " + collectionName + kid +" where fid = ?");
stmt.setString(1,fid);
int rowsDeleted = stmt.executeUpdate();
System.out.println("pg deleted rows: " + rowsDeleted);
} catch (SQLException e) {
throw new RuntimeException(e);
}finally {
try {
connection.close();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}
}

@Override
public void storeEmbeddings(List<String> chunkList, List<List<Double>> vectorList, String kid, String docId, List<String> fidList) {
try {
for (int i = 0; i < chunkList.size(); i++) {
PreparedStatement stmt = connection.prepareStatement("insert into " + collectionName + kid + " (content,kid,docId,fid,embedding) values (?,?,?,?,?)");
stmt.setString(1,chunkList.get(i));
stmt.setString(2,kid);
stmt.setString(3,docId);
stmt.setString(4,fidList.get(i));
stmt.setObject(5,new PGvector(vectorList.get(i)));
stmt.executeUpdate();
}
} catch (SQLException e) {
throw new RuntimeException(e);
}finally {
try {
connection.close();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}
}



@Override
public void removeByDocId(String kid, String docId) {
PreparedStatement stmt;
try {
stmt = connection.prepareStatement("delete from " + collectionName + kid +" where docId = ?");
stmt.setString(1,docId);
int rowsDeleted = stmt.executeUpdate();
System.out.println("pg deleted rows: " + rowsDeleted);
} catch (SQLException e) {
throw new RuntimeException(e);
}finally {
try {
connection.close();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}
}

@Override
public void removeByKid(String kid) {
PreparedStatement stmt;
try {
stmt = connection.prepareStatement("drop table " + collectionName + kid);
stmt.executeUpdate();
} catch (SQLException e) {
throw new RuntimeException(e);
}finally {
try {
connection.close();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}
}

@Override
public List<String> nearest(List<Double> queryVector, String kid) {
PreparedStatement stmt;
List<String> result = new ArrayList<>();
try {
stmt = connection.prepareStatement("SELECT content FROM " + collectionName + kid + " ORDER BY embedding <=> ? LIMIT 5");
stmt.setObject(1, new PGvector(queryVector));
ResultSet rs = stmt.executeQuery();
while (rs.next()) {
result.add(rs.getString("content"));
}
return result;
} catch (SQLException e) {
throw new RuntimeException(e);
}finally {
try {
connection.close();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}
}

/**
* milvus 不支持通过文本检索相似性
* @param query
* @param kid
* @return
*/
@Override
public List<String> nearest(String query, String kid) {
return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,21 @@ public class VectorStoreFactory {

private final MilvusVectorStore milvusVectorStore;

public VectorStoreFactory(WeaviateVectorStore weaviateVectorStore, MilvusVectorStore milvusVectorStore) {
private final PgVectorStore pgVectorStore;

public VectorStoreFactory(WeaviateVectorStore weaviateVectorStore, MilvusVectorStore milvusVectorStore, PgVectorStore pgVectorStore) {
this.weaviateVectorStore = weaviateVectorStore;
this.milvusVectorStore = milvusVectorStore;
this.pgVectorStore = pgVectorStore;
}

public VectorStore getVectorStore(){
if ("weaviate".equals(type)){
return weaviateVectorStore;
}else if ("milvus".equals(type)){
return milvusVectorStore;
}else if ("pg".equals(type)){
return pgVectorStore;
}
return null;
}
Expand Down
3 changes: 3 additions & 0 deletions src/main/resources/application-dev.properties
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ google.search.cx=xxxxxxxxxxxxxxxxxxxxxx
zhipu.ai.token=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
kimi.ai.token=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx

doubao.ai.token=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
claude.ai.token=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx

spring.datasource.url=jdbc:mysql://127.0.0.1:3306/ai?serverTimezone=Asia/Shanghai&characterEncoding=utf8&useUnicode=true&useSSL=false&autoReconnect=true&zeroDateTimeBehavior=convertToNull&allowMultiQueries=true&rewriteBatchedStatements=true
spring.datasource.username=root
spring.datasource.password=123
Expand Down
14 changes: 10 additions & 4 deletions src/main/resources/application.properties
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ chain.split.chunk.size=200
chain.split.chunk.overlay=0
chain.split.chunk.qaspliter=######

chain.vectorization.type=zhipu
chain.vectorization.type=openai
chain.vectorization.openai.token=${openai.token}
chain.vectorization.openai.model=text-embedding-ada-002
chain.vectorization.baidu.model=bge-large-zh
Expand All @@ -49,9 +49,14 @@ chain.vector.store.milvus.host=192.168.40.229
chain.vector.store.milvus.port=19530
chain.vector.store.milvus.dimension=1536
chain.vector.store.milvus.collection=LocalKnowledge
chain.vector.store.pgvector.host=192.168.40.229
chain.vector.store.pgvector.port=5433
chain.vector.store.pgvector.dimension=1536
chain.vector.store.pgvector.collection=LocalKnowledge


chain.llm.openai.token=${openai.token}
chain.llm.openai.model=gpt-4-1106-preview
chain.llm.openai.model=gpt-4o
#chain.llm.openai.model=gpt-4
chain.llm.chatglm.baseurl=http://127.0.0.1:8000/
chain.llm.chatglm.model=chatglm2-6b
Expand All @@ -60,9 +65,10 @@ chain.llm.baidu.secretKey=${baidu.secretKey}
chain.llm.baidu.model=ernie_bot
chain.llm.zhipu.model=glm-4
chain.llm.kimi.model=moonshot-v1-32k

chain.llm.audio.type=openai
chain.llm.text.type=kimi
chain.llm.function.type=baidu
chain.llm.text.type=openai
chain.llm.function.type=openai
chain.llm.vision.type=openai
chain.llm.image.type=openai

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
services:
# Qdrant vector store.
pgvector:
image: pgvector/pgvector:pg16
restart: always
environment:
PGUSER: postgres
# The password for the default postgres user.
POSTGRES_PASSWORD: pg123456
# The name of the default postgres database.
POSTGRES_DB: pg
# postgres data directory
PGDATA: /var/lib/postgresql/data/pgdata
volumes:
- ./volumes/pgvector/data:/var/lib/postgresql/data
# uncomment to expose db(postgresql) port to host
ports:
- "5433:5432"
healthcheck:
test: [ "CMD", "pg_isready" ]
interval: 1s
timeout: 3s
retries: 30
9 changes: 7 additions & 2 deletions src/main/resources/static/css/chat.css
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,8 @@ a:hover {
}

.visionSelectedPreviewImage{
width: 100%;
max-width: 800px;
width: auto;
max-height: 80%;
}


Expand Down Expand Up @@ -651,6 +651,11 @@ a:hover {
padding: 2%;
}

.visionSelectedPreviewImage{
width: 100%;
max-width: 800px;
}

}

.knowledge-select-container {
Expand Down

0 comments on commit 526361c

Please sign in to comment.