@@ -134,6 +134,32 @@ static std::string tokens_to_output_formatted_string(const llama_context *ctx, c
 return out;
 }

+// Adds an RPC server
+// https://github.com/ggerganov/llama.cpp/compare/4dbc8b9cb71876e005724f4e8f73a3544646bcf5..3edfa7d3753c29e44b964c0ff424d2ea8d5fdee6
+static void add_rpc_devices(std::string servers) {
+    auto rpc_servers = string_split<std::string>(servers, ',');
+    if (rpc_servers.empty()) {
+        throw std::invalid_argument("no RPC servers specified");
+    }
+    ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
+    if (!rpc_reg) {
+        throw std::invalid_argument("failed to find RPC backend");
+    }
+    typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
+    ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
+    if (!ggml_backend_rpc_add_device_fn) {
+        throw std::invalid_argument("failed to find RPC device add function");
+    }
+    for (const auto & server : rpc_servers) {
+        ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
+        if (dev) {
+            ggml_backend_device_register(dev);
+        } else {
+            throw std::invalid_argument("failed to register RPC device");
+        }
+    }
+}
+
 // convert a vector of completion_token_output to json
 static json probs_vector_to_json(const llama_context *ctx, const std::vector<completion_token_output> &probs)
 {
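Two details of this helper are worth calling out. First, `string_split<std::string>(servers, ',')` comes from llama.cpp's common helpers and turns a list like `"host1:port1,host2:port2"` into one endpoint per element. A minimal standalone equivalent, shown only as an illustrative sketch (the helper name `split_endpoints` is hypothetical, not the actual llama.cpp implementation):

```cpp
#include <sstream>
#include <string>
#include <vector>

// Illustrative stand-in for llama.cpp's string_split<std::string>:
// splits "host1:port1,host2:port2" on commas into individual endpoints.
static std::vector<std::string> split_endpoints(const std::string & servers) {
    std::vector<std::string> endpoints;
    std::stringstream ss(servers);
    std::string item;
    while (std::getline(ss, item, ',')) {
        if (!item.empty()) {
            endpoints.push_back(item);
        }
    }
    return endpoints;
}
```

Second, the `typedef` plus `ggml_backend_reg_get_proc_address` indirection resolves `ggml_backend_rpc_add_device` at runtime from the RPC backend's registry entry instead of calling it directly, so the code still works when the RPC backend is built as a dynamically loaded module rather than linked in.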
@@ -2282,7 +2308,7 @@ static void params_parse(const backend::ModelOptions* request,
 
     const char *llama_grpc_servers = std::getenv("LLAMACPP_GRPC_SERVERS");
     if (llama_grpc_servers != NULL) {
-        params.rpc_servers = std::string(llama_grpc_servers);
+        add_rpc_devices(std::string(llama_grpc_servers));
     }
 
     // TODO: Add yarn
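With this change the gRPC server no longer stashes the raw endpoint string in `params.rpc_servers`; following the upstream llama.cpp change linked in the comment above, it registers each RPC endpoint as a ggml device up front. A hypothetical end-to-end sketch (the endpoints are placeholders, and `add_rpc_devices` is the function added in this diff):

```cpp
#include <cstdlib>
#include <string>

int main() {
    // Placeholder endpoints for two RPC workers; setenv is POSIX.
    setenv("LLAMACPP_GRPC_SERVERS", "192.168.1.10:50052,192.168.1.11:50052", 1);

    // Mirrors what params_parse does: read the env var and register
    // every listed endpoint as a ggml backend device.
    const char *llama_grpc_servers = std::getenv("LLAMACPP_GRPC_SERVERS");
    if (llama_grpc_servers != NULL) {
        add_rpc_devices(std::string(llama_grpc_servers));
    }
    return 0;
}
```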