Skip to content

Commit f47bc43

Browse files
committed
llama : expose API to retrieve devices used by model/session.
It's useful from the library to be able to do things like list the features being used by the session.
1 parent 06c2b15 commit f47bc43

File tree

3 files changed

+20
-0
lines changed

3 files changed

+20
-0
lines changed

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,8 @@ extern "C" {
470470

471471
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
472472
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
473+
LLAMA_API size_t llama_n_backends(const struct llama_context * ctx);
474+
LLAMA_API size_t llama_get_backends(const struct llama_context * ctx, ggml_backend_t * out_buf, size_t out_len);
473475

474476
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
475477
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
@@ -479,6 +481,7 @@ extern "C" {
479481
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
480482
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
481483
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
484+
LLAMA_API const ggml_backend_dev_t * llama_model_get_devices (const struct llama_model * model, size_t * out_len);
482485

483486
// Get the model's RoPE frequency scaling factor
484487
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);

src/llama-context.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,18 @@ enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
610610
return ctx->cparams.pooling_type;
611611
}
612612

613+
size_t llama_n_backends(const struct llama_context * ctx) {
614+
return ctx->backends.size();
615+
}
616+
617+
size_t llama_get_backends(const struct llama_context * ctx, ggml_backend_t * out, size_t out_len) {
618+
size_t return_len = std::min(ctx->backends.size(), out_len);
619+
for (size_t i = 0; i < return_len; i++) {
620+
out[i] = ctx->backends[i].get();
621+
}
622+
return return_len;
623+
}
624+
613625
void llama_attach_threadpool(
614626
struct llama_context * ctx,
615627
ggml_threadpool_t threadpool,

src/llama-model.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3845,6 +3845,11 @@ int32_t llama_model_n_head_kv(const struct llama_model * model) {
38453845
return model->hparams.n_head_kv();
38463846
}
38473847

3848+
const ggml_backend_dev_t * llama_model_get_devices(const struct llama_model * model, size_t * out_len) {
3849+
*out_len = model->devices.size();
3850+
return model->devices.data();
3851+
}
3852+
38483853
// deprecated
38493854
int32_t llama_n_ctx_train(const struct llama_model * model) {
38503855
return llama_model_n_ctx_train(model);

0 commit comments

Comments
 (0)