44#include " ggml-cpp.h"
55
66#include < cinttypes>
7+ #include < cstdlib>
78#include < string>
89#include < vector>
910#include < memory>
@@ -97,6 +98,7 @@ enum rpc_cmd {
9798 RPC_CMD_GET_ALLOC_SIZE,
9899 RPC_CMD_HELLO,
99100 RPC_CMD_COUNT,
101+ RPC_CMD_AUTH,
100102};
101103
102104// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
@@ -108,6 +110,15 @@ struct rpc_msg_hello_rsp {
108110 uint8_t patch;
109111};
110112
// Request payload for RPC_CMD_AUTH.
struct rpc_msg_auth_req {
    // Length of the token in bytes, as computed by the sender with strlen().
    // NOTE(review): the sender does not clamp this to sizeof(token), so it may
    // be larger than the number of bytes actually carried below -- receivers
    // should validate it before using it as a read length.
    uint16_t length;

    // Token bytes. The client fills this buffer with snprintf(), which leaves
    // room for at most 255 token bytes followed by a terminating NUL.
    uint8_t token[256];
};
117+
// Response payload for RPC_CMD_AUTH.
struct rpc_msg_auth_resp {
    // Non-zero when the server accepted the token, zero when it was rejected.
    //
    // Deliberately a fixed-width byte rather than `bool`: this struct is
    // written to and read from the socket as raw bytes, and reading a `bool`
    // whose stored byte is neither 0 nor 1 is undefined behavior. A uint8_t has
    // a defined value for every byte the peer can send and a guaranteed size
    // of one byte on the wire.
    uint8_t result;
};
121+
111122struct rpc_msg_get_alloc_size_req {
112123 rpc_tensor tensor;
113124};
@@ -426,8 +437,32 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
426437// RPC client-side implementation
427438
428439static bool check_server_version (const std::shared_ptr<socket_t > & sock) {
440+ const char * auth_token_s = std::getenv (" GGML_RPC_TOKEN" );
441+
442+ if (auth_token_s == nullptr ) {
443+ fprintf (stderr, " No authentication token secret found in environment\n " );
444+ return false ;
445+ }
446+
447+ rpc_msg_auth_req auth_request;
448+ auth_request.length = strlen (auth_token_s);
449+ snprintf ((char *)auth_request.token ,
450+ sizeof (auth_request.token ),
451+ " %.*s" ,
452+ (int )sizeof (auth_request.token )-1 ,
453+ auth_token_s);
454+
455+ rpc_msg_auth_resp auth_response;
456+ bool status = send_rpc_cmd (sock, RPC_CMD_AUTH, &auth_request, sizeof (rpc_msg_auth_req), &auth_response, sizeof (rpc_msg_auth_resp));
457+ RPC_STATUS_ASSERT (status);
458+
459+ if (auth_response.result == false ) {
460+ fprintf (stderr, " Failed to authenticate to RPC server\n " );
461+ return false ;
462+ }
463+
429464 rpc_msg_hello_rsp response;
430- bool status = send_rpc_cmd (sock, RPC_CMD_HELLO, nullptr , 0 , &response, sizeof (response));
465+ status = send_rpc_cmd (sock, RPC_CMD_HELLO, nullptr , 0 , &response, sizeof (response));
431466 RPC_STATUS_ASSERT (status);
432467 if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
433468 fprintf (stderr, " RPC server version mismatch: %d.%d.%d\n " , response.major , response.minor , response.patch );
@@ -1371,14 +1406,86 @@ rpc_server::~rpc_server() {
13711406 }
13721407}
13731408
// Byte-wise comparison adapted from https://github.com/chmike/cst_time_memcmp
//
// Compares the first n bytes of m1 and m2 as unsigned chars. Every one of the
// n bytes is visited no matter where (or whether) the buffers differ, and the
// running result is updated with a bit mask rather than an if/else.
//
// Returns -1, 0 or 1: the sign of the difference at the lowest-index byte
// that differs, or 0 when n == 0 or all n bytes are equal.
static int cst_time_memcmp(const void * m1, const void * m2, size_t n) {
    const unsigned char * lhs = static_cast<const unsigned char *>(m1);
    const unsigned char * rhs = static_cast<const unsigned char *>(m2);

    int verdict = 0;

    // Walk from the last byte down to the first so that a mismatch at a lower
    // index replaces any verdict recorded at a higher one.
    for (size_t idx = n; idx != 0; --idx) {
        const int delta = lhs[idx - 1] - rhs[idx - 1];
        // All ones when the bytes match (verdict is kept), zero when they
        // differ (verdict is replaced by delta).
        const int keep_mask = -static_cast<int>(delta == 0);
        verdict = (verdict & keep_mask) | delta;
    }

    // Reduce the stored difference to its sign.
    return (verdict > 0) - (verdict < 0);
}
1423+
13741424static void rpc_serve_client (ggml_backend_t backend, const char * cache_dir,
13751425 sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1426+
1427+ const char * auth_token_s = std::getenv (" GGML_RPC_TOKEN" );
1428+ if (auth_token_s == nullptr ) {
1429+ fprintf (stderr, " [%s] Authentication token secret not set\n " , __func__);
1430+ return ;
1431+ }
1432+
1433+ size_t auth_token_s_len = strlen (auth_token_s);
1434+
13761435 rpc_server server (backend, cache_dir);
13771436 uint8_t cmd;
1437+
1438+ if (!recv_data (sockfd, &cmd, 1 )) {
1439+ return ;
1440+ }
1441+
1442+ // The first command sent by the client must be AUTH
1443+ if (cmd != RPC_CMD_AUTH) {
1444+ fprintf (stderr, " Expected AUTH command, update client\n " );
1445+ return ;
1446+ }
1447+
1448+ rpc_msg_auth_req request;
1449+ if (!recv_msg (sockfd, &request, sizeof (request))) {
1450+ fprintf (stderr, " Failed to process AUTH request, update client\n " );
1451+ return ;
1452+ }
1453+
1454+ rpc_msg_auth_resp auth_response;
1455+
1456+ // This is insecure for the following reasons:
1457+ // 0) It is probably susceptible to cache timing attacks
1458+ // 1) It may leak the size of the secret auth token
1459+ // 2) It can be brute forced
1460+ // 3) It compares secrets directly, not their hashes
1461+ // 4) It can be intercepted on the wire (use socat/openssl)
1462+ // 5) The token doesn't expire
1463+ if (request.length != auth_token_s_len ||
1464+ cst_time_memcmp ((void *) auth_token_s, (void *) &request.token , auth_token_s_len) != 0 ) {
1465+ struct sockaddr_in peer_addr;
1466+ socklen_t peer_len = sizeof (peer_addr);
1467+
1468+ if (getpeername (sockfd, (struct sockaddr *)&peer_addr, &peer_len) == 0 ) {
1469+ char *ip = inet_ntoa (peer_addr.sin_addr );
1470+ fprintf (stderr, " [%s] Invalid authentication token from %s\n " ,
1471+ __func__, ip);
1472+ } else {
1473+ fprintf (stderr, " [%s] Invalid authentication token from unknown (getpeername failed)\n " ,
1474+ __func__);
1475+ }
1476+ auth_response.result = false ;
1477+ send_msg (sockfd, &auth_response, sizeof (auth_response));
1478+ return ;
1479+ }
1480+
1481+ auth_response.result = true ;
1482+ send_msg (sockfd, &auth_response, sizeof (auth_response));
1483+
13781484 if (!recv_data (sockfd, &cmd, 1 )) {
13791485 return ;
13801486 }
1381- // the first command sent by the client must be HELLO
1487+
1488+ // The second command sent by the client must be HELLO
13821489 if (cmd != RPC_CMD_HELLO) {
13831490 fprintf (stderr, " Expected HELLO command, update client\n " );
13841491 return ;
0 commit comments