fix stream && StreamCallback
Weaxs committed Nov 21, 2023
1 parent e603049 commit 4765d6d
Showing 5 changed files with 161 additions and 78 deletions.
90 changes: 77 additions & 13 deletions binding.cpp
@@ -30,6 +30,22 @@ void sigint_handler(int signo) {
}
#endif

// streamer that invokes a Go callback for each flushed chunk; copied from chatglm::TextStreamer
class TextBindStreamer : public chatglm::BaseStreamer {
public:
TextBindStreamer(chatglm::BaseTokenizer *tokenizer, void* draft_pipe)
: draft_pipe(draft_pipe), tokenizer_(tokenizer), is_prompt_(true), print_len_(0) {}
void put(const std::vector<int> &output_ids) override;
void end() override;

private:
void* draft_pipe;
chatglm::BaseTokenizer *tokenizer_;
bool is_prompt_;
std::vector<int> token_cache_;
int print_len_;
};

std::vector<std::string> create_vector(const char** strings, int count) {
auto vec = new std::vector<std::string>;
for (int i = 0; i < count; i++) {
@@ -55,16 +71,18 @@ int chat(void* pipe_pr, const char** history, int history_count, void* params_pt
return 0;
}

void* stream_chat(void* pipe_pr, const char** history, int history_count, void* params_ptr) {
int stream_chat(void* pipe_pr, const char** history, int history_count, void* params_ptr, char* result) {
std::vector<std::string> vectors = create_vector(history, history_count);
chatglm::Pipeline* pipe_p = (chatglm::Pipeline*) pipe_pr;
chatglm::GenerationConfig* params = (chatglm::GenerationConfig*) params_ptr;

chatglm::PerfStreamer* streamer = new chatglm::PerfStreamer;
pipe_p->chat(vectors, *params, streamer);
TextBindStreamer* text_stream = new TextBindStreamer(pipe_p->tokenizer.get(), pipe_pr);

std::string res = pipe_p->chat(vectors, *params, text_stream);
strcpy(result, res.c_str());

vectors.clear();
return streamer;
return 0;
}

int generate(void* pipe_pr, const char *prompt, void* params_ptr, char* result) {
@@ -77,20 +95,15 @@ int generate(void* pipe_pr, const char *prompt, void* params_ptr, char* result)
return 0;
}

void* stream_generate(void* pipe_pr, const char *prompt, void* params_ptr) {
int stream_generate(void* pipe_pr, const char *prompt, void* params_ptr, char* result) {
chatglm::Pipeline* pipe_p = (chatglm::Pipeline*) pipe_pr;
chatglm::GenerationConfig* params = (chatglm::GenerationConfig*) params_ptr;

chatglm::PerfStreamer* streamer = new chatglm::PerfStreamer;
pipe_p->generate(std::string(prompt), *params, streamer);
return streamer;
}

int stream_to_string(void* steamer_pr, char* result) {
chatglm::PerfStreamer* streamer_p = (chatglm::PerfStreamer*) steamer_pr;
TextBindStreamer* text_stream = new TextBindStreamer(pipe_p->tokenizer.get(), pipe_pr);

std::string res = streamer_p->to_string();
std::string res = pipe_p->generate(std::string(prompt), *params, text_stream);
strcpy(result, res.c_str());

return 0;
}

@@ -131,5 +144,56 @@ void chatglm_free_model(void* pipe_pr) {
delete pipe_p;
}

// copied from chatglm::TextStreamer
void TextBindStreamer::put(const std::vector<int> &output_ids) {
if (is_prompt_) {
// skip prompt
is_prompt_ = false;
return;
}

static const std::vector<char> puncts{',', '!', ':', ';', '?'};

token_cache_.insert(token_cache_.end(), output_ids.begin(), output_ids.end());
std::string text = tokenizer_->decode(token_cache_);
if (text.empty()) {
return;
}

std::string printable_text;
if (text.back() == '\n') {
// flush the cache after newline
printable_text = text.substr(print_len_);

token_cache_.clear();
print_len_ = 0;
} else if (std::find(puncts.begin(), puncts.end(), text.back()) != puncts.end()) {
// last symbol is a punctuation, hold on
} else if (text.size() >= 3 && text.compare(text.size() - 3, 3, "�") == 0) {
// ends with an incomplete token, hold on
} else {
printable_text = text.substr(print_len_);
print_len_ = text.size();
}

// invoke the Go callback
if (!streamCallback(draft_pipe, (char*)printable_text.c_str())) {
return;
}
}

// copied from chatglm::TextStreamer
void TextBindStreamer::end() {
std::string text = tokenizer_->decode(token_cache_);
// invoke the Go callback with the remaining text
if (!streamCallback(draft_pipe, (char*)text.substr(print_len_).c_str())) {
return;
}
is_prompt_ = true;
token_cache_.clear();
print_len_ = 0;
}
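
TextBindStreamer::put above flushes only text that can no longer change: everything up to a newline, while holding back trailing punctuation and the UTF-8 replacement character "�" that marks a partially decoded token. Below is a minimal Go sketch of the same flushing heuristic, for illustration only; the package and helper names are hypothetical and not part of this commit.

package streamer

import "strings"

// flushChunk mirrors TextBindStreamer::put's heuristic: given the full
// decoded text and the number of bytes already emitted, it returns the
// next printable chunk and whether the token cache should be cleared.
// Hypothetical helper for illustration; not part of this commit.
func flushChunk(text string, printLen int) (printable string, resetCache bool) {
	if text == "" {
		return "", false
	}
	switch {
	case strings.HasSuffix(text, "\n"):
		// a newline makes the pending text stable: flush it and clear the cache
		return text[printLen:], true
	case strings.ContainsAny(text[len(text)-1:], ",!:;?"):
		// trailing punctuation may still merge with the next token: hold on
		return "", false
	case strings.HasSuffix(text, "\uFFFD"):
		// a trailing replacement character marks an incomplete token: hold on
		return "", false
	default:
		return text[printLen:], false
	}
}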




10 changes: 5 additions & 5 deletions binding.h
@@ -5,19 +5,17 @@ extern "C" {

#include <stdbool.h>

extern unsigned char tokenCallback(void *, char *);
extern bool streamCallback(void *, char *);

void* load_model(const char *name);

int chat(void* pipe_pr, const char** history, int history_count, void* params_ptr, char* result);

void* stream_chat(void* pipe_pr, const char** history, int history_count, void* params_ptr);
int stream_chat(void* pipe_pr, const char** history, int history_count, void* params_ptr, char* result);

int generate(void* pipe_pr, const char *prompt, void* params_ptr, char* result);

void* stream_generate(void* pipe_pr, const char *prompt, void* params_ptr);

int stream_to_string(void* steamer, char* result);
int stream_generate(void* pipe_pr, const char *prompt, void* params_ptr, char* result);

int get_embedding(void* pipe_pr, void* params_ptr, const char *prompt, int * result);

@@ -31,4 +29,6 @@ void chatglm_free_model(void* pipe_pr);
#ifdef __cplusplus
}



#endif
90 changes: 71 additions & 19 deletions chatglm.go
@@ -13,14 +13,17 @@ import "C"
import (
"fmt"
"strings"
"sync"
"unsafe"
)

type Chatglm struct {
pipeline unsafe.Pointer
stream unsafe.Pointer
// default stream buffer; supply a StreamCallback to customize streaming
stream strings.Builder
}

// New creates the llm struct and loads the model
func New(model string) (*Chatglm, error) {
modelPath := C.CString(model)
defer C.free(unsafe.Pointer(modelPath))
Expand All @@ -33,6 +36,7 @@ func New(model string) (*Chatglm, error) {
return llm, nil
}

// Chat runs a synchronous chat
func (llm *Chatglm) Chat(history []string, opts ...GenerationOption) (string, error) {
opt := NewGenerationOptions(opts...)
params := allocateParams(opt)
@@ -62,7 +66,7 @@ func (llm *Chatglm) Chat(history []string, opts ...GenerationOption) (string, er
return res, nil
}

func (llm *Chatglm) StreamChat(history []string, opts ...GenerationOption) error {
func (llm *Chatglm) StreamChat(history []string, opts ...GenerationOption) (string, error) {
opt := NewGenerationOptions(opts...)
params := allocateParams(opt)
defer freeParams(params)
@@ -76,9 +80,24 @@ func (llm *Chatglm) StreamChat(history []string, opts ...GenerationOption) error
pass = &reversePrompt[0]
}

streamer := C.stream_chat(llm.pipeline, pass, C.int(reverseCount), params)
llm.stream = streamer
return nil
if opt.StreamCallback != nil {
addStreamCallback(llm.pipeline, opt.StreamCallback)
} else {
addStreamCallback(llm.pipeline, defaultStreamCallback(llm))
}

if opt.MaxContextLength == 0 {
opt.MaxContextLength = 99999999
}
out := make([]byte, opt.MaxContextLength)
success := C.stream_chat(llm.pipeline, pass, C.int(reverseCount), params, (*C.char)(unsafe.Pointer(&out[0])))
if success != 0 {
return "", fmt.Errorf("model chat failed")
}
res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
res = strings.TrimPrefix(res, " ")
res = strings.TrimPrefix(res, "\n")
return res, nil
}

func (llm *Chatglm) Generate(prompt string, opts ...GenerationOption) (string, error) {
@@ -99,31 +118,25 @@ func (llm *Chatglm) Generate(prompt string, opts ...GenerationOption) (string, e
return res, nil
}

func (llm *Chatglm) StreamGenerate(prompt string, opts ...GenerationOption) error {
func (llm *Chatglm) StreamGenerate(prompt string, opts ...GenerationOption) (string, error) {
opt := NewGenerationOptions(opts...)
params := allocateParams(opt)
defer freeParams(params)

streamer := C.stream_generate(llm.pipeline, C.CString(prompt), params)
llm.stream = streamer
return nil
}

func (llm *Chatglm) GetStream(opts ...GenerationOption) (string, error) {
if llm.stream == nil {
return "", fmt.Errorf("stream is nil")
if opt.StreamCallback != nil {
addStreamCallback(llm.pipeline, opt.StreamCallback)
} else {
addStreamCallback(llm.pipeline, defaultStreamCallback(llm))
}

opt := NewGenerationOptions(opts...)
params := allocateParams(opt)
defer freeParams(params)
if opt.MaxContextLength == 0 {
opt.MaxContextLength = 99999999
}
out := make([]byte, opt.MaxContextLength)
result := C.stream_to_string(llm.stream, (*C.char)(unsafe.Pointer(&out[0])))
result := C.generate(llm.pipeline, C.CString(prompt), params, (*C.char)(unsafe.Pointer(&out[0])))

if result != 0 {
return "", fmt.Errorf("get stream failed")
return "", fmt.Errorf("model generate failed")
}
res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
return res, nil
@@ -159,3 +172,42 @@ func allocateParams(opt *GenerationOptions) unsafe.Pointer {
func freeParams(params unsafe.Pointer) {
C.chatglm_free_params(params)
}

var (
m sync.RWMutex
callbacks = map[unsafe.Pointer]func(string) bool{}
)

//export streamCallback
func streamCallback(pipeline unsafe.Pointer, printableText *C.char) C.bool {
m.RLock()
defer m.RUnlock()

if callback, ok := callbacks[pipeline]; ok {
return C.bool(callback(C.GoString(printableText)))
}

return C.bool(true)
}

func addStreamCallback(pipeline unsafe.Pointer, callback func(string) bool) {
m.Lock()
defer m.Unlock()

if callback == nil {
delete(callbacks, pipeline)
} else {
callbacks[pipeline] = callback
}
}

// defaultStreamCallback returns a callback that appends streamed text to llm's default stream buffer
func defaultStreamCallback(llm *Chatglm) func(string) bool {
return func(text string) bool {
_, err := llm.stream.WriteString(text)
if err != nil {
return false
}
return true
}
}
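
Taken together, the new flow is: register a callback keyed by the pipeline pointer, let the exported streamCallback trampoline route each chunk from C to it, and still return the full reply. A hedged usage sketch follows; the module import path and model filename are assumptions, not taken from this commit.

package main

import (
	"fmt"
	"log"

	chatglm "github.com/Weaxs/go-chatglm.cpp" // assumed import path
)

func main() {
	llm, err := chatglm.New("chatglm3-ggml.bin") // placeholder model file
	if err != nil {
		log.Fatal(err)
	}

	// Print each flushed chunk as it arrives; returning true keeps streaming.
	reply, err := llm.StreamChat([]string{"2+2等于多少"},
		chatglm.SetStreamCallback(func(chunk string) bool {
			fmt.Print(chunk)
			return true
		}))
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("\nfull reply:", reply)
}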
41 changes: 0 additions & 41 deletions chatglm_test.go
@@ -36,19 +36,6 @@ func TestGenerate(t *testing.T) {
assert.Contains(t, ret, "4")
}

func TestStreamGenerate(t *testing.T) {
err := chatglm.StreamGenerate("2+2等于多少")
if err != nil {
assert.Fail(t, "stream generate failed.")
}

ret, err := chatglm.GetStream()
if err != nil {
assert.Fail(t, "get stream failed.")
}
assert.Contains(t, ret, "4")
}

func TestChat(t *testing.T) {
history := []string{"2+2等于多少"}
ret, err := chatglm.Chat(history)
@@ -69,34 +56,6 @@ func TestChat(t *testing.T) {
assert.Len(t, history, 4)
}

func TestStreamChat(t *testing.T) {
history := []string{"2+2等于多少"}
err := chatglm.StreamChat(history)
if err != nil {
assert.Fail(t, "first chat failed")
}
ret, err := chatglm.GetStream()
if err != nil {
assert.Fail(t, "first get stream failed.")
}
assert.Contains(t, ret, "4")

history = append(history, ret)
history = append(history, "再加4等于多少")
err = chatglm.StreamChat(history)
if err != nil {
assert.Fail(t, "second chat failed")
}
ret, err = chatglm.GetStream()
if err != nil {
assert.Fail(t, "first get stream failed.")
}
assert.Contains(t, ret, "8")

history = append(history, ret)
assert.Len(t, history, 4)
}

func TestEmbedding(t *testing.T) {
maxLength := 1024
embeddings, err := chatglm.Embeddings("你好", SetMaxLength(1024))
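
The deleted TestStreamGenerate and TestStreamChat relied on the removed GetStream API. A sketch of an equivalent test against the new callback-based signature; it assumes the package-level chatglm fixture and testify's assert used by the remaining tests, plus a strings import.

func TestStreamChatCallback(t *testing.T) {
	var sb strings.Builder
	history := []string{"2+2等于多少"}
	ret, err := chatglm.StreamChat(history,
		SetStreamCallback(func(chunk string) bool {
			sb.WriteString(chunk) // collect streamed chunks
			return true
		}))
	if err != nil {
		assert.Fail(t, "stream chat failed.")
	}
	assert.Contains(t, ret, "4")
	assert.Contains(t, sb.String(), "4") // streamed text should match the reply
}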
8 changes: 8 additions & 0 deletions options.go
@@ -9,6 +9,7 @@ type GenerationOptions struct {
Temperature float32
RepetitionPenalty float32
NumThreads int
StreamCallback func(string) bool
}

type GenerationOption func(g *GenerationOptions)
@@ -22,6 +23,7 @@ var DefaultGenerationOptions *GenerationOptions = &GenerationOptions{
Temperature: 0.95,
RepetitionPenalty: 1.0,
NumThreads: 0,
StreamCallback: nil,
}

func NewGenerationOptions(opts ...GenerationOption) *GenerationOptions {
@@ -79,3 +81,9 @@ func SetNumThreads(numThreads int) GenerationOption {
g.NumThreads = numThreads
}
}

func SetStreamCallback(callback func(string) bool) GenerationOption {
return func(g *GenerationOptions) {
g.StreamCallback = callback
}
}
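
SetStreamCallback composes with the other GenerationOption setters; when it is nil the binding falls back to defaultStreamCallback. A small sketch combining it with SetMaxLength; the helper name and prompt are illustrative, and fmt is assumed to be imported.

// streamGenerate is an illustrative helper, not part of this commit.
func streamGenerate(llm *Chatglm, prompt string) (string, error) {
	return llm.StreamGenerate(prompt,
		SetMaxLength(2048), // cap the generated length
		SetStreamCallback(func(chunk string) bool {
			fmt.Print(chunk) // forward each chunk to stdout as it arrives
			return true
		}),
	)
}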
