Skip to content

Commit d8ce886

Browse files
authored
Make Go JSON-RPC requests context-aware (#1643)
* Make Go JSON-RPC requests context-aware * Fix JSON-RPC pending request test race
1 parent 736bd6e commit d8ce886

7 files changed

Lines changed: 397 additions & 260 deletions

File tree

go/canvas_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ func TestCanvasRegisterClientSessionAPIHandlers_RawJSONRoundTrip(t *testing.T) {
251251
_ = serverToClientReader.Close()
252252
})
253253

254-
raw, err := requester.Request("canvas.open", map[string]any{
254+
raw, err := requester.Request(t.Context(), "canvas.open", map[string]any{
255255
"sessionId": "s1",
256256
"extensionId": "ext",
257257
"canvasId": "echo",
@@ -284,7 +284,7 @@ func TestCanvasRegisterClientSessionAPIHandlers_RawJSONRoundTrip(t *testing.T) {
284284
t.Fatalf("expected status=ready, got %v", decoded["status"])
285285
}
286286

287-
actionRaw, err := requester.Request("canvas.action.invoke", map[string]any{
287+
actionRaw, err := requester.Request(t.Context(), "canvas.action.invoke", map[string]any{
288288
"sessionId": "s1",
289289
"extensionId": "ext",
290290
"canvasId": "echo",

go/client.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,7 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses
830830
}
831831
}
832832

833-
result, err := c.client.RequestWithInlineResponse("session.create", req, inlineCb)
833+
result, err := c.client.RequestWithInlineResponse(ctx, "session.create", req, inlineCb)
834834
if err != nil {
835835
if registeredSessionID != "" {
836836
c.sessionsMux.Lock()
@@ -1075,7 +1075,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string,
10751075
session.clientSessionAPIs.SessionFS = newSessionFSAdapter(provider)
10761076
}
10771077

1078-
result, err := c.client.Request("session.resume", req)
1078+
result, err := c.client.Request(ctx, "session.resume", req)
10791079
if err != nil {
10801080
c.sessionsMux.Lock()
10811081
delete(c.sessions, sessionID)
@@ -1136,7 +1136,7 @@ func (c *Client) ListSessions(ctx context.Context, filter *SessionListFilter) ([
11361136
if filter != nil {
11371137
params.Filter = filter
11381138
}
1139-
result, err := c.client.Request("session.list", params)
1139+
result, err := c.client.Request(ctx, "session.list", params)
11401140
if err != nil {
11411141
return nil, err
11421142
}
@@ -1168,7 +1168,7 @@ func (c *Client) GetSessionMetadata(ctx context.Context, sessionID string) (*Ses
11681168
return nil, err
11691169
}
11701170

1171-
result, err := c.client.Request("session.getMetadata", getSessionMetadataRequest{SessionID: sessionID})
1171+
result, err := c.client.Request(ctx, "session.getMetadata", getSessionMetadataRequest{SessionID: sessionID})
11721172
if err != nil {
11731173
return nil, err
11741174
}
@@ -1199,7 +1199,7 @@ func (c *Client) DeleteSession(ctx context.Context, sessionID string) error {
11991199
return err
12001200
}
12011201

1202-
result, err := c.client.Request("session.delete", deleteSessionRequest{SessionID: sessionID})
1202+
result, err := c.client.Request(ctx, "session.delete", deleteSessionRequest{SessionID: sessionID})
12031203
if err != nil {
12041204
return err
12051205
}
@@ -1246,7 +1246,7 @@ func (c *Client) GetLastSessionID(ctx context.Context) (*string, error) {
12461246
return nil, err
12471247
}
12481248

1249-
result, err := c.client.Request("session.getLastId", getLastSessionIDRequest{})
1249+
result, err := c.client.Request(ctx, "session.getLastId", getLastSessionIDRequest{})
12501250
if err != nil {
12511251
return nil, err
12521252
}
@@ -1278,7 +1278,7 @@ func (c *Client) GetForegroundSessionID(ctx context.Context) (*string, error) {
12781278
return nil, err
12791279
}
12801280

1281-
result, err := c.client.Request("session.getForeground", getForegroundSessionRequest{})
1281+
result, err := c.client.Request(ctx, "session.getForeground", getForegroundSessionRequest{})
12821282
if err != nil {
12831283
return nil, err
12841284
}
@@ -1306,7 +1306,7 @@ func (c *Client) SetForegroundSessionID(ctx context.Context, sessionID string) e
13061306
return err
13071307
}
13081308

1309-
result, err := c.client.Request("session.setForeground", setForegroundSessionRequest{SessionID: sessionID})
1309+
result, err := c.client.Request(ctx, "session.setForeground", setForegroundSessionRequest{SessionID: sessionID})
13101310
if err != nil {
13111311
return err
13121312
}
@@ -1446,7 +1446,7 @@ func (c *Client) Ping(ctx context.Context, message string) (*PingResponse, error
14461446
return nil, fmt.Errorf("client not connected")
14471447
}
14481448

1449-
result, err := c.client.Request("ping", pingRequest{Message: message})
1449+
result, err := c.client.Request(ctx, "ping", pingRequest{Message: message})
14501450
if err != nil {
14511451
return nil, err
14521452
}
@@ -1464,7 +1464,7 @@ func (c *Client) GetStatus(ctx context.Context) (*GetStatusResponse, error) {
14641464
return nil, fmt.Errorf("client not connected")
14651465
}
14661466

1467-
result, err := c.client.Request("status.get", getStatusRequest{})
1467+
result, err := c.client.Request(ctx, "status.get", getStatusRequest{})
14681468
if err != nil {
14691469
return nil, err
14701470
}
@@ -1482,7 +1482,7 @@ func (c *Client) GetAuthStatus(ctx context.Context) (*GetAuthStatusResponse, err
14821482
return nil, fmt.Errorf("client not connected")
14831483
}
14841484

1485-
result, err := c.client.Request("auth.getStatus", getAuthStatusRequest{})
1485+
result, err := c.client.Request(ctx, "auth.getStatus", getAuthStatusRequest{})
14861486
if err != nil {
14871487
return nil, err
14881488
}
@@ -1523,7 +1523,7 @@ func (c *Client) ListModels(ctx context.Context) ([]ModelInfo, error) {
15231523
return nil, fmt.Errorf("client not connected")
15241524
}
15251525
// Cache miss - fetch from backend while holding lock
1526-
result, err := c.client.Request("models.list", listModelsRequest{})
1526+
result, err := c.client.Request(ctx, "models.list", listModelsRequest{})
15271527
if err != nil {
15281528
return nil, err
15291529
}

go/internal/jsonrpc2/jsonrpc2.go

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package jsonrpc2
22

33
import (
4+
"context"
45
"crypto/rand"
56
"encoding/json"
67
"errors"
@@ -202,8 +203,8 @@ func (c *Client) SetRequestHandler(method string, handler RequestHandler) {
202203
}
203204

204205
// Request sends a JSON-RPC request and waits for the response
205-
func (c *Client) Request(method string, params any) (json.RawMessage, error) {
206-
return c.RequestWithInlineResponse(method, params, nil)
206+
func (c *Client) Request(ctx context.Context, method string, params any) (json.RawMessage, error) {
207+
return c.RequestWithInlineResponse(ctx, method, params, nil)
207208
}
208209

209210
// RequestWithInlineResponse sends a JSON-RPC request and waits for the response,
@@ -214,7 +215,13 @@ func (c *Client) Request(method string, params any) (json.RawMessage, error) {
214215
// server in the response) before any subsequent notification on the same
215216
// connection is dispatched. If the callback returns an error, that error is
216217
// returned to the awaiter in place of the response.
217-
func (c *Client) RequestWithInlineResponse(method string, params any, onResponseInline func(json.RawMessage) error) (json.RawMessage, error) {
218+
func (c *Client) RequestWithInlineResponse(ctx context.Context, method string, params any, onResponseInline func(json.RawMessage) error) (json.RawMessage, error) {
219+
select {
220+
case <-ctx.Done():
221+
return nil, ctx.Err()
222+
default:
223+
}
224+
218225
requestID := generateUUID()
219226

220227
// Create response channel
@@ -237,6 +244,8 @@ func (c *Client) RequestWithInlineResponse(method string, params any, onResponse
237244
// Check if process already exited before sending
238245
if c.processDone != nil {
239246
select {
247+
case <-ctx.Done():
248+
return nil, ctx.Err()
240249
case <-c.processDone:
241250
if err := c.getProcessError(); err != nil {
242251
return nil, err
@@ -266,13 +275,18 @@ func (c *Client) RequestWithInlineResponse(method string, params any, onResponse
266275
Params: paramsData,
267276
}
268277

269-
if err := c.sendMessage(request); err != nil {
278+
if err := c.sendMessage(ctx, request); err != nil {
279+
if ctxErr := ctx.Err(); ctxErr != nil {
280+
return nil, ctxErr
281+
}
270282
return nil, fmt.Errorf("failed to send request: %w", err)
271283
}
272284

273285
// Wait for response, also checking for process exit
274286
if c.processDone != nil {
275287
select {
288+
case <-ctx.Done():
289+
return nil, ctx.Err()
276290
case response := <-responseChan:
277291
if response.Error != nil {
278292
return nil, response.Error
@@ -288,6 +302,8 @@ func (c *Client) RequestWithInlineResponse(method string, params any, onResponse
288302
}
289303
}
290304
select {
305+
case <-ctx.Done():
306+
return nil, ctx.Err()
291307
case response := <-responseChan:
292308
if response.Error != nil {
293309
return nil, response.Error
@@ -301,13 +317,26 @@ func (c *Client) RequestWithInlineResponse(method string, params any, onResponse
301317
// sendMessage writes a message to the stream.
302318
// Write serialization is achieved via a 1-buffered channel that holds the
303319
// writer when not in use, avoiding the need for a mutex on the write path.
304-
func (c *Client) sendMessage(message any) error {
320+
func (c *Client) sendMessage(ctx context.Context, message any) error {
321+
select {
322+
case <-ctx.Done():
323+
return ctx.Err()
324+
default:
325+
}
326+
305327
data, err := json.Marshal(message)
306328
if err != nil {
307329
return fmt.Errorf("failed to marshal message: %w", err)
308330
}
309331

310-
w := <-c.writer
332+
var w *headerWriter
333+
select {
334+
case <-ctx.Done():
335+
return ctx.Err()
336+
case <-c.stopChan:
337+
return fmt.Errorf("client stopped")
338+
case w = <-c.writer:
339+
}
311340
defer func() { c.writer <- w }()
312341
return w.Write(data)
313342
}
@@ -402,13 +431,15 @@ func (c *Client) handleResponse(response *Response) {
402431
}
403432

404433
func (c *Client) handleRequest(request *Request) {
434+
ctx := context.Background()
435+
405436
c.mu.Lock()
406437
handler := c.requestHandlers[request.Method]
407438
c.mu.Unlock()
408439

409440
if handler == nil {
410441
if request.IsCall() {
411-
c.sendErrorResponse(request.ID, &Error{
442+
c.sendErrorResponse(ctx, request.ID, &Error{
412443
Code: ErrMethodNotFound.Code,
413444
Message: fmt.Sprintf("Method not found: %s", request.Method),
414445
})
@@ -425,7 +456,7 @@ func (c *Client) handleRequest(request *Request) {
425456
go func() {
426457
defer func() {
427458
if r := recover(); r != nil {
428-
c.sendErrorResponse(request.ID, &Error{
459+
c.sendErrorResponse(ctx, request.ID, &Error{
429460
Code: ErrInternal.Code,
430461
Message: fmt.Sprintf("request handler panic: %v", r),
431462
})
@@ -434,31 +465,31 @@ func (c *Client) handleRequest(request *Request) {
434465

435466
result, err := handler(request.Params)
436467
if err != nil {
437-
c.sendErrorResponse(request.ID, err)
468+
c.sendErrorResponse(ctx, request.ID, err)
438469
return
439470
}
440-
c.sendResponse(request.ID, result)
471+
c.sendResponse(ctx, request.ID, result)
441472
}()
442473
}
443474

444-
func (c *Client) sendResponse(id json.RawMessage, result json.RawMessage) {
475+
func (c *Client) sendResponse(ctx context.Context, id json.RawMessage, result json.RawMessage) {
445476
response := Response{
446477
JSONRPC: version,
447478
ID: id,
448479
Result: result,
449480
}
450-
if err := c.sendMessage(response); err != nil {
481+
if err := c.sendMessage(ctx, response); err != nil {
451482
fmt.Printf("Failed to send JSON-RPC response: %v\n", err)
452483
}
453484
}
454485

455-
func (c *Client) sendErrorResponse(id json.RawMessage, rpcErr *Error) {
486+
func (c *Client) sendErrorResponse(ctx context.Context, id json.RawMessage, rpcErr *Error) {
456487
response := Response{
457488
JSONRPC: version,
458489
ID: id,
459490
Error: rpcErr,
460491
}
461-
if err := c.sendMessage(response); err != nil {
492+
if err := c.sendMessage(ctx, response); err != nil {
462493
fmt.Printf("Failed to send JSON-RPC error response: %v\n", err)
463494
}
464495
}

0 commit comments

Comments
 (0)