Skip to content

Commit 1487d32

Browse files
authored
Add streaming for omnivlm (#39)
* omni vlm add streaming * omni vlm add streaming
1 parent 5962b50 commit 1487d32

File tree

2 files changed

+136
-3
lines changed

2 files changed

+136
-3
lines changed

examples/omni-vlm/omni-vlm-wrapper.cpp

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
#include <vector>
1616
#include <string>
1717
#include <iostream>
18+
#include <memory>
1819

1920
#include "omni-vlm-wrapper.h"
2021

21-
2222
struct omnivlm_context {
2323
struct clip_ctx * ctx_clip = NULL;
2424
struct llama_context * ctx_llama = NULL;
@@ -30,6 +30,53 @@ void* internal_chars = nullptr;
3030
static struct common_params params;
3131
static struct llama_model* model;
3232
static struct omnivlm_context* ctx_omnivlm;
33+
static std::unique_ptr<struct omni_streaming_sample> g_oss = nullptr;
34+
35+
static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past);
36+
static void omnivlm_free(struct omnivlm_context * ctx_omnivlm);
37+
38+
struct omni_streaming_sample {
39+
struct common_sampler * ctx_sampling_;
40+
std::string image_;
41+
std::string ret_str_;
42+
int32_t n_past_;
43+
int32_t dec_cnt_;
44+
45+
omni_streaming_sample() = delete;
46+
omni_streaming_sample(const std::string& image)
47+
:image_(image) {
48+
n_past_ = 0;
49+
dec_cnt_ = 0;
50+
params.sparams.top_k = 1;
51+
params.sparams.top_p = 1.0f;
52+
ctx_sampling_ = common_sampler_init(model, params.sparams);
53+
}
54+
55+
int32_t sample() {
56+
const llama_token id = common_sampler_sample(ctx_sampling_, ctx_omnivlm->ctx_llama, -1);
57+
common_sampler_accept(ctx_sampling_, id, true);
58+
if (llama_token_is_eog(llama_get_model(ctx_omnivlm->ctx_llama), id)) {
59+
ret_str_ = "</s>";
60+
} else {
61+
ret_str_ = common_token_to_piece(ctx_omnivlm->ctx_llama, id);
62+
}
63+
eval_id(ctx_omnivlm->ctx_llama, id, &n_past_);
64+
65+
++dec_cnt_;
66+
return id;
67+
}
68+
69+
~omni_streaming_sample() {
70+
common_sampler_free(ctx_sampling_);
71+
if(ctx_omnivlm != nullptr) {
72+
ctx_omnivlm->model = nullptr;
73+
omnivlm_free(ctx_omnivlm);
74+
free(ctx_omnivlm);
75+
ctx_omnivlm = nullptr;
76+
}
77+
}
78+
};
79+
3380

3481
static struct omni_image_embed * load_image(omnivlm_context * ctx_omnivlm, common_params * params, const std::string & fname) {
3582

@@ -286,3 +333,81 @@ void omnivlm_free() {
286333
}
287334
llama_free_model(model);
288335
}
336+
337+
338+
struct omni_streaming_sample* omnivlm_inference_streaming(const char *prompt, const char *imag_path) {
339+
if (g_oss) {
340+
g_oss.reset();
341+
}
342+
g_oss = std::make_unique<omni_streaming_sample>(std::string(imag_path));
343+
344+
ctx_omnivlm = omnivlm_init_context(&params, model);
345+
346+
params.prompt = prompt;
347+
348+
if (params.omni_vlm_version == "vlm-81-ocr") {
349+
params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n <|ocr_start|><|vision_start|><|image_pad|><|vision_end|><|ocr_end|><|im_end|>";
350+
} else if (params.omni_vlm_version == "vlm-81-instruct" || params.omni_vlm_version == "nano-vlm-instruct") {
351+
params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n\n<|vision_start|><|image_pad|><|vision_end|>" + params.prompt + "<|im_end|>";
352+
} else {
353+
LOG_ERR("%s : error: you set wrong vlm version info:'%s'.\n", __func__, params.omni_vlm_version.c_str());
354+
throw std::runtime_error("You set wrong vlm_version info strings.");
355+
}
356+
357+
return g_oss.get();
358+
}
359+
360+
int32_t sample(omni_streaming_sample* oss) {
361+
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
362+
int32_t ret_id;
363+
if(oss->n_past_ == 0) {
364+
auto * image_embed = load_image(ctx_omnivlm, &params, oss->image_);
365+
if (!image_embed) {
366+
LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, oss->image_.c_str());
367+
throw std::runtime_error("failed to load image " + oss->image_);
368+
}
369+
370+
size_t image_pos = params.prompt.find("<|image_pad|>");
371+
std::string system_prompt, user_prompt;
372+
373+
system_prompt = params.prompt.substr(0, image_pos);
374+
user_prompt = params.prompt.substr(image_pos + std::string("<|image_pad|>").length());
375+
if (params.verbose_prompt) {
376+
auto tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, system_prompt, true, true);
377+
for (int i = 0; i < (int) tmp.size(); i++) {
378+
LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
379+
}
380+
}
381+
if (params.verbose_prompt) {
382+
auto tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, user_prompt, true, true);
383+
for (int i = 0; i < (int) tmp.size(); i++) {
384+
LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
385+
}
386+
}
387+
388+
eval_string(ctx_omnivlm->ctx_llama, system_prompt.c_str(), params.n_batch, &(oss->n_past_), true);
389+
omnivlm_eval_image_embed(ctx_omnivlm->ctx_llama, image_embed, params.n_batch, &(oss->n_past_));
390+
eval_string(ctx_omnivlm->ctx_llama, user_prompt.c_str(), params.n_batch, &(oss->n_past_), false);
391+
392+
omnivlm_image_embed_free(image_embed);
393+
394+
ret_id = oss->sample();
395+
if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>" ) {
396+
ret_id = -1;
397+
}
398+
} else {
399+
if(oss->dec_cnt_ == max_tgt_len) {
400+
ret_id = -2;
401+
} else {
402+
ret_id = oss->sample();
403+
if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>" ) {
404+
ret_id = -1;
405+
}
406+
}
407+
}
408+
return ret_id;
409+
}
410+
411+
const char* get_str(omni_streaming_sample* oss) {
412+
return oss->ret_str_.c_str();
413+
}

examples/omni-vlm/omni-vlm-wrapper.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
21
#ifndef OMNIVLMWRAPPER_H
32
#define OMNIVLMWRAPPER_H
3+
#include <stdint.h>
44

55
#ifdef LLAMA_SHARED
66
# if defined(_WIN32) && !defined(__MINGW32__)
@@ -20,14 +20,22 @@
2020
extern "C" {
2121
#endif
2222

23+
struct omni_streaming_sample;
24+
2325
OMNIVLM_API void omnivlm_init(const char* llm_model_path, const char* projector_model_path, const char* omni_vlm_version);
2426

2527
OMNIVLM_API const char* omnivlm_inference(const char* prompt, const char* imag_path);
2628

29+
OMNIVLM_API struct omni_streaming_sample* omnivlm_inference_streaming(const char* prompt, const char* imag_path);
30+
31+
OMNIVLM_API int32_t sample(struct omni_streaming_sample *);
32+
33+
OMNIVLM_API const char* get_str(struct omni_streaming_sample *);
34+
2735
OMNIVLM_API void omnivlm_free();
2836

2937
#ifdef __cplusplus
3038
}
3139
#endif
3240

33-
#endif
41+
#endif

0 commit comments

Comments
 (0)