llama : expose API to retrieve devices associated with the model. #12073

Open · wants to merge 1 commit into master
3 changes: 3 additions & 0 deletions include/llama.h
@@ -499,6 +499,8 @@ extern "C" {
LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type

LLAMA_API size_t llama_n_backends(const struct llama_context * ctx);
LLAMA_API size_t llama_get_backends(const struct llama_context * ctx, ggml_backend_t * out_buf, size_t out_len);
DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");

LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
@@ -510,6 +512,7 @@ extern "C" {
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
LLAMA_API const ggml_backend_dev_t * llama_model_get_devices (const struct llama_model * model, size_t * out_len);

// Get the model's RoPE frequency scaling factor
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
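A minimal caller-side sketch of the new context accessors declared above (hypothetical usage, not part of this PR; it assumes the existing ggml_backend_name() helper, which llama.h already pulls in via ggml-backend.h):

// Query-then-copy pattern: ask for the backend count, then fetch the handles.
#include "llama.h"

#include <cstdio>
#include <vector>

static void print_context_backends(const llama_context * ctx) {
    const size_t n = llama_n_backends(ctx);

    std::vector<ggml_backend_t> backends(n);
    const size_t copied = llama_get_backends(ctx, backends.data(), backends.size());

    for (size_t i = 0; i < copied; ++i) {
        // ggml_backend_name() labels each backend, e.g. "CUDA0" or "CPU"
        printf("backend %zu: %s\n", i, ggml_backend_name(backends[i]));
    }
}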
26 changes: 26 additions & 0 deletions src/llama-context.cpp
@@ -6,6 +6,7 @@
#include "llama-mmap.h"
#include "llama-model.h"

#include <algorithm>
#include <cinttypes>
#include <cstring>
#include <limits>
@@ -418,6 +419,23 @@ uint32_t llama_context::n_threads_batch() const {
return cparams.n_threads_batch;
}

size_t llama_context::n_backends() const {
return backends.size();
}

size_t llama_context::copy_backends_list(ggml_backend_t* out, size_t out_len) const {
size_t copy_len;
if (out_len > backends.size()) {
copy_len = backends.size();
} else {
copy_len = out_len;
}
std::transform(backends.begin(), backends.begin() + copy_len, out, [](const ggml_backend_ptr& ptr) {
return ptr.get();
});
return copy_len;
}

llama_memory_t llama_context::get_memory() const {
return memory.get();
}
@@ -2428,6 +2446,14 @@ llama_memory_t llama_get_memory(const struct llama_context * ctx) {
return ctx->get_memory();
}

size_t llama_n_backends(const struct llama_context * ctx) {
return ctx->n_backends();
}

size_t llama_get_backends(const struct llama_context * ctx, ggml_backend_t * out, size_t out_len) {
return ctx->copy_backends_list(out, out_len);
}

void llama_memory_clear(llama_memory_t mem, bool data) {
if (!mem) {
return;
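Because copy_backends_list() (and therefore llama_get_backends()) returns the number of handles it actually wrote, a caller can also pass a fixed-size buffer without querying the count first — a small sketch assuming the API above:

// Hypothetical caller-side sketch (not part of this PR): fixed-size buffer,
// relying on the truncation behaviour of copy_backends_list().
#include "llama.h"

#include <cstdio>

static void print_backend_count(const llama_context * ctx) {
    ggml_backend_t backends[8];
    const size_t copied = llama_get_backends(ctx, backends, 8);
    printf("copied %zu of %zu backends\n", copied, llama_n_backends(ctx));
}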
3 changes: 3 additions & 0 deletions src/llama-context.h
@@ -46,6 +46,9 @@ struct llama_context {
uint32_t n_threads() const;
uint32_t n_threads_batch() const;

size_t n_backends() const;
size_t copy_backends_list(ggml_backend_t* out, size_t out_len) const;

llama_memory_t get_memory() const;

// return true if the KV cache was updated
5 changes: 5 additions & 0 deletions src/llama-model.cpp
@@ -13628,6 +13628,11 @@ const char * llama_model_cls_label(const struct llama_model * model, uint32_t i)
return nullptr;
}

const ggml_backend_dev_t * llama_model_get_devices(const struct llama_model * model, size_t * out_len) {
*out_len = model->devices.size();
return model->devices.data();
}

// deprecated
int32_t llama_n_ctx_train(const llama_model * model) {
return llama_model_n_ctx_train(model);
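A usage sketch for the model-level accessor (hypothetical, not part of this PR; ggml_backend_dev_name() is the existing ggml helper assumed here for labelling devices):

#include "llama.h"

#include <cstdio>

static void print_model_devices(const llama_model * model) {
    size_t n_devices = 0;
    const ggml_backend_dev_t * devices = llama_model_get_devices(model, &n_devices);

    for (size_t i = 0; i < n_devices; ++i) {
        // ggml_backend_dev_name() returns a label such as "CUDA0" or "Metal"
        printf("device %zu: %s\n", i, ggml_backend_dev_name(devices[i]));
    }
}

Note that the returned pointer aliases storage owned by the model (model->devices.data()), so it remains valid only as long as the model itself.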