@@ -21,13 +21,8 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
 
 // --------------------------------------------
 
-std::string create_embedding_payload(const std::vector<float> &embedding,
+Json::Value create_embedding_payload(const std::vector<float> &embedding,
                                       int prompt_tokens) {
-  Json::Value root;
-
-  root["object"] = "list";
-
-  Json::Value dataArray(Json::arrayValue);
   Json::Value dataItem;
 
   dataItem["object"] = "embedding";
@@ -39,20 +34,7 @@ std::string create_embedding_payload(const std::vector<float> &embedding,
   dataItem["embedding"] = embeddingArray;
   dataItem["index"] = 0;
 
-  dataArray.append(dataItem);
-  root["data"] = dataArray;
-
-  root["model"] = "_";
-
-  Json::Value usage;
-  usage["prompt_tokens"] = prompt_tokens;
-  usage["total_tokens"] = prompt_tokens; // Assuming total tokens equals prompt
-                                         // tokens in this context
-  root["usage"] = usage;
-
-  Json::StreamWriterBuilder writer;
-  writer["indentation"] = ""; // Compact output
-  return Json::writeString(writer, root);
+  return dataItem;
 }
 
 std::string create_full_return_json(const std::string &id,
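
Note (not part of the commit): after this refactor, create_embedding_payload only builds the per-input "embedding" item; the "list"/"usage" envelope it used to emit is now assembled by the caller, as the handler change below shows. A minimal standalone sketch of that split, assuming jsoncpp is available and using placeholder embedding values:

    #include <iostream>
    #include <json/json.h>

    int main() {
      // Placeholder item mirroring the fields the refactored helper sets.
      Json::Value item;
      item["object"] = "embedding";
      Json::Value values(Json::arrayValue);
      for (float v : {0.1f, 0.2f, 0.3f}) // placeholder embedding values
        values.append(v);
      item["embedding"] = values;
      item["index"] = 0;

      // The caller now owns the envelope that previously lived in the helper.
      Json::Value root;
      Json::Value data(Json::arrayValue);
      data.append(item);
      root["data"] = data;
      root["object"] = "list";
      root["model"] = "_";
      Json::Value usage;
      usage["prompt_tokens"] = 0;
      usage["total_tokens"] = 0;
      root["usage"] = usage;

      std::cout << Json::writeString(Json::StreamWriterBuilder(), root) << "\n";
      return 0;
    }

Serialized, this produces the same OpenAI-style shape the old helper emitted for a single input.
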
@@ -406,31 +388,42 @@ void llamaCPP::embedding(
     std::function<void(const HttpResponsePtr &)> &&callback) {
   const auto &jsonBody = req->getJsonObject();
 
-  json prompt;
-  if (jsonBody->isMember("input") != 0) {
-    if ((*jsonBody)["input"].isString()) {
-      prompt = (*jsonBody)["input"].asString();
-    } else if ((*jsonBody)["input"].isArray()) {
-      const auto &inputArray = (*jsonBody)["input"];
-      std::vector<std::string> inputStrings;
-      for (const auto &input : inputArray) {
-        if (input.isString()) {
-          inputStrings.push_back(input.asString());
+  Json::Value responseData(Json::arrayValue);
+
+  if (jsonBody->isMember("input")) {
+    const Json::Value &input = (*jsonBody)["input"];
+    if (input.isString()) {
+      // Process the single string input
+      const int task_id = llama.request_completion(
+          {{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
+      task_result result = llama.next_result(task_id);
+      std::vector<float> embedding_result = result.result_json["embedding"];
+      responseData.append(create_embedding_payload(embedding_result, 0));
+    } else if (input.isArray()) {
+      // Process each element in the array input
+      for (const auto &elem : input) {
+        if (elem.isString()) {
+          const int task_id = llama.request_completion(
+              {{"prompt", elem.asString()}, {"n_predict", 0}}, false, true, -1);
+          task_result result = llama.next_result(task_id);
+          std::vector<float> embedding_result = result.result_json["embedding"];
+          responseData.append(create_embedding_payload(embedding_result, 0));
         }
       }
-      prompt = inputStrings;
     }
-  } else {
-    prompt = "";
   }
 
-  const int task_id = llama.request_completion(
-      {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
-  task_result result = llama.next_result(task_id);
-  std::vector<float> embedding_result = result.result_json["embedding"];
   auto resp = nitro_utils::nitroHttpResponse();
-  std::string embedding_resp = create_embedding_payload(embedding_result, 0);
-  resp->setBody(embedding_resp);
+  Json::Value root;
+  root["data"] = responseData;
+  root["model"] = "_";
+  root["object"] = "list";
+  Json::Value usage;
+  usage["prompt_tokens"] = 0;
+  usage["total_tokens"] = 0;
+  root["usage"] = usage;
+
+  resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
   resp->setContentTypeString("application/json");
   callback(resp);
   return;
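
Usage sketch (not from the commit): with this handler, "input" may be either a single string or an array of strings, and each string contributes one entry to "data". The snippet below builds such a request body with jsoncpp and notes the approximate reply shape in comments; the values are placeholders, and as written the handler reports zero token usage and index 0 for every item:

    #include <iostream>
    #include <json/json.h>

    int main() {
      // Request body for the embedding endpoint: "input" as an array of strings.
      Json::Value body;
      Json::Value inputs(Json::arrayValue);
      inputs.append("first text to embed");
      inputs.append("second text to embed");
      body["input"] = inputs; // a single string is also accepted

      std::cout << Json::writeString(Json::StreamWriterBuilder(), body) << "\n";

      // The handler above replies with an OpenAI-style list, one "data" item
      // per input string, roughly:
      // {"data":[{"embedding":[...],"index":0,"object":"embedding"},
      //          {"embedding":[...],"index":0,"object":"embedding"}],
      //  "model":"_","object":"list","usage":{"prompt_tokens":0,"total_tokens":0}}
      return 0;
    }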