Skip to content

Commit

Permalink
feat: 图文多模态 vision chat 功能。
Browse files Browse the repository at this point in the history
  • Loading branch information
hkh1012 committed Feb 19, 2024
1 parent 11a9078 commit f62ed50
Show file tree
Hide file tree
Showing 26 changed files with 539 additions and 106 deletions.
26 changes: 12 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,29 @@
| openai库 | openai-java |
| 前端 | freemarker、bootstrap、jquery、recorder.js |
| stream-chat | SSE |
| LLMs | openaichatglm2文心一言 |
| LLMs | openai、chatglm2、文心一言、智谱AI |
| embeddings | openai、text2vec-transformers、文心一言 |
| vector store | weaviate、milvus |

## langchain 原理
<img src="src/main/resources/assets/langchain+chatglm.png" alt="原理图"/>

## 路线图
已完成本地知识库上传及完成openaichatglm2两个LLMs模型流式聊天功能未来计划会接入更多大语言模型以满足更多需求场景
## 模型能力矩阵
| 模型/能力 | 文本生成 | 流式输出 | 语音 | 函数调用 | 图片生成 | 多模态(VISION) | 嵌入EMBEDDING |
|-----------|------|------|-----|------|---------|-------------|-------------|
| openai | 支持 | 支持 | 支持 | 支持 |-| 支持 | 支持 |
| 百度(文心) | 支持 | 支持 | - | 支持 |-| - | 支持 |
| 智谱(GLM-4) | 支持 | 支持 | - | 支持 |-| 支持 | 支持 |
| chatglm2 | 支持 | 支持 | - | - |-| - | - |
| ... ... | - | - | - | - |-| - |- |

## 功能路线图
已完成本地知识库上传,以及 openai、chatglm2、百度(文心)、智谱(GLM-4)四个 LLM 模型的流式聊天功能。未来计划接入更多大语言模型,以满足更多需求场景。
- [ ] Langchain 知识库
- [x] 接入非结构化文档(已支持 md、pdf、docx、txt、csv 等文件格式)
- [ ] 搜索引擎接入
- [ ] 结构化数据接入(如 Excel、SQL 等)
- [ ] 知识图谱/图数据库接入
- [ ] 增加更多 LLM 模型支持
- [x] [OPENAI](https://platform.openai.com/docs/api-reference)
- [x] [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
- [x] [百度/文心一言](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t)
- [ ] [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
-
- [ ] 增加更多 Embedding 模型支持
- [x] [OPENAI/embedding](https://platform.openai.com/docs/api-reference/embeddings)
- [x] [weaviate/text2vec-transformers](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/text2vec-transformers)
- [x] [百度/文心一言](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/dllz04sro)
- [ ] [shibing624/text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese)
- [ ] 系统功能
- [ ] 用户
- [x] 用户登录
Expand Down
10 changes: 10 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,16 @@
<artifactId>mysql-connector-j</artifactId>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>io.projectreactor.netty</groupId>
<artifactId>reactor-netty-core</artifactId>
<version>1.1.16</version>
</dependency>
<dependency>
<groupId>io.projectreactor.netty</groupId>
<artifactId>reactor-netty-http</artifactId>
<version>1.1.16</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-configuration-processor</artifactId>
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/com/hkh/ai/chain/callback/FunctionCaller.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
public class FunctionCaller {

/**
* 查询天气API
* 查询天气API(模拟)
* @param jsonNode
* @return
*/
public String get_location_weather(JSONObject jsonNode){
String location = jsonNode.getString("location");
String datePeriod = jsonNode.getString("datePeriod");
String valueByName = DatePeriod.getValueByName(datePeriod);
String result = "(模拟)通过查询天气API得到:" +location + valueByName + "天气为晴天";
// String valueByName = DatePeriod.getValueByName(datePeriod);
String result = "(模拟)通过查询天气API得到:" +location + datePeriod + "天气为晴天";
log.info(result);
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ public class ZhipuAiUtil {
private String completionModel;

/**
* 百度千帆开饭的embedding api model(默认:bge-large-zh)
* 可选模型:Embedding-V1、bge-large-zh、bge-large-en
* 可选模型:embedding-2
*
*/
@Value("${chain.vectorization.zhipu.model}")
private String embeddingModel;
Expand All @@ -60,12 +60,14 @@ public String getAccessToken(){
queryWrapper.ge("expired_time",LocalDateTime.now().plusSeconds(60L));
AccessToken accessToken = accessTokenService.getOne(queryWrapper,false);
if (accessToken == null){
String[] keys = appKey.split(".");
String[] keys = appKey.split("\\.");
LocalDateTime now = LocalDateTime.now();
LocalDateTime expiredTime = now.plusSeconds(60 * 60 * 24);

Map<String,String> jwtHeader = new HashMap<>();
jwtHeader.put("alg","HS256");
jwtHeader.put("sign_type","SIGN");

Map<String,Object> payloads = new HashMap<>();
payloads.put("api_key",keys[0]);
payloads.put("exp",expiredTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli());
Expand All @@ -74,7 +76,7 @@ public String getAccessToken(){
JWT jwt = JWT.create()
.addHeaders(jwtHeader)
.addPayloads(payloads)
.setSigner(JWTSignerUtil.createSigner("HS256", DatatypeConverter.parseBase64Binary(keys[1])));
.setSigner(JWTSignerUtil.createSigner("HS256", keys[1].getBytes()));
String jwtToken = jwt.sign();

AccessToken newAccessToken = new AccessToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ public List<FunctionCompletionResult> functionCompletion(String content, List<Ch
functionCompletionResult.setType("function");
functionCompletionResult.setName(choice.getMessage().getFunctionCall().getName());
JsonNode arguments = choice.getMessage().getFunctionCall().getArguments();
JSONObject jsonObject = JSONObject.parseObject(arguments.asText());
functionCompletionResult.setArguments(jsonObject);
String argumentsStr = JSONObject.toJSONString(arguments);
functionCompletionResult.setArguments(JSONObject.parseObject(argumentsStr));
functionResultList.add(functionCompletionResult);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import cn.hutool.http.HttpRequest;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import com.alibaba.fastjson2.JSONReader;
import com.hkh.ai.chain.llm.capabilities.generation.ZhipuAiUtil;
import com.hkh.ai.chain.llm.capabilities.generation.ZhipuChatApis;
import com.hkh.ai.chain.llm.capabilities.generation.function.ChatFunctionObject;
Expand Down Expand Up @@ -34,6 +35,14 @@ public class ZhipuAiFunctionChatService implements FunctionChatService {
public List<FunctionCompletionResult> functionCompletion(String content, List<ChatFunctionObject> functionObjectList) {
String accessToken = zhipuAiUtil.getAccessToken();

JSONArray toolArray = new JSONArray();
for (ChatFunctionObject functionObject : functionObjectList){
JSONObject tool = new JSONObject();
tool.put("type","function");
tool.put("function",functionObject);
toolArray.add(tool);
}

// 构建 message
JSONArray messages = new JSONArray();
JSONObject jsonObject = new JSONObject();
Expand All @@ -45,16 +54,24 @@ public List<FunctionCompletionResult> functionCompletion(String content, List<Ch
JSONObject body = new JSONObject();
body.put("messages",messages);
body.put("model",zhipuAiUtil.getCompletionModel());
body.put("tools", functionObjectList);
body.put("tools", toolArray);
body.put("tool_choice","auto");

HttpRequest httpRequest = new HttpRequest(UrlBuilder.of(ZhipuChatApis.COMPLETION_TEXT));
httpRequest.header("Authorization",accessToken);
httpRequest.header("content-type","application/json");
httpRequest.body(body.toJSONString());
String resultStr = httpRequest.execute().body();
String resultStr = HttpRequest.post(ZhipuChatApis.COMPLETION_TEXT)
.header("Authorization",accessToken)
.header("content-type","application/json")
.body(body.toJSONString())
.execute().body();

// 返回的是非标json,需要特殊处理。。。
resultStr = resultStr
.replaceAll("\\\\\\{","{")
.replaceAll("\\\\}\"","}")
.replaceAll("\\\\\"", "\"")
.replaceAll("\"\\{","{")
.replaceAll("}\"","}");

BlockCompletionResult result = JSONObject.parseObject(resultStr, BlockCompletionResult.class);
BlockCompletionResult result = JSONObject.parseObject(resultStr, BlockCompletionResult.class, JSONReader.Feature.AllowUnQuotedFieldNames);
List<BlockCompletionResult.BlockCompletionResultChoiceMessageToolCall> tool_calls = result.getChoices().get(0).getMessage().getTool_calls();
List<FunctionCompletionResult> functionResultList = new ArrayList<>();
for (int i = 0; i < tool_calls.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class BaiduQianFanCompletionWebClient {

@PostConstruct
public void init(){
log.info("baidu api web client init...");
log.info("baidu ai web client init...");
this.webClient = WebClient.builder()
.defaultHeader(HttpHeaders.CONTENT_TYPE, "application/json")
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,30 +41,32 @@ public class ZhipuCompletionBizProcessor {
private final List<Integer> promptTokens;

public void bizProcess(String item){
log.info(item);
StreamCompletionResult resultObj = JSONObject.parseObject(item, StreamCompletionResult.class);
String content = resultObj.getChoices().get(0).getDelta().getContent();
try {
if (StringUtils.isNotBlank(resultObj.getChoices().get(0).getFinish_reason())) {
this.getSseEmitter().send("[END]");
String fullContent = this.getSb().toString();
List<Integer> completionToken = this.getEnc().encode(fullContent);
System.out.println("total token costs: " + (this.getPromptTokens().size() + completionToken.size()));
this.getConversationService().saveConversation(this.getSysUser().getId(), this.getRequest().getSessionId(), this.getSb().toString(), "A");
} else {
if (content.contains("\n") || content.contains("\r")) {
content = content.replaceAll("\n", "<br>");
content = content.replaceAll("\r", "<br>");
System.out.println("智普流式输出:" +item);
if (!"[DONE]".equals(item)){
StreamCompletionResult resultObj = JSONObject.parseObject(item, StreamCompletionResult.class);
String content = resultObj.getChoices().get(0).getDelta().getContent();
try {
if (StringUtils.isNotBlank(resultObj.getChoices().get(0).getFinish_reason())) {
this.getSseEmitter().send("[END]");
String fullContent = this.getSb().toString();
List<Integer> completionToken = this.getEnc().encode(fullContent);
System.out.println("total token costs: " + (this.getPromptTokens().size() + completionToken.size()));
this.getConversationService().saveConversation(this.getSysUser().getId(), this.getRequest().getSessionId(), this.getSb().toString(), "A");
} else {
if (content.contains("\n") || content.contains("\r")) {
content = content.replaceAll("\n", "<br>");
content = content.replaceAll("\r", "<br>");
}
if (content.contains(" ")) {
content = content.replaceAll(" ", "&nbsp;");
}
this.getSb().append(content);
this.getSseEmitter().send(content);
}
if (content.contains(" ")) {
content = content.replaceAll(" ", "&nbsp;");
}
this.getSb().append(content);
this.getSseEmitter().send(content);
} catch (IOException e) {
log.error("ZhipuCompletionBizProcessor--->>bizProcess异常", e);
throw new RuntimeException(e);
}
} catch (IOException e) {
log.error("ZhipuCompletionBizProcessor--->>bizProcess异常", e);
throw new RuntimeException(e);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import com.alibaba.fastjson2.JSONObject;
import com.hkh.ai.chain.llm.capabilities.generation.BaiduQianFanUtil;
import com.hkh.ai.chain.llm.capabilities.generation.ZhipuAiUtil;
import com.hkh.ai.chain.llm.capabilities.generation.ZhipuChatApis;
import jakarta.annotation.PostConstruct;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClientResponseException;
Expand All @@ -30,26 +32,25 @@ public class ZhipuCompletionWebClient {
public void init(){
log.info("zhipu ai api web client init...");
this.webClient = WebClient.builder()
.defaultHeader(HttpHeaders.CONTENT_TYPE, "application/json")
.defaultHeader("content-type", "application/json")
.build();
}

public Flux<String> streamChatCompletion(JSONObject requestBody){
log.info("streamChatCompletion 参数:{}",requestBody);
String url = zhipuAiUtil.getCompletionModel();
String accessToken = zhipuAiUtil.getAccessToken();
return webClient.post()
.uri(url)
.header(HttpHeaders.CONTENT_TYPE,"application/json")
.uri(ZhipuChatApis.COMPLETION_TEXT)
.bodyValue(requestBody)
.header("Authorization",accessToken)
.bodyValue(requestBody.toJSONString())
.retrieve()
.bodyToFlux(String.class)
.onErrorResume(WebClientResponseException.class, ex -> {
ex.printStackTrace();
HttpStatusCode statusCode = ex.getStatusCode();
String res = ex.getResponseBodyAsString();
log.error("ZhipuAI API error: {} {}", statusCode, res);
return Mono.error(new RuntimeException(res));
return Flux.error(new RuntimeException(res));
});

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public void streamChat(CustomChatMessage request, List<String> nearestList, List
body.put("messages",messages);
body.put("stream",true);
body.put("model",zhipuAiUtil.getCompletionModel());
body.put("request_id", UUID.fastUUID());
body.put("request_id", UUID.fastUUID().toString(true));
body.put("temperature",0.95);

ZhipuCompletionBizProcessor bizProcessor = ZhipuCompletionBizProcessor.builder()
Expand Down Expand Up @@ -111,7 +111,7 @@ public String blockCompletion(String content) {
JSONObject body = new JSONObject();
body.put("messages",messages);
body.put("model",zhipuAiUtil.getCompletionModel());
body.put("request_id", UUID.fastUUID());
body.put("request_id", UUID.fastUUID().toString(true));
body.put("stream",false);
body.put("temperature",0.95);

Expand Down
Loading

0 comments on commit f62ed50

Please sign in to comment.