11package jsonrpc2
22
33import (
4+ "context"
45 "crypto/rand"
56 "encoding/json"
67 "errors"
@@ -202,8 +203,8 @@ func (c *Client) SetRequestHandler(method string, handler RequestHandler) {
202203}
203204
204205// Request sends a JSON-RPC request and waits for the response
205- func (c * Client ) Request (method string , params any ) (json.RawMessage , error ) {
206- return c .RequestWithInlineResponse (method , params , nil )
206+ func (c * Client ) Request (ctx context. Context , method string , params any ) (json.RawMessage , error ) {
207+ return c .RequestWithInlineResponse (ctx , method , params , nil )
207208}
208209
209210// RequestWithInlineResponse sends a JSON-RPC request and waits for the response,
@@ -214,7 +215,13 @@ func (c *Client) Request(method string, params any) (json.RawMessage, error) {
214215// server in the response) before any subsequent notification on the same
215216// connection is dispatched. If the callback returns an error, that error is
216217// returned to the awaiter in place of the response.
217- func (c * Client ) RequestWithInlineResponse (method string , params any , onResponseInline func (json.RawMessage ) error ) (json.RawMessage , error ) {
218+ func (c * Client ) RequestWithInlineResponse (ctx context.Context , method string , params any , onResponseInline func (json.RawMessage ) error ) (json.RawMessage , error ) {
219+ select {
220+ case <- ctx .Done ():
221+ return nil , ctx .Err ()
222+ default :
223+ }
224+
218225 requestID := generateUUID ()
219226
220227 // Create response channel
@@ -237,6 +244,8 @@ func (c *Client) RequestWithInlineResponse(method string, params any, onResponse
237244 // Check if process already exited before sending
238245 if c .processDone != nil {
239246 select {
247+ case <- ctx .Done ():
248+ return nil , ctx .Err ()
240249 case <- c .processDone :
241250 if err := c .getProcessError (); err != nil {
242251 return nil , err
@@ -266,13 +275,18 @@ func (c *Client) RequestWithInlineResponse(method string, params any, onResponse
266275 Params : paramsData ,
267276 }
268277
269- if err := c .sendMessage (request ); err != nil {
278+ if err := c .sendMessage (ctx , request ); err != nil {
279+ if ctxErr := ctx .Err (); ctxErr != nil {
280+ return nil , ctxErr
281+ }
270282 return nil , fmt .Errorf ("failed to send request: %w" , err )
271283 }
272284
273285 // Wait for response, also checking for process exit
274286 if c .processDone != nil {
275287 select {
288+ case <- ctx .Done ():
289+ return nil , ctx .Err ()
276290 case response := <- responseChan :
277291 if response .Error != nil {
278292 return nil , response .Error
@@ -288,6 +302,8 @@ func (c *Client) RequestWithInlineResponse(method string, params any, onResponse
288302 }
289303 }
290304 select {
305+ case <- ctx .Done ():
306+ return nil , ctx .Err ()
291307 case response := <- responseChan :
292308 if response .Error != nil {
293309 return nil , response .Error
@@ -301,13 +317,26 @@ func (c *Client) RequestWithInlineResponse(method string, params any, onResponse
301317// sendMessage writes a message to the stream.
302318// Write serialization is achieved via a 1-buffered channel that holds the
303319// writer when not in use, avoiding the need for a mutex on the write path.
304- func (c * Client ) sendMessage (message any ) error {
320+ func (c * Client ) sendMessage (ctx context.Context , message any ) error {
321+ select {
322+ case <- ctx .Done ():
323+ return ctx .Err ()
324+ default :
325+ }
326+
305327 data , err := json .Marshal (message )
306328 if err != nil {
307329 return fmt .Errorf ("failed to marshal message: %w" , err )
308330 }
309331
310- w := <- c .writer
332+ var w * headerWriter
333+ select {
334+ case <- ctx .Done ():
335+ return ctx .Err ()
336+ case <- c .stopChan :
337+ return fmt .Errorf ("client stopped" )
338+ case w = <- c .writer :
339+ }
311340 defer func () { c .writer <- w }()
312341 return w .Write (data )
313342}
@@ -402,13 +431,15 @@ func (c *Client) handleResponse(response *Response) {
402431}
403432
404433func (c * Client ) handleRequest (request * Request ) {
434+ ctx := context .Background ()
435+
405436 c .mu .Lock ()
406437 handler := c .requestHandlers [request .Method ]
407438 c .mu .Unlock ()
408439
409440 if handler == nil {
410441 if request .IsCall () {
411- c .sendErrorResponse (request .ID , & Error {
442+ c .sendErrorResponse (ctx , request .ID , & Error {
412443 Code : ErrMethodNotFound .Code ,
413444 Message : fmt .Sprintf ("Method not found: %s" , request .Method ),
414445 })
@@ -425,7 +456,7 @@ func (c *Client) handleRequest(request *Request) {
425456 go func () {
426457 defer func () {
427458 if r := recover (); r != nil {
428- c .sendErrorResponse (request .ID , & Error {
459+ c .sendErrorResponse (ctx , request .ID , & Error {
429460 Code : ErrInternal .Code ,
430461 Message : fmt .Sprintf ("request handler panic: %v" , r ),
431462 })
@@ -434,31 +465,31 @@ func (c *Client) handleRequest(request *Request) {
434465
435466 result , err := handler (request .Params )
436467 if err != nil {
437- c .sendErrorResponse (request .ID , err )
468+ c .sendErrorResponse (ctx , request .ID , err )
438469 return
439470 }
440- c .sendResponse (request .ID , result )
471+ c .sendResponse (ctx , request .ID , result )
441472 }()
442473}
443474
444- func (c * Client ) sendResponse (id json.RawMessage , result json.RawMessage ) {
475+ func (c * Client ) sendResponse (ctx context. Context , id json.RawMessage , result json.RawMessage ) {
445476 response := Response {
446477 JSONRPC : version ,
447478 ID : id ,
448479 Result : result ,
449480 }
450- if err := c .sendMessage (response ); err != nil {
481+ if err := c .sendMessage (ctx , response ); err != nil {
451482 fmt .Printf ("Failed to send JSON-RPC response: %v\n " , err )
452483 }
453484}
454485
455- func (c * Client ) sendErrorResponse (id json.RawMessage , rpcErr * Error ) {
486+ func (c * Client ) sendErrorResponse (ctx context. Context , id json.RawMessage , rpcErr * Error ) {
456487 response := Response {
457488 JSONRPC : version ,
458489 ID : id ,
459490 Error : rpcErr ,
460491 }
461- if err := c .sendMessage (response ); err != nil {
492+ if err := c .sendMessage (ctx , response ); err != nil {
462493 fmt .Printf ("Failed to send JSON-RPC error response: %v\n " , err )
463494 }
464495}
0 commit comments