 #include <vector>
 #include <string>
 #include <iostream>
+#include <memory>
 
 #include "omni-vlm-wrapper.h"
 
-
 struct omnivlm_context {
     struct clip_ctx * ctx_clip = NULL;
     struct llama_context * ctx_llama = NULL;
@@ -30,6 +30,53 @@ void* internal_chars = nullptr;
 static struct common_params params;
 static struct llama_model * model;
 static struct omnivlm_context * ctx_omnivlm;
+static std::unique_ptr<struct omni_streaming_sample> g_oss = nullptr;
+
+static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past);
+static void omnivlm_free(struct omnivlm_context * ctx_omnivlm);
+
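+// Per-stream decoding state: the sampler, the input image path, the text
+// piece produced by the most recent sample() call, the KV-cache position
+// (n_past_), and the number of tokens decoded so far (dec_cnt_).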
+struct omni_streaming_sample {
+    struct common_sampler * ctx_sampling_;
+    std::string image_;
+    std::string ret_str_;
+    int32_t n_past_;
+    int32_t dec_cnt_;
+
+    omni_streaming_sample() = delete;
+    omni_streaming_sample(const std::string& image)
+        : image_(image) {
+        n_past_  = 0;
+        dec_cnt_ = 0;
+        // top_k = 1 forces greedy decoding, so the stream is deterministic;
+        // note that this mutates the shared global sampling params
+        params.sparams.top_k = 1;
+        params.sparams.top_p = 1.0f;
+        ctx_sampling_ = common_sampler_init(model, params.sparams);
+    }
+
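+    // Sample one token: accept it into the sampler state, cache its text
+    // piece in ret_str_ ("</s>" for an end-of-generation token), then feed
+    // the token back through the model to advance n_past_.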
+    int32_t sample() {
+        const llama_token id = common_sampler_sample(ctx_sampling_, ctx_omnivlm->ctx_llama, -1);
+        common_sampler_accept(ctx_sampling_, id, true);
+        if (llama_token_is_eog(llama_get_model(ctx_omnivlm->ctx_llama), id)) {
+            ret_str_ = "</s>";
+        } else {
+            ret_str_ = common_token_to_piece(ctx_omnivlm->ctx_llama, id);
+        }
+        eval_id(ctx_omnivlm->ctx_llama, id, &n_past_);
+
+        ++dec_cnt_;
+        return id;
+    }
+
+    ~omni_streaming_sample() {
+        common_sampler_free(ctx_sampling_);
+        if (ctx_omnivlm != nullptr) {
+            // detach the model before freeing the context: the model is a
+            // shared global that is released separately in omnivlm_free()
+            ctx_omnivlm->model = nullptr;
+            omnivlm_free(ctx_omnivlm);
+            free(ctx_omnivlm);
+            ctx_omnivlm = nullptr;
+        }
+    }
+};
+
 
 static struct omni_image_embed * load_image(omnivlm_context * ctx_omnivlm, common_params * params, const std::string & fname) {
 
@@ -286,3 +333,81 @@ void omnivlm_free() {
     }
     llama_free_model(model);
 }
+
+
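+// Start a new streaming generation: tear down any previous stream, build a
+// fresh context, and rewrite params.prompt into the chat template for the
+// configured omni_vlm_version. The prompt and image are not evaluated here;
+// that happens on the first call to sample().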
+struct omni_streaming_sample * omnivlm_inference_streaming(const char *prompt, const char *imag_path) {
+    if (g_oss) {
+        g_oss.reset();
+    }
+    g_oss = std::make_unique<omni_streaming_sample>(std::string(imag_path));
+
+    ctx_omnivlm = omnivlm_init_context(&params, model);
+
+    params.prompt = prompt;
+
+    if (params.omni_vlm_version == "vlm-81-ocr") {
+        params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|ocr_start|><|vision_start|><|image_pad|><|vision_end|><|ocr_end|><|im_end|>";
+    } else if (params.omni_vlm_version == "vlm-81-instruct" || params.omni_vlm_version == "nano-vlm-instruct") {
+        params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n\n<|vision_start|><|image_pad|><|vision_end|>" + params.prompt + "<|im_end|>";
+    } else {
+        LOG_ERR("%s: error: unsupported omni_vlm_version: '%s'\n", __func__, params.omni_vlm_version.c_str());
+        throw std::runtime_error("unsupported omni_vlm_version: " + params.omni_vlm_version);
+    }
+
+    return g_oss.get();
+}
+
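+// Advance the stream by one token. On the first call (n_past_ == 0) this
+// evaluates the system prompt, the image embedding, and the user prompt
+// before sampling. Returns the sampled token id, -1 once an end token
+// ("<|im_end|>" or "</s>") is produced, or -2 when max_tgt_len is reached.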
+int32_t sample(omni_streaming_sample* oss) {
+    const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
+    int32_t ret_id;
+    if (oss->n_past_ == 0) {
+        auto * image_embed = load_image(ctx_omnivlm, &params, oss->image_);
+        if (!image_embed) {
+            LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, oss->image_.c_str());
+            throw std::runtime_error("failed to load image " + oss->image_);
+        }
+
+        // split the templated prompt around the image placeholder
+        size_t image_pos = params.prompt.find("<|image_pad|>");
+        std::string system_prompt = params.prompt.substr(0, image_pos);
+        std::string user_prompt   = params.prompt.substr(image_pos + std::string("<|image_pad|>").length());
+        if (params.verbose_prompt) {
+            auto tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, system_prompt, true, true);
+            for (int i = 0; i < (int) tmp.size(); i++) {
+                LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
+            }
+            tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, user_prompt, true, true);
+            for (int i = 0; i < (int) tmp.size(); i++) {
+                LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
+            }
+        }
+
+        // prefill: system prompt, then image embedding, then user prompt
+        eval_string(ctx_omnivlm->ctx_llama, system_prompt.c_str(), params.n_batch, &(oss->n_past_), true);
+        omnivlm_eval_image_embed(ctx_omnivlm->ctx_llama, image_embed, params.n_batch, &(oss->n_past_));
+        eval_string(ctx_omnivlm->ctx_llama, user_prompt.c_str(), params.n_batch, &(oss->n_past_), false);
+
+        omnivlm_image_embed_free(image_embed);
+
+        ret_id = oss->sample();
+        if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>") {
+            ret_id = -1;  // end of generation
+        }
+    } else {
+        if (oss->dec_cnt_ == max_tgt_len) {
+            ret_id = -2;  // generation length limit reached
+        } else {
+            ret_id = oss->sample();
+            if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>") {
+                ret_id = -1;  // end of generation
+            }
+        }
+    }
+    return ret_id;
+}
+
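+// Returns the text piece produced by the most recent sample() call.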
+const char* get_str(omni_streaming_sample* oss) {
+    return oss->ret_str_.c_str();
+}
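+
+// Usage sketch (illustrative only, not part of the wrapper itself): assuming
+// the model has already been loaded (e.g. via an init call such as
+// omnivlm_init) and params.omni_vlm_version is set, a caller would stream a
+// reply like so:
+//
+//     omni_streaming_sample * oss =
+//         omnivlm_inference_streaming("Describe this image.", "/path/to/image.png");
+//     for (;;) {
+//         int32_t id = sample(oss);
+//         if (id < 0) break;           // -1: end token, -2: length limit hit
+//         printf("%s", get_str(oss));  // print the newly decoded piece
+//         fflush(stdout);
+//     }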