support chatglm.cpp v0.3.0
Weaxs committed Nov 28, 2023
1 parent ec1e093 commit fe2d49a
Showing 10 changed files with 394 additions and 67 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -164,7 +164,7 @@ prepare:

# build chatglm.cpp
build/chatglm.cpp: prepare
-cd build && CC="$(CC)" CXX="$(CXX)" cmake ../chatglm.cpp $(CMAKE_ARGS) && VERBOSE=1 cmake --build . -j --config Release
+cd build && CC="$(CC)" CXX="$(CXX)" cmake $(CMAKE_ARGS) ../chatglm.cpp && VERBOSE=1 cmake --build . -j --config Release

# chatglm.dir
chatglm.dir: build/chatglm.cpp
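Note on the change above: only the argument order differs. Passing $(CMAKE_ARGS) before the source directory follows the conventional "cmake [options] <path-to-source>" form, so user-supplied cache options (for example a flag such as -DGGML_CUBLAS=ON, used here purely as an illustration) are parsed as options rather than as trailing arguments.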
114 changes: 95 additions & 19 deletions binding.cpp
@@ -1,5 +1,4 @@
#include "chatglm.h"

#include "binding.h"
#include <string>
#include <vector>
@@ -47,40 +46,88 @@ class TextBindStreamer : public chatglm::BaseStreamer {
int print_len_;
};

-std::vector<std::string> create_vector(const char** strings, int count) {
-auto vec = new std::vector<std::string>;
+std::vector<chatglm::ChatMessage> create_chat_message_vector(void** history, int count) {
+std::vector<chatglm::ChatMessage>* vec = new std::vector<chatglm::ChatMessage>;
for (int i = 0; i < count; i++) {
-vec->push_back(std::string(strings[i]));
+chatglm::ChatMessage* msg = (chatglm::ChatMessage*) history[i];
+vec->push_back(*msg);
}

return *vec;
}

std::vector<chatglm::ToolCallMessage> create_tool_call_vector(void** tool_calls, int count) {
std::vector<chatglm::ToolCallMessage>* vec = new std::vector<chatglm::ToolCallMessage>;
for (int i = 0; i < count; i++) {
chatglm::ToolCallMessage* msg = (chatglm::ToolCallMessage*) tool_calls[i];
vec->push_back(*msg);
}

return *vec;
}

std::string decode_with_special_tokens(chatglm::ChatGLM3Tokenizer* tokenizer, const std::vector<int> &ids) {
std::vector<std::string> pieces;
for (int id : ids) {
auto pos = tokenizer->index_special_tokens.find(id);
if (pos != tokenizer->index_special_tokens.end()) {
// special tokens
pieces.emplace_back(pos->second);
} else {
// normal tokens
pieces.emplace_back(tokenizer->sp.IdToPiece(id));
}
}

std::string text = tokenizer->sp.DecodePieces(pieces);
return text;
}

void* load_model(const char *name) {
return new chatglm::Pipeline(name);
}

-int chat(void* pipe_pr, const char** history, int history_count, void* params_ptr, char* result) {
-std::vector<std::string> vectors = create_vector(history, history_count);
+int chat(void* pipe_pr, void** history, int history_count, void* params_ptr, char* result) {
+std::vector<chatglm::ChatMessage> vectors = create_chat_message_vector(history, history_count);
chatglm::Pipeline* pipe_p = (chatglm::Pipeline*) pipe_pr;
chatglm::GenerationConfig* params = (chatglm::GenerationConfig*) params_ptr;

-std::string res = pipe_p->chat(vectors, *params);
-strcpy(result, res.c_str());
+chatglm::ChatMessage res = pipe_p->chat(vectors, *params);

-vectors.clear();
std::string out = res.content;
// ChatGLM3Tokenizer::decode_message post-processes the raw output into a structured ChatMessage,
// so re-encode the message and decode it with special tokens kept to recover the original text
if (pipe_p->model->config.model_type == chatglm::ModelType::CHATGLM3) {
std::vector<chatglm::ChatMessage>* resultVec = new std::vector<chatglm::ChatMessage>{res};
chatglm::ChatGLM3Tokenizer* tokenizer = dynamic_cast<chatglm::ChatGLM3Tokenizer*>(pipe_p->tokenizer.get());
std::vector<int> input_ids = tokenizer->encode_messages(*resultVec, params->max_context_length);
out = decode_with_special_tokens(tokenizer, input_ids);
}
strcpy(result, out.c_str());

vectors.clear();
return 0;
}

-int stream_chat(void* pipe_pr, const char** history, int history_count,void* params_ptr, char* result) {
-std::vector<std::string> vectors = create_vector(history, history_count);
+int stream_chat(void* pipe_pr, void** history, int history_count, void* params_ptr, char* result) {
+std::vector<chatglm::ChatMessage> vectors = create_chat_message_vector(history, history_count);
chatglm::Pipeline* pipe_p = (chatglm::Pipeline*) pipe_pr;
chatglm::GenerationConfig* params = (chatglm::GenerationConfig*) params_ptr;

TextBindStreamer* text_stream = new TextBindStreamer(pipe_p->tokenizer.get(), pipe_pr);

-std::string res = pipe_p->chat(vectors, *params, text_stream);
-strcpy(result, res.c_str());
+chatglm::ChatMessage res = pipe_p->chat(vectors, *params, text_stream);

std::string out = res.content;
// ChatGLM3Tokenizer::decode_message post-processes the raw output into a structured ChatMessage,
// so re-encode the message and decode it with special tokens kept to recover the original text
if (pipe_p->model->config.model_type == chatglm::ModelType::CHATGLM3) {
std::vector<chatglm::ChatMessage>* resultVec = new std::vector<chatglm::ChatMessage>{res};
chatglm::ChatGLM3Tokenizer* tokenizer = dynamic_cast<chatglm::ChatGLM3Tokenizer*>(pipe_p->tokenizer.get());
std::vector<int> input_ids = tokenizer->encode_messages(*resultVec, params->max_context_length);
out = decode_with_special_tokens(tokenizer, input_ids);
}
strcpy(result, out.c_str());

vectors.clear();
return 0;
@@ -108,11 +155,10 @@ int stream_generate(void* pipe_pr, const char *prompt, void* params_ptr, char* result) {
return 0;
}

-int get_embedding(void* pipe_pr, void* params_ptr, const char *prompt, int * result) {
+int get_embedding(void* pipe_pr, const char *prompt, int max_length, int * result) {
chatglm::Pipeline* pipe_p = (chatglm::Pipeline*) pipe_pr;
-chatglm::GenerationConfig* params = (chatglm::GenerationConfig*) params_ptr;

-std::vector<int> embeddings = pipe_p->tokenizer->encode(prompt, params->max_length);
+std::vector<int> embeddings = pipe_p->tokenizer->encode(prompt, max_length);

for (size_t i = 0; i < embeddings.size(); i++) {
result[i]=embeddings[i];
@@ -122,7 +168,7 @@ int get_embedding(void* pipe_pr, void* params_ptr, const char *prompt, int * result) {
}

void* allocate_params(int max_length, int max_context_length, bool do_sample, int top_k,
-float top_p, float temperature, float repetition_penalty, int num_threads) {
+float top_p, float temperature, float repetition_penalty, int num_threads) {
chatglm::GenerationConfig* gen_config = new chatglm::GenerationConfig;
gen_config->max_length = max_length;
gen_config->max_context_length = max_context_length;
@@ -145,6 +191,36 @@ void free_model(void* pipe_pr) {
delete pipe_p;
}

void* create_chat_message(const char* role, const char *content, void** tool_calls, int tool_calls_count) {
std::vector<chatglm::ToolCallMessage> vector = create_tool_call_vector(tool_calls, tool_calls_count);
return new chatglm::ChatMessage(role, content, vector);
}

void* create_tool_call(const char* type, void* codeOrFunc) {
if (type == chatglm::ToolCallMessage::TYPE_FUNCTION) {
chatglm::FunctionMessage* function_p = (chatglm::FunctionMessage*) codeOrFunc;
return new chatglm::ToolCallMessage(*function_p);
} else if (type == chatglm::ToolCallMessage::TYPE_CODE) {
chatglm::CodeMessage* code_p = (chatglm::CodeMessage*) codeOrFunc;
return new chatglm::ToolCallMessage(*code_p);
}
return nullptr;
}

void* create_function(const char* name, const char *arguments) {
return new chatglm::FunctionMessage(name, arguments);
}


void* create_code(const char* input) {
return new chatglm::CodeMessage(input);
}

char* get_model_type(void* pipe_pr) {
chatglm::Pipeline* pipe_p = (chatglm::Pipeline*) pipe_pr;
return strdup(to_string(pipe_p->model->config.model_type).data());
}

// copy from chatglm::TextStreamer
void TextBindStreamer::put(const std::vector<int> &output_ids) {
if (is_prompt_) {
@@ -178,7 +254,7 @@ void TextBindStreamer::put(const std::vector<int> &output_ids) {
}

// callback go function
-if (!streamCallback(draft_pipe, (char*)printable_text.c_str())) {
+if (!streamCallback(draft_pipe, printable_text.data())) {
return;
}
}
@@ -187,7 +263,7 @@ void TextBindStreamer::put(const std::vector<int> &output_ids) {
void TextBindStreamer::end() {
std::string text = tokenizer_->decode(token_cache_);
// callback go function
-if (!streamCallback(draft_pipe, (char*)text.substr(print_len_).c_str())) {
+if (!streamCallback(draft_pipe, text.substr(print_len_).data())) {
return;
}
is_prompt_ = true;
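Note for anyone exercising binding.cpp from plain C++ rather than Go: streamCallback is only declared (see binding.h below); in this project its definition is exported from Go. A minimal stub along the following lines should satisfy the linker. This is a sketch under that assumption, not part of the commit; TextBindStreamer::put and TextBindStreamer::end return early whenever the callback returns false.

#include <cstdio>

// Hypothetical stand-in for the Go-exported callback: print each decoded
// chunk as it arrives and tell the streamer to continue.
extern "C" bool streamCallback(void* pipe, char* token) {
    std::fputs(token, stdout);
    return true; // returning false makes the streamer return early
}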
16 changes: 13 additions & 3 deletions binding.h
@@ -9,15 +9,15 @@ extern bool streamCallback(void *, char *);

void* load_model(const char *name);

-int chat(void* pipe_pr, const char** history, int history_count, void* params_ptr, char* result);
+int chat(void* pipe_pr, void** history, int history_count, void* params_ptr, char* result);

-int stream_chat(void* pipe_pr, const char** history, int history_count, void* params_ptr, char* result);
+int stream_chat(void* pipe_pr, void** history, int history_count, void* params_ptr, char* result);

int generate(void* pipe_pr, const char *prompt, void* params_ptr, char* result);

int stream_generate(void* pipe_pr, const char *prompt, void* params_ptr, char* result);

-int get_embedding(void* pipe_pr, void* params_ptr, const char *prompt, int * result);
+int get_embedding(void* pipe_pr, const char *prompt, int max_length, int * result);

void* allocate_params(int max_length, int max_context_length, bool do_sample, int top_k,
float top_p, float temperature, float repetition_penalty, int num_threads);
@@ -26,6 +26,16 @@ void free_params(void* params_ptr);

void free_model(void* pipe_pr);

void* create_chat_message(const char* role, const char *content, void** tool_calls, int tool_calls_count);

void* create_tool_call(const char* type, void* codeOrFunc);

void* create_function(const char* name, const char *arguments);

void* create_code(const char* code);

char* get_model_type(void* pipe_pr);

#ifdef __cplusplus
}

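Putting the new API together, a minimal C++ caller might look like the sketch below. It is illustrative only and not part of the commit: the model path, output buffer size, and sampling values are assumptions (the sampling numbers mirror chatglm.cpp's usual defaults), the role string "user" follows chatglm.cpp's ChatMessage roles, and linking also needs a streamCallback definition such as the stub shown after binding.cpp.

#include "binding.h"
#include <cstdio>

int main() {
    void* pipe = load_model("./chatglm3-ggml.bin");  // assumed model path
    // max_length, max_context_length, do_sample, top_k, top_p,
    // temperature, repetition_penalty, num_threads
    void* params = allocate_params(2048, 512, true, 0, 0.7f, 0.95f, 1.0f, 8);

    // A single user turn with no tool calls attached.
    void* msg = create_chat_message("user", "Hello!", nullptr, 0);
    void* history[] = { msg };

    char result[8192];  // caller-owned output buffer; chat() copies into it
    if (chat(pipe, history, 1, params, result) == 0) {
        std::printf("%s\n", result);
    }
    std::printf("model type: %s\n", get_model_type(pipe));

    free_params(params);
    free_model(pipe);
    return 0;
}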
(Diff for the remaining 7 changed files not shown.)
