Skip to content

Commit bbef3e4

Browse files
committed
internal/mcp: implement roots
Add support for roots. Change-Id: Ia50abc88f0047238272d698f30ce615b1a8fd486 Reviewed-on: https://go-review.googlesource.com/c/tools/+/671360 LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Sam Thanawalla <[email protected]> Reviewed-by: Robert Findley <[email protected]>
1 parent 2835a17 commit bbef3e4

File tree

6 files changed

+133
-28
lines changed

6 files changed

+133
-28
lines changed

internal/mcp/client.go

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"context"
1010
"encoding/json"
1111
"fmt"
12+
"slices"
1213
"sync"
1314

1415
jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2"
@@ -24,6 +25,7 @@ type Client struct {
2425
opts ClientOptions
2526
mu sync.Mutex
2627
conn *jsonrpc2.Connection
28+
roots *featureSet[protocol.Root]
2729
initializeResult *protocol.InitializeResult
2830
}
2931

@@ -37,6 +39,7 @@ func NewClient(name, version string, t Transport, opts *ClientOptions) *Client {
3739
name: name,
3840
version: version,
3941
transport: t,
42+
roots: newFeatureSet(func(r protocol.Root) string { return r.URI }),
4043
}
4144
if opts != nil {
4245
c.opts = *opts
@@ -106,13 +109,47 @@ func (c *Client) Wait() error {
106109
return c.conn.Wait()
107110
}
108111

109-
func (*Client) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) {
112+
// AddRoots adds the given roots to the client,
113+
// replacing any with the same URIs,
114+
// and notifies any connected servers.
115+
// TODO: notification
116+
func (c *Client) AddRoots(roots ...protocol.Root) {
117+
c.mu.Lock()
118+
defer c.mu.Unlock()
119+
c.roots.add(roots...)
120+
}
121+
122+
// RemoveRoots removes the roots with the given URIs,
123+
// and notifies any connected servers if the list has changed.
124+
// It is not an error to remove a nonexistent root.
125+
// TODO: notification
126+
func (c *Client) RemoveRoots(uris ...string) {
127+
c.mu.Lock()
128+
defer c.mu.Unlock()
129+
c.roots.remove(uris...)
130+
}
131+
132+
func (c *Client) listRoots(_ context.Context, _ *protocol.ListRootsParams) (*protocol.ListRootsResult, error) {
133+
c.mu.Lock()
134+
defer c.mu.Unlock()
135+
return &protocol.ListRootsResult{
136+
Roots: slices.Collect(c.roots.all()),
137+
}, nil
138+
}
139+
140+
func (c *Client) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) {
141+
// TODO: when we switch to ClientSessions, use a copy of the server's dispatch function, or
142+
// maybe just add another type parameter.
143+
//
110144
// No need to check that the connection is initialized, since we initialize
111145
// it in Connect.
112146
switch req.Method {
113147
case "ping":
114148
// The spec says that 'ping' expects an empty object result.
115149
return struct{}{}, nil
150+
case "roots/list":
151+
// ListRootsParams happens to be unused.
152+
return c.listRoots(ctx, nil)
116153
}
117154
return nil, jsonrpc2.ErrNotHandled
118155
}
@@ -162,10 +199,6 @@ func (c *Client) ListTools(ctx context.Context) ([]protocol.Tool, error) {
162199
}
163200

164201
// CallTool calls the tool with the given name and arguments.
165-
//
166-
// TODO(jba): make the following true:
167-
// If the provided arguments do not conform to the schema for the given tool,
168-
// the call fails.
169202
func (c *Client) CallTool(ctx context.Context, name string, args map[string]any) (_ *protocol.CallToolResult, err error) {
170203
defer func() {
171204
if err != nil {
@@ -180,14 +213,17 @@ func (c *Client) CallTool(ctx context.Context, name string, args map[string]any)
180213
}
181214
argsJSON[name] = argJSON
182215
}
183-
var (
184-
params = &protocol.CallToolParams{
185-
Name: name,
186-
Arguments: argsJSON,
187-
}
188-
result protocol.CallToolResult
189-
)
190-
if err := call(ctx, c.conn, "tools/call", params, &result); err != nil {
216+
217+
params := &protocol.CallToolParams{
218+
Name: name,
219+
Arguments: argsJSON,
220+
}
221+
return standardCall[protocol.CallToolResult](ctx, c.conn, "tools/call", params)
222+
}
223+
224+
func standardCall[TRes, TParams any](ctx context.Context, conn *jsonrpc2.Connection, method string, params TParams) (*TRes, error) {
225+
var result TRes
226+
if err := call(ctx, conn, method, params, &result); err != nil {
191227
return nil, err
192228
}
193229
return &result, nil

internal/mcp/mcp_test.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func TestEndToEnd(t *testing.T) {
6363
)
6464

6565
// Connect the server.
66-
cc, err := s.Connect(ctx, st, nil)
66+
sc, err := s.Connect(ctx, st, nil)
6767
if err != nil {
6868
t.Fatal(err)
6969
}
@@ -75,13 +75,14 @@ func TestEndToEnd(t *testing.T) {
7575
var clientWG sync.WaitGroup
7676
clientWG.Add(1)
7777
go func() {
78-
if err := cc.Wait(); err != nil {
78+
if err := sc.Wait(); err != nil {
7979
t.Errorf("server failed: %v", err)
8080
}
8181
clientWG.Done()
8282
}()
8383

8484
c := NewClient("testClient", "v1.0.0", ct, nil)
85+
c.AddRoots(protocol.Root{URI: "file:///root"})
8586

8687
// Connect the client.
8788
if err := c.Start(ctx); err != nil {
@@ -182,6 +183,13 @@ func TestEndToEnd(t *testing.T) {
182183
t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff)
183184
}
184185

186+
rootRes, err := sc.ListRoots(ctx, &protocol.ListRootsParams{})
187+
gotRoots := rootRes.Roots
188+
wantRoots := slices.Collect(c.roots.all())
189+
if diff := cmp.Diff(wantRoots, gotRoots); diff != "" {
190+
t.Errorf("roots/list mismatch (-want +got):\n%s", diff)
191+
}
192+
185193
// Disconnect.
186194
c.Close()
187195
clientWG.Wait()

internal/mcp/protocol/generate.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,22 @@ var declarations = config{
7575
Fields: config{"Params": {Name: "ListPromptsParams"}},
7676
},
7777
"ListPromptsResult": {Name: "ListPromptsResult"},
78+
"ListRootsRequest": {
79+
Fields: config{"Params": {Name: "ListRootsParams"}},
80+
},
81+
"ListRootsResult": {Name: "ListRootsResult"},
7882
"ListToolsRequest": {
7983
Fields: config{"Params": {Name: "ListToolsParams"}},
8084
},
8185
"ListToolsResult": {Name: "ListToolsResult"},
8286
"Prompt": {Name: "Prompt"},
8387
"PromptMessage": {Name: "PromptMessage"},
8488
"PromptArgument": {Name: "PromptArgument"},
89+
"ProgressToken": {Substitute: "any"}, // null|number|string
8590
"RequestId": {Substitute: "any"}, // null|number|string
8691
"Role": {Name: "Role"},
92+
"Root": {Name: "Root"},
93+
8794
"ServerCapabilities": {
8895
Name: "ServerCapabilities",
8996
Fields: config{
@@ -243,7 +250,11 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma
243250
// For types that explicitly allow additional properties, we can either
244251
// unmarshal them into a map[string]any, or delay unmarshalling with
245252
// json.RawMessage. For now, use json.RawMessage as it defers the choice.
246-
if def.Type == "object" && canHaveAdditionalProperties(def) {
253+
//
254+
// TODO(jba): further refine this classification of object schemas.
255+
// For example, the typescript "object" type, which should map to a Go "any",
256+
// is represented in schema.json by `{type: object, properties: {}, additionalProperties: true}`.
257+
if def.Type == "object" && canHaveAdditionalProperties(def) && def.Properties == nil {
247258
w.Write([]byte("map[string]"))
248259
return writeType(w, nil, def.AdditionalProperties, named)
249260
}

internal/mcp/protocol/protocol.go

Lines changed: 47 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/mcp/root.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// Copyright 2025 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package mcp

internal/mcp/server.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ func NewServer(name, version string, opts *ServerOptions) *Server {
5757
}
5858
}
5959

60-
// AddPrompts adds the given prompts to the server.
60+
// AddPrompts adds the given prompts to the server,
61+
// replacing any with the same names.
6162
func (s *Server) AddPrompts(prompts ...*Prompt) {
6263
s.mu.Lock()
6364
defer s.mu.Unlock()
@@ -67,7 +68,7 @@ func (s *Server) AddPrompts(prompts ...*Prompt) {
6768
// TODO(rfindley): notify connected clients
6869
}
6970

70-
// RemovePrompts removes if the prompts with the given names.
71+
// RemovePrompts removes the prompts with the given names.
7172
// It is not an error to remove a nonexistent prompt.
7273
func (s *Server) RemovePrompts(names ...string) {
7374
s.mu.Lock()
@@ -77,9 +78,8 @@ func (s *Server) RemovePrompts(names ...string) {
7778
}
7879
}
7980

80-
// AddTools adds the given tools to the server.
81-
//
82-
// TODO(rfindley): notify connected clients of any changes.
81+
// AddTools adds the given tools to the server,
82+
// replacing any with the same names.
8383
func (s *Server) AddTools(tools ...*Tool) {
8484
s.mu.Lock()
8585
defer s.mu.Unlock()
@@ -89,7 +89,7 @@ func (s *Server) AddTools(tools ...*Tool) {
8989
// TODO(rfindley): notify connected clients
9090
}
9191

92-
// RemoveTools removes if the tools with the given names.
92+
// RemoveTools removes the tools with the given names.
9393
// It is not an error to remove a nonexistent tool.
9494
func (s *Server) RemoveTools(names ...string) {
9595
s.mu.Lock()
@@ -210,6 +210,10 @@ func (cc *ServerConnection) Ping(ctx context.Context) error {
210210
return call(ctx, cc.conn, "ping", nil, nil)
211211
}
212212

213+
func (cc *ServerConnection) ListRoots(ctx context.Context, params *protocol.ListRootsParams) (*protocol.ListRootsResult, error) {
214+
return standardCall[protocol.ListRootsResult](ctx, cc.conn, "roots/list", params)
215+
}
216+
213217
func (cc *ServerConnection) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) {
214218
cc.mu.Lock()
215219
initialized := cc.initialized

0 commit comments

Comments
 (0)