diff --git a/examples/memory/kb.go b/examples/memory/kb.go new file mode 100644 index 0000000..c3819f1 --- /dev/null +++ b/examples/memory/kb.go @@ -0,0 +1,676 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "slices" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// store is an interface for knowledge base persistence. +type store interface { + Read() ([]byte, error) + Write(data []byte) error +} + +// memoryStore implements the store interface for in-memory persistence. +type memoryStore struct { + data []byte +} + +// Read reads data from the memory. If the data is empty, it returns an empty slice. +func (ms *memoryStore) Read() ([]byte, error) { + return ms.data, nil +} + +// Write writes data to the memory. +func (ms *memoryStore) Write(data []byte) error { + ms.data = data + return nil +} + +// fileStore implements the store interface for file-based persistence. +type fileStore struct { + path string +} + +// Read reads data from the file. If the file does not exist, it returns an empty slice. +func (fs *fileStore) Read() ([]byte, error) { + data, err := os.ReadFile(fs.path) + if err != nil { + if os.IsNotExist(err) { + return []byte{}, nil + } + return nil, fmt.Errorf("failed to read file %s: %w", fs.path, err) + } + return data, nil +} + +// Write writes data to the file. +func (fs *fileStore) Write(data []byte) error { + if err := os.WriteFile(fs.path, data, 0600); err != nil { + return fmt.Errorf("failed to write file %s: %w", fs.path, err) + } + return nil +} + +type knowledgeBase struct { + s store +} + +type kbItem struct { + Type string `json:"type"` + + // For Type == "entity" + Name string `json:"name,omitempty"` + EntityType string `json:"entityType,omitempty"` + Observations []string `json:"observations,omitempty"` + + // For Type == "relation" + From string `json:"from,omitempty"` + To string `json:"to,omitempty"` + RelationType string `json:"relationType,omitempty"` +} + +func (k knowledgeBase) loadGraph() (KnowledgeGraph, error) { + data, err := k.s.Read() + if err != nil { + return KnowledgeGraph{}, fmt.Errorf("failed to read from store: %w", err) + } + + if len(data) == 0 { + return KnowledgeGraph{}, nil + } + + var items []kbItem + if err := json.Unmarshal(data, &items); err != nil { + return KnowledgeGraph{}, fmt.Errorf("failed to unmarshal from store: %w", err) + } + + graph := KnowledgeGraph{ + Entities: []Entity{}, + Relations: []Relation{}, + } + + for _, item := range items { + switch item.Type { + case "entity": + graph.Entities = append(graph.Entities, Entity{ + Name: item.Name, + EntityType: item.EntityType, + Observations: item.Observations, + }) + case "relation": + graph.Relations = append(graph.Relations, Relation{ + From: item.From, + To: item.To, + RelationType: item.RelationType, + }) + } + } + + return graph, nil +} + +func (k knowledgeBase) saveGraph(graph KnowledgeGraph) error { + items := make([]kbItem, 0, len(graph.Entities)+len(graph.Relations)) + + for _, entity := range graph.Entities { + items = append(items, kbItem{ + Type: "entity", + Name: entity.Name, + EntityType: entity.EntityType, + Observations: entity.Observations, + }) + } + + for _, relation := range graph.Relations { + items = append(items, kbItem{ + Type: "relation", + From: relation.From, + To: relation.To, + RelationType: relation.RelationType, + }) + } + + itemsJSON, err := json.Marshal(items) + if err != nil { + return fmt.Errorf("failed to marshal items: %w", err) + } + + if err := k.s.Write(itemsJSON); err != nil { + return fmt.Errorf("failed to write to store: %w", err) + } + return nil +} + +func (k knowledgeBase) createEntities(entities []Entity) ([]Entity, error) { + graph, err := k.loadGraph() + if err != nil { + return nil, err + } + + var newEntities []Entity + for _, entity := range entities { + exists := false + for _, existingEntity := range graph.Entities { + if existingEntity.Name == entity.Name { + exists = true + break + } + } + + if !exists { + newEntities = append(newEntities, entity) + graph.Entities = append(graph.Entities, entity) + } + } + + if err := k.saveGraph(graph); err != nil { + return nil, err + } + + return newEntities, nil +} + +func (k knowledgeBase) createRelations(relations []Relation) ([]Relation, error) { + graph, err := k.loadGraph() + if err != nil { + return nil, err + } + + var newRelations []Relation + for _, relation := range relations { + exists := false + for _, existingRelation := range graph.Relations { + if existingRelation.From == relation.From && + existingRelation.To == relation.To && + existingRelation.RelationType == relation.RelationType { + exists = true + break + } + } + + if !exists { + newRelations = append(newRelations, relation) + graph.Relations = append(graph.Relations, relation) + } + } + + if err := k.saveGraph(graph); err != nil { + return nil, err + } + + return newRelations, nil +} + +func (k knowledgeBase) addObservations(observations []Observation) ([]Observation, error) { + graph, err := k.loadGraph() + if err != nil { + return nil, err + } + + var results []Observation + + for _, obs := range observations { + entityIndex := -1 + for i, entity := range graph.Entities { + if entity.Name == obs.EntityName { + entityIndex = i + break + } + } + + if entityIndex == -1 { + return nil, fmt.Errorf("entity with name %s not found", obs.EntityName) + } + + var newObservations []string + for _, content := range obs.Contents { + exists := slices.Contains(graph.Entities[entityIndex].Observations, content) + + if !exists { + newObservations = append(newObservations, content) + graph.Entities[entityIndex].Observations = append(graph.Entities[entityIndex].Observations, content) + } + } + + results = append(results, Observation{ + EntityName: obs.EntityName, + Contents: newObservations, + }) + } + + if err := k.saveGraph(graph); err != nil { + return nil, err + } + + return results, nil +} + +func (k knowledgeBase) deleteEntities(entityNames []string) error { + graph, err := k.loadGraph() + if err != nil { + return err + } + + // Create map for quick lookup + entitiesToDelete := make(map[string]bool) + for _, name := range entityNames { + entitiesToDelete[name] = true + } + + // Filter entities + var filteredEntities []Entity + for _, entity := range graph.Entities { + if !entitiesToDelete[entity.Name] { + filteredEntities = append(filteredEntities, entity) + } + } + graph.Entities = filteredEntities + + // Filter relations + var filteredRelations []Relation + for _, relation := range graph.Relations { + if !entitiesToDelete[relation.From] && !entitiesToDelete[relation.To] { + filteredRelations = append(filteredRelations, relation) + } + } + graph.Relations = filteredRelations + + return k.saveGraph(graph) +} + +func (k knowledgeBase) deleteObservations(deletions []Observation) error { + graph, err := k.loadGraph() + if err != nil { + return err + } + + for _, deletion := range deletions { + for i, entity := range graph.Entities { + if entity.Name == deletion.EntityName { + // Create a map for quick lookup + observationsToDelete := make(map[string]bool) + for _, observation := range deletion.Observations { + observationsToDelete[observation] = true + } + + // Filter observations + var filteredObservations []string + for _, observation := range entity.Observations { + if !observationsToDelete[observation] { + filteredObservations = append(filteredObservations, observation) + } + } + + graph.Entities[i].Observations = filteredObservations + break + } + } + } + + return k.saveGraph(graph) +} + +func (k knowledgeBase) deleteRelations(relations []Relation) error { + graph, err := k.loadGraph() + if err != nil { + return err + } + + var filteredRelations []Relation + for _, existingRelation := range graph.Relations { + shouldKeep := true + + for _, relationToDelete := range relations { + if existingRelation.From == relationToDelete.From && + existingRelation.To == relationToDelete.To && + existingRelation.RelationType == relationToDelete.RelationType { + shouldKeep = false + break + } + } + + if shouldKeep { + filteredRelations = append(filteredRelations, existingRelation) + } + } + + graph.Relations = filteredRelations + return k.saveGraph(graph) +} + +func (k knowledgeBase) readGraph() (KnowledgeGraph, error) { + return k.loadGraph() +} + +func (k knowledgeBase) searchNodes(query string) (KnowledgeGraph, error) { + graph, err := k.loadGraph() + if err != nil { + return KnowledgeGraph{}, err + } + + queryLower := strings.ToLower(query) + var filteredEntities []Entity + + // Filter entities + for _, entity := range graph.Entities { + if strings.Contains(strings.ToLower(entity.Name), queryLower) || + strings.Contains(strings.ToLower(entity.EntityType), queryLower) { + filteredEntities = append(filteredEntities, entity) + continue + } + + // Check observations + for _, observation := range entity.Observations { + if strings.Contains(strings.ToLower(observation), queryLower) { + filteredEntities = append(filteredEntities, entity) + break + } + } + } + + // Create map for quick entity lookup + filteredEntityNames := make(map[string]bool) + for _, entity := range filteredEntities { + filteredEntityNames[entity.Name] = true + } + + // Filter relations + var filteredRelations []Relation + for _, relation := range graph.Relations { + if filteredEntityNames[relation.From] && filteredEntityNames[relation.To] { + filteredRelations = append(filteredRelations, relation) + } + } + + return KnowledgeGraph{ + Entities: filteredEntities, + Relations: filteredRelations, + }, nil +} + +func (k knowledgeBase) openNodes(names []string) (KnowledgeGraph, error) { + graph, err := k.loadGraph() + if err != nil { + return KnowledgeGraph{}, err + } + + // Create map for quick name lookup + nameSet := make(map[string]bool) + for _, name := range names { + nameSet[name] = true + } + + // Filter entities + var filteredEntities []Entity + for _, entity := range graph.Entities { + if nameSet[entity.Name] { + filteredEntities = append(filteredEntities, entity) + } + } + + // Create map for quick entity lookup + filteredEntityNames := make(map[string]bool) + for _, entity := range filteredEntities { + filteredEntityNames[entity.Name] = true + } + + // Filter relations + var filteredRelations []Relation + for _, relation := range graph.Relations { + if filteredEntityNames[relation.From] && filteredEntityNames[relation.To] { + filteredRelations = append(filteredRelations, relation) + } + } + + return KnowledgeGraph{ + Entities: filteredEntities, + Relations: filteredRelations, + }, nil +} + +func (k knowledgeBase) CreateEntities(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[CreateEntitiesArgs]) (*mcp.CallToolResultFor[CreateEntitiesResult], error) { + var res mcp.CallToolResultFor[CreateEntitiesResult] + + entities, err := k.createEntities(params.Arguments.Entities) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + // I think marshalling the entities and pass it as a content should not be necessary, but as for now, it looks like + // the StructuredContent is not being unmarshalled in CallToolResultFor. + entitiesJSON, err := json.Marshal(entities) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + res.Content = []mcp.Content{ + &mcp.TextContent{Text: string(entitiesJSON)}, + } + + res.StructuredContent = CreateEntitiesResult{ + Entities: entities, + } + + return &res, nil +} + +func (k knowledgeBase) CreateRelations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[CreateRelationsArgs]) (*mcp.CallToolResultFor[CreateRelationsResult], error) { + var res mcp.CallToolResultFor[CreateRelationsResult] + + relations, err := k.createRelations(params.Arguments.Relations) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + relationsJSON, err := json.Marshal(relations) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + res.Content = []mcp.Content{ + &mcp.TextContent{Text: string(relationsJSON)}, + } + + res.StructuredContent = CreateRelationsResult{ + Relations: relations, + } + + return &res, nil +} + +func (k knowledgeBase) AddObservations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[AddObservationsArgs]) (*mcp.CallToolResultFor[AddObservationsResult], error) { + var res mcp.CallToolResultFor[AddObservationsResult] + + observations, err := k.addObservations(params.Arguments.Observations) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + observationsJSON, err := json.Marshal(observations) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + res.Content = []mcp.Content{ + &mcp.TextContent{Text: string(observationsJSON)}, + } + + res.StructuredContent = AddObservationsResult{ + Observations: observations, + } + + return &res, nil +} + +func (k knowledgeBase) DeleteEntities(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[DeleteEntitiesArgs]) (*mcp.CallToolResultFor[struct{}], error) { + var res mcp.CallToolResultFor[struct{}] + + err := k.deleteEntities(params.Arguments.EntityNames) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Entities deleted successfully"}, + } + + return &res, nil +} + +func (k knowledgeBase) DeleteObservations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[DeleteObservationsArgs]) (*mcp.CallToolResultFor[struct{}], error) { + var res mcp.CallToolResultFor[struct{}] + + err := k.deleteObservations(params.Arguments.Deletions) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Observations deleted successfully"}, + } + + return &res, nil +} + +func (k knowledgeBase) DeleteRelations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[DeleteRelationsArgs]) (*mcp.CallToolResultFor[struct{}], error) { + var res mcp.CallToolResultFor[struct{}] + + err := k.deleteRelations(params.Arguments.Relations) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Relations deleted successfully"}, + } + + return &res, nil +} + +func (k knowledgeBase) ReadGraph(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[struct{}]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { + var res mcp.CallToolResultFor[KnowledgeGraph] + + graph, err := k.readGraph() + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + graphJSON, err := json.Marshal(graph) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + res.Content = []mcp.Content{ + &mcp.TextContent{Text: string(graphJSON)}, + } + + res.StructuredContent = graph + return &res, nil +} + +func (k knowledgeBase) SearchNodes(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[SearchNodesArgs]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { + var res mcp.CallToolResultFor[KnowledgeGraph] + + graph, err := k.readGraph() + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + graphJSON, err := json.Marshal(graph) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + res.Content = []mcp.Content{ + &mcp.TextContent{Text: string(graphJSON)}, + } + + res.StructuredContent = graph + return &res, nil +} + +func (k knowledgeBase) OpenNodes(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[OpenNodesArgs]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { + var res mcp.CallToolResultFor[KnowledgeGraph] + + graph, err := k.openNodes(params.Arguments.Names) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + graphJSON, err := json.Marshal(graph) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + res.Content = []mcp.Content{ + &mcp.TextContent{Text: string(graphJSON)}, + } + + res.StructuredContent = graph + return &res, nil +} diff --git a/examples/memory/kb_test.go b/examples/memory/kb_test.go new file mode 100644 index 0000000..025305e --- /dev/null +++ b/examples/memory/kb_test.go @@ -0,0 +1,430 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "reflect" + "testing" +) + +// getStoreFactories returns a map of store factory functions for testing. +// Each factory provides a fresh store instance, ensuring test isolation. +func getStoreFactories() map[string]func(t *testing.T) store { + return map[string]func(t *testing.T) store{ + "file": func(t *testing.T) store { + tempDir, err := os.MkdirTemp("", "kb-test-file-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + t.Cleanup(func() { os.RemoveAll(tempDir) }) + return &fileStore{path: filepath.Join(tempDir, "test-memory.json")} + }, + "memory": func(t *testing.T) store { + return &memoryStore{} + }, + } +} + +func TestKnowledgeBaseOperations(t *testing.T) { + factories := getStoreFactories() + + for name, factory := range factories { + t.Run(name, func(t *testing.T) { + s := factory(t) + kb := knowledgeBase{s: s} + + // Test empty graph + graph, err := kb.loadGraph() + if err != nil { + t.Fatalf("failed to load empty graph: %v", err) + } + if len(graph.Entities) != 0 || len(graph.Relations) != 0 { + t.Errorf("expected empty graph, got %+v", graph) + } + + // Test creating entities + testEntities := []Entity{ + { + Name: "Alice", + EntityType: "Person", + Observations: []string{"Likes coffee"}, + }, + { + Name: "Bob", + EntityType: "Person", + Observations: []string{"Likes tea"}, + }, + } + + createdEntities, err := kb.createEntities(testEntities) + if err != nil { + t.Fatalf("failed to create entities: %v", err) + } + if len(createdEntities) != 2 { + t.Errorf("expected 2 created entities, got %d", len(createdEntities)) + } + + // Test reading graph + graph, err = kb.readGraph() + if err != nil { + t.Fatalf("failed to read graph: %v", err) + } + if len(graph.Entities) != 2 { + t.Errorf("expected 2 entities, got %d", len(graph.Entities)) + } + + // Test creating relations + testRelations := []Relation{ + { + From: "Alice", + To: "Bob", + RelationType: "friend", + }, + } + + createdRelations, err := kb.createRelations(testRelations) + if err != nil { + t.Fatalf("failed to create relations: %v", err) + } + if len(createdRelations) != 1 { + t.Errorf("expected 1 created relation, got %d", len(createdRelations)) + } + + // Test adding observations + testObservations := []Observation{ + { + EntityName: "Alice", + Contents: []string{"Works as developer", "Lives in New York"}, + }, + } + + addedObservations, err := kb.addObservations(testObservations) + if err != nil { + t.Fatalf("failed to add observations: %v", err) + } + if len(addedObservations) != 1 || len(addedObservations[0].Contents) != 2 { + t.Errorf("expected 1 observation with 2 contents, got %+v", addedObservations) + } + + // Test searching nodes + searchResult, err := kb.searchNodes("developer") + if err != nil { + t.Fatalf("failed to search nodes: %v", err) + } + if len(searchResult.Entities) != 1 || searchResult.Entities[0].Name != "Alice" { + t.Errorf("expected to find Alice when searching for 'developer', got %+v", searchResult) + } + + // Test opening specific nodes + openResult, err := kb.openNodes([]string{"Bob"}) + if err != nil { + t.Fatalf("failed to open nodes: %v", err) + } + if len(openResult.Entities) != 1 || openResult.Entities[0].Name != "Bob" { + t.Errorf("expected to find Bob when opening 'Bob', got %+v", openResult) + } + + // Test deleting observations + deleteObs := []Observation{ + { + EntityName: "Alice", + Observations: []string{"Works as developer"}, + }, + } + err = kb.deleteObservations(deleteObs) + if err != nil { + t.Fatalf("failed to delete observations: %v", err) + } + + // Verify observation was deleted + graph, _ = kb.readGraph() + aliceFound := false + for _, e := range graph.Entities { + if e.Name == "Alice" { + aliceFound = true + for _, obs := range e.Observations { + if obs == "Works as developer" { + t.Errorf("observation 'Works as developer' should have been deleted") + } + } + } + } + if !aliceFound { + t.Errorf("entity 'Alice' not found after deleting observation") + } + + // Test deleting relations + err = kb.deleteRelations(testRelations) + if err != nil { + t.Fatalf("failed to delete relations: %v", err) + } + + // Verify relation was deleted + graph, _ = kb.readGraph() + if len(graph.Relations) != 0 { + t.Errorf("expected 0 relations after deletion, got %d", len(graph.Relations)) + } + + // Test deleting entities + err = kb.deleteEntities([]string{"Alice"}) + if err != nil { + t.Fatalf("failed to delete entities: %v", err) + } + + // Verify entity was deleted + graph, _ = kb.readGraph() + if len(graph.Entities) != 1 || graph.Entities[0].Name != "Bob" { + t.Errorf("expected only Bob to remain after deleting Alice, got %+v", graph.Entities) + } + }) + } +} + +func TestSaveAndLoadGraph(t *testing.T) { + factories := getStoreFactories() + + for name, factory := range factories { + t.Run(name, func(t *testing.T) { + s := factory(t) + kb := knowledgeBase{s: s} + + // Create test graph + testGraph := KnowledgeGraph{ + Entities: []Entity{ + { + Name: "Charlie", + EntityType: "Person", + Observations: []string{"Likes hiking"}, + }, + }, + Relations: []Relation{ + { + From: "Charlie", + To: "Mountains", + RelationType: "enjoys", + }, + }, + } + + // Save graph + err := kb.saveGraph(testGraph) + if err != nil { + t.Fatalf("failed to save graph: %v", err) + } + + // Load graph + loadedGraph, err := kb.loadGraph() + if err != nil { + t.Fatalf("failed to load graph: %v", err) + } + + // Check if loaded graph matches test graph + if !reflect.DeepEqual(testGraph, loadedGraph) { + t.Errorf("loaded graph does not match saved graph.\nExpected: %+v\nGot: %+v", testGraph, loadedGraph) + } + + // Test invalid JSON - specific to fileStore + if fs, ok := s.(*fileStore); ok { + err := os.WriteFile(fs.path, []byte("invalid json"), 0600) + if err != nil { + t.Fatalf("failed to write invalid json: %v", err) + } + + _, err = kb.loadGraph() + if err == nil { + t.Errorf("expected error when loading invalid JSON, got nil") + } + } + }) + } +} + +func TestDuplicateEntitiesAndRelations(t *testing.T) { + factories := getStoreFactories() + + for name, factory := range factories { + t.Run(name, func(t *testing.T) { + s := factory(t) + kb := knowledgeBase{s: s} + + // Create initial entities + initialEntities := []Entity{ + { + Name: "Dave", + EntityType: "Person", + Observations: []string{"Plays guitar"}, + }, + } + + _, err := kb.createEntities(initialEntities) + if err != nil { + t.Fatalf("failed to create initial entities: %v", err) + } + + // Try to create duplicate entities + duplicateEntities := []Entity{ + { + Name: "Dave", + EntityType: "Person", + Observations: []string{"Sings well"}, + }, + { + Name: "Eve", + EntityType: "Person", + Observations: []string{"Plays piano"}, + }, + } + + newEntities, err := kb.createEntities(duplicateEntities) + if err != nil { + t.Fatalf("failed when adding duplicate entities: %v", err) + } + + // Should only create Eve, not Dave (duplicate) + if len(newEntities) != 1 || newEntities[0].Name != "Eve" { + t.Errorf("expected only 'Eve' to be created, got %+v", newEntities) + } + + // Create initial relation + initialRelation := []Relation{ + { + From: "Dave", + To: "Eve", + RelationType: "friend", + }, + } + + _, err = kb.createRelations(initialRelation) + if err != nil { + t.Fatalf("failed to create initial relation: %v", err) + } + + // Try to create duplicate relation + duplicateRelations := []Relation{ + { + From: "Dave", + To: "Eve", + RelationType: "friend", + }, + { + From: "Eve", + To: "Dave", + RelationType: "friend", + }, + } + + newRelations, err := kb.createRelations(duplicateRelations) + if err != nil { + t.Fatalf("failed when adding duplicate relations: %v", err) + } + + // Should only create the Eve->Dave relation, not Dave->Eve (duplicate) + if len(newRelations) != 1 || newRelations[0].From != "Eve" || newRelations[0].To != "Dave" { + t.Errorf("expected only 'Eve->Dave' relation to be created, got %+v", newRelations) + } + }) + } +} + +func TestErrorHandling(t *testing.T) { + t.Run("FileStoreWriteError", func(t *testing.T) { + // Test with non-existent directory, specific to fileStore + kb := knowledgeBase{ + s: &fileStore{path: filepath.Join("nonexistent", "directory", "file.json")}, + } + + testEntities := []Entity{ + {Name: "TestEntity"}, + } + + _, err := kb.createEntities(testEntities) + if err == nil { + t.Errorf("expected error when writing to non-existent directory, got nil") + } + }) + + factories := getStoreFactories() + for name, factory := range factories { + t.Run(fmt.Sprintf("AddObservationToNonExistentEntity_%s", name), func(t *testing.T) { + s := factory(t) + kb := knowledgeBase{s: s} + + // Create a test entity first + _, err := kb.createEntities([]Entity{{Name: "RealEntity"}}) + if err != nil { + t.Fatalf("failed to create test entity: %v", err) + } + + // Try to add observation to non-existent entity + nonExistentObs := []Observation{ + { + EntityName: "NonExistentEntity", + Contents: []string{"This shouldn't work"}, + }, + } + + _, err = kb.addObservations(nonExistentObs) + if err == nil { + t.Errorf("expected error when adding observations to non-existent entity, got nil") + } + }) + } +} + +func TestFileFormatting(t *testing.T) { + factories := getStoreFactories() + + for name, factory := range factories { + t.Run(name, func(t *testing.T) { + s := factory(t) + kb := knowledgeBase{s: s} + + // Create entities + testEntities := []Entity{ + { + Name: "FileTest", + EntityType: "TestEntity", + Observations: []string{"Test observation"}, + }, + } + + _, err := kb.createEntities(testEntities) + if err != nil { + t.Fatalf("failed to create test entity: %v", err) + } + + // Read data from the store interface + data, err := s.Read() + if err != nil { + t.Fatalf("failed to read from store: %v", err) + } + + // Parse JSON to verify structure + var items []kbItem + err = json.Unmarshal(data, &items) + if err != nil { + t.Fatalf("failed to parse store data JSON: %v", err) + } + + // Verify format + if len(items) != 1 { + t.Fatalf("expected 1 item in memory file, got %d", len(items)) + } + + item := items[0] + if item.Type != "entity" || + item.Name != "FileTest" || + item.EntityType != "TestEntity" || + len(item.Observations) != 1 || + item.Observations[0] != "Test observation" { + t.Errorf("store item format incorrect: %+v", item) + } + }) + } +} diff --git a/examples/memory/main.go b/examples/memory/main.go new file mode 100644 index 0000000..6de91a1 --- /dev/null +++ b/examples/memory/main.go @@ -0,0 +1,151 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "flag" + "log" + "net/http" + "os" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") + memoryFilePath = flag.String("memory", "", "If set, persist the knowledge base to this file; otherwise, it will be stored in memory and lost on exit.") +) + +type HiArgs struct { + Name string `json:"name"` +} + +type Entity struct { + Name string `json:"name"` + EntityType string `json:"entityType"` + Observations []string `json:"observations"` +} + +type Relation struct { + From string `json:"from"` + To string `json:"to"` + RelationType string `json:"relationType"` +} + +type Observation struct { + EntityName string `json:"entityName"` + Contents []string `json:"contents"` + + Observations []string `json:"observations,omitempty"` // For deletions. +} + +type KnowledgeGraph struct { + Entities []Entity `json:"entities"` + Relations []Relation `json:"relations"` +} + +type CreateEntitiesArgs struct { + Entities []Entity `json:"entities"` +} + +type CreateEntitiesResult struct { + Entities []Entity `json:"entities"` +} + +type CreateRelationsArgs struct { + Relations []Relation `json:"relations"` +} + +type CreateRelationsResult struct { + Relations []Relation `json:"relations"` +} + +type AddObservationsArgs struct { + Observations []Observation `json:"observations"` +} + +type AddObservationsResult struct { + Observations []Observation `json:"observations"` +} + +type DeleteEntitiesArgs struct { + EntityNames []string `json:"entityNames"` +} + +type DeleteObservationsArgs struct { + Deletions []Observation `json:"deletions"` +} + +type DeleteRelationsArgs struct { + Relations []Relation `json:"relations"` +} + +type SearchNodesArgs struct { + Query string `json:"query"` +} + +type OpenNodesArgs struct { + Names []string `json:"names"` +} + +func SayHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[HiArgs]) (*mcp.CallToolResultFor[struct{}], error) { + return &mcp.CallToolResultFor[struct{}]{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Hi " + params.Arguments.Name}, + }, + }, nil +} + +func main() { + flag.Parse() + + var kbStore store + kbStore = &memoryStore{} + if *memoryFilePath != "" { + kbStore = &fileStore{path: *memoryFilePath} + } + kb := knowledgeBase{s: kbStore} + + server := mcp.NewServer("memory", "v0.0.1", nil) + server.AddTools(mcp.NewServerTool("create_entities", "Create multiple new entities in the knowledge graph", kb.CreateEntities, mcp.Input( + mcp.Property("entities", mcp.Description("Entities to create")), + ))) + server.AddTools(mcp.NewServerTool("create_relations", "Create multiple new relations between entities", kb.CreateRelations, mcp.Input( + mcp.Property("relations", mcp.Description("Relations to create")), + ))) + server.AddTools(mcp.NewServerTool("add_observations", "Add new observations to existing entities", kb.AddObservations, mcp.Input( + mcp.Property("observations", mcp.Description("Observations to add")), + ))) + server.AddTools(mcp.NewServerTool("delete_entities", "Remove entities and their relations", kb.DeleteEntities, mcp.Input( + mcp.Property("entityNames", mcp.Description("Names of entities to delete")), + ))) + server.AddTools(mcp.NewServerTool("delete_observations", "Remove specific observations from entities", kb.DeleteObservations, mcp.Input( + mcp.Property("deletions", mcp.Description("Observations to delete")), + ))) + server.AddTools(mcp.NewServerTool("delete_relations", "Remove specific relations from the graph", kb.DeleteRelations, mcp.Input( + mcp.Property("relations", mcp.Description("Relations to delete")), + ))) + server.AddTools(mcp.NewServerTool("read_graph", "Read the entire knowledge graph", kb.ReadGraph)) + server.AddTools(mcp.NewServerTool("search_nodes", "Search for nodes based on query", kb.SearchNodes, mcp.Input( + mcp.Property("query", mcp.Description("Query string")), + ))) + server.AddTools(mcp.NewServerTool("open_nodes", "Retrieve specific nodes by name", kb.OpenNodes, mcp.Input( + mcp.Property("names", mcp.Description("Names of nodes to open")), + ))) + + if *httpAddr != "" { + handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, nil) + log.Printf("MCP handler listening at %s", *httpAddr) + http.ListenAndServe(*httpAddr, handler) + } else { + t := mcp.NewLoggingTransport(mcp.NewStdioTransport(), os.Stderr) + if err := server.Run(context.Background(), t); err != nil { + log.Printf("Server failed: %v", err) + } + } +}