Skip to content

Commit 39ee6ec

Browse files
committed
feat: add custom handler support for all MCP server methods
- Add custom handler fields to MCPServer struct for all basic MCP methods - Implement custom handler logic in all handle* methods with proper error handling - Support custom handlers for: Initialize, Ping, SetLevel, ListResources, ListResourceTemplates, ReadResource, ListPrompts, GetPrompt, ListTools, CallTool, and Notification methods - Maintain backward compatibility by falling back to default behavior when custom handlers are not set - Enable more flexible server customization and middleware integration
1 parent f60537b commit 39ee6ec

File tree

1 file changed

+138
-4
lines changed

1 file changed

+138
-4
lines changed

server/server.go

Lines changed: 138 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,19 @@ type MCPServer struct {
166166
paginationLimit *int
167167
sessions sync.Map
168168
hooks *Hooks
169+
170+
// custom handlers for basic methods
171+
InitializeHandler func(ctx context.Context, request mcp.InitializeRequest) (*mcp.InitializeResult, error)
172+
PingHandler func(ctx context.Context, request mcp.PingRequest) (*mcp.EmptyResult, error)
173+
ListResourcesHandler func(ctx context.Context, request mcp.ListResourcesRequest) (*mcp.ListResourcesResult, error)
174+
ListResourceTemplatesHandler func(ctx context.Context, request mcp.ListResourceTemplatesRequest) (*mcp.ListResourceTemplatesResult, error)
175+
ReadResourceHandler func(ctx context.Context, request mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error)
176+
ListPromptsHandler func(ctx context.Context, request mcp.ListPromptsRequest) (*mcp.ListPromptsResult, error)
177+
GetPromptHandler func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error)
178+
ListToolsHandler func(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
179+
CallToolHandler func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
180+
SetLevelHandler func(ctx context.Context, request mcp.SetLevelRequest) (*mcp.EmptyResult, error)
181+
NotificationHandler func(ctx context.Context, notification mcp.JSONRPCNotification)
169182
}
170183

171184
// WithPaginationLimit sets the pagination limit for the server.
@@ -650,9 +663,21 @@ func (s *MCPServer) AddNotificationHandler(
650663

651664
func (s *MCPServer) handleInitialize(
652665
ctx context.Context,
653-
_ any,
666+
id any,
654667
request mcp.InitializeRequest,
655668
) (*mcp.InitializeResult, *requestError) {
669+
if s.InitializeHandler != nil {
670+
result, err := s.InitializeHandler(ctx, request)
671+
if err != nil {
672+
return nil, &requestError{
673+
id: id,
674+
code: mcp.INTERNAL_ERROR,
675+
err: err,
676+
}
677+
}
678+
return result, nil
679+
}
680+
656681
capabilities := mcp.ServerCapabilities{}
657682

658683
// Only add resource capabilities if they're configured
@@ -736,10 +761,21 @@ func (s *MCPServer) protocolVersion(clientVersion string) string {
736761
}
737762

738763
func (s *MCPServer) handlePing(
739-
_ context.Context,
740-
_ any,
741-
_ mcp.PingRequest,
764+
ctx context.Context,
765+
id any,
766+
request mcp.PingRequest,
742767
) (*mcp.EmptyResult, *requestError) {
768+
if s.PingHandler != nil {
769+
result, err := s.PingHandler(ctx, request)
770+
if err != nil {
771+
return nil, &requestError{
772+
id: id,
773+
code: mcp.INTERNAL_ERROR,
774+
err: err,
775+
}
776+
}
777+
return result, nil
778+
}
743779
return &mcp.EmptyResult{}, nil
744780
}
745781

@@ -748,6 +784,18 @@ func (s *MCPServer) handleSetLevel(
748784
id any,
749785
request mcp.SetLevelRequest,
750786
) (*mcp.EmptyResult, *requestError) {
787+
if s.SetLevelHandler != nil {
788+
result, err := s.SetLevelHandler(ctx, request)
789+
if err != nil {
790+
return nil, &requestError{
791+
id: id,
792+
code: mcp.INTERNAL_ERROR,
793+
err: err,
794+
}
795+
}
796+
return result, nil
797+
}
798+
751799
clientSession := ClientSessionFromContext(ctx)
752800
if clientSession == nil || !clientSession.Initialized() {
753801
return nil, &requestError{
@@ -827,6 +875,18 @@ func (s *MCPServer) handleListResources(
827875
id any,
828876
request mcp.ListResourcesRequest,
829877
) (*mcp.ListResourcesResult, *requestError) {
878+
if s.ListResourcesHandler != nil {
879+
result, err := s.ListResourcesHandler(ctx, request)
880+
if err != nil {
881+
return nil, &requestError{
882+
id: id,
883+
code: mcp.INTERNAL_ERROR,
884+
err: err,
885+
}
886+
}
887+
return result, nil
888+
}
889+
830890
s.resourcesMu.RLock()
831891
resourceMap := make(map[string]mcp.Resource, len(s.resources))
832892
for uri, entry := range s.resources {
@@ -880,6 +940,18 @@ func (s *MCPServer) handleListResourceTemplates(
880940
id any,
881941
request mcp.ListResourceTemplatesRequest,
882942
) (*mcp.ListResourceTemplatesResult, *requestError) {
943+
if s.ListResourceTemplatesHandler != nil {
944+
result, err := s.ListResourceTemplatesHandler(ctx, request)
945+
if err != nil {
946+
return nil, &requestError{
947+
id: id,
948+
code: mcp.INTERNAL_ERROR,
949+
err: err,
950+
}
951+
}
952+
return result, nil
953+
}
954+
883955
s.resourcesMu.RLock()
884956
templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates))
885957
for _, entry := range s.resourceTemplates {
@@ -916,6 +988,18 @@ func (s *MCPServer) handleReadResource(
916988
id any,
917989
request mcp.ReadResourceRequest,
918990
) (*mcp.ReadResourceResult, *requestError) {
991+
if s.ReadResourceHandler != nil {
992+
result, err := s.ReadResourceHandler(ctx, request)
993+
if err != nil {
994+
return nil, &requestError{
995+
id: id,
996+
code: mcp.INTERNAL_ERROR,
997+
err: err,
998+
}
999+
}
1000+
return result, nil
1001+
}
1002+
9191003
s.resourcesMu.RLock()
9201004

9211005
// First check session-specific resources
@@ -1030,6 +1114,18 @@ func (s *MCPServer) handleListPrompts(
10301114
id any,
10311115
request mcp.ListPromptsRequest,
10321116
) (*mcp.ListPromptsResult, *requestError) {
1117+
if s.ListPromptsHandler != nil {
1118+
result, err := s.ListPromptsHandler(ctx, request)
1119+
if err != nil {
1120+
return nil, &requestError{
1121+
id: id,
1122+
code: mcp.INTERNAL_ERROR,
1123+
err: err,
1124+
}
1125+
}
1126+
return result, nil
1127+
}
1128+
10331129
s.promptsMu.RLock()
10341130
prompts := make([]mcp.Prompt, 0, len(s.prompts))
10351131
for _, prompt := range s.prompts {
@@ -1068,6 +1164,18 @@ func (s *MCPServer) handleGetPrompt(
10681164
id any,
10691165
request mcp.GetPromptRequest,
10701166
) (*mcp.GetPromptResult, *requestError) {
1167+
if s.GetPromptHandler != nil {
1168+
result, err := s.GetPromptHandler(ctx, request)
1169+
if err != nil {
1170+
return nil, &requestError{
1171+
id: id,
1172+
code: mcp.INTERNAL_ERROR,
1173+
err: err,
1174+
}
1175+
}
1176+
return result, nil
1177+
}
1178+
10711179
s.promptsMu.RLock()
10721180
handler, ok := s.promptHandlers[request.Params.Name]
10731181
s.promptsMu.RUnlock()
@@ -1097,6 +1205,17 @@ func (s *MCPServer) handleListTools(
10971205
id any,
10981206
request mcp.ListToolsRequest,
10991207
) (*mcp.ListToolsResult, *requestError) {
1208+
if s.ListToolsHandler != nil {
1209+
result, err := s.ListToolsHandler(ctx, request)
1210+
if err != nil {
1211+
return nil, &requestError{
1212+
id: id,
1213+
code: mcp.INTERNAL_ERROR,
1214+
err: err,
1215+
}
1216+
}
1217+
return result, nil
1218+
}
11001219
// Get the base tools from the server
11011220
s.toolsMu.RLock()
11021221
tools := make([]mcp.Tool, 0, len(s.tools))
@@ -1187,6 +1306,17 @@ func (s *MCPServer) handleToolCall(
11871306
id any,
11881307
request mcp.CallToolRequest,
11891308
) (*mcp.CallToolResult, *requestError) {
1309+
if s.CallToolHandler != nil {
1310+
result, err := s.CallToolHandler(ctx, request)
1311+
if err != nil {
1312+
return nil, &requestError{
1313+
id: id,
1314+
code: mcp.INTERNAL_ERROR,
1315+
err: err,
1316+
}
1317+
}
1318+
return result, nil
1319+
}
11901320
// First check session-specific tools
11911321
var tool ServerTool
11921322
var ok bool
@@ -1246,6 +1376,10 @@ func (s *MCPServer) handleNotification(
12461376
ctx context.Context,
12471377
notification mcp.JSONRPCNotification,
12481378
) mcp.JSONRPCMessage {
1379+
if s.NotificationHandler != nil {
1380+
s.NotificationHandler(ctx, notification)
1381+
return nil
1382+
}
12491383
s.notificationHandlersMu.RLock()
12501384
handler, ok := s.notificationHandlers[notification.Method]
12511385
s.notificationHandlersMu.RUnlock()

0 commit comments

Comments
 (0)