diff --git a/a2aclient/rest.go b/a2aclient/rest.go index 2414d777..dfbed816 100644 --- a/a2aclient/rest.go +++ b/a2aclient/rest.go @@ -179,7 +179,7 @@ func (t *RESTTransport) GetTask(ctx context.Context, params ServiceParams, req * // ListTasks retrieves a list of tasks. func (t *RESTTransport) ListTasks(ctx context.Context, params ServiceParams, req *a2a.ListTasksRequest) (*a2a.ListTasksResponse, error) { - path := rest.MakeTasksListPath() + path := rest.MakeListTasksPath() query := url.Values{} if req.ContextID != "" { diff --git a/a2aext/activator_test.go b/a2aext/activator_test.go index d640af2d..7688f5f9 100644 --- a/a2aext/activator_test.go +++ b/a2aext/activator_test.go @@ -100,8 +100,8 @@ func TestActivator(t *testing.T) { t.Fatalf("a2aclient.NewFromEndpoints() error = %v", err) } - _, err = client.SendMessage(ctx, &a2a.MessageSendParams{ - Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "verify extensions"}), + _, err = client.SendMessage(ctx, &a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("verify extensions")), }) if err != nil { t.Fatalf("client.SendMessage() error = %v", err) @@ -123,15 +123,16 @@ func startServerWithExtensions(t *testing.T, executor a2asrv.AgentExecutor, exte for _, uri := range extensionURIs { extensions = append(extensions, a2a.AgentExtension{URI: uri}) } - card := &a2a.AgentCard{ - Capabilities: a2a.AgentCapabilities{ - Extensions: extensions, - }, - } + reqHandler := a2asrv.NewHandler(executor) server := httptest.NewServer(a2asrv.NewJSONRPCHandler(reqHandler)) - card.URL = server.URL - card.PreferredTransport = a2a.TransportProtocolJSONRPC t.Cleanup(server.Close) + + card := &a2a.AgentCard{ + Capabilities: a2a.AgentCapabilities{Extensions: extensions}, + SupportedInterfaces: []a2a.AgentInterface{ + {URL: server.URL, ProtocolBinding: a2a.TransportProtocolJSONRPC}, + }, + } return card } diff --git a/a2aext/propagator_test.go b/a2aext/propagator_test.go index 0bf7a6dc..fc39a593 100644 --- a/a2aext/propagator_test.go +++ b/a2aext/propagator_test.go @@ -129,8 +129,8 @@ func TestTripleHopPropagation(t *testing.T) { t.Fatalf("a2aclient.NewFromEndpoints() error = %v", err) } - resp, err := client.SendMessage(ctx, &a2a.MessageSendParams{ - Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Hi!"}), + resp, err := client.SendMessage(ctx, &a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Hi!")), Metadata: tc.clientSvcParams, }) if err != nil { @@ -233,8 +233,8 @@ func TestDefaultPropagation(t *testing.T) { t.Fatalf("a2aclient.NewFromEndpoints() error = %v", err) } - resp, err := client.SendMessage(ctx, &a2a.MessageSendParams{ - Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Hi!"}), + resp, err := client.SendMessage(ctx, &a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Hi!")), Metadata: tc.clientSvcParams, }) if err != nil { @@ -260,7 +260,7 @@ func startServer(t *testing.T, interceptor a2asrv.CallInterceptor, executor a2as reqHandler := a2asrv.NewHandler(executor, a2asrv.WithCallInterceptors(interceptor)) server := httptest.NewServer(a2asrv.NewJSONRPCHandler(reqHandler)) t.Cleanup(server.Close) - return a2a.AgentInterface{URL: server.URL, Transport: a2a.TransportProtocolJSONRPC} + return a2a.AgentInterface{URL: server.URL, ProtocolBinding: a2a.TransportProtocolJSONRPC} } func newAgentCard(endpoint a2a.AgentInterface, extensionURIs []string) *a2a.AgentCard { @@ -269,10 +269,9 @@ func newAgentCard(endpoint a2a.AgentInterface, extensionURIs []string) *a2a.Agen extensions[i] = a2a.AgentExtension{URI: uri} } return &a2a.AgentCard{ - URL: endpoint.URL, - PreferredTransport: endpoint.Transport, - Capabilities: a2a.AgentCapabilities{ - Extensions: extensions, + Capabilities: a2a.AgentCapabilities{Extensions: extensions}, + SupportedInterfaces: []a2a.AgentInterface{ + {URL: endpoint.URL, ProtocolBinding: endpoint.ProtocolBinding}, }, } } @@ -300,7 +299,7 @@ func newProxyExecutor(interceptor a2aclient.CallInterceptor, target proxyTarget) yield(nil, err) return } - result, err := client.SendMessage(ctx, &a2a.MessageSendParams{ + result, err := client.SendMessage(ctx, &a2a.SendMessageRequest{ Message: a2a.NewMessage(a2a.MessageRoleUser, execCtx.Message.Parts...), }) if err != nil { diff --git a/a2asrv/agentexec.go b/a2asrv/agentexec.go index 1bbc7c9b..28a3bb36 100644 --- a/a2asrv/agentexec.go +++ b/a2asrv/agentexec.go @@ -34,8 +34,7 @@ import ( // For streaming responses [a2a.TaskArtifactUpdatEvent]-s should be used. // A2A server stops processing events after one of these events: // - An [a2a.Message] with any payload. -// - An [a2a.TaskStatusUpdateEvent] with Final field set to true. -// - An [a2a.Task] with a [a2a.TaskState] for which Terminal() method returns true. +// - An [a2a.Task] or [a2a.TaskStatusUpdateEvent] with a [a2a.TaskState] for which Terminal() method returns true or it is TaskStateInputRequired. // // The following code can be used as a streaming implementation template with generateOutputs and toParts missing: // @@ -78,7 +77,6 @@ import ( // } // // event = a2a.NewStatusUpdateEvent(execCtx, a2a.TaskStateCompleted, nil) -// event.Final = true // if err := queue.Write(ctx, event); err != nil { // return fmt.Errorf("failed to write state working: %w", err) // } @@ -113,7 +111,7 @@ type factory struct { var _ taskexec.Factory = (*factory)(nil) -func (f *factory) CreateExecutor(ctx context.Context, tid a2a.TaskID, params *a2a.MessageSendParams) (taskexec.Executor, taskexec.Processor, error) { +func (f *factory) CreateExecutor(ctx context.Context, tid a2a.TaskID, params *a2a.SendMessageRequest) (taskexec.Executor, taskexec.Processor, error) { execCtx, err := f.loadExecutionContext(ctx, tid, params) if err != nil { return nil, nil, err @@ -149,7 +147,7 @@ type executionContext struct { } // loadExecutionContext returns the information necessary for creating agent executor and agent event processor. -func (f *factory) loadExecutionContext(ctx context.Context, tid a2a.TaskID, params *a2a.MessageSendParams) (*executionContext, error) { +func (f *factory) loadExecutionContext(ctx context.Context, tid a2a.TaskID, params *a2a.SendMessageRequest) (*executionContext, error) { message := params.Message taskStoreTask, err := f.taskStore.Get(ctx, tid) @@ -207,7 +205,7 @@ func (f *factory) loadExecutionContext(ctx context.Context, tid a2a.TaskID, para }, nil } -func (f *factory) createNewExecutionContext(tid a2a.TaskID, params *a2a.MessageSendParams) (*executionContext, error) { +func (f *factory) createNewExecutionContext(tid a2a.TaskID, params *a2a.SendMessageRequest) (*executionContext, error) { msg := params.Message contextID := msg.ContextID if contextID == "" { @@ -222,7 +220,7 @@ func (f *factory) createNewExecutionContext(tid a2a.TaskID, params *a2a.MessageS return &executionContext{ctx: execCtx, task: nil}, nil } -func (f *factory) CreateCanceler(ctx context.Context, params *a2a.TaskIDParams) (taskexec.Canceler, taskexec.Processor, error) { +func (f *factory) CreateCanceler(ctx context.Context, params *a2a.CancelTaskRequest) (taskexec.Canceler, taskexec.Processor, error) { storedTask, err := f.taskStore.Get(ctx, params.ID) if err != nil { return nil, nil, fmt.Errorf("failed to load a task: %w", err) @@ -237,7 +235,7 @@ func (f *factory) CreateCanceler(ctx context.Context, params *a2a.TaskIDParams) TaskID: task.ID, StoredTask: task, ContextID: task.ContextID, - Metadata: params.Metadata, + Metadata: nil, // TODO: Fix spec https://github.com/a2aproject/A2A/pull/1485 } if callCtx, ok := CallContextFrom(ctx); ok { execCtx.User = callCtx.User diff --git a/a2asrv/handler.go b/a2asrv/handler.go index 759da26a..e7b61dd9 100644 --- a/a2asrv/handler.go +++ b/a2asrv/handler.go @@ -31,38 +31,38 @@ import ( // RequestHandler defines a transport-agnostic interface for handling incoming A2A requests. type RequestHandler interface { - // OnGetTask handles the 'tasks/get' protocol method. - OnGetTask(ctx context.Context, query *a2a.TaskQueryParams) (*a2a.Task, error) + // GetTask handles the 'tasks/get' protocol method. + GetTask(context.Context, *a2a.GetTaskRequest) (*a2a.Task, error) - // OnListTasks handles the 'tasks/list' protocol method. - OnListTasks(ctx context.Context, req *a2a.ListTasksRequest) (*a2a.ListTasksResponse, error) + // ListTasks handles the 'tasks/list' protocol method. + ListTasks(context.Context, *a2a.ListTasksRequest) (*a2a.ListTasksResponse, error) - // OnCancelTask handles the 'tasks/cancel' protocol method. - OnCancelTask(ctx context.Context, id *a2a.TaskIDParams) (*a2a.Task, error) + // CancelTask handles the 'tasks/cancel' protocol method. + CancelTask(context.Context, *a2a.CancelTaskRequest) (*a2a.Task, error) - // OnSendMessage handles the 'message/send' protocol method (non-streaming). - OnSendMessage(ctx context.Context, message *a2a.MessageSendParams) (a2a.SendMessageResult, error) + // SendMessage handles the 'message/send' protocol method (non-streaming). + SendMessage(context.Context, *a2a.SendMessageRequest) (a2a.SendMessageResult, error) - // OnResubscribeToTask handles the `tasks/resubscribe` protocol method. - OnResubscribeToTask(ctx context.Context, id *a2a.TaskIDParams) iter.Seq2[a2a.Event, error] + // SubscribeToTask handles the `tasks/resubscribe` protocol method. + SubscribeToTask(context.Context, *a2a.SubscribeToTaskRequest) iter.Seq2[a2a.Event, error] - // OnSendMessageStream handles the 'message/stream' protocol method (streaming). - OnSendMessageStream(ctx context.Context, message *a2a.MessageSendParams) iter.Seq2[a2a.Event, error] + // SendStreamingMessage handles the 'message/stream' protocol method (streaming). + SendStreamingMessage(context.Context, *a2a.SendMessageRequest) iter.Seq2[a2a.Event, error] - // OnGetTaskPushConfig handles the `tasks/pushNotificationConfig/get` protocol method. - OnGetTaskPushConfig(ctx context.Context, params *a2a.GetTaskPushConfigParams) (*a2a.TaskPushConfig, error) + // GetTaskPushConfig handles the `tasks/pushNotificationConfig/get` protocol method. + GetTaskPushConfig(context.Context, *a2a.GetTaskPushConfigRequest) (*a2a.TaskPushConfig, error) - // OnListTaskPushConfig handles the `tasks/pushNotificationConfig/list` protocol method. - OnListTaskPushConfig(ctx context.Context, params *a2a.ListTaskPushConfigParams) ([]*a2a.TaskPushConfig, error) + // ListTaskPushConfig handles the `tasks/pushNotificationConfig/list` protocol method. + ListTaskPushConfig(context.Context, *a2a.ListTaskPushConfigRequest) ([]*a2a.TaskPushConfig, error) - // OnSetTaskPushConfig handles the `tasks/pushNotificationConfig/set` protocol method. - OnSetTaskPushConfig(ctx context.Context, params *a2a.TaskPushConfig) (*a2a.TaskPushConfig, error) + // CreateTaskPushConfig handles the `tasks/pushNotificationConfig/set` protocol method. + CreateTaskPushConfig(context.Context, *a2a.CreateTaskPushConfigRequest) (*a2a.TaskPushConfig, error) - // OnDeleteTaskPushConfig handles the `tasks/pushNotificationConfig/delete` protocol method. - OnDeleteTaskPushConfig(ctx context.Context, params *a2a.DeleteTaskPushConfigParams) error + // DeleteTaskPushConfig handles the `tasks/pushNotificationConfig/delete` protocol method. + DeleteTaskPushConfig(context.Context, *a2a.DeleteTaskPushConfigRequest) error // GetAgentCard returns an extended a2a.AgentCard if configured. - OnGetExtendedAgentCard(ctx context.Context) (*a2a.AgentCard, error) + GetExtendedAgentCard(context.Context) (*a2a.AgentCard, error) } // Implements a2asrv.RequestHandler. @@ -201,8 +201,8 @@ func NewHandler(executor AgentExecutor, options ...RequestHandlerOption) Request return ih } -func (h *defaultRequestHandler) OnGetTask(ctx context.Context, query *a2a.TaskQueryParams) (*a2a.Task, error) { - taskID := query.ID +func (h *defaultRequestHandler) GetTask(ctx context.Context, req *a2a.GetTaskRequest) (*a2a.Task, error) { + taskID := req.ID if taskID == "" { return nil, fmt.Errorf("missing TaskID: %w", a2a.ErrInvalidParams) } @@ -213,8 +213,8 @@ func (h *defaultRequestHandler) OnGetTask(ctx context.Context, query *a2a.TaskQu } task := storedTask.Task - if query.HistoryLength != nil { - historyLength := *query.HistoryLength + if req.HistoryLength != nil { + historyLength := *req.HistoryLength if historyLength <= 0 { task.History = []*a2a.Message{} @@ -226,7 +226,7 @@ func (h *defaultRequestHandler) OnGetTask(ctx context.Context, query *a2a.TaskQu return task, nil } -func (h *defaultRequestHandler) OnListTasks(ctx context.Context, req *a2a.ListTasksRequest) (*a2a.ListTasksResponse, error) { +func (h *defaultRequestHandler) ListTasks(ctx context.Context, req *a2a.ListTasksRequest) (*a2a.ListTasksResponse, error) { listResponse, err := h.taskStore.List(ctx, req) if err != nil { return nil, fmt.Errorf("failed to list tasks: %w", err) @@ -234,20 +234,20 @@ func (h *defaultRequestHandler) OnListTasks(ctx context.Context, req *a2a.ListTa return listResponse, nil } -func (h *defaultRequestHandler) OnCancelTask(ctx context.Context, params *a2a.TaskIDParams) (*a2a.Task, error) { - if params == nil { +func (h *defaultRequestHandler) CancelTask(ctx context.Context, req *a2a.CancelTaskRequest) (*a2a.Task, error) { + if req == nil { return nil, a2a.ErrInvalidParams } - response, err := h.execManager.Cancel(ctx, params) + response, err := h.execManager.Cancel(ctx, req) if err != nil { return nil, fmt.Errorf("failed to cancel: %w", err) } return response, nil } -func (h *defaultRequestHandler) OnSendMessage(ctx context.Context, params *a2a.MessageSendParams) (a2a.SendMessageResult, error) { - subscription, err := h.handleSendMessage(ctx, params) +func (h *defaultRequestHandler) SendMessage(ctx context.Context, req *a2a.SendMessageRequest) (a2a.SendMessageResult, error) { + subscription, err := h.handleSendMessage(ctx, req) if err != nil { return nil, err } @@ -258,7 +258,7 @@ func (h *defaultRequestHandler) OnSendMessage(ctx context.Context, params *a2a.M return nil, err } - if taskID, interrupt := shouldInterruptNonStreaming(params, event); interrupt { + if taskID, interrupt := shouldInterruptNonStreaming(req, event); interrupt { storedTask, err := h.taskStore.Get(ctx, taskID) if err != nil { return nil, fmt.Errorf("failed to load task on event processing interrupt: %w", err) @@ -279,9 +279,9 @@ func (h *defaultRequestHandler) OnSendMessage(ctx context.Context, params *a2a.M return task.Task, nil } -func (h *defaultRequestHandler) OnSendMessageStream(ctx context.Context, params *a2a.MessageSendParams) iter.Seq2[a2a.Event, error] { +func (h *defaultRequestHandler) SendStreamingMessage(ctx context.Context, req *a2a.SendMessageRequest) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { - subscription, err := h.handleSendMessage(ctx, params) + subscription, err := h.handleSendMessage(ctx, req) if err != nil { yield(nil, err) return @@ -295,14 +295,14 @@ func (h *defaultRequestHandler) OnSendMessageStream(ctx context.Context, params } } -func (h *defaultRequestHandler) OnResubscribeToTask(ctx context.Context, params *a2a.TaskIDParams) iter.Seq2[a2a.Event, error] { +func (h *defaultRequestHandler) SubscribeToTask(ctx context.Context, req *a2a.SubscribeToTaskRequest) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { - if params == nil { + if req == nil { yield(nil, a2a.ErrInvalidParams) return } - subscription, err := h.execManager.Resubscribe(ctx, params.ID) + subscription, err := h.execManager.Resubscribe(ctx, req.ID) if err != nil { yield(nil, fmt.Errorf("%w: %w", a2a.ErrTaskNotFound, err)) return @@ -316,88 +316,91 @@ func (h *defaultRequestHandler) OnResubscribeToTask(ctx context.Context, params } } -func (h *defaultRequestHandler) handleSendMessage(ctx context.Context, params *a2a.MessageSendParams) (taskexec.Subscription, error) { +func (h *defaultRequestHandler) handleSendMessage(ctx context.Context, req *a2a.SendMessageRequest) (taskexec.Subscription, error) { switch { - case params == nil: + case req == nil: return nil, fmt.Errorf("message send params is required: %w", a2a.ErrInvalidParams) - case params.Message == nil: + case req.Message == nil: return nil, fmt.Errorf("message is required: %w", a2a.ErrInvalidParams) - case params.Message.ID == "": + case req.Message.ID == "": return nil, fmt.Errorf("message ID is required: %w", a2a.ErrInvalidParams) - case len(params.Message.Parts) == 0: + case len(req.Message.Parts) == 0: return nil, fmt.Errorf("message parts is required: %w", a2a.ErrInvalidParams) - case params.Message.Role == "": + case req.Message.Role == "": return nil, fmt.Errorf("message role is required: %w", a2a.ErrInvalidParams) } - return h.execManager.Execute(ctx, params) + return h.execManager.Execute(ctx, req) } -func (h *defaultRequestHandler) OnGetTaskPushConfig(ctx context.Context, params *a2a.GetTaskPushConfigParams) (*a2a.TaskPushConfig, error) { +func (h *defaultRequestHandler) GetTaskPushConfig(ctx context.Context, req *a2a.GetTaskPushConfigRequest) (*a2a.TaskPushConfig, error) { if h.pushConfigStore == nil || h.pushSender == nil { return nil, a2a.ErrPushNotificationNotSupported } - config, err := h.pushConfigStore.Get(ctx, params.TaskID, params.ConfigID) + config, err := h.pushConfigStore.Get(ctx, req.TaskID, req.ID) if err != nil { return nil, fmt.Errorf("failed to get push configs: %w", err) } if config != nil { return &a2a.TaskPushConfig{ - TaskID: params.TaskID, + TaskID: req.TaskID, Config: *config, }, nil } return nil, push.ErrPushConfigNotFound } -func (h *defaultRequestHandler) OnListTaskPushConfig(ctx context.Context, params *a2a.ListTaskPushConfigParams) ([]*a2a.TaskPushConfig, error) { +func (h *defaultRequestHandler) ListTaskPushConfig(ctx context.Context, req *a2a.ListTaskPushConfigRequest) ([]*a2a.TaskPushConfig, error) { if h.pushConfigStore == nil || h.pushSender == nil { return nil, a2a.ErrPushNotificationNotSupported } - configs, err := h.pushConfigStore.List(ctx, params.TaskID) + configs, err := h.pushConfigStore.List(ctx, req.TaskID) if err != nil { return nil, fmt.Errorf("failed to list push configs: %w", err) } result := make([]*a2a.TaskPushConfig, len(configs)) for i, config := range configs { result[i] = &a2a.TaskPushConfig{ - TaskID: params.TaskID, + TaskID: req.TaskID, Config: *config, } } return result, nil } -func (h *defaultRequestHandler) OnSetTaskPushConfig(ctx context.Context, params *a2a.TaskPushConfig) (*a2a.TaskPushConfig, error) { +func (h *defaultRequestHandler) CreateTaskPushConfig(ctx context.Context, req *a2a.CreateTaskPushConfigRequest) (*a2a.TaskPushConfig, error) { if h.pushConfigStore == nil || h.pushSender == nil { return nil, a2a.ErrPushNotificationNotSupported } - saved, err := h.pushConfigStore.Save(ctx, params.TaskID, ¶ms.Config) + + saved, err := h.pushConfigStore.Save(ctx, req.TaskID, &req.Config) if err != nil { return nil, fmt.Errorf("failed to save push config: %w", err) } + return &a2a.TaskPushConfig{ - TaskID: params.TaskID, + TaskID: req.TaskID, + ID: saved.ID, Config: *saved, }, nil } -func (h *defaultRequestHandler) OnDeleteTaskPushConfig(ctx context.Context, params *a2a.DeleteTaskPushConfigParams) error { +func (h *defaultRequestHandler) DeleteTaskPushConfig(ctx context.Context, req *a2a.DeleteTaskPushConfigRequest) error { if h.pushConfigStore == nil || h.pushSender == nil { return a2a.ErrPushNotificationNotSupported } - return h.pushConfigStore.Delete(ctx, params.TaskID, params.ConfigID) + return h.pushConfigStore.Delete(ctx, req.TaskID, req.ID) } -func (h *defaultRequestHandler) OnGetExtendedAgentCard(ctx context.Context) (*a2a.AgentCard, error) { +func (h *defaultRequestHandler) GetExtendedAgentCard(ctx context.Context) (*a2a.AgentCard, error) { if h.authenticatedCardProducer == nil { return nil, a2a.ErrAuthenticatedExtendedCardNotConfigured } return h.authenticatedCardProducer.Card(ctx) } -func shouldInterruptNonStreaming(params *a2a.MessageSendParams, event a2a.Event) (a2a.TaskID, bool) { +func shouldInterruptNonStreaming(req *a2a.SendMessageRequest, event a2a.Event) (a2a.TaskID, bool) { // Non-blocking clients receive a result on the first task event, default Blocking to TRUE - if params.Config != nil && params.Config.Blocking != nil && !(*params.Config.Blocking) { + if req.Config != nil && req.Config.Blocking != nil && !(*req.Config.Blocking) { if _, ok := event.(*a2a.Message); ok { return "", false } diff --git a/a2asrv/handler_test.go b/a2asrv/handler_test.go index c6f8f4e7..405ee1cc 100644 --- a/a2asrv/handler_test.go +++ b/a2asrv/handler_test.go @@ -35,7 +35,7 @@ import ( var fixedTime = time.Now() -func TestRequestHandler_OnSendMessage(t *testing.T) { +func TestRequestHandler_SendMessage(t *testing.T) { artifactID := a2a.NewArtifactID() taskSeed := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID()} inputRequiredTaskSeed := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID(), Status: a2a.TaskStatus{State: a2a.TaskStateInputRequired}} @@ -44,7 +44,7 @@ func TestRequestHandler_OnSendMessage(t *testing.T) { type testCase struct { name string - input *a2a.MessageSendParams + input *a2a.SendMessageRequest agentEvents []a2a.Event wantResult a2a.SendMessageResult wantErr error @@ -79,15 +79,15 @@ func TestRequestHandler_OnSendMessage(t *testing.T) { }, { name: "fails if unknown task state", - input: &a2a.MessageSendParams{ + input: &a2a.SendMessageRequest{ Message: newUserMessage(taskSeed, "Work"), }, agentEvents: []a2a.Event{newTaskWithStatus(taskSeed, a2a.TaskStateUnknown, "...")}, - wantErr: fmt.Errorf("unknown task state: unknown"), + wantErr: fmt.Errorf("unknown task state: %s", a2a.TaskStateUnknown), }, { name: "final task overwrites intermediate task events", - input: &a2a.MessageSendParams{ + input: &a2a.SendMessageRequest{ Message: newUserMessage(taskSeed, "Work"), }, agentEvents: []a2a.Event{ @@ -98,7 +98,7 @@ func TestRequestHandler_OnSendMessage(t *testing.T) { }, { name: "final task overwrites intermediate status updates", - input: &a2a.MessageSendParams{ + input: &a2a.SendMessageRequest{ Message: newUserMessage(taskSeed, "Work"), }, agentEvents: []a2a.Event{ @@ -108,27 +108,9 @@ func TestRequestHandler_OnSendMessage(t *testing.T) { }, wantResult: newTaskWithStatus(taskSeed, a2a.TaskStateCompleted, "no status change history"), }, - { - name: "event final flag takes precedence over task state", - input: &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "Work")}, - agentEvents: []a2a.Event{ - newTaskStatusUpdate(taskSeed, a2a.TaskStateCompleted, "Working..."), - newFinalTaskStatusUpdate(taskSeed, a2a.TaskStateWorking, "Done!"), - }, - wantResult: &a2a.Task{ - ID: taskSeed.ID, - ContextID: taskSeed.ContextID, - Status: a2a.TaskStatus{ - State: a2a.TaskStateWorking, - Message: newAgentMessage("Done!"), - Timestamp: &fixedTime, - }, - History: []*a2a.Message{newUserMessage(taskSeed, "Work"), newAgentMessage("Working...")}, - }, - }, { name: "task status update accumulation", - input: &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "Syn")}, + input: &a2a.SendMessageRequest{Message: newUserMessage(taskSeed, "Syn")}, agentEvents: []a2a.Event{ newTaskStatusUpdate(taskSeed, a2a.TaskStateSubmitted, "Ack"), newTaskStatusUpdate(taskSeed, a2a.TaskStateWorking, "Working..."), @@ -151,7 +133,7 @@ func TestRequestHandler_OnSendMessage(t *testing.T) { }, { name: "input-required task status update", - input: &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "Syn")}, + input: &a2a.SendMessageRequest{Message: newUserMessage(taskSeed, "Syn")}, agentEvents: []a2a.Event{ newTaskStatusUpdate(taskSeed, a2a.TaskStateSubmitted, "Ack"), newTaskStatusUpdate(taskSeed, a2a.TaskStateWorking, "Working..."), @@ -174,11 +156,11 @@ func TestRequestHandler_OnSendMessage(t *testing.T) { }, { name: "task artifact streaming", - input: &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "Syn")}, + input: &a2a.SendMessageRequest{Message: newUserMessage(taskSeed, "Syn")}, agentEvents: []a2a.Event{ newTaskStatusUpdate(taskSeed, a2a.TaskStateSubmitted, "Ack"), - newArtifactEvent(taskSeed, artifactID, a2a.TextPart{Text: "Hello"}), - a2a.NewArtifactUpdateEvent(taskSeed, artifactID, a2a.TextPart{Text: ", world!"}), + newArtifactEvent(taskSeed, artifactID, a2a.NewTextPart("Hello")), + a2a.NewArtifactUpdateEvent(taskSeed, artifactID, a2a.NewTextPart(", world!")), newFinalTaskStatusUpdate(taskSeed, a2a.TaskStateCompleted, "Done!"), }, wantResult: &a2a.Task{ @@ -187,17 +169,17 @@ func TestRequestHandler_OnSendMessage(t *testing.T) { Status: a2a.TaskStatus{State: a2a.TaskStateCompleted, Message: newAgentMessage("Done!"), Timestamp: &fixedTime}, History: []*a2a.Message{newUserMessage(taskSeed, "Syn"), newAgentMessage("Ack")}, Artifacts: []*a2a.Artifact{ - {ID: artifactID, Parts: a2a.ContentParts{a2a.TextPart{Text: "Hello"}, a2a.TextPart{Text: ", world!"}}}, + {ID: artifactID, Parts: a2a.ContentParts{a2a.NewTextPart("Hello"), a2a.NewTextPart(", world!")}}, }, }, }, { name: "task with multiple artifacts", - input: &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "Syn")}, + input: &a2a.SendMessageRequest{Message: newUserMessage(taskSeed, "Syn")}, agentEvents: []a2a.Event{ newTaskStatusUpdate(taskSeed, a2a.TaskStateSubmitted, "Ack"), - newArtifactEvent(taskSeed, artifactID, a2a.TextPart{Text: "Hello"}), - newArtifactEvent(taskSeed, artifactID+"2", a2a.TextPart{Text: "World"}), + newArtifactEvent(taskSeed, artifactID, a2a.NewTextPart("Hello")), + newArtifactEvent(taskSeed, artifactID+"2", a2a.NewTextPart("World")), newFinalTaskStatusUpdate(taskSeed, a2a.TaskStateCompleted, "Done!"), }, wantResult: &a2a.Task{ @@ -206,14 +188,14 @@ func TestRequestHandler_OnSendMessage(t *testing.T) { Status: a2a.TaskStatus{State: a2a.TaskStateCompleted, Message: newAgentMessage("Done!"), Timestamp: &fixedTime}, History: []*a2a.Message{newUserMessage(taskSeed, "Syn"), newAgentMessage("Ack")}, Artifacts: []*a2a.Artifact{ - {ID: artifactID, Parts: a2a.ContentParts{a2a.TextPart{Text: "Hello"}}}, - {ID: artifactID + "2", Parts: a2a.ContentParts{a2a.TextPart{Text: "World"}}}, + {ID: artifactID, Parts: a2a.ContentParts{a2a.NewTextPart("Hello")}}, + {ID: artifactID + "2", Parts: a2a.ContentParts{a2a.NewTextPart("World")}}, }, }, }, { name: "task continuation", - input: &a2a.MessageSendParams{ + input: &a2a.SendMessageRequest{ Message: newUserMessage(inputRequiredTaskSeed, "continue"), }, agentEvents: []a2a.Event{ @@ -236,20 +218,20 @@ func TestRequestHandler_OnSendMessage(t *testing.T) { }, { name: "fails if no message", - input: &a2a.MessageSendParams{}, + input: &a2a.SendMessageRequest{}, wantErr: fmt.Errorf("message is required: %w", a2a.ErrInvalidParams), }, { name: "fails if no message ID", - input: &a2a.MessageSendParams{Message: &a2a.Message{ - Parts: a2a.ContentParts{a2a.TextPart{Text: "Test"}}, + input: &a2a.SendMessageRequest{Message: &a2a.Message{ + Parts: a2a.ContentParts{a2a.NewTextPart("Test")}, Role: a2a.MessageRoleUser, }}, wantErr: fmt.Errorf("message ID is required: %w", a2a.ErrInvalidParams), }, { name: "fails if no message parts", - input: &a2a.MessageSendParams{Message: &a2a.Message{ + input: &a2a.SendMessageRequest{Message: &a2a.Message{ ID: a2a.NewMessageID(), Role: a2a.MessageRoleUser, }}, @@ -257,19 +239,19 @@ func TestRequestHandler_OnSendMessage(t *testing.T) { }, { name: "fails if no message role", - input: &a2a.MessageSendParams{Message: &a2a.Message{ + input: &a2a.SendMessageRequest{Message: &a2a.Message{ ID: a2a.NewMessageID(), - Parts: a2a.ContentParts{a2a.TextPart{Text: "Test"}}, + Parts: a2a.ContentParts{a2a.NewTextPart("Test")}, }}, wantErr: fmt.Errorf("message role is required: %w", a2a.ErrInvalidParams), }, { name: "fails on non-existent task reference", - input: &a2a.MessageSendParams{ + input: &a2a.SendMessageRequest{ Message: &a2a.Message{ TaskID: "non-existent", ID: "test-message", - Parts: a2a.ContentParts{a2a.TextPart{Text: "Test"}}, + Parts: a2a.ContentParts{a2a.NewTextPart("Test")}, Role: a2a.MessageRoleUser, }, }, @@ -277,14 +259,14 @@ func TestRequestHandler_OnSendMessage(t *testing.T) { }, { name: "fails if contextID not equal to task contextID", - input: &a2a.MessageSendParams{ + input: &a2a.SendMessageRequest{ Message: &a2a.Message{TaskID: taskSeed.ID, ContextID: taskSeed.ContextID + "1", ID: "test-message"}, }, wantErr: a2a.ErrInvalidParams, }, { name: "fails if message references completed task", - input: &a2a.MessageSendParams{ + input: &a2a.SendMessageRequest{ Message: newUserMessage(completedTaskSeed, "Test"), }, wantErr: fmt.Errorf("setup failed: task in a terminal state %q: %w", a2a.TaskStateCompleted, a2a.ErrInvalidParams), @@ -293,7 +275,7 @@ func TestRequestHandler_OnSendMessage(t *testing.T) { } for _, tt := range createTestCases() { - input := &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "Test")} + input := &a2a.SendMessageRequest{Message: newUserMessage(taskSeed, "Test")} if tt.input != nil { input = tt.input } @@ -305,29 +287,29 @@ func TestRequestHandler_OnSendMessage(t *testing.T) { executor := newEventReplayAgent(tt.agentEvents, nil) handler := NewHandler(executor, WithTaskStore(store)) - result, gotErr := handler.OnSendMessage(ctx, input) + result, gotErr := handler.SendMessage(ctx, input) if tt.wantErr == nil { if gotErr != nil { - t.Errorf("OnSendMessage() error = %v, wantErr nil", gotErr) + t.Errorf("SendMessage() error = %v, wantErr nil", gotErr) return } if diff := cmp.Diff(tt.wantResult, result); diff != "" { - t.Errorf("OnSendMessage() (+got,-want):\ngot = %v\nwant %v\ndiff = %s", result, tt.wantResult, diff) + t.Errorf("SendMessage() (+got,-want):\ngot = %v\nwant %v\ndiff = %s", result, tt.wantResult, diff) } } else { if gotErr == nil { - t.Errorf("OnSendMessage() error = nil, wantErr %q", tt.wantErr) + t.Errorf("SendMessage() error = nil, wantErr %q", tt.wantErr) return } if gotErr.Error() != tt.wantErr.Error() && !errors.Is(gotErr, tt.wantErr) { - t.Errorf("OnSendMessage() error = %v, wantErr %v", gotErr, tt.wantErr) + t.Errorf("SendMessage() error = %v, wantErr %v", gotErr, tt.wantErr) } } }) } for _, tt := range createTestCases() { - input := &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "Test")} + input := &a2a.SendMessageRequest{Message: newUserMessage(taskSeed, "Test")} if tt.input != nil { input = tt.input } @@ -341,14 +323,14 @@ func TestRequestHandler_OnSendMessage(t *testing.T) { eventI := 0 var streamErr error - for got, gotErr := range handler.OnSendMessageStream(ctx, input) { + for got, gotErr := range handler.SendStreamingMessage(ctx, input) { if streamErr != nil { - t.Errorf("handler.OnSendMessageStream() got (%v, %v) after error, want stream end", got, gotErr) + t.Errorf("handler.SendStreamingMessage() got (%v, %v) after error, want stream end", got, gotErr) return } if gotErr != nil && tt.wantErr == nil { - t.Errorf("OnSendMessageStream() error = %v, wantErr nil", gotErr) + t.Errorf("SendStreamingMessage() error = %v, wantErr nil", gotErr) return } if gotErr != nil { @@ -362,26 +344,26 @@ func TestRequestHandler_OnSendMessage(t *testing.T) { eventI++ } if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("OnSendMessageStream() (+got,-want):\ngot = %v\nwant %v\ndiff = %s", got, want, diff) + t.Errorf("SendStreamingMessage() (+got,-want):\ngot = %v\nwant %v\ndiff = %s", got, want, diff) return } } if tt.wantErr == nil && eventI != len(tt.agentEvents) { - t.Errorf("OnSendMessageStream() received %d events, want %d", eventI, len(tt.agentEvents)) + t.Errorf("SendStreamingMessage() received %d events, want %d", eventI, len(tt.agentEvents)) return } if tt.wantErr != nil && streamErr == nil { - t.Errorf("OnSendMessageStream() error = nil, want %v", tt.wantErr) + t.Errorf("SendStreamingMessage() error = nil, want %v", tt.wantErr) return } if tt.wantErr != nil && (streamErr.Error() != tt.wantErr.Error() && !errors.Is(streamErr, tt.wantErr)) { - t.Errorf("OnSendMessageStream() error = %v, wantErr %v", streamErr, tt.wantErr) + t.Errorf("SendStreamingMessage() error = %v, wantErr %v", streamErr, tt.wantErr) } }) } } -func TestRequestHandler_OnSendMessage_AuthRequired(t *testing.T) { +func TestRequestHandler_SendMessage_AuthRequired(t *testing.T) { ctx := t.Context() ts := testutil.NewTestTaskStore() authCredentialsChan := make(chan struct{}) @@ -395,45 +377,43 @@ func TestRequestHandler_OnSendMessage_AuthRequired(t *testing.T) { return } <-authCredentialsChan - result := a2a.NewStatusUpdateEvent(execCtx, a2a.TaskStateCompleted, nil) - result.Final = true - yield(result, nil) + yield(a2a.NewStatusUpdateEvent(execCtx, a2a.TaskStateCompleted, nil), nil) } }, } handler := NewHandler(executor, WithTaskStore(ts)) - msg := a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "perform protected operation"}) - result, err := handler.OnSendMessage(ctx, &a2a.MessageSendParams{Message: msg}) + msg := a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("perform protected operation")) + result, err := handler.SendMessage(ctx, &a2a.SendMessageRequest{Message: msg}) if err != nil { - t.Fatalf("OnSendMessage() error = %v, wantErr nil", err) + t.Fatalf("SendMessage() error = %v, wantErr nil", err) } var taskID a2a.TaskID if task, ok := result.(*a2a.Task); ok { if task.Status.State != a2a.TaskStateAuthRequired { - t.Fatalf("OnSendMessage() = %v, want a2a.Task in %q state", result, a2a.TaskStateAuthRequired) + t.Fatalf("SendMessage() = %v, want a2a.Task in %q state", result, a2a.TaskStateAuthRequired) } msg.TaskID = task.ID taskID = task.ID } else { - t.Fatalf("OnSendMessage() = %v, want a2a.Task", result) + t.Fatalf("SendMessage() = %v, want a2a.Task", result) } - _, err = handler.OnSendMessage(ctx, &a2a.MessageSendParams{Message: msg}) + _, err = handler.SendMessage(ctx, &a2a.SendMessageRequest{Message: msg}) if !strings.Contains(err.Error(), "execution is already in progress") { - t.Fatalf("OnSendMessage() error = %v, want err to contain 'execution is already in progress'", err) + t.Fatalf("SendMessage() error = %v, want err to contain 'execution is already in progress'", err) } authCredentialsChan <- struct{}{} time.Sleep(time.Millisecond * 10) - task, err := handler.OnGetTask(ctx, &a2a.TaskQueryParams{ID: taskID}) + task, err := handler.GetTask(ctx, &a2a.GetTaskRequest{ID: taskID}) if task.Status.State != a2a.TaskStateCompleted { t.Fatalf("handler.OnGetTask() = (%v, %v), want a task in state %q", task, err, a2a.TaskStateCompleted) } } -func TestRequestHandler_OnSendMessage_NonBlocking(t *testing.T) { +func TestRequestHandler_SendMessage_NonBlocking(t *testing.T) { taskSeed := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID(), Status: a2a.TaskStatus{State: a2a.TaskStateInputRequired}} createExecutor := func(generateEvent func(execCtx *ExecutorContext) []a2a.Event) (*mockAgentExecutor, chan struct{}) { @@ -462,7 +442,7 @@ func TestRequestHandler_OnSendMessage_NonBlocking(t *testing.T) { type testCase struct { name string blocking bool - input *a2a.MessageSendParams + input *a2a.SendMessageRequest agentEvents func(execCtx *ExecutorContext) []a2a.Event wantState a2a.TaskState wantEvents int @@ -473,7 +453,7 @@ func TestRequestHandler_OnSendMessage_NonBlocking(t *testing.T) { { name: "defaults to blocking", blocking: true, - input: &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "Work"), Config: &a2a.MessageSendConfig{}}, + input: &a2a.SendMessageRequest{Message: newUserMessage(taskSeed, "Work"), Config: &a2a.SendMessageConfig{}}, agentEvents: func(execCtx *ExecutorContext) []a2a.Event { return []a2a.Event{ newTaskWithStatus(execCtx, a2a.TaskStateWorking, "Working..."), @@ -485,7 +465,7 @@ func TestRequestHandler_OnSendMessage_NonBlocking(t *testing.T) { }, { name: "non-terminal task state", - input: &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "Work"), Config: &a2a.MessageSendConfig{Blocking: utils.Ptr(false)}}, + input: &a2a.SendMessageRequest{Message: newUserMessage(taskSeed, "Work"), Config: &a2a.SendMessageConfig{Blocking: utils.Ptr(false)}}, agentEvents: func(execCtx *ExecutorContext) []a2a.Event { return []a2a.Event{ newTaskWithStatus(execCtx, a2a.TaskStateWorking, "Working..."), @@ -497,7 +477,7 @@ func TestRequestHandler_OnSendMessage_NonBlocking(t *testing.T) { }, { name: "non-final status update", - input: &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "Work"), Config: &a2a.MessageSendConfig{Blocking: utils.Ptr(false)}}, + input: &a2a.SendMessageRequest{Message: newUserMessage(taskSeed, "Work"), Config: &a2a.SendMessageConfig{Blocking: utils.Ptr(false)}}, agentEvents: func(execCtx *ExecutorContext) []a2a.Event { return []a2a.Event{ newTaskStatusUpdate(execCtx, a2a.TaskStateWorking, "Working..."), @@ -509,10 +489,10 @@ func TestRequestHandler_OnSendMessage_NonBlocking(t *testing.T) { }, { name: "artifact update", - input: &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "Work"), Config: &a2a.MessageSendConfig{Blocking: utils.Ptr(false)}}, + input: &a2a.SendMessageRequest{Message: newUserMessage(taskSeed, "Work"), Config: &a2a.SendMessageConfig{Blocking: utils.Ptr(false)}}, agentEvents: func(execCtx *ExecutorContext) []a2a.Event { return []a2a.Event{ - newArtifactEvent(execCtx, a2a.NewArtifactID(), a2a.TextPart{Text: "Artifact"}), + newArtifactEvent(execCtx, a2a.NewArtifactID(), a2a.NewTextPart("Artifact")), newFinalTaskStatusUpdate(execCtx, a2a.TaskStateCompleted, "Done!"), } }, @@ -521,7 +501,7 @@ func TestRequestHandler_OnSendMessage_NonBlocking(t *testing.T) { }, { name: "message for existing task", - input: &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "Work"), Config: &a2a.MessageSendConfig{Blocking: utils.Ptr(false)}}, + input: &a2a.SendMessageRequest{Message: newUserMessage(taskSeed, "Work"), Config: &a2a.SendMessageConfig{Blocking: utils.Ptr(false)}}, agentEvents: func(execCtx *ExecutorContext) []a2a.Event { return []a2a.Event{ newTaskStatusUpdate(taskSeed, a2a.TaskStateWorking, "Working..."), @@ -533,11 +513,11 @@ func TestRequestHandler_OnSendMessage_NonBlocking(t *testing.T) { }, { name: "message", - input: &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "Work"), Config: &a2a.MessageSendConfig{Blocking: utils.Ptr(false)}}, + input: &a2a.SendMessageRequest{Message: newUserMessage(taskSeed, "Work"), Config: &a2a.SendMessageConfig{Blocking: utils.Ptr(false)}}, agentEvents: func(execCtx *ExecutorContext) []a2a.Event { return []a2a.Event{ - a2a.NewMessageForTask(a2a.MessageRoleAgent, execCtx, a2a.TextPart{Text: "Done"}), - a2a.NewMessageForTask(a2a.MessageRoleAgent, execCtx, a2a.TextPart{Text: "Done-2"}), + a2a.NewMessageForTask(a2a.MessageRoleAgent, execCtx, a2a.NewTextPart("Done")), + a2a.NewMessageForTask(a2a.MessageRoleAgent, execCtx, a2a.NewTextPart("Done-2")), } }, wantEvents: 1, // streaming processing stops imeddiately after the first message @@ -556,26 +536,26 @@ func TestRequestHandler_OnSendMessage_NonBlocking(t *testing.T) { } handler := NewHandler(executor, WithTaskStore(store)) - result, gotErr := handler.OnSendMessage(ctx, tt.input) + result, gotErr := handler.SendMessage(ctx, tt.input) if !tt.blocking { close(waitingChan) } if gotErr != nil { - t.Errorf("OnSendMessage() error = %v, wantErr nil", gotErr) + t.Errorf("SendMessage() error = %v, wantErr nil", gotErr) return } if tt.wantState != a2a.TaskStateUnspecified { task, ok := result.(*a2a.Task) if !ok { - t.Errorf("OnSendMessage() returned %T, want a2a.Task", result) + t.Errorf("SendMessage() returned %T, want a2a.Task", result) return } if task.Status.State != tt.wantState { - t.Errorf("OnSendMessage() task.State = %v, want %v", task.Status.State, tt.wantState) + t.Errorf("SendMessage() task.State = %v, want %v", task.Status.State, tt.wantState) } } else { if _, ok := result.(*a2a.Message); !ok { - t.Errorf("OnSendMessage() returned %T, want a2a.Message", result) + t.Errorf("SendMessage() returned %T, want a2a.Message", result) } } }) @@ -591,20 +571,20 @@ func TestRequestHandler_OnSendMessage_NonBlocking(t *testing.T) { handler := NewHandler(executor, WithTaskStore(store)) gotEvents := 0 - for _, gotErr := range handler.OnSendMessageStream(ctx, tt.input) { + for _, gotErr := range handler.SendStreamingMessage(ctx, tt.input) { if gotErr != nil { - t.Errorf("OnSendMessageStream() error = %v, wantErr nil", gotErr) + t.Errorf("SendStreamingMessage() error = %v, wantErr nil", gotErr) } gotEvents++ } if gotEvents != tt.wantEvents { - t.Errorf("OnSendMessageStream() event count = %d, want %d", gotEvents, tt.wantEvents) + t.Errorf("SendStreamingMessage() event count = %d, want %d", gotEvents, tt.wantEvents) } }) } } -func TestRequestHandler_OnSendMessageStreaming_AuthRequired(t *testing.T) { +func TestRequestHandler_SendMessageStreaming_AuthRequired(t *testing.T) { ctx := t.Context() ts := testutil.NewTestTaskStore() authCredentialsChan := make(chan struct{}) @@ -618,43 +598,41 @@ func TestRequestHandler_OnSendMessageStreaming_AuthRequired(t *testing.T) { return } <-authCredentialsChan - result := a2a.NewStatusUpdateEvent(execCtx, a2a.TaskStateCompleted, nil) - result.Final = true - yield(result, nil) + yield(a2a.NewStatusUpdateEvent(execCtx, a2a.TaskStateCompleted, nil), nil) } }, } handler := NewHandler(executor, WithTaskStore(ts)) var lastEvent a2a.Event - msg := a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "perform protected operation"}) - for event, err := range handler.OnSendMessageStream(ctx, &a2a.MessageSendParams{Message: msg}) { + msg := a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("perform protected operation")) + for event, err := range handler.SendStreamingMessage(ctx, &a2a.SendMessageRequest{Message: msg}) { if upd, ok := event.(*a2a.TaskStatusUpdateEvent); ok && upd.Status.State == a2a.TaskStateAuthRequired { go func() { authCredentialsChan <- struct{}{} }() } if err != nil { - t.Fatalf("OnSendMessageStream() error = %v, wantErr nil", err) + t.Fatalf("SendStreamingMessage() error = %v, wantErr nil", err) } lastEvent = event } if task, ok := lastEvent.(*a2a.TaskStatusUpdateEvent); ok { if task.Status.State != a2a.TaskStateCompleted { - t.Fatalf("OnSendMessageStream() = %v, want status update with state %q", lastEvent, a2a.TaskStateAuthRequired) + t.Fatalf("SendStreamingMessage() = %v, want status update with state %q", lastEvent, a2a.TaskStateAuthRequired) } } else { - t.Fatalf("OnSendMessageStream() = %v, want a2a.TaskStatusUpdateEvent", lastEvent) + t.Fatalf("SendStreamingMessage() = %v, want a2a.TaskStatusUpdateEvent", lastEvent) } } -func TestRequestHandler_OnSendMessage_PushNotifications(t *testing.T) { +func TestRequestHandler_SendMessage_PushNotifications(t *testing.T) { ctx := t.Context() taskSeed := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID()} pushConfig := &a2a.PushConfig{URL: "https://example.com/push"} - input := &a2a.MessageSendParams{ + input := &a2a.SendMessageRequest{ Message: newUserMessage(taskSeed, "work"), - Config: &a2a.MessageSendConfig{ + Config: &a2a.SendMessageConfig{ PushConfig: pushConfig, }, } @@ -670,12 +648,12 @@ func TestRequestHandler_OnSendMessage_PushNotifications(t *testing.T) { pn := testutil.NewTestPushSender(t).SetSendPushError(nil) handler := NewHandler(executor, WithTaskStore(store), WithPushNotifications(ps, pn)) - result, err := handler.OnSendMessage(ctx, input) + result, err := handler.SendMessage(ctx, input) if err != nil { - t.Fatalf("OnSendMessage() failed: %v", err) + t.Fatalf("SendMessage() failed: %v", err) } if diff := cmp.Diff(wantResult, result); diff != "" { - t.Fatalf("OnSendMessage() mismatch (-want +got):\n%s", diff) + t.Fatalf("SendMessage() mismatch (-want +got):\n%s", diff) } saved, err := ps.List(ctx, taskSeed.ID) if err != nil || len(saved) != 1 { @@ -694,9 +672,9 @@ func TestRequestHandler_TaskExecutionFailOnPush(t *testing.T) { sender := push.NewHTTPPushSender(&push.HTTPSenderConfig{FailOnError: true}) taskSeed := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID()} - input := &a2a.MessageSendParams{ + input := &a2a.SendMessageRequest{ Message: newUserMessage(taskSeed, "work"), - Config: &a2a.MessageSendConfig{PushConfig: pushConfig}, + Config: &a2a.SendMessageConfig{PushConfig: pushConfig}, } agentEvents := []a2a.Event{ newFinalTaskStatusUpdate(taskSeed, a2a.TaskStateCompleted, "Done!"), @@ -708,16 +686,16 @@ func TestRequestHandler_TaskExecutionFailOnPush(t *testing.T) { executor := newEventReplayAgent(agentEvents, nil) handler := NewHandler(executor, WithTaskStore(store), WithPushNotifications(pushConfigStore, sender)) - result, err := handler.OnSendMessage(ctx, input) + result, err := handler.SendMessage(ctx, input) if err != nil { - t.Fatalf("OnSendMessage() error = %v", err) + t.Fatalf("SendMessage() error = %v", err) } task, ok := result.(*a2a.Task) if !ok { - t.Fatalf("OnSendMessage() result type = %T, want *a2a.Task", result) + t.Fatalf("SendMessage() result type = %T, want *a2a.Task", result) } if task.Status.State != a2a.TaskStateFailed { - t.Fatalf("OnSendMessage() result = %+v, want state %q", result, a2a.TaskStateFailed) + t.Fatalf("SendMessage() result = %+v, want state %q", result, a2a.TaskStateFailed) } } @@ -746,7 +724,7 @@ func TestRequestHandler_TaskExecutionFailOnInvalidEvent(t *testing.T) { t.Parallel() ctx := t.Context() taskSeed := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID()} - input := &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "work")} + input := &a2a.SendMessageRequest{Message: newUserMessage(taskSeed, "work")} store := testutil.NewTestTaskStore().WithTasks(t, taskSeed) executor := newEventReplayAgent([]a2a.Event{tc.event}, nil) @@ -754,33 +732,33 @@ func TestRequestHandler_TaskExecutionFailOnInvalidEvent(t *testing.T) { var result a2a.Event if streaming { - for event, err := range handler.OnSendMessageStream(ctx, input) { + for event, err := range handler.SendStreamingMessage(ctx, input) { if err != nil { - t.Fatalf("OnSendMessageStream() error = %v", err) + t.Fatalf("SendStreamingMessage() error = %v", err) } result = event } } else { - localResult, err := handler.OnSendMessage(ctx, input) + localResult, err := handler.SendMessage(ctx, input) if err != nil { - t.Fatalf("OnSendMessage() error = %v", err) + t.Fatalf("SendMessage() error = %v", err) } result = localResult } task, ok := result.(*a2a.Task) if !ok { - t.Fatalf("OnSendMessage() result type = %T, want *a2a.Task", result) + t.Fatalf("SendMessage() result type = %T, want *a2a.Task", result) } if task.Status.State != a2a.TaskStateFailed { - t.Fatalf("OnSendMessage() result = %+v, want state %q", result, a2a.TaskStateFailed) + t.Fatalf("SendMessage() result = %+v, want state %q", result, a2a.TaskStateFailed) } }) } } } -func TestRequestHandler_OnSendMessage_FailsToStoreFailedState(t *testing.T) { +func TestRequestHandler_SendMessage_FailsToStoreFailedState(t *testing.T) { ctx := t.Context() taskSeed := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID()} @@ -791,19 +769,19 @@ func TestRequestHandler_OnSendMessage_FailsToStoreFailedState(t *testing.T) { } return store.InMemory.Update(ctx, req) } - input := &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "work")} + input := &a2a.SendMessageRequest{Message: newUserMessage(taskSeed, "work")} executor := newEventReplayAgent([]a2a.Event{&a2a.Task{ID: "wrong id", ContextID: a2a.NewContextID()}}, nil) handler := NewHandler(executor, WithTaskStore(store)) wantErr := "wrong id" - _, err := handler.OnSendMessage(ctx, input) + _, err := handler.SendMessage(ctx, input) if !strings.Contains(err.Error(), wantErr) { - t.Fatalf("OnSendMessage() err = %v, want to contain %q", err, wantErr) + t.Fatalf("SendMessage() err = %v, want to contain %q", err, wantErr) } } -func TestRequestHandler_OnSendMessage_TaskVersion(t *testing.T) { +func TestRequestHandler_SendMessage_TaskVersion(t *testing.T) { ctx := t.Context() gotPrevVersions := make([]taskstore.TaskVersion, 0) @@ -826,9 +804,8 @@ func TestRequestHandler_OnSendMessage_TaskVersion(t *testing.T) { } } events := statusUpdates[0] - for i, state := range events { + for _, state := range events { event := a2a.NewStatusUpdateEvent(execCtx, state, nil) - event.Final = i == len(events)-1 if !yield(event, nil) { return } @@ -848,17 +825,17 @@ func TestRequestHandler_OnSendMessage_TaskVersion(t *testing.T) { for _, wantPrev := range wantPrevVersions { var msg *a2a.Message if existingTask == nil { - msg = a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Hi!"}) + msg = a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Hi!")) } else { - msg = a2a.NewMessageForTask(a2a.MessageRoleUser, existingTask, a2a.TextPart{Text: "Hi!"}) + msg = a2a.NewMessageForTask(a2a.MessageRoleUser, existingTask, a2a.NewTextPart("Hi!")) } - res, err := handler.OnSendMessage(ctx, &a2a.MessageSendParams{Message: msg}) + res, err := handler.SendMessage(ctx, &a2a.SendMessageRequest{Message: msg}) if err != nil { - t.Fatalf("OnSendMessage() error = %v", err) + t.Fatalf("SendMessage() error = %v", err) } task, ok := res.(*a2a.Task) if !ok { - t.Fatalf("OnSendMessage() returned %T, want *a2a.Task", res) + t.Fatalf("SendMessage() returned %T, want *a2a.Task", res) } existingTask = task @@ -870,12 +847,12 @@ func TestRequestHandler_OnSendMessage_TaskVersion(t *testing.T) { } -func TestRequestHandler_OnSendMessage_AgentExecutorPanicFailsTask(t *testing.T) { +func TestRequestHandler_SendMessage_AgentExecutorPanicFailsTask(t *testing.T) { ctx := t.Context() taskSeed := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID()} store := testutil.NewTestTaskStore().WithTasks(t, taskSeed) - input := &a2a.MessageSendParams{Message: newUserMessage(taskSeed, "work")} + input := &a2a.SendMessageRequest{Message: newUserMessage(taskSeed, "work")} executor := &mockAgentExecutor{ ExecuteFunc: func(ctx context.Context, execCtx *ExecutorContext) iter.Seq2[a2a.Event, error] { @@ -886,20 +863,20 @@ func TestRequestHandler_OnSendMessage_AgentExecutorPanicFailsTask(t *testing.T) } handler := NewHandler(executor, WithTaskStore(store)) - result, err := handler.OnSendMessage(ctx, input) + result, err := handler.SendMessage(ctx, input) if err != nil { - t.Fatalf("OnSendMessage() error = %v", err) + t.Fatalf("SendMessage() error = %v", err) } task, ok := result.(*a2a.Task) if !ok { - t.Fatalf("OnSendMessage() result type = %T, want *a2a.Task", result) + t.Fatalf("SendMessage() result type = %T, want *a2a.Task", result) } if task.Status.State != a2a.TaskStateFailed { - t.Fatalf("OnSendMessage() result = %+v, want state %q", result, a2a.TaskStateFailed) + t.Fatalf("SendMessage() result = %+v, want state %q", result, a2a.TaskStateFailed) } } -func TestRequestHandler_OnGetAgentCard(t *testing.T) { +func TestRequestHandler_GetAgentCard(t *testing.T) { card := &a2a.AgentCard{Name: "agent"} tests := []struct { @@ -942,7 +919,7 @@ func TestRequestHandler_OnGetAgentCard(t *testing.T) { } handler := newTestHandler(options...) - result, gotErr := handler.OnGetExtendedAgentCard(ctx) + result, gotErr := handler.GetExtendedAgentCard(ctx) if tt.wantErr == nil { if gotErr != nil { @@ -965,54 +942,54 @@ func TestRequestHandler_OnGetAgentCard(t *testing.T) { } } -func TestRequestHandler_OnSendMessage_QueueCreationFails(t *testing.T) { +func TestRequestHandler_SendMessage_QueueCreationFails(t *testing.T) { ctx := t.Context() wantErr := errors.New("failed to create a queue") qm := testutil.NewTestQueueManager().SetError(wantErr) handler := newTestHandler(WithEventQueueManager(qm)) - result, err := handler.OnSendMessage(ctx, &a2a.MessageSendParams{ - Message: a2a.NewMessage(a2a.MessageRoleAgent, a2a.TextPart{Text: "Work"}), + result, err := handler.SendMessage(ctx, &a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart("Work")), }) if result != nil || !errors.Is(err, wantErr) { - t.Fatalf("handler.OnSendMessage() = (%v, %v), want error %v", result, err, wantErr) + t.Fatalf("handler.SendMessage() = (%v, %v), want error %v", result, err, wantErr) } } -func TestRequestHandler_OnSendMessage_QueueReadFails(t *testing.T) { +func TestRequestHandler_SendMessage_QueueReadFails(t *testing.T) { ctx := t.Context() wantErr := errors.New("Read() failed") queue := testutil.NewTestEventQueue().SetReadOverride(nil, wantErr) qm := testutil.NewTestQueueManager().SetQueue(queue) handler := newTestHandler(WithEventQueueManager(qm)) - result, err := handler.OnSendMessage(ctx, &a2a.MessageSendParams{ - Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Work"}), + result, err := handler.SendMessage(ctx, &a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Work")), }) if result != nil || !errors.Is(err, wantErr) { - t.Fatalf("handler.OnSendMessage() = (%v, %v), want error %v", result, err, wantErr) + t.Fatalf("handler.SendMessage() = (%v, %v), want error %v", result, err, wantErr) } } -func TestRequestHandler_OnSendMessage_RelatedTaskLoading(t *testing.T) { +func TestRequestHandler_SendMessage_RelatedTaskLoading(t *testing.T) { existingTask := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID()} ctx := t.Context() ts := testutil.NewTestTaskStore().WithTasks(t, existingTask) - executor := newEventReplayAgent([]a2a.Event{a2a.NewMessage(a2a.MessageRoleAgent, a2a.TextPart{Text: "Hello!"})}, nil) + executor := newEventReplayAgent([]a2a.Event{a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart("Hello!"))}, nil) handler := NewHandler(executor, WithExecutorContextInterceptor(&ReferencedTasksLoader{Store: ts})) - request := &a2a.MessageSendParams{ + request := &a2a.SendMessageRequest{ Message: &a2a.Message{ ID: a2a.NewMessageID(), - Parts: a2a.ContentParts{a2a.TextPart{Text: "Work"}}, + Parts: a2a.ContentParts{a2a.NewTextPart("Work")}, Role: a2a.MessageRoleUser, ReferenceTasks: []a2a.TaskID{a2a.NewTaskID(), existingTask.ID}, }, } - _, err := handler.OnSendMessage(ctx, request) + _, err := handler.SendMessage(ctx, request) if err != nil { - t.Fatalf("handler.OnSendMessage() failed: %v", err) + t.Fatalf("handler.SendMessage() failed: %v", err) } capturedExecContext := executor.capturedExecContext @@ -1021,21 +998,21 @@ func TestRequestHandler_OnSendMessage_RelatedTaskLoading(t *testing.T) { } } -func TestRequestHandler_OnSendMessage_AgentExecutionFails(t *testing.T) { +func TestRequestHandler_SendMessage_AgentExecutionFails(t *testing.T) { ctx := t.Context() wantErr := errors.New("failed to create a queue") executor := newEventReplayAgent([]a2a.Event{}, wantErr) handler := NewHandler(executor) - result, err := handler.OnSendMessage(ctx, &a2a.MessageSendParams{ - Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Work"}), + result, err := handler.SendMessage(ctx, &a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Work")), }) if result != nil || !errors.Is(err, wantErr) { - t.Fatalf("handler.OnSendMessage() = (%v, %v), want error %v", result, err, wantErr) + t.Fatalf("handler.SendMessage() = (%v, %v), want error %v", result, err, wantErr) } } -func TestRequestHandler_OnSendMessage_NoTaskCreated(t *testing.T) { +func TestRequestHandler_SendMessage_NoTaskCreated(t *testing.T) { ctx := t.Context() getCalled := 0 savedCalled := 0 @@ -1052,25 +1029,25 @@ func TestRequestHandler_OnSendMessage_NoTaskCreated(t *testing.T) { executor := newEventReplayAgent([]a2a.Event{newAgentMessage("hello")}, nil) handler := NewHandler(executor, WithTaskStore(mockStore)) - result, gotErr := handler.OnSendMessage(ctx, &a2a.MessageSendParams{ - Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Work"}), + result, gotErr := handler.SendMessage(ctx, &a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Work")), }) if gotErr != nil { - t.Fatalf("OnSendMessage() error = %v, wantErr nil", gotErr) + t.Fatalf("SendMessage() error = %v, wantErr nil", gotErr) } if _, ok := result.(*a2a.Message); !ok { - t.Fatalf("OnSendMessage() = %v, want a2a.Message", result) + t.Fatalf("SendMessage() = %v, want a2a.Message", result) } if getCalled != 1 { - t.Fatalf("OnSendMessage() TaskStore.Get called %d times, want 1", getCalled) + t.Fatalf("SendMessage() TaskStore.Get called %d times, want 1", getCalled) } if savedCalled > 0 { - t.Fatalf("OnSendMessage() TaskStore.Save called %d times, want 0", savedCalled) + t.Fatalf("SendMessage() TaskStore.Save called %d times, want 0", savedCalled) } } -func TestRequestHandler_OnSendMessage_NewTaskHistory(t *testing.T) { +func TestRequestHandler_SendMessage_NewTaskHistory(t *testing.T) { ctx := t.Context() ts := taskstore.NewInMemory(nil) executor := &mockAgentExecutor{ @@ -1084,21 +1061,21 @@ func TestRequestHandler_OnSendMessage_NewTaskHistory(t *testing.T) { } handler := NewHandler(executor, WithTaskStore(ts)) - msg := a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Complete the task!"}) - result, gotErr := handler.OnSendMessage(ctx, &a2a.MessageSendParams{Message: msg}) + msg := a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Complete the task!")) + result, gotErr := handler.SendMessage(ctx, &a2a.SendMessageRequest{Message: msg}) if gotErr != nil { - t.Fatalf("OnSendMessage() error = %v, wantErr nil", gotErr) + t.Fatalf("SendMessage() error = %v, wantErr nil", gotErr) } if task, ok := result.(*a2a.Task); ok { if diff := cmp.Diff([]*a2a.Message{msg}, task.History); diff != "" { - t.Fatalf("OnSendMessage() wrong result (+got,-want):\ngot = %v\nwant = %v\ndiff = %s", task.History, []*a2a.Message{msg}, diff) + t.Fatalf("SendMessage() wrong result (+got,-want):\ngot = %v\nwant = %v\ndiff = %s", task.History, []*a2a.Message{msg}, diff) } } else { - t.Fatalf("OnSendMessage() = %v, want a2a.Task", result) + t.Fatalf("SendMessage() = %v, want a2a.Task", result) } } -func TestRequestHandler_OnGetTask(t *testing.T) { +func TestRequestHandler_GetTask(t *testing.T) { ptr := func(i int) *int { return &i } @@ -1108,43 +1085,43 @@ func TestRequestHandler_OnGetTask(t *testing.T) { tests := []struct { name string - query *a2a.TaskQueryParams + query *a2a.GetTaskRequest want *a2a.Task wantErr error }{ { name: "success with TaskID and full history", - query: &a2a.TaskQueryParams{ID: existingTaskID}, + query: &a2a.GetTaskRequest{ID: existingTaskID}, want: &a2a.Task{ID: existingTaskID, History: history}, }, { name: "missing TaskID", - query: &a2a.TaskQueryParams{ID: ""}, + query: &a2a.GetTaskRequest{ID: ""}, wantErr: fmt.Errorf("missing TaskID: %w", a2a.ErrInvalidParams), }, { name: "task not found", - query: &a2a.TaskQueryParams{ID: a2a.NewTaskID()}, + query: &a2a.GetTaskRequest{ID: a2a.NewTaskID()}, wantErr: fmt.Errorf("failed to get task: %w", a2a.ErrTaskNotFound), }, { name: "get task with limited HistoryLength", - query: &a2a.TaskQueryParams{ID: existingTaskID, HistoryLength: ptr(len(history) - 1)}, + query: &a2a.GetTaskRequest{ID: existingTaskID, HistoryLength: ptr(len(history) - 1)}, want: &a2a.Task{ID: existingTaskID, History: history[1:]}, }, { name: "get task with larger than available HistoryLength", - query: &a2a.TaskQueryParams{ID: existingTaskID, HistoryLength: ptr(len(history) + 1)}, + query: &a2a.GetTaskRequest{ID: existingTaskID, HistoryLength: ptr(len(history) + 1)}, want: &a2a.Task{ID: existingTaskID, History: history}, }, { name: "get task with zero HistoryLength", - query: &a2a.TaskQueryParams{ID: existingTaskID, HistoryLength: ptr(0)}, + query: &a2a.GetTaskRequest{ID: existingTaskID, HistoryLength: ptr(0)}, want: &a2a.Task{ID: existingTaskID, History: make([]*a2a.Message, 0)}, }, { name: "get task with negative HistoryLength", - query: &a2a.TaskQueryParams{ID: existingTaskID, HistoryLength: ptr(-1)}, + query: &a2a.GetTaskRequest{ID: existingTaskID, HistoryLength: ptr(-1)}, want: &a2a.Task{ID: existingTaskID, History: make([]*a2a.Message, 0)}, }, } @@ -1154,7 +1131,7 @@ func TestRequestHandler_OnGetTask(t *testing.T) { ctx := t.Context() ts := testutil.NewTestTaskStore().WithTasks(t, &a2a.Task{ID: existingTaskID, History: history}) handler := newTestHandler(WithTaskStore(ts)) - result, err := handler.OnGetTask(ctx, tt.query) + result, err := handler.GetTask(ctx, tt.query) if tt.wantErr == nil { if err != nil { t.Errorf("OnGetTask() error = %v, wantErr nil", err) @@ -1176,19 +1153,19 @@ func TestRequestHandler_OnGetTask(t *testing.T) { } } -func TestRequestHandler_OnGetTask_StoreGetFails(t *testing.T) { +func TestRequestHandler_GetTask_StoreGetFails(t *testing.T) { ctx := t.Context() wantErr := errors.New("failed to get task: store get failed") ts := testutil.NewTestTaskStore().SetGetOverride(nil, wantErr) handler := newTestHandler(WithTaskStore(ts)) - result, err := handler.OnGetTask(ctx, &a2a.TaskQueryParams{ID: a2a.NewTaskID()}) + result, err := handler.GetTask(ctx, &a2a.GetTaskRequest{ID: a2a.NewTaskID()}) if result != nil || !errors.Is(err, wantErr) { t.Fatalf("OnGetTask() = (%v, %v), want error %v", result, err, wantErr) } } -func TestRequestHandler_OnListTasks(t *testing.T) { +func TestRequestHandler_ListTasks(t *testing.T) { id1, id2, id3 := a2a.NewTaskID(), a2a.NewTaskID(), a2a.NewTaskID() startTime := time.Date(2025, time.December, 11, 14, 54, 0, 0, time.UTC) cutOffTime := startTime.Add(2 * time.Second) @@ -1215,7 +1192,7 @@ func TestRequestHandler_OnListTasks(t *testing.T) { {ID: id2, ContextID: "context2", History: []*a2a.Message{{ID: "test-message-4"}, {ID: "test-message-5"}}, Status: a2a.TaskStatus{State: a2a.TaskStateCanceled}}, {ID: id3, Artifacts: []*a2a.Artifact{{Name: "artifact3"}}, ContextID: "context1", History: []*a2a.Message{{ID: "test-message-6"}, {ID: "test-message-7"}, {ID: "test-message-8"}, {ID: "test-message-9"}}, Status: a2a.TaskStatus{State: a2a.TaskStateCompleted}}, }, - request: &a2a.ListTasksRequest{PageSize: 2, ContextID: "context1", Status: a2a.TaskStateCompleted, HistoryLength: 2, LastUpdatedAfter: &cutOffTime, IncludeArtifacts: true}, + request: &a2a.ListTasksRequest{PageSize: 2, ContextID: "context1", Status: a2a.TaskStateCompleted, HistoryLength: 2, StatusTimestampAfter: &cutOffTime, IncludeArtifacts: true}, wantResponse: &a2a.ListTasksResponse{Tasks: []*a2a.Task{{ID: id3, Artifacts: []*a2a.Artifact{{Name: "artifact3"}}, ContextID: "context1", History: []*a2a.Message{{ID: "test-message-8"}, {ID: "test-message-9"}}, Status: a2a.TaskStatus{State: a2a.TaskStateCompleted}}}, TotalSize: 1, PageSize: 2}, storeConfig: &taskstore.InMemoryStoreConfig{Authenticator: testAuthenticator()}, }, @@ -1239,7 +1216,7 @@ func TestRequestHandler_OnListTasks(t *testing.T) { ctx := t.Context() ts := testutil.NewTestTaskStoreWithConfig(tt.storeConfig).WithTasks(t, tt.givenTasks...) handler := newTestHandler(WithTaskStore(ts)) - result, err := handler.OnListTasks(ctx, tt.request) + result, err := handler.ListTasks(ctx, tt.request) if tt.wantErr == nil { if err != nil { @@ -1268,20 +1245,20 @@ func testAuthenticator() func(context.Context) (string, error) { } } -func TestRequestHandler_OnCancelTask(t *testing.T) { +func TestRequestHandler_CancelTask(t *testing.T) { taskToCancel := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID(), Status: a2a.TaskStatus{State: a2a.TaskStateWorking}} completedTask := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID(), Status: a2a.TaskStatus{State: a2a.TaskStateCompleted}} canceledTask := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID(), Status: a2a.TaskStatus{State: a2a.TaskStateCanceled}} tests := []struct { name string - params *a2a.TaskIDParams + params *a2a.CancelTaskRequest want *a2a.Task wantErr error }{ { name: "success", - params: &a2a.TaskIDParams{ID: taskToCancel.ID}, + params: &a2a.CancelTaskRequest{ID: taskToCancel.ID}, want: newTaskWithStatus(taskToCancel, a2a.TaskStateCanceled, "Cancelled"), }, { @@ -1291,17 +1268,17 @@ func TestRequestHandler_OnCancelTask(t *testing.T) { }, { name: "task not found", - params: &a2a.TaskIDParams{ID: a2a.NewTaskID()}, + params: &a2a.CancelTaskRequest{ID: a2a.NewTaskID()}, wantErr: fmt.Errorf("failed to cancel: cancelation failed: setup failed: failed to load a task: %w", a2a.ErrTaskNotFound), }, { name: "task already completed", - params: &a2a.TaskIDParams{ID: completedTask.ID}, + params: &a2a.CancelTaskRequest{ID: completedTask.ID}, wantErr: fmt.Errorf("failed to cancel: cancelation failed: setup failed: task in non-cancelable state %s: %w", a2a.TaskStateCompleted, a2a.ErrTaskNotCancelable), }, { name: "task already canceled", - params: &a2a.TaskIDParams{ID: canceledTask.ID}, + params: &a2a.CancelTaskRequest{ID: canceledTask.ID}, want: canceledTask, }, } @@ -1320,7 +1297,7 @@ func TestRequestHandler_OnCancelTask(t *testing.T) { } handler := NewHandler(executor, WithTaskStore(store)) - result, err := handler.OnCancelTask(ctx, tt.params) + result, err := handler.CancelTask(ctx, tt.params) if tt.wantErr == nil { if err != nil { t.Errorf("OnCancelTask() error = %v, wantErr nil", err) @@ -1342,7 +1319,7 @@ func TestRequestHandler_OnCancelTask(t *testing.T) { } } -func TestRequestHandler_OnResubscribeToTask_Success(t *testing.T) { +func TestRequestHandler_ResubscribeToTask_Success(t *testing.T) { ctx := t.Context() taskSeed := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID()} wantEvents := []a2a.Event{ @@ -1369,7 +1346,7 @@ func TestRequestHandler_OnResubscribeToTask_Success(t *testing.T) { } go func() { - for range handler.OnSendMessageStream(ctx, &a2a.MessageSendParams{ + for range handler.SendStreamingMessage(ctx, &a2a.SendMessageRequest{ Message: newUserMessage(taskSeed, "Work"), }) { // Events have to be consumed to prevent a deadlock. @@ -1378,7 +1355,7 @@ func TestRequestHandler_OnResubscribeToTask_Success(t *testing.T) { <-executionStarted - seq := handler.OnResubscribeToTask(ctx, &a2a.TaskIDParams{ID: taskSeed.ID}) + seq := handler.SubscribeToTask(ctx, &a2a.SubscribeToTaskRequest{ID: taskSeed.ID}) gotEvents, err := collectEvents(seq) if err != nil { t.Fatalf("collectEvents() failed: %v", err) @@ -1389,21 +1366,21 @@ func TestRequestHandler_OnResubscribeToTask_Success(t *testing.T) { } } -func TestRequestHandler_OnResubscribeToTask_NotFound(t *testing.T) { +func TestRequestHandler_ResubscribeToTask_NotFound(t *testing.T) { ctx := t.Context() taskID := a2a.NewTaskID() wantErr := a2a.ErrTaskNotFound executor := &mockAgentExecutor{} handler := NewHandler(executor) - result, err := collectEvents(handler.OnResubscribeToTask(ctx, &a2a.TaskIDParams{ID: taskID})) + result, err := collectEvents(handler.SubscribeToTask(ctx, &a2a.SubscribeToTaskRequest{ID: taskID})) if result != nil || !errors.Is(err, wantErr) { t.Fatalf("OnResubscribeToTask() = (%v, %v), want error %v", result, err, wantErr) } } -func TestRequestHandler_OnCancelTask_AgentCancelFails(t *testing.T) { +func TestRequestHandler_CancelTask_AgentCancelFails(t *testing.T) { ctx := t.Context() taskToCancel := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID(), Status: a2a.TaskStatus{State: a2a.TaskStateWorking}} wantErr := fmt.Errorf("failed to cancel: cancelation failed: agent cancel error") @@ -1415,7 +1392,7 @@ func TestRequestHandler_OnCancelTask_AgentCancelFails(t *testing.T) { } handler := NewHandler(executor, WithTaskStore(store)) - result, err := handler.OnCancelTask(ctx, &a2a.TaskIDParams{ID: taskToCancel.ID}) + result, err := handler.CancelTask(ctx, &a2a.CancelTaskRequest{ID: taskToCancel.ID}) if result != nil || err.Error() != wantErr.Error() { t.Fatalf("OnCancelTask() error = %v, wantErr %v", err, wantErr) } @@ -1423,7 +1400,7 @@ func TestRequestHandler_OnCancelTask_AgentCancelFails(t *testing.T) { func TestRequestHandler_MultipleRequestContextInterceptors(t *testing.T) { ctx := t.Context() - executor := newEventReplayAgent([]a2a.Event{a2a.NewMessage(a2a.MessageRoleAgent, a2a.TextPart{Text: "Hello!"})}, nil) + executor := newEventReplayAgent([]a2a.Event{a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart("Hello!"))}, nil) type key1Type struct{} key1, val1 := key1Type{}, 2 interceptor1 := interceptExecCtxFn(func(ctx context.Context, execCtx *ExecutorContext) (context.Context, error) { @@ -1440,11 +1417,11 @@ func TestRequestHandler_MultipleRequestContextInterceptors(t *testing.T) { WithExecutorContextInterceptor(interceptor2), ) - _, err := handler.OnSendMessage(ctx, &a2a.MessageSendParams{ - Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Work"}), + _, err := handler.SendMessage(ctx, &a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Work")), }) if err != nil { - t.Fatalf("handler.OnSendMessage() failed: %v", err) + t.Fatalf("handler.SendMessage() failed: %v", err) } capturedContext := executor.capturedContext @@ -1455,19 +1432,19 @@ func TestRequestHandler_MultipleRequestContextInterceptors(t *testing.T) { func TestRequestHandler_RequestContextInterceptorRejectsRequest(t *testing.T) { ctx := t.Context() - executor := newEventReplayAgent([]a2a.Event{a2a.NewMessage(a2a.MessageRoleAgent, a2a.TextPart{Text: "Hello!"})}, nil) + executor := newEventReplayAgent([]a2a.Event{a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart("Hello!"))}, nil) wantErr := errors.New("rejected") interceptor := interceptExecCtxFn(func(ctx context.Context, execCtx *ExecutorContext) (context.Context, error) { return ctx, wantErr }) handler := NewHandler(executor, WithExecutorContextInterceptor(interceptor)) - _, err := handler.OnSendMessage(ctx, &a2a.MessageSendParams{ - Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Work"}), + _, err := handler.SendMessage(ctx, &a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Work")), }) if !errors.Is(err, wantErr) { - t.Fatalf("handler.OnSendMessage() error = %v, want %v", err, wantErr) + t.Fatalf("handler.SendMessage() error = %v, want %v", err, wantErr) } if executor.executeCalled { t.Fatal("want agent executor to no be called") @@ -1483,17 +1460,17 @@ func TestRequestHandler_ExecuteRequestContextLoading(t *testing.T) { } testCases := []struct { name string - newRequest func() *a2a.MessageSendParams + newRequest func() *a2a.SendMessageRequest wantExecCtxMeta map[string]any wantStoredTask *a2a.Task wantContextID string }{ { name: "new task", - newRequest: func() *a2a.MessageSendParams { - msg := a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Hello"}) + newRequest: func() *a2a.SendMessageRequest { + msg := a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Hello")) msg.Metadata = map[string]any{"foo1": "bar1"} - return &a2a.MessageSendParams{ + return &a2a.SendMessageRequest{ Message: msg, Metadata: map[string]any{"foo2": "bar2"}, } @@ -1502,10 +1479,10 @@ func TestRequestHandler_ExecuteRequestContextLoading(t *testing.T) { }, { name: "stored tasks", - newRequest: func() *a2a.MessageSendParams { - msg := a2a.NewMessageForTask(a2a.MessageRoleUser, taskSeed, a2a.TextPart{Text: "Hello"}) + newRequest: func() *a2a.SendMessageRequest { + msg := a2a.NewMessageForTask(a2a.MessageRoleUser, taskSeed, a2a.NewTextPart("Hello")) msg.Metadata = map[string]any{"foo1": "bar1"} - return &a2a.MessageSendParams{ + return &a2a.SendMessageRequest{ Message: msg, Metadata: map[string]any{"foo2": "bar2"}, } @@ -1516,10 +1493,10 @@ func TestRequestHandler_ExecuteRequestContextLoading(t *testing.T) { }, { name: "preserve message context", - newRequest: func() *a2a.MessageSendParams { - msg := a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Hello"}) + newRequest: func() *a2a.SendMessageRequest { + msg := a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Hello")) msg.ContextID = ctxID - return &a2a.MessageSendParams{Message: msg} + return &a2a.SendMessageRequest{Message: msg} }, wantContextID: ctxID, }, @@ -1529,7 +1506,7 @@ func TestRequestHandler_ExecuteRequestContextLoading(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() ctx := t.Context() - executor := newEventReplayAgent([]a2a.Event{a2a.NewMessage(a2a.MessageRoleAgent, a2a.TextPart{Text: "Done!"})}, nil) + executor := newEventReplayAgent([]a2a.Event{a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart("Done!"))}, nil) var gotExecCtx *ExecutorContext handler := NewHandler( executor, @@ -1540,9 +1517,9 @@ func TestRequestHandler_ExecuteRequestContextLoading(t *testing.T) { })), ) request := tc.newRequest() - _, err := handler.OnSendMessage(ctx, request) + _, err := handler.SendMessage(ctx, request) if err != nil { - t.Fatalf("handler.OnSendMessage() error = %v, want nil", err) + t.Fatalf("handler.SendMessage() error = %v, want nil", err) } opts := []cmp.Option{cmpopts.IgnoreFields(a2a.Task{}, "History")} if diff := cmp.Diff(tc.wantStoredTask, gotExecCtx.StoredTask, opts...); diff != "" { @@ -1560,34 +1537,35 @@ func TestRequestHandler_ExecuteRequestContextLoading(t *testing.T) { } } -func TestRequestHandler_OnSetTaskPushConfig(t *testing.T) { +func TestRequestHandler_SetTaskPushConfig(t *testing.T) { ctx := t.Context() taskID := a2a.TaskID("test-task") testCases := []struct { name string - params *a2a.TaskPushConfig + params *a2a.CreateTaskPushConfigRequest wantErr error }{ { name: "valid config with id", - params: &a2a.TaskPushConfig{ + params: &a2a.CreateTaskPushConfigRequest{ TaskID: taskID, Config: a2a.PushConfig{ID: "config-1", URL: "https://example.com/push"}, }, }, { name: "valid config without id", - params: &a2a.TaskPushConfig{ + params: &a2a.CreateTaskPushConfigRequest{ TaskID: taskID, Config: a2a.PushConfig{URL: "https://example.com/push-no-id"}, }, }, { name: "invalid config - empty URL", - params: &a2a.TaskPushConfig{ - TaskID: taskID, - Config: a2a.PushConfig{ID: "config-invalid"}, + params: &a2a.CreateTaskPushConfigRequest{ + TaskID: taskID, + ConfigID: "config-invalid", + Config: a2a.PushConfig{ID: "config-invalid"}, }, wantErr: fmt.Errorf("failed to save push config: %w: push config endpoint cannot be empty", a2a.ErrInvalidParams), }, @@ -1598,7 +1576,7 @@ func TestRequestHandler_OnSetTaskPushConfig(t *testing.T) { ps := testutil.NewTestPushConfigStore() pn := testutil.NewTestPushSender(t) handler := newTestHandler(WithPushNotifications(ps, pn)) - got, err := handler.OnSetTaskPushConfig(ctx, tc.params) + got, err := handler.CreateTaskPushConfig(ctx, tc.params) if tc.wantErr != nil { if err == nil || err.Error() != tc.wantErr.Error() { @@ -1621,14 +1599,25 @@ func TestRequestHandler_OnSetTaskPushConfig(t *testing.T) { got.Config.ID = "" } - if diff := cmp.Diff(tc.params, got); diff != "" { + want := &a2a.TaskPushConfig{ + ID: tc.params.Config.ID, + Config: tc.params.Config, + TaskID: tc.params.TaskID, + } + if want.ID == "" { + want.ID = got.ID + } + if want.Config.ID == "" { + want.Config.ID = got.Config.ID + } + if diff := cmp.Diff(want, got); diff != "" { t.Errorf("OnSetTaskPushConfig() mismatch (-want +got):\n%s", diff) } }) } } -func TestRequestHandler_OnGetTaskPushConfig(t *testing.T) { +func TestRequestHandler_GetTaskPushConfig(t *testing.T) { ctx := t.Context() taskID := a2a.TaskID("test-task") config1 := &a2a.PushConfig{ID: "config-1", URL: "https://example.com/push1"} @@ -1638,30 +1627,30 @@ func TestRequestHandler_OnGetTaskPushConfig(t *testing.T) { testCases := []struct { name string - params *a2a.GetTaskPushConfigParams + params *a2a.GetTaskPushConfigRequest want *a2a.TaskPushConfig wantErr error }{ { name: "success", - params: &a2a.GetTaskPushConfigParams{TaskID: taskID, ConfigID: config1.ID}, + params: &a2a.GetTaskPushConfigRequest{TaskID: taskID, ID: config1.ID}, want: &a2a.TaskPushConfig{TaskID: taskID, Config: *config1}, }, { name: "non-existent config", - params: &a2a.GetTaskPushConfigParams{TaskID: taskID, ConfigID: "non-existent"}, + params: &a2a.GetTaskPushConfigRequest{TaskID: taskID, ID: "non-existent"}, wantErr: push.ErrPushConfigNotFound, }, { name: "non-existent task", - params: &a2a.GetTaskPushConfigParams{TaskID: "non-existent-task", ConfigID: config1.ID}, + params: &a2a.GetTaskPushConfigRequest{TaskID: "non-existent-task", ID: config1.ID}, wantErr: push.ErrPushConfigNotFound, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - got, err := handler.OnGetTaskPushConfig(ctx, tc.params) + got, err := handler.GetTaskPushConfig(ctx, tc.params) if !errors.Is(err, tc.wantErr) { t.Errorf("OnGetTaskPushConfig() error = %v, want %v", err, tc.wantErr) return @@ -1675,7 +1664,7 @@ func TestRequestHandler_OnGetTaskPushConfig(t *testing.T) { } } -func TestRequestHandler_OnListTaskPushConfig(t *testing.T) { +func TestRequestHandler_ListTaskPushConfig(t *testing.T) { ctx := t.Context() taskID := a2a.TaskID("test-task") config1 := a2a.PushConfig{ID: "config-1", URL: "https://example.com/push1"} @@ -1695,12 +1684,12 @@ func TestRequestHandler_OnListTaskPushConfig(t *testing.T) { testCases := []struct { name string - params *a2a.ListTaskPushConfigParams + params *a2a.ListTaskPushConfigRequest want []*a2a.TaskPushConfig }{ { name: "list existing", - params: &a2a.ListTaskPushConfigParams{TaskID: taskID}, + params: &a2a.ListTaskPushConfigRequest{TaskID: taskID}, want: []*a2a.TaskPushConfig{ {TaskID: taskID, Config: config1}, {TaskID: taskID, Config: config2}, @@ -1708,19 +1697,19 @@ func TestRequestHandler_OnListTaskPushConfig(t *testing.T) { }, { name: "list with empty task", - params: &a2a.ListTaskPushConfigParams{TaskID: emptyTaskID}, + params: &a2a.ListTaskPushConfigRequest{TaskID: emptyTaskID}, want: []*a2a.TaskPushConfig{}, }, { name: "list non-existent task", - params: &a2a.ListTaskPushConfigParams{TaskID: "non-existent-task"}, + params: &a2a.ListTaskPushConfigRequest{TaskID: "non-existent-task"}, want: []*a2a.TaskPushConfig{}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - got, err := handler.OnListTaskPushConfig(ctx, tc.params) + got, err := handler.ListTaskPushConfig(ctx, tc.params) if err != nil { t.Errorf("OnListTaskPushConfig() failed: %v", err) return @@ -1733,7 +1722,7 @@ func TestRequestHandler_OnListTaskPushConfig(t *testing.T) { } } -func TestRequestHandler_OnDeleteTaskPushConfig(t *testing.T) { +func TestRequestHandler_DeleteTaskPushConfig(t *testing.T) { ctx := t.Context() taskID := a2a.TaskID("test-task") config1 := a2a.PushConfig{ID: "config-1", URL: "https://example.com/push1"} @@ -1741,17 +1730,17 @@ func TestRequestHandler_OnDeleteTaskPushConfig(t *testing.T) { testCases := []struct { name string - params *a2a.DeleteTaskPushConfigParams + params *a2a.DeleteTaskPushConfigRequest wantRemain []*a2a.TaskPushConfig }{ { name: "delete existing", - params: &a2a.DeleteTaskPushConfigParams{TaskID: taskID, ConfigID: config1.ID}, + params: &a2a.DeleteTaskPushConfigRequest{TaskID: taskID, ID: config1.ID}, wantRemain: []*a2a.TaskPushConfig{{TaskID: taskID, Config: config2}}, }, { name: "delete non-existent config", - params: &a2a.DeleteTaskPushConfigParams{TaskID: taskID, ConfigID: "non-existent"}, + params: &a2a.DeleteTaskPushConfigRequest{TaskID: taskID, ID: "non-existent"}, wantRemain: []*a2a.TaskPushConfig{ {TaskID: taskID, Config: config1}, {TaskID: taskID, Config: config2}, @@ -1759,7 +1748,7 @@ func TestRequestHandler_OnDeleteTaskPushConfig(t *testing.T) { }, { name: "delete from non-existent task", - params: &a2a.DeleteTaskPushConfigParams{TaskID: "non-existent-task", ConfigID: config1.ID}, + params: &a2a.DeleteTaskPushConfigRequest{TaskID: "non-existent-task", ID: config1.ID}, wantRemain: []*a2a.TaskPushConfig{ {TaskID: taskID, Config: config1}, {TaskID: taskID, Config: config2}, @@ -1772,13 +1761,13 @@ func TestRequestHandler_OnDeleteTaskPushConfig(t *testing.T) { ps := testutil.NewTestPushConfigStore().WithConfigs(t, taskID, &config1, &config2) pn := testutil.NewTestPushSender(t) handler := newTestHandler(WithPushNotifications(ps, pn)) - err := handler.OnDeleteTaskPushConfig(ctx, tc.params) + err := handler.DeleteTaskPushConfig(ctx, tc.params) if err != nil { t.Errorf("OnDeleteTaskPushConfig() failed: %v", err) return } - got, err := handler.OnListTaskPushConfig(ctx, &a2a.ListTaskPushConfigParams{TaskID: taskID}) + got, err := handler.ListTaskPushConfig(ctx, &a2a.ListTaskPushConfigRequest{TaskID: taskID}) if err != nil { t.Errorf("OnListTaskPushConfig() for verification failed: %v", err) return @@ -1852,16 +1841,15 @@ func newTestHandler(opts ...RequestHandlerOption) RequestHandler { } func newAgentMessage(text string) *a2a.Message { - return &a2a.Message{ID: "message-id", Parts: []a2a.Part{a2a.TextPart{Text: text}}, Role: a2a.MessageRoleAgent} + m := a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart(text)) + m.ID = "message-id" + return m } func newUserMessage(task *a2a.Task, text string) *a2a.Message { - return &a2a.Message{ - ID: "message-id", - Parts: []a2a.Part{a2a.TextPart{Text: text}}, - Role: a2a.MessageRoleUser, - TaskID: task.ID, - } + m := a2a.NewMessageForTask(a2a.MessageRoleUser, task, a2a.NewTextPart(text)) + m.ID = "message-id" + return m } func newTaskStatusUpdate(task a2a.TaskInfoProvider, state a2a.TaskState, msg string) *a2a.TaskStatusUpdateEvent { @@ -1871,9 +1859,7 @@ func newTaskStatusUpdate(task a2a.TaskInfoProvider, state a2a.TaskState, msg str } func newFinalTaskStatusUpdate(task a2a.TaskInfoProvider, state a2a.TaskState, msg string) *a2a.TaskStatusUpdateEvent { - res := newTaskStatusUpdate(task, state, msg) - res.Final = true - return res + return newTaskStatusUpdate(task, state, msg) } func newTaskWithStatus(task a2a.TaskInfoProvider, state a2a.TaskState, msg string) *a2a.Task { @@ -1887,7 +1873,7 @@ func newTaskWithMeta(task a2a.TaskInfoProvider, meta map[string]any) *a2a.Task { return &a2a.Task{ID: task.TaskInfo().TaskID, ContextID: task.TaskInfo().ContextID, Metadata: meta} } -func newArtifactEvent(task a2a.TaskInfoProvider, aid a2a.ArtifactID, parts ...a2a.Part) *a2a.TaskArtifactUpdateEvent { +func newArtifactEvent(task a2a.TaskInfoProvider, aid a2a.ArtifactID, parts ...*a2a.Part) *a2a.TaskArtifactUpdateEvent { ev := a2a.NewArtifactEvent(task, parts...) ev.Artifact.ID = aid return ev diff --git a/a2asrv/intercepted_handler.go b/a2asrv/intercepted_handler.go index c32a892b..c7bfc690 100644 --- a/a2asrv/intercepted_handler.go +++ b/a2asrv/intercepted_handler.go @@ -46,34 +46,34 @@ type interceptBeforeResult[Req any, Resp any] struct { var _ RequestHandler = (*InterceptedHandler)(nil) -func (h *InterceptedHandler) OnGetTask(ctx context.Context, query *a2a.TaskQueryParams) (*a2a.Task, error) { - ctx, callCtx := withMethodCallContext(ctx, "OnGetTask") - if query != nil { - ctx = h.withLoggerContext(ctx, slog.String("task_id", string(query.ID))) +func (h *InterceptedHandler) GetTask(ctx context.Context, req *a2a.GetTaskRequest) (*a2a.Task, error) { + ctx, callCtx := withMethodCallContext(ctx, "GetTask") + if req != nil { + ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.ID))) } - return doCall(ctx, callCtx, h, query, h.Handler.OnGetTask) + return doCall(ctx, callCtx, h, req, h.Handler.GetTask) } -func (h *InterceptedHandler) OnListTasks(ctx context.Context, req *a2a.ListTasksRequest) (*a2a.ListTasksResponse, error) { - ctx, callCtx := withMethodCallContext(ctx, "OnListTasks") +func (h *InterceptedHandler) ListTasks(ctx context.Context, req *a2a.ListTasksRequest) (*a2a.ListTasksResponse, error) { + ctx, callCtx := withMethodCallContext(ctx, "ListTasks") if req != nil { ctx = h.withLoggerContext(ctx) } - return doCall(ctx, callCtx, h, req, h.Handler.OnListTasks) + return doCall(ctx, callCtx, h, req, h.Handler.ListTasks) } -func (h *InterceptedHandler) OnCancelTask(ctx context.Context, params *a2a.TaskIDParams) (*a2a.Task, error) { - ctx, callCtx := withMethodCallContext(ctx, "OnCancelTask") - if params != nil { - ctx = h.withLoggerContext(ctx, slog.String("task_id", string(params.ID))) +func (h *InterceptedHandler) CancelTask(ctx context.Context, req *a2a.CancelTaskRequest) (*a2a.Task, error) { + ctx, callCtx := withMethodCallContext(ctx, "CancelTask") + if req != nil { + ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.ID))) } - return doCall(ctx, callCtx, h, params, h.Handler.OnCancelTask) + return doCall(ctx, callCtx, h, req, h.Handler.CancelTask) } -func (h *InterceptedHandler) OnSendMessage(ctx context.Context, params *a2a.MessageSendParams) (a2a.SendMessageResult, error) { - ctx, callCtx := withMethodCallContext(ctx, "OnSendMessage") - if params != nil && params.Message != nil { - msg := params.Message +func (h *InterceptedHandler) SendMessage(ctx context.Context, req *a2a.SendMessageRequest) (a2a.SendMessageResult, error) { + ctx, callCtx := withMethodCallContext(ctx, "SendMessage") + if req != nil && req.Message != nil { + msg := req.Message ctx = h.withLoggerContext( ctx, slog.String("message_id", msg.ID), @@ -83,14 +83,14 @@ func (h *InterceptedHandler) OnSendMessage(ctx context.Context, params *a2a.Mess } else { ctx = h.withLoggerContext(ctx) } - return doCall(ctx, callCtx, h, params, h.Handler.OnSendMessage) + return doCall(ctx, callCtx, h, req, h.Handler.SendMessage) } -func (h *InterceptedHandler) OnSendMessageStream(ctx context.Context, params *a2a.MessageSendParams) iter.Seq2[a2a.Event, error] { +func (h *InterceptedHandler) SendStreamingMessage(ctx context.Context, req *a2a.SendMessageRequest) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { - ctx, callCtx := withMethodCallContext(ctx, "OnSendMessageStream") - if params != nil && params.Message != nil { - msg := params.Message + ctx, callCtx := withMethodCallContext(ctx, "SendStreamingMessage") + if req != nil && req.Message != nil { + msg := req.Message ctx = h.withLoggerContext( ctx, slog.String("message_id", msg.ID), @@ -100,7 +100,7 @@ func (h *InterceptedHandler) OnSendMessageStream(ctx context.Context, params *a2 } else { ctx = h.withLoggerContext(ctx) } - ctx, res := interceptBefore[*a2a.MessageSendParams, a2a.SendMessageResult](ctx, h, callCtx, params) + ctx, res := interceptBefore[*a2a.SendMessageRequest, a2a.SendMessageResult](ctx, h, callCtx, req) if res.earlyErr != nil { yield(nil, res.earlyErr) return @@ -109,7 +109,7 @@ func (h *InterceptedHandler) OnSendMessageStream(ctx context.Context, params *a2 yield(*res.earlyResponse, nil) return } - for event, err := range h.Handler.OnSendMessageStream(ctx, res.reqOverride) { + for event, err := range h.Handler.SendStreamingMessage(ctx, res.reqOverride) { interceptedEvent, errOverride := interceptAfter(ctx, h.Interceptors, callCtx, event, err) if errOverride != nil { yield(nil, errOverride) @@ -122,13 +122,13 @@ func (h *InterceptedHandler) OnSendMessageStream(ctx context.Context, params *a2 } } -func (h *InterceptedHandler) OnResubscribeToTask(ctx context.Context, params *a2a.TaskIDParams) iter.Seq2[a2a.Event, error] { +func (h *InterceptedHandler) SubscribeToTask(ctx context.Context, req *a2a.SubscribeToTaskRequest) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { - ctx, callCtx := withMethodCallContext(ctx, "OnResubscribeToTask") - if params != nil { - ctx = h.withLoggerContext(ctx, slog.String("task_id", string(params.ID))) + ctx, callCtx := withMethodCallContext(ctx, "SubscribeToTask") + if req != nil { + ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.ID))) } - ctx, res := interceptBefore[*a2a.TaskIDParams, a2a.SendMessageResult](ctx, h, callCtx, params) + ctx, res := interceptBefore[*a2a.SubscribeToTaskRequest, a2a.SendMessageResult](ctx, h, callCtx, req) if res.earlyErr != nil { yield(nil, res.earlyErr) return @@ -137,7 +137,7 @@ func (h *InterceptedHandler) OnResubscribeToTask(ctx context.Context, params *a2 yield(*res.earlyResponse, nil) return } - for event, err := range h.Handler.OnResubscribeToTask(ctx, res.reqOverride) { + for event, err := range h.Handler.SubscribeToTask(ctx, res.reqOverride) { interceptedEvent, errOverride := interceptAfter(ctx, h.Interceptors, callCtx, event, err) if errOverride != nil { yield(nil, errOverride) @@ -150,50 +150,50 @@ func (h *InterceptedHandler) OnResubscribeToTask(ctx context.Context, params *a2 } } -func (h *InterceptedHandler) OnGetTaskPushConfig(ctx context.Context, params *a2a.GetTaskPushConfigParams) (*a2a.TaskPushConfig, error) { - ctx, callCtx := withMethodCallContext(ctx, "OnGetTaskPushConfig") - if params != nil { - ctx = h.withLoggerContext(ctx, slog.String("task_id", string(params.TaskID))) +func (h *InterceptedHandler) GetTaskPushConfig(ctx context.Context, req *a2a.GetTaskPushConfigRequest) (*a2a.TaskPushConfig, error) { + ctx, callCtx := withMethodCallContext(ctx, "GetTaskPushConfig") + if req != nil { + ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.TaskID))) } - return doCall(ctx, callCtx, h, params, h.Handler.OnGetTaskPushConfig) + return doCall(ctx, callCtx, h, req, h.Handler.GetTaskPushConfig) } -func (h *InterceptedHandler) OnListTaskPushConfig(ctx context.Context, params *a2a.ListTaskPushConfigParams) ([]*a2a.TaskPushConfig, error) { - ctx, callCtx := withMethodCallContext(ctx, "OnListTaskPushConfig") - if params != nil { - ctx = h.withLoggerContext(ctx, slog.String("task_id", string(params.TaskID))) +func (h *InterceptedHandler) ListTaskPushConfig(ctx context.Context, req *a2a.ListTaskPushConfigRequest) ([]*a2a.TaskPushConfig, error) { + ctx, callCtx := withMethodCallContext(ctx, "ListTaskPushConfig") + if req != nil { + ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.TaskID))) } - return doCall(ctx, callCtx, h, params, h.Handler.OnListTaskPushConfig) + return doCall(ctx, callCtx, h, req, h.Handler.ListTaskPushConfig) } -func (h *InterceptedHandler) OnSetTaskPushConfig(ctx context.Context, params *a2a.TaskPushConfig) (*a2a.TaskPushConfig, error) { - ctx, callCtx := withMethodCallContext(ctx, "OnSetTaskPushConfig") - if params != nil { - ctx = h.withLoggerContext(ctx, slog.String("task_id", string(params.TaskID))) +func (h *InterceptedHandler) CreateTaskPushConfig(ctx context.Context, req *a2a.CreateTaskPushConfigRequest) (*a2a.TaskPushConfig, error) { + ctx, callCtx := withMethodCallContext(ctx, "CreateTaskPushConfig") + if req != nil { + ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.TaskID))) } - return doCall(ctx, callCtx, h, params, h.Handler.OnSetTaskPushConfig) + return doCall(ctx, callCtx, h, req, h.Handler.CreateTaskPushConfig) } -func (h *InterceptedHandler) OnDeleteTaskPushConfig(ctx context.Context, params *a2a.DeleteTaskPushConfigParams) error { - ctx, callCtx := withMethodCallContext(ctx, "OnDeleteTaskPushConfig") - if params != nil { - ctx = h.withLoggerContext(ctx, slog.String("task_id", string(params.TaskID))) +func (h *InterceptedHandler) DeleteTaskPushConfig(ctx context.Context, req *a2a.DeleteTaskPushConfigRequest) error { + ctx, callCtx := withMethodCallContext(ctx, "DeleteTaskPushConfig") + if req != nil { + ctx = h.withLoggerContext(ctx, slog.String("task_id", string(req.TaskID))) } - ctx, res := interceptBefore[*a2a.DeleteTaskPushConfigParams, struct{}](ctx, h, callCtx, params) + ctx, res := interceptBefore[*a2a.DeleteTaskPushConfigRequest, struct{}](ctx, h, callCtx, req) if res.earlyErr != nil { return res.earlyErr } if res.earlyResponse != nil { return nil } - err := h.Handler.OnDeleteTaskPushConfig(ctx, res.reqOverride) + err := h.Handler.DeleteTaskPushConfig(ctx, res.reqOverride) var emptyResponse struct{} _, errOverride := interceptAfter(ctx, h.Interceptors, callCtx, emptyResponse, err) return errOverride } -func (h *InterceptedHandler) OnGetExtendedAgentCard(ctx context.Context) (*a2a.AgentCard, error) { - ctx, callCtx := withMethodCallContext(ctx, "OnGetExtendedAgentCard") +func (h *InterceptedHandler) GetExtendedAgentCard(ctx context.Context) (*a2a.AgentCard, error) { + ctx, callCtx := withMethodCallContext(ctx, "GetExtendedAgentCard") ctx = h.withLoggerContext(ctx) var req *struct{} @@ -204,7 +204,7 @@ func (h *InterceptedHandler) OnGetExtendedAgentCard(ctx context.Context) (*a2a.A if res.earlyResponse != nil { return *res.earlyResponse, nil } - response, err := h.Handler.OnGetExtendedAgentCard(ctx) + response, err := h.Handler.GetExtendedAgentCard(ctx) return interceptAfter(ctx, h.Interceptors, callCtx, response, err) } diff --git a/a2asrv/intercepted_handler_test.go b/a2asrv/intercepted_handler_test.go index b5531861..4f339748 100644 --- a/a2asrv/intercepted_handler_test.go +++ b/a2asrv/intercepted_handler_test.go @@ -27,19 +27,19 @@ import ( ) type mockHandler struct { - lastCallContext *CallContext - resultErr error - OnGetTaskFn func(ctx context.Context, query *a2a.TaskQueryParams) (*a2a.Task, error) - OnSendMessageFn func(ctx context.Context, params *a2a.MessageSendParams) (a2a.SendMessageResult, error) - OnSendMessageStreamFn func(ctx context.Context, params *a2a.MessageSendParams) iter.Seq2[a2a.Event, error] + lastCallContext *CallContext + resultErr error + GetTaskFn func(context.Context, *a2a.GetTaskRequest) (*a2a.Task, error) + SendMessageFn func(context.Context, *a2a.SendMessageRequest) (a2a.SendMessageResult, error) + SendStreamingMessageFn func(context.Context, *a2a.SendMessageRequest) iter.Seq2[a2a.Event, error] } var _ RequestHandler = (*mockHandler)(nil) -func (h *mockHandler) OnGetTask(ctx context.Context, query *a2a.TaskQueryParams) (*a2a.Task, error) { +func (h *mockHandler) GetTask(ctx context.Context, query *a2a.GetTaskRequest) (*a2a.Task, error) { h.lastCallContext, _ = CallContextFrom(ctx) - if h.OnGetTaskFn != nil { - return h.OnGetTaskFn(ctx, query) + if h.GetTaskFn != nil { + return h.GetTaskFn(ctx, query) } if h.resultErr != nil { return nil, h.resultErr @@ -47,7 +47,7 @@ func (h *mockHandler) OnGetTask(ctx context.Context, query *a2a.TaskQueryParams) return &a2a.Task{}, nil } -func (h *mockHandler) OnListTasks(ctx context.Context, req *a2a.ListTasksRequest) (*a2a.ListTasksResponse, error) { +func (h *mockHandler) ListTasks(ctx context.Context, req *a2a.ListTasksRequest) (*a2a.ListTasksResponse, error) { h.lastCallContext, _ = CallContextFrom(ctx) if h.resultErr != nil { return nil, h.resultErr @@ -55,7 +55,7 @@ func (h *mockHandler) OnListTasks(ctx context.Context, req *a2a.ListTasksRequest return &a2a.ListTasksResponse{}, nil } -func (h *mockHandler) OnCancelTask(ctx context.Context, params *a2a.TaskIDParams) (*a2a.Task, error) { +func (h *mockHandler) CancelTask(ctx context.Context, params *a2a.CancelTaskRequest) (*a2a.Task, error) { h.lastCallContext, _ = CallContextFrom(ctx) if h.resultErr != nil { return nil, h.resultErr @@ -63,10 +63,10 @@ func (h *mockHandler) OnCancelTask(ctx context.Context, params *a2a.TaskIDParams return &a2a.Task{}, nil } -func (h *mockHandler) OnSendMessage(ctx context.Context, params *a2a.MessageSendParams) (a2a.SendMessageResult, error) { +func (h *mockHandler) SendMessage(ctx context.Context, params *a2a.SendMessageRequest) (a2a.SendMessageResult, error) { h.lastCallContext, _ = CallContextFrom(ctx) - if h.OnSendMessageFn != nil { - return h.OnSendMessageFn(ctx, params) + if h.SendMessageFn != nil { + return h.SendMessageFn(ctx, params) } if h.resultErr != nil { return nil, h.resultErr @@ -74,9 +74,9 @@ func (h *mockHandler) OnSendMessage(ctx context.Context, params *a2a.MessageSend return &a2a.Task{}, nil } -func (h *mockHandler) OnSendMessageStream(ctx context.Context, params *a2a.MessageSendParams) iter.Seq2[a2a.Event, error] { - if h.OnSendMessageStreamFn != nil { - return h.OnSendMessageStreamFn(ctx, params) +func (h *mockHandler) SendStreamingMessage(ctx context.Context, params *a2a.SendMessageRequest) iter.Seq2[a2a.Event, error] { + if h.SendStreamingMessageFn != nil { + return h.SendStreamingMessageFn(ctx, params) } return func(yield func(a2a.Event, error) bool) { h.lastCallContext, _ = CallContextFrom(ctx) @@ -88,7 +88,7 @@ func (h *mockHandler) OnSendMessageStream(ctx context.Context, params *a2a.Messa } } -func (h *mockHandler) OnResubscribeToTask(ctx context.Context, params *a2a.TaskIDParams) iter.Seq2[a2a.Event, error] { +func (h *mockHandler) SubscribeToTask(ctx context.Context, params *a2a.SubscribeToTaskRequest) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { h.lastCallContext, _ = CallContextFrom(ctx) if h.resultErr != nil { @@ -99,7 +99,7 @@ func (h *mockHandler) OnResubscribeToTask(ctx context.Context, params *a2a.TaskI } } -func (h *mockHandler) OnGetTaskPushConfig(ctx context.Context, params *a2a.GetTaskPushConfigParams) (*a2a.TaskPushConfig, error) { +func (h *mockHandler) GetTaskPushConfig(ctx context.Context, params *a2a.GetTaskPushConfigRequest) (*a2a.TaskPushConfig, error) { h.lastCallContext, _ = CallContextFrom(ctx) if h.resultErr != nil { return nil, h.resultErr @@ -107,7 +107,7 @@ func (h *mockHandler) OnGetTaskPushConfig(ctx context.Context, params *a2a.GetTa return &a2a.TaskPushConfig{}, h.resultErr } -func (h *mockHandler) OnListTaskPushConfig(ctx context.Context, params *a2a.ListTaskPushConfigParams) ([]*a2a.TaskPushConfig, error) { +func (h *mockHandler) ListTaskPushConfig(ctx context.Context, params *a2a.ListTaskPushConfigRequest) ([]*a2a.TaskPushConfig, error) { h.lastCallContext, _ = CallContextFrom(ctx) if h.resultErr != nil { return nil, h.resultErr @@ -115,7 +115,7 @@ func (h *mockHandler) OnListTaskPushConfig(ctx context.Context, params *a2a.List return []*a2a.TaskPushConfig{{}}, nil } -func (h *mockHandler) OnSetTaskPushConfig(ctx context.Context, params *a2a.TaskPushConfig) (*a2a.TaskPushConfig, error) { +func (h *mockHandler) CreateTaskPushConfig(ctx context.Context, params *a2a.CreateTaskPushConfigRequest) (*a2a.TaskPushConfig, error) { h.lastCallContext, _ = CallContextFrom(ctx) if h.resultErr != nil { return nil, h.resultErr @@ -123,12 +123,12 @@ func (h *mockHandler) OnSetTaskPushConfig(ctx context.Context, params *a2a.TaskP return &a2a.TaskPushConfig{}, h.resultErr } -func (h *mockHandler) OnDeleteTaskPushConfig(ctx context.Context, params *a2a.DeleteTaskPushConfigParams) error { +func (h *mockHandler) DeleteTaskPushConfig(ctx context.Context, params *a2a.DeleteTaskPushConfigRequest) error { h.lastCallContext, _ = CallContextFrom(ctx) return h.resultErr } -func (h *mockHandler) OnGetExtendedAgentCard(ctx context.Context) (*a2a.AgentCard, error) { +func (h *mockHandler) GetExtendedAgentCard(ctx context.Context) (*a2a.AgentCard, error) { h.lastCallContext, _ = CallContextFrom(ctx) if h.resultErr != nil { return nil, h.resultErr @@ -173,69 +173,69 @@ var methodCalls = []struct { call func(ctx context.Context, h RequestHandler) (any, error) }{ { - method: "OnGetTask", + method: "GetTask", call: func(ctx context.Context, h RequestHandler) (any, error) { - return h.OnGetTask(ctx, &a2a.TaskQueryParams{}) + return h.GetTask(ctx, &a2a.GetTaskRequest{}) }, }, { - method: "OnListTasks", + method: "ListTasks", call: func(ctx context.Context, h RequestHandler) (any, error) { - return h.OnListTasks(ctx, &a2a.ListTasksRequest{}) + return h.ListTasks(ctx, &a2a.ListTasksRequest{}) }, }, { - method: "OnCancelTask", + method: "CancelTask", call: func(ctx context.Context, h RequestHandler) (any, error) { - return h.OnCancelTask(ctx, &a2a.TaskIDParams{}) + return h.CancelTask(ctx, &a2a.CancelTaskRequest{}) }, }, { - method: "OnSendMessage", + method: "SendMessage", call: func(ctx context.Context, h RequestHandler) (any, error) { - return h.OnSendMessage(ctx, &a2a.MessageSendParams{}) + return h.SendMessage(ctx, &a2a.SendMessageRequest{}) }, }, { - method: "OnSendMessageStream", + method: "SendStreamingMessage", call: func(ctx context.Context, h RequestHandler) (any, error) { - return handleSingleItemSeq(h.OnSendMessageStream(ctx, &a2a.MessageSendParams{})) + return handleSingleItemSeq(h.SendStreamingMessage(ctx, &a2a.SendMessageRequest{})) }, }, { - method: "OnResubscribeToTask", + method: "SubscribeToTask", call: func(ctx context.Context, h RequestHandler) (any, error) { - return handleSingleItemSeq(h.OnResubscribeToTask(ctx, &a2a.TaskIDParams{})) + return handleSingleItemSeq(h.SubscribeToTask(ctx, &a2a.SubscribeToTaskRequest{})) }, }, { - method: "OnListTaskPushConfig", + method: "ListTaskPushConfig", call: func(ctx context.Context, h RequestHandler) (any, error) { - return h.OnListTaskPushConfig(ctx, &a2a.ListTaskPushConfigParams{}) + return h.ListTaskPushConfig(ctx, &a2a.ListTaskPushConfigRequest{}) }, }, { - method: "OnSetTaskPushConfig", + method: "CreateTaskPushConfig", call: func(ctx context.Context, h RequestHandler) (any, error) { - return h.OnSetTaskPushConfig(ctx, &a2a.TaskPushConfig{}) + return h.CreateTaskPushConfig(ctx, &a2a.CreateTaskPushConfigRequest{}) }, }, { - method: "OnGetTaskPushConfig", + method: "GetTaskPushConfig", call: func(ctx context.Context, h RequestHandler) (any, error) { - return h.OnGetTaskPushConfig(ctx, &a2a.GetTaskPushConfigParams{}) + return h.GetTaskPushConfig(ctx, &a2a.GetTaskPushConfigRequest{}) }, }, { - method: "OnDeleteTaskPushConfig", + method: "DeleteTaskPushConfig", call: func(ctx context.Context, h RequestHandler) (any, error) { - return nil, h.OnDeleteTaskPushConfig(ctx, &a2a.DeleteTaskPushConfigParams{}) + return nil, h.DeleteTaskPushConfig(ctx, &a2a.DeleteTaskPushConfigRequest{}) }, }, { - method: "OnGetExtendedAgentCard", + method: "GetExtendedAgentCard", call: func(ctx context.Context, h RequestHandler) (any, error) { - return h.OnGetExtendedAgentCard(ctx) + return h.GetExtendedAgentCard(ctx) }, }, } @@ -246,11 +246,11 @@ func TestInterceptedHandler_Auth(t *testing.T) { handler := &InterceptedHandler{Handler: mockHandler, Interceptors: []CallInterceptor{mockInterceptor}} var capturedCallCtx *CallContext - mockHandler.OnSendMessageFn = func(ctx context.Context, params *a2a.MessageSendParams) (a2a.SendMessageResult, error) { + mockHandler.SendMessageFn = func(ctx context.Context, params *a2a.SendMessageRequest) (a2a.SendMessageResult, error) { if callCtx, ok := CallContextFrom(ctx); ok { capturedCallCtx = callCtx } - return a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Hi!"}), nil + return a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Hi!")), nil } mockInterceptor.beforeFn = func(ctx context.Context, callCtx *CallContext, req *Request) (context.Context, any, error) { @@ -258,7 +258,7 @@ func TestInterceptedHandler_Auth(t *testing.T) { return ctx, nil, nil } - _, _ = handler.OnSendMessage(ctx, &a2a.MessageSendParams{}) + _, _ = handler.SendMessage(ctx, &a2a.SendMessageRequest{}) if !capturedCallCtx.User.Authenticated { t.Fatal("CallContext.User.Authenticated = false, want true") @@ -273,15 +273,15 @@ func TestInterceptedHandler_RequestResponseModification(t *testing.T) { mockHandler, mockInterceptor := &mockHandler{}, &mockInterceptor{} handler := &InterceptedHandler{Handler: mockHandler, Interceptors: []CallInterceptor{mockInterceptor}} - var capturedRequest *a2a.MessageSendParams - mockHandler.OnSendMessageFn = func(ctx context.Context, params *a2a.MessageSendParams) (a2a.SendMessageResult, error) { + var capturedRequest *a2a.SendMessageRequest + mockHandler.SendMessageFn = func(ctx context.Context, params *a2a.SendMessageRequest) (a2a.SendMessageResult, error) { capturedRequest = params - return a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Hi!"}), nil + return a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Hi!")), nil } wantReqKey, wantReqVal := "reqKey", 42 mockInterceptor.beforeFn = func(ctx context.Context, callCtx *CallContext, req *Request) (context.Context, any, error) { - payload := req.Payload.(*a2a.MessageSendParams) + payload := req.Payload.(*a2a.SendMessageRequest) payload.Metadata = map[string]any{wantReqKey: wantReqVal} return ctx, nil, nil } @@ -293,20 +293,20 @@ func TestInterceptedHandler_RequestResponseModification(t *testing.T) { return nil } - request := &a2a.MessageSendParams{Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Hello!"})} - response, err := handler.OnSendMessage(ctx, request) - if mockHandler.lastCallContext.method != "OnSendMessage" { - t.Fatalf("handler.OnSendMessage() CallContext = %v, want method=OnSendMessage", mockHandler.lastCallContext) + request := &a2a.SendMessageRequest{Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Hello!"))} + response, err := handler.SendMessage(ctx, request) + if mockHandler.lastCallContext.method != "SendMessage" { + t.Fatalf("handler.SendMessage() CallContext = %v, want method=SendMessage", mockHandler.lastCallContext) } if err != nil { - t.Fatalf("handler.OnSendMessage() error = %v, want nil", err) + t.Fatalf("handler.SendMessage() error = %v, want nil", err) } if capturedRequest.Metadata[wantReqKey] != wantReqVal { - t.Fatalf("OnSendMessage() Request.Metadata[%q] = %v, want %d", wantReqKey, capturedRequest.Metadata[wantReqKey], wantReqVal) + t.Fatalf("SendMessage() Request.Metadata[%q] = %v, want %d", wantReqKey, capturedRequest.Metadata[wantReqKey], wantReqVal) } responsMsg := response.(*a2a.Message) if responsMsg.Metadata[wantRespKey] != wantRespVal { - t.Fatalf("OnSendMessage() Response.Metadata[%q] = %v, want %d", wantRespKey, responsMsg.Metadata[wantRespKey], wantRespVal) + t.Fatalf("SendMessage() Response.Metadata[%q] = %v, want %d", wantRespKey, responsMsg.Metadata[wantRespKey], wantRespVal) } } @@ -314,39 +314,39 @@ func TestInterceptedHandler_RequestModification(t *testing.T) { ctx := t.Context() mockHandler, mockInterceptor := &mockHandler{}, &mockInterceptor{} handler := &InterceptedHandler{Handler: mockHandler, Interceptors: []CallInterceptor{mockInterceptor}} - originalParams := &a2a.MessageSendParams{ - Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Hello!"}), + originalParams := &a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Hello!")), } - var receivedParams *a2a.MessageSendParams + var receivedParams *a2a.SendMessageRequest - mockHandler.OnSendMessageFn = func(ctx context.Context, params *a2a.MessageSendParams) (a2a.SendMessageResult, error) { + mockHandler.SendMessageFn = func(ctx context.Context, params *a2a.SendMessageRequest) (a2a.SendMessageResult, error) { receivedParams = params - message := params.Message.Parts[0].(a2a.TextPart).Text - return a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: message}), nil + message := string(params.Message.Parts[0].Content.(a2a.Text)) + return a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart(message)), nil } mockInterceptor.beforeFn = func(ctx context.Context, callCtx *CallContext, req *Request) (context.Context, any, error) { - if _, ok := req.Payload.(*a2a.MessageSendParams); ok { - req.Payload = &a2a.MessageSendParams{ - Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Modified!"}), + if _, ok := req.Payload.(*a2a.SendMessageRequest); ok { + req.Payload = &a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Modified!")), } } return ctx, nil, nil } - resp, err := handler.OnSendMessage(ctx, originalParams) + resp, err := handler.SendMessage(ctx, originalParams) if err != nil { - t.Fatalf("handler.OnSendMessage() error = %v, want nil", err) + t.Fatalf("handler.SendMessage() error = %v, want nil", err) } - if mockHandler.lastCallContext.method != "OnSendMessage" { - t.Fatalf("handler.OnSendMessage() CallContext = %v, want method=OnSendMessage", mockHandler.lastCallContext) + if mockHandler.lastCallContext.method != "SendMessage" { + t.Fatalf("handler.SendMessage() CallContext = %v, want method=SendMessage", mockHandler.lastCallContext) } if receivedParams == originalParams { - t.Fatalf("handler.OnSendMessage() receivedParams = %v, want %v", receivedParams, originalParams) + t.Fatalf("handler.SendMessage() receivedParams = %v, want %v", receivedParams, originalParams) } reqMsg := resp.(*a2a.Message) - if reqMsg.Parts[0].(a2a.TextPart).Text != "Modified!" { - t.Fatalf("handler.OnSendMessage() Request.Text = %q, want %q", reqMsg.Parts[0].(a2a.TextPart).Text, "Modified!") + if string(reqMsg.Parts[0].Content.(a2a.Text)) != "Modified!" { + t.Fatalf("handler.SendMessage() Request.Text = %q, want %q", string(reqMsg.Parts[0].Content.(a2a.Text)), "Modified!") } } @@ -364,16 +364,16 @@ func TestInterceptedHandler_ResponseAndErrorModification(t *testing.T) { }{ { name: "replace response object", - handlerResp: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Original!"}), + handlerResp: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Original!")), interceptorFn: func(ctx context.Context, callCtx *CallContext, resp *Response) error { - resp.Payload = a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Modified!"}) + resp.Payload = a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Modified!")) return nil }, wantRespText: "Modified!", }, { name: "injected error: handler success, interceptor error", - handlerResp: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Success!"}), + handlerResp: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Success!")), interceptorFn: func(ctx context.Context, callCtx *CallContext, resp *Response) error { resp.Err = injectedErr return nil @@ -387,7 +387,7 @@ func TestInterceptedHandler_ResponseAndErrorModification(t *testing.T) { if resp.Err != nil { resp.Err = nil - resp.Payload = a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Recovered from error!"}) + resp.Payload = a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Recovered from error!")) } return nil }, @@ -401,27 +401,27 @@ func TestInterceptedHandler_ResponseAndErrorModification(t *testing.T) { mockHandler, mockInterceptor := &mockHandler{}, &mockInterceptor{} handler := &InterceptedHandler{Handler: mockHandler, Interceptors: []CallInterceptor{mockInterceptor}} - mockHandler.OnSendMessageFn = func(ctx context.Context, params *a2a.MessageSendParams) (a2a.SendMessageResult, error) { + mockHandler.SendMessageFn = func(ctx context.Context, params *a2a.SendMessageRequest) (a2a.SendMessageResult, error) { return tt.handlerResp, tt.handlerErr } mockInterceptor.afterFn = tt.interceptorFn - resp, err := handler.OnSendMessage(ctx, &a2a.MessageSendParams{ - Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Hello!"}), + resp, err := handler.SendMessage(ctx, &a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Hello!")), }) if !errors.Is(err, tt.wantErr) { - t.Errorf("handler.OnSendMessage() error = %v, want %v", err, tt.wantErr) + t.Errorf("handler.SendMessage() error = %v, want %v", err, tt.wantErr) } if tt.wantErr == nil { if resp == nil { - t.Errorf("handler.OnSendMessage() resp = nil, want %v", tt.wantRespText) + t.Errorf("handler.SendMessage() resp = nil, want %v", tt.wantRespText) } msg := resp.(*a2a.Message) - if msg.Parts[0].(a2a.TextPart).Text != tt.wantRespText { - t.Errorf("handler.OnSendMessage() resp.Text = %q, want %q", msg.Parts[0].(a2a.TextPart).Text, tt.wantRespText) + if string(msg.Parts[0].Content.(a2a.Text)) != tt.wantRespText { + t.Errorf("handler.SendMessage() resp.Text = %q, want %q", string(msg.Parts[0].Content.(a2a.Text)), tt.wantRespText) } } }) @@ -438,7 +438,7 @@ func TestInterceptedHandler_TypeSafety(t *testing.T) { return ctx, nil, nil } - _, err := handler.OnSendMessage(ctx, &a2a.MessageSendParams{}) + _, err := handler.SendMessage(ctx, &a2a.SendMessageRequest{}) if err == nil { t.Fatal("got nil error, want error due to payload type mismatch") @@ -471,7 +471,7 @@ func TestInterceptedHandler_InterceptorOrdering(t *testing.T) { interceptor1, interceptor2 := createInterceptor(1), createInterceptor(2) handler := &InterceptedHandler{Handler: mockHandler, Interceptors: []CallInterceptor{interceptor1, interceptor2}} - _, _ = handler.OnGetTask(ctx, &a2a.TaskQueryParams{}) + _, _ = handler.GetTask(ctx, &a2a.GetTaskRequest{}) wantBefore := []int{1, 2} if !reflect.DeepEqual(beforeCalls, wantBefore) { @@ -489,7 +489,7 @@ func TestInterceptedHandler_EveryStreamValueIntercepted(t *testing.T) { handler := &InterceptedHandler{Handler: mockHandler, Interceptors: []CallInterceptor{mockInterceptor}} totalCount := 5 - mockHandler.OnSendMessageStreamFn = func(ctx context.Context, params *a2a.MessageSendParams) iter.Seq2[a2a.Event, error] { + mockHandler.SendStreamingMessageFn = func(ctx context.Context, params *a2a.SendMessageRequest) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { for range totalCount { if !yield(&a2a.TaskStatusUpdateEvent{Metadata: map[string]any{"count": 0}}, nil) { @@ -509,10 +509,10 @@ func TestInterceptedHandler_EveryStreamValueIntercepted(t *testing.T) { } count := 0 - request := &a2a.MessageSendParams{Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Hello!"})} - for ev, err := range handler.OnSendMessageStream(ctx, request) { + request := &a2a.SendMessageRequest{Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Hello!"))} + for ev, err := range handler.SendStreamingMessage(ctx, request) { if err != nil { - t.Fatalf("handler.OnSendMessageStream() error %v, want nil", err) + t.Fatalf("handler.SendStreamingMessage() error %v, want nil", err) } if ev.Meta()[countKey] != count { t.Fatalf("event.Meta()[%q] = %v, want %v", countKey, ev.Meta()[countKey], count) @@ -521,7 +521,7 @@ func TestInterceptedHandler_EveryStreamValueIntercepted(t *testing.T) { } if count != afterCount { - t.Fatalf("handler.OnSendMessageStream() produced %d events, want %d", count, totalCount) + t.Fatalf("handler.SendStreamingMessage() produced %d events, want %d", count, totalCount) } } @@ -651,13 +651,13 @@ func TestInterceptedHandler_RejectResponse(t *testing.T) { func TestInterceptedHandler_EarlyReturn(t *testing.T) { ctx := t.Context() - originalQuery := &a2a.TaskQueryParams{ID: "original"} + originalQuery := &a2a.GetTaskRequest{ID: "original"} earlyResult := &a2a.Task{ID: "early-cached-result"} mockHandler, interceptor1, interceptor2, interceptor3 := &mockHandler{}, &mockInterceptor{}, &mockInterceptor{}, &mockInterceptor{} handlerCalled := false - mockHandler.OnGetTaskFn = func(ctx context.Context, query *a2a.TaskQueryParams) (*a2a.Task, error) { + mockHandler.GetTaskFn = func(ctx context.Context, query *a2a.GetTaskRequest) (*a2a.Task, error) { handlerCalled = true return nil, fmt.Errorf("handler should not be called") } @@ -690,7 +690,7 @@ func TestInterceptedHandler_EarlyReturn(t *testing.T) { Interceptors: []CallInterceptor{interceptor1, interceptor2, interceptor3}, } - response, err := handler.OnGetTask(ctx, originalQuery) + response, err := handler.GetTask(ctx, originalQuery) if err != nil { t.Errorf("OnGetTask() error = %v, want nil", err) } diff --git a/a2asrv/jsonrpc.go b/a2asrv/jsonrpc.go index 60e27ea9..0212ef22 100644 --- a/a2asrv/jsonrpc.go +++ b/a2asrv/jsonrpc.go @@ -154,7 +154,11 @@ func (h *jsonrpcHandler) handleRequest(ctx context.Context, rw http.ResponseWrit case jsonrpc.MethodTasksList: result, err = h.onListTasks(ctx, req.Params) case jsonrpc.MethodMessageSend: - result, err = h.onSendMessage(ctx, req.Params) + var res a2a.SendMessageResult + res, err = h.onSendMessage(ctx, req.Params) + if err == nil { + result = a2a.StreamResponse{Event: res} + } case jsonrpc.MethodTasksCancel: result, err = h.onCancelTask(ctx, req.Params) case jsonrpc.MethodPushConfigGet: @@ -282,7 +286,7 @@ func eventSeqToSSEDataStream(ctx context.Context, req *jsonrpcRequest, sseChan c return } - resp := jsonrpcResponse{JSONRPC: jsonrpc.Version, ID: req.ID, Result: event} + resp := jsonrpcResponse{JSONRPC: jsonrpc.Version, ID: req.ID, Result: a2a.StreamResponse{Event: event}} bytes, err := json.Marshal(resp) if err != nil { handleError(err) @@ -298,11 +302,11 @@ func eventSeqToSSEDataStream(ctx context.Context, req *jsonrpcRequest, sseChan c } func (h *jsonrpcHandler) onGetTask(ctx context.Context, raw json.RawMessage) (*a2a.Task, error) { - var query a2a.TaskQueryParams + var query a2a.GetTaskRequest if err := json.Unmarshal(raw, &query); err != nil { return nil, handleUnmarshalError(err) } - return h.handler.OnGetTask(ctx, &query) + return h.handler.GetTask(ctx, &query) } func (h *jsonrpcHandler) onListTasks(ctx context.Context, raw json.RawMessage) (*a2a.ListTasksResponse, error) { @@ -310,33 +314,33 @@ func (h *jsonrpcHandler) onListTasks(ctx context.Context, raw json.RawMessage) ( if err := json.Unmarshal(raw, &request); err != nil { return nil, handleUnmarshalError(err) } - return h.handler.OnListTasks(ctx, &request) + return h.handler.ListTasks(ctx, &request) } func (h *jsonrpcHandler) onCancelTask(ctx context.Context, raw json.RawMessage) (*a2a.Task, error) { - var id a2a.TaskIDParams + var id a2a.CancelTaskRequest if err := json.Unmarshal(raw, &id); err != nil { return nil, handleUnmarshalError(err) } - return h.handler.OnCancelTask(ctx, &id) + return h.handler.CancelTask(ctx, &id) } func (h *jsonrpcHandler) onSendMessage(ctx context.Context, raw json.RawMessage) (a2a.SendMessageResult, error) { - var message a2a.MessageSendParams + var message a2a.SendMessageRequest if err := json.Unmarshal(raw, &message); err != nil { return nil, handleUnmarshalError(err) } - return h.handler.OnSendMessage(ctx, &message) + return h.handler.SendMessage(ctx, &message) } func (h *jsonrpcHandler) onResubscribeToTask(ctx context.Context, raw json.RawMessage) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { - var id a2a.TaskIDParams + var id a2a.SubscribeToTaskRequest if err := json.Unmarshal(raw, &id); err != nil { yield(nil, handleUnmarshalError(err)) return } - for event, err := range h.handler.OnResubscribeToTask(ctx, &id) { + for event, err := range h.handler.SubscribeToTask(ctx, &id) { if !yield(event, err) { return } @@ -346,12 +350,12 @@ func (h *jsonrpcHandler) onResubscribeToTask(ctx context.Context, raw json.RawMe func (h *jsonrpcHandler) onSendMessageStream(ctx context.Context, raw json.RawMessage) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { - var message a2a.MessageSendParams + var message a2a.SendMessageRequest if err := json.Unmarshal(raw, &message); err != nil { yield(nil, handleUnmarshalError(err)) return } - for event, err := range h.handler.OnSendMessageStream(ctx, &message) { + for event, err := range h.handler.SendStreamingMessage(ctx, &message) { if !yield(event, err) { return } @@ -361,39 +365,39 @@ func (h *jsonrpcHandler) onSendMessageStream(ctx context.Context, raw json.RawMe } func (h *jsonrpcHandler) onGetTaskPushConfig(ctx context.Context, raw json.RawMessage) (*a2a.TaskPushConfig, error) { - var params a2a.GetTaskPushConfigParams + var params a2a.GetTaskPushConfigRequest if err := json.Unmarshal(raw, ¶ms); err != nil { return nil, handleUnmarshalError(err) } - return h.handler.OnGetTaskPushConfig(ctx, ¶ms) + return h.handler.GetTaskPushConfig(ctx, ¶ms) } func (h *jsonrpcHandler) onListTaskPushConfig(ctx context.Context, raw json.RawMessage) ([]*a2a.TaskPushConfig, error) { - var params a2a.ListTaskPushConfigParams + var params a2a.ListTaskPushConfigRequest if err := json.Unmarshal(raw, ¶ms); err != nil { return nil, handleUnmarshalError(err) } - return h.handler.OnListTaskPushConfig(ctx, ¶ms) + return h.handler.ListTaskPushConfig(ctx, ¶ms) } func (h *jsonrpcHandler) onSetTaskPushConfig(ctx context.Context, raw json.RawMessage) (*a2a.TaskPushConfig, error) { - var params a2a.TaskPushConfig + var params a2a.CreateTaskPushConfigRequest if err := json.Unmarshal(raw, ¶ms); err != nil { return nil, handleUnmarshalError(err) } - return h.handler.OnSetTaskPushConfig(ctx, ¶ms) + return h.handler.CreateTaskPushConfig(ctx, ¶ms) } func (h *jsonrpcHandler) onDeleteTaskPushConfig(ctx context.Context, raw json.RawMessage) error { - var params a2a.DeleteTaskPushConfigParams + var params a2a.DeleteTaskPushConfigRequest if err := json.Unmarshal(raw, ¶ms); err != nil { return handleUnmarshalError(err) } - return h.handler.OnDeleteTaskPushConfig(ctx, ¶ms) + return h.handler.DeleteTaskPushConfig(ctx, ¶ms) } func (h *jsonrpcHandler) onGetAgentCard(ctx context.Context) (*a2a.AgentCard, error) { - return h.handler.OnGetExtendedAgentCard(ctx) + return h.handler.GetExtendedAgentCard(ctx) } func marshalJSONRPCError(req *jsonrpcRequest, err error) ([]byte, bool) { diff --git a/a2asrv/jsonrpc_test.go b/a2asrv/jsonrpc_test.go index 33680824..e2809fd5 100644 --- a/a2asrv/jsonrpc_test.go +++ b/a2asrv/jsonrpc_test.go @@ -42,69 +42,69 @@ func TestJSONRPC_RequestRouting(t *testing.T) { call func(ctx context.Context, client *a2aclient.Client) (any, error) }{ { - method: "OnGetTask", + method: "GetTask", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return client.GetTask(ctx, &a2a.TaskQueryParams{}) + return client.GetTask(ctx, &a2a.GetTaskRequest{}) }, }, { - method: "OnListTasks", + method: "ListTasks", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { return client.ListTasks(ctx, &a2a.ListTasksRequest{}) }, }, { - method: "OnCancelTask", + method: "CancelTask", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return client.CancelTask(ctx, &a2a.TaskIDParams{}) + return client.CancelTask(ctx, &a2a.CancelTaskRequest{}) }, }, { - method: "OnSendMessage", + method: "SendMessage", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return client.SendMessage(ctx, &a2a.MessageSendParams{}) + return client.SendMessage(ctx, &a2a.SendMessageRequest{}) }, }, { - method: "OnSendMessageStream", + method: "SendStreamingMessage", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return handleSingleItemSeq(client.SendStreamingMessage(ctx, &a2a.MessageSendParams{})) + return handleSingleItemSeq(client.SendStreamingMessage(ctx, &a2a.SendMessageRequest{})) }, }, { - method: "OnResubscribeToTask", + method: "SubscribeToTask", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return handleSingleItemSeq(client.ResubscribeToTask(ctx, &a2a.TaskIDParams{})) + return handleSingleItemSeq(client.SubscribeToTask(ctx, &a2a.SubscribeToTaskRequest{})) }, }, { - method: "OnListTaskPushConfig", + method: "ListTaskPushConfig", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return client.ListTaskPushConfig(ctx, &a2a.ListTaskPushConfigParams{}) + return client.ListTaskPushConfig(ctx, &a2a.ListTaskPushConfigRequest{}) }, }, { - method: "OnSetTaskPushConfig", + method: "CreateTaskPushConfig", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return client.SetTaskPushConfig(ctx, &a2a.TaskPushConfig{}) + return client.CreateTaskPushConfig(ctx, &a2a.CreateTaskPushConfigRequest{}) }, }, { - method: "OnGetTaskPushConfig", + method: "GetTaskPushConfig", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return client.GetTaskPushConfig(ctx, &a2a.GetTaskPushConfigParams{}) + return client.GetTaskPushConfig(ctx, &a2a.GetTaskPushConfigRequest{}) }, }, { - method: "OnDeleteTaskPushConfig", + method: "DeleteTaskPushConfig", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return nil, client.DeleteTaskPushConfig(ctx, &a2a.DeleteTaskPushConfigParams{}) + return nil, client.DeleteTaskPushConfig(ctx, &a2a.DeleteTaskPushConfigRequest{}) }, }, { - method: "OnGetExtendedAgentCard", + method: "GetExtendedAgentCard", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return client.GetAgentCard(ctx) + return client.GetExtendedAgentCard(ctx) }, }, } @@ -125,7 +125,7 @@ func TestJSONRPC_RequestRouting(t *testing.T) { server := httptest.NewServer(NewJSONRPCHandler(reqHandler)) client, err := a2aclient.NewFromEndpoints(ctx, []a2a.AgentInterface{ - {URL: server.URL, Transport: a2a.TransportProtocolJSONRPC}, + {URL: server.URL, ProtocolBinding: a2a.TransportProtocolJSONRPC}, }) if err != nil { t.Fatalf("a2aclient.NewFromEndpoints() error = %v", err) @@ -327,7 +327,7 @@ func TestJSONRPC_StreamingKeepAlive(t *testing.T) { ExecuteFunc: func(ctx context.Context, execCtx *ExecutorContext) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { time.Sleep(agentTimeout) - if !yield(a2a.NewMessage(a2a.MessageRoleAgent, a2a.TextPart{Text: "test message"}), nil) { + if !yield(a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart("test message")), nil) { return } } @@ -345,8 +345,8 @@ func TestJSONRPC_StreamingKeepAlive(t *testing.T) { request := jsonrpcRequest{ JSONRPC: "2.0", Method: jsonrpc.MethodMessageStream, - Params: mustMarshal(t, &a2a.MessageSendParams{ - Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "hello"}), + Params: mustMarshal(t, &a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("hello")), }), ID: 1, } diff --git a/a2asrv/push/sender.go b/a2asrv/push/sender.go index d8b44ec4..958d2c4e 100644 --- a/a2asrv/push/sender.go +++ b/a2asrv/push/sender.go @@ -76,14 +76,11 @@ func (s *HTTPPushSender) SendPush(ctx context.Context, config *a2a.PushConfig, t req.Header.Set(tokenHeader, config.Token) } if config.Auth != nil && config.Auth.Credentials != "" { - // Find the first supported scheme and apply it. - for _, scheme := range config.Auth.Schemes { - switch strings.ToLower(scheme) { - case "bearer": - req.Header.Set("Authorization", "Bearer "+config.Auth.Credentials) - case "basic": - req.Header.Set("Authorization", "Basic "+config.Auth.Credentials) - } + switch strings.ToLower(config.Auth.Scheme) { + case "bearer": + req.Header.Set("Authorization", "Bearer "+config.Auth.Credentials) + case "basic": + req.Header.Set("Authorization", "Basic "+config.Auth.Credentials) } } diff --git a/a2asrv/push/sender_test.go b/a2asrv/push/sender_test.go index 50f69899..76e541ec 100644 --- a/a2asrv/push/sender_test.go +++ b/a2asrv/push/sender_test.go @@ -101,7 +101,7 @@ func TestHTTPPushSender_SendPushSuccess(t *testing.T) { config := &a2a.PushConfig{ URL: server.URL, Auth: &a2a.PushAuthInfo{ - Schemes: []string{"Bearer"}, + Scheme: "Bearer", Credentials: "my-bearer-token", }, } @@ -118,7 +118,7 @@ func TestHTTPPushSender_SendPushSuccess(t *testing.T) { }) t.Run("success with basic auth", func(t *testing.T) { - config := &a2a.PushConfig{URL: server.URL, Auth: &a2a.PushAuthInfo{Schemes: []string{"Basic"}, Credentials: "dXNlcjpwYXNz"}} + config := &a2a.PushConfig{URL: server.URL, Auth: &a2a.PushAuthInfo{Scheme: "Basic", Credentials: "dXNlcjpwYXNz"}} sender := NewHTTPPushSender(nil) err := sender.SendPush(ctx, config, task) diff --git a/a2asrv/rest.go b/a2asrv/rest.go index b28d9db6..0e1ffeb2 100644 --- a/a2asrv/rest.go +++ b/a2asrv/rest.go @@ -32,16 +32,17 @@ import ( func NewRESTHandler(handler RequestHandler) http.Handler { mux := http.NewServeMux() - mux.HandleFunc("POST /v1/message:send", handleSendMessage(handler)) - mux.HandleFunc("POST /v1/message:stream", handleStreamMessage(handler)) - mux.HandleFunc("GET /v1/tasks/{id}", handleGetTask(handler)) - mux.HandleFunc("GET /v1/tasks", handleListTasks(handler)) - mux.HandleFunc("POST /v1/tasks/{idAndAction}", handlePOSTTasks(handler)) - mux.HandleFunc("POST /v1/tasks/{id}/pushNotificationConfigs", handleSetTaskPushConfig(handler)) - mux.HandleFunc("GET /v1/tasks/{id}/pushNotificationConfigs/{configId}", handleGetTaskPushConfig(handler)) - mux.HandleFunc("GET /v1/tasks/{id}/pushNotificationConfigs", handleListTaskPushConfig(handler)) - mux.HandleFunc("DELETE /v1/tasks/{id}/pushNotificationConfigs/{configId}", handleDeleteTaskPushConfig(handler)) - mux.HandleFunc("GET /v1/card", handleGetExtendedAgentCard(handler)) + // TODO: handle tenant + mux.HandleFunc("POST "+rest.MakeSendMessagePath(), handleSendMessage(handler)) + mux.HandleFunc("POST "+rest.MakeStreamMessagePath(), handleStreamMessage(handler)) + mux.HandleFunc("GET "+rest.MakeGetTaskPath("{id}"), handleGetTask(handler)) + mux.HandleFunc("GET "+rest.MakeListTasksPath(), handleListTasks(handler)) + mux.HandleFunc("POST /tasks/{idAndAction}", handlePOSTTasks(handler)) + mux.HandleFunc("POST "+rest.MakeCreatePushConfigPath("{id}"), handleCreateTaskPushConfig(handler)) + mux.HandleFunc("GET "+rest.MakeGetPushConfigPath("{id}", "{configId}"), handleGetTaskPushConfig(handler)) + mux.HandleFunc("GET "+rest.MakeListPushConfigsPath("{id}"), handleListTaskPushConfig(handler)) + mux.HandleFunc("DELETE "+rest.MakeDeletePushConfigPath("{id}", "{configId}"), handleDeleteTaskPushConfig(handler)) + mux.HandleFunc("GET "+rest.MakeGetExtendedAgentCardPath(), handleGetExtendedAgentCard(handler)) return mux } @@ -49,20 +50,20 @@ func NewRESTHandler(handler RequestHandler) http.Handler { func handleSendMessage(handler RequestHandler) http.HandlerFunc { return func(rw http.ResponseWriter, req *http.Request) { ctx := req.Context() - var message a2a.MessageSendParams + var message a2a.SendMessageRequest if err := json.NewDecoder(req.Body).Decode(&message); err != nil { writeRESTError(ctx, rw, a2a.ErrParseError, a2a.TaskID("")) return } - result, err := handler.OnSendMessage(ctx, &message) + result, err := handler.SendMessage(ctx, &message) if err != nil { writeRESTError(ctx, rw, err, a2a.TaskID("")) return } - if err := json.NewEncoder(rw).Encode(result); err != nil { + if err := json.NewEncoder(rw).Encode(a2a.StreamResponse{Event: result}); err != nil { log.Error(ctx, "failed to encode response", err) } } @@ -71,12 +72,12 @@ func handleSendMessage(handler RequestHandler) http.HandlerFunc { func handleStreamMessage(handler RequestHandler) http.HandlerFunc { return func(rw http.ResponseWriter, req *http.Request) { ctx := req.Context() - var message a2a.MessageSendParams + var message a2a.SendMessageRequest if err := json.NewDecoder(req.Body).Decode(&message); err != nil { writeRESTError(ctx, rw, a2a.ErrParseError, a2a.TaskID("")) return } - handleStreamingRequest(handler.OnSendMessageStream(ctx, &message), rw, req) + handleStreamingRequest(handler.SendStreamingMessage(ctx, &message), rw, req) } } @@ -98,12 +99,12 @@ func handleGetTask(handler RequestHandler) http.HandlerFunc { writeRESTError(ctx, rw, a2a.ErrInvalidRequest, a2a.TaskID("")) return } - params := &a2a.TaskQueryParams{ + params := &a2a.GetTaskRequest{ ID: a2a.TaskID(taskID), HistoryLength: historyLength, } - result, err := handler.OnGetTask(ctx, params) + result, err := handler.GetTask(ctx, params) if err != nil { writeRESTError(ctx, rw, err, a2a.TaskID(taskID)) return @@ -146,13 +147,13 @@ func handleListTasks(handler RequestHandler) http.HandlerFunc { parse("pageSize", &request.PageSize) parse("pageToken", &request.PageToken) parse("historyLength", &request.HistoryLength) - parse("lastUpdatedAfter", &request.LastUpdatedAfter) + parse("statusTimestampAfter", &request.StatusTimestampAfter) parse("includeArtifacts", &request.IncludeArtifacts) if err != nil { writeRESTError(ctx, rw, a2a.ErrInvalidRequest, a2a.TaskID("")) return } - result, err := handler.OnListTasks(ctx, request) + result, err := handler.ListTasks(ctx, request) if err != nil { writeRESTError(ctx, rw, err, a2a.TaskID("")) return @@ -172,12 +173,12 @@ func handlePOSTTasks(handler RequestHandler) http.HandlerFunc { return } - if strings.HasSuffix(idAndAction, ":cancel") { - taskID := strings.TrimSuffix(idAndAction, ":cancel") + if before, ok := strings.CutSuffix(idAndAction, ":cancel"); ok { + taskID := before handleCancelTask(handler, taskID, rw, req) - } else if strings.HasSuffix(idAndAction, ":subscribe") { - taskID := strings.TrimSuffix(idAndAction, ":subscribe") - handleStreamingRequest(handler.OnResubscribeToTask(ctx, &a2a.TaskIDParams{ID: a2a.TaskID(taskID)}), rw, req) + } else if before, ok := strings.CutSuffix(idAndAction, ":subscribe"); ok { + taskID := before + handleStreamingRequest(handler.SubscribeToTask(ctx, &a2a.SubscribeToTaskRequest{ID: a2a.TaskID(taskID)}), rw, req) } else { writeRESTError(ctx, rw, a2a.ErrInvalidRequest, a2a.TaskID("")) return @@ -188,11 +189,11 @@ func handlePOSTTasks(handler RequestHandler) http.HandlerFunc { func handleCancelTask(handler RequestHandler, taskID string, rw http.ResponseWriter, req *http.Request) { ctx := req.Context() - id := &a2a.TaskIDParams{ + id := &a2a.CancelTaskRequest{ ID: a2a.TaskID(taskID), } - result, err := handler.OnCancelTask(ctx, id) + result, err := handler.CancelTask(ctx, id) if err != nil { writeRESTError(ctx, rw, err, a2a.TaskID(taskID)) @@ -218,6 +219,7 @@ func handleStreamingRequest(eventSequence iter.Seq2[a2a.Event, error], rw http.R requestCtx, cancel := context.WithCancel(ctx) defer cancel() + // TODO: handle panic and sse keep-alives similar to jsonrpc go func() { defer close(sseChan) events := eventSequence @@ -228,7 +230,7 @@ func handleStreamingRequest(eventSequence iter.Seq2[a2a.Event, error], rw http.R return } - b, jbErr := json.Marshal(event) + b, jbErr := json.Marshal(a2a.StreamResponse{Event: event}) if jbErr != nil { errObj := map[string]string{"error": jbErr.Error()} if eb, err := json.Marshal(errObj); err == nil { @@ -268,7 +270,7 @@ func handleStreamingRequest(eventSequence iter.Seq2[a2a.Event, error], rw http.R } } -func handleSetTaskPushConfig(handler RequestHandler) http.HandlerFunc { +func handleCreateTaskPushConfig(handler RequestHandler) http.HandlerFunc { return func(rw http.ResponseWriter, req *http.Request) { ctx := req.Context() taskID := req.PathValue("id") @@ -283,12 +285,12 @@ func handleSetTaskPushConfig(handler RequestHandler) http.HandlerFunc { return } - params := &a2a.TaskPushConfig{ + params := &a2a.CreateTaskPushConfigRequest{ TaskID: a2a.TaskID(taskID), Config: *config, } - result, err := handler.OnSetTaskPushConfig(ctx, params) + result, err := handler.CreateTaskPushConfig(ctx, params) if err != nil { writeRESTError(ctx, rw, err, a2a.TaskID(taskID)) @@ -312,12 +314,12 @@ func handleGetTaskPushConfig(handler RequestHandler) http.HandlerFunc { return } - params := &a2a.GetTaskPushConfigParams{ - TaskID: a2a.TaskID(taskID), - ConfigID: configID, + params := &a2a.GetTaskPushConfigRequest{ + TaskID: a2a.TaskID(taskID), + ID: configID, } - result, err := handler.OnGetTaskPushConfig(ctx, params) + result, err := handler.GetTaskPushConfig(ctx, params) if err != nil { writeRESTError(ctx, rw, err, a2a.TaskID(taskID)) @@ -340,11 +342,11 @@ func handleListTaskPushConfig(handler RequestHandler) http.HandlerFunc { return } - params := &a2a.ListTaskPushConfigParams{ + params := &a2a.ListTaskPushConfigRequest{ TaskID: a2a.TaskID(taskID), } - result, err := handler.OnListTaskPushConfig(ctx, params) + result, err := handler.ListTaskPushConfig(ctx, params) if err != nil { writeRESTError(ctx, rw, err, a2a.TaskID(taskID)) @@ -367,12 +369,12 @@ func handleDeleteTaskPushConfig(handler RequestHandler) http.HandlerFunc { return } - params := &a2a.DeleteTaskPushConfigParams{ - TaskID: a2a.TaskID(taskID), - ConfigID: configID, + params := &a2a.DeleteTaskPushConfigRequest{ + TaskID: a2a.TaskID(taskID), + ID: configID, } - err := handler.OnDeleteTaskPushConfig(ctx, params) + err := handler.DeleteTaskPushConfig(ctx, params) if err != nil { writeRESTError(ctx, rw, err, a2a.TaskID(taskID)) @@ -384,7 +386,7 @@ func handleDeleteTaskPushConfig(handler RequestHandler) http.HandlerFunc { func handleGetExtendedAgentCard(handler RequestHandler) http.HandlerFunc { return func(rw http.ResponseWriter, req *http.Request) { ctx := req.Context() - result, err := handler.OnGetExtendedAgentCard(ctx) + result, err := handler.GetExtendedAgentCard(ctx) if err != nil { writeRESTError(ctx, rw, err, a2a.TaskID("")) diff --git a/a2asrv/rest_test.go b/a2asrv/rest_test.go index 132fff55..2691a00d 100644 --- a/a2asrv/rest_test.go +++ b/a2asrv/rest_test.go @@ -40,69 +40,69 @@ func TestREST_RequestRouting(t *testing.T) { call func(ctx context.Context, client *a2aclient.Client) (any, error) }{ { - method: "OnSendMessage", + method: "SendMessage", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return client.SendMessage(ctx, &a2a.MessageSendParams{}) + return client.SendMessage(ctx, &a2a.SendMessageRequest{}) }, }, { - method: "OnSendMessageStream", + method: "SendStreamingMessage", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return handleSingleItemSeq(client.SendStreamingMessage(ctx, &a2a.MessageSendParams{})) + return handleSingleItemSeq(client.SendStreamingMessage(ctx, &a2a.SendMessageRequest{})) }, }, { - method: "OnGetTask", + method: "GetTask", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return client.GetTask(ctx, &a2a.TaskQueryParams{ID: "test-id"}) + return client.GetTask(ctx, &a2a.GetTaskRequest{ID: "test-id"}) }, }, { - method: "OnListTasks", + method: "ListTasks", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { return client.ListTasks(ctx, &a2a.ListTasksRequest{}) }, }, { - method: "OnCancelTask", + method: "CancelTask", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return client.CancelTask(ctx, &a2a.TaskIDParams{}) + return client.CancelTask(ctx, &a2a.CancelTaskRequest{ID: "test-id"}) }, }, { - method: "OnResubscribeToTask", + method: "SubscribeToTask", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return handleSingleItemSeq(client.ResubscribeToTask(ctx, &a2a.TaskIDParams{})) + return handleSingleItemSeq(client.SubscribeToTask(ctx, &a2a.SubscribeToTaskRequest{ID: "test-id"})) }, }, { - method: "OnSetTaskPushConfig", + method: "CreateTaskPushConfig", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return client.SetTaskPushConfig(ctx, &a2a.TaskPushConfig{TaskID: a2a.TaskID("test-id")}) + return client.CreateTaskPushConfig(ctx, &a2a.CreateTaskPushConfigRequest{TaskID: a2a.TaskID("test-id")}) }, }, { - method: "OnGetTaskPushConfig", + method: "GetTaskPushConfig", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return client.GetTaskPushConfig(ctx, &a2a.GetTaskPushConfigParams{TaskID: a2a.TaskID("test-id"), ConfigID: "test-config-id"}) + return client.GetTaskPushConfig(ctx, &a2a.GetTaskPushConfigRequest{TaskID: a2a.TaskID("test-id"), ID: "test-config-id"}) }, }, { - method: "OnListTaskPushConfig", + method: "ListTaskPushConfig", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return client.ListTaskPushConfig(ctx, &a2a.ListTaskPushConfigParams{TaskID: a2a.TaskID("test-id")}) + return client.ListTaskPushConfig(ctx, &a2a.ListTaskPushConfigRequest{TaskID: a2a.TaskID("test-id")}) }, }, { - method: "OnDeleteTaskPushConfig", + method: "DeleteTaskPushConfig", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return nil, client.DeleteTaskPushConfig(ctx, &a2a.DeleteTaskPushConfigParams{TaskID: a2a.TaskID("test-id"), ConfigID: "test-config-id"}) + return nil, client.DeleteTaskPushConfig(ctx, &a2a.DeleteTaskPushConfigRequest{TaskID: a2a.TaskID("test-id"), ID: "test-config-id"}) }, }, { - method: "OnGetExtendedAgentCard", + method: "GetExtendedAgentCard", call: func(ctx context.Context, client *a2aclient.Client) (any, error) { - return client.GetAgentCard(ctx) + return client.GetExtendedAgentCard(ctx) }, }, } @@ -124,7 +124,7 @@ func TestREST_RequestRouting(t *testing.T) { server := httptest.NewServer(NewRESTHandler(reqHandler)) client, err := a2aclient.NewFromEndpoints(ctx, []a2a.AgentInterface{ - {URL: server.URL, Transport: a2a.TransportProtocolHTTPJSON}, + {URL: server.URL, ProtocolBinding: a2a.TransportProtocolHTTPJSON}, }) if err != nil { t.Fatalf("a2aclient.NewFromEndpoints() error = %v", err) @@ -168,50 +168,50 @@ func TestREST_Validations(t *testing.T) { { name: "SendMessage", methods: []string{http.MethodPost}, - path: "/v1/message:send", - body: a2a.MessageSendParams{Message: a2a.NewMessage(a2a.MessageRoleUser, &a2a.TextPart{Text: "test"})}, + path: "/message:send", + body: a2a.SendMessageRequest{Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("test"))}, }, { name: "SendMessageStream", methods: []string{http.MethodPost}, - path: "/v1/message:stream", + path: "/message:stream", }, { name: "GetTask", methods: []string{http.MethodGet}, - path: "/v1/tasks/" + string(taskID), + path: "/tasks/" + string(taskID), }, { name: "ListTasks", methods: []string{http.MethodGet}, - path: "/v1/tasks", + path: "/tasks", }, { name: "CancelTask", methods: []string{http.MethodPost}, - path: "/v1/tasks/" + string(taskID) + ":cancel", + path: "/tasks/" + string(taskID) + ":cancel", }, { name: "ResubscribeToTask", methods: []string{http.MethodPost}, - path: "/v1/tasks/" + string(taskID) + ":subscribe", + path: "/tasks/" + string(taskID) + ":subscribe", }, { name: "SetAndListTaskPushConfig", methods: []string{http.MethodGet, http.MethodPost}, - path: "/v1/tasks/" + string(taskID) + "/pushNotificationConfigs", + path: "/tasks/" + string(taskID) + "/pushNotificationConfigs", body: config, }, { name: "GetAndDeleteTaskPushConfig", methods: []string{http.MethodGet, http.MethodDelete}, - path: "/v1/tasks/" + string(taskID) + "/pushNotificationConfigs/" + string(config.ID), + path: "/tasks/" + string(taskID) + "/pushNotificationConfigs/" + string(config.ID), body: config, }, { name: "GetExtendedAgentCard", methods: []string{http.MethodGet}, - path: "/v1/card", + path: "/extendedAgentCard", }, } store := testutil.NewTestTaskStoreWithConfig(&taskstore.InMemoryStoreConfig{ @@ -222,7 +222,7 @@ func TestREST_Validations(t *testing.T) { mock := &mockAgentExecutor{ ExecuteFunc: func(ctx context.Context, execCtx *ExecutorContext) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { - yield(a2a.NewMessage(a2a.MessageRoleAgent, a2a.TextPart{Text: "test message"}), nil) + yield(a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart("test message")), nil) } }, CancelFunc: func(ctx context.Context, execCtx *ExecutorContext) iter.Seq2[a2a.Event, error] { @@ -302,15 +302,15 @@ func TestREST_InvalidPayloads(t *testing.T) { }{ { name: "SendMessage with invalid payload", - path: "/v1/message:send", + path: "/message:send", }, { name: "SendMessageStream with invalid payload", - path: "/v1/message:stream", + path: "/message:stream", }, { name: "SetTaskPushConfig with invalid payload", - path: "/v1/tasks/" + string(taskID) + "/pushNotificationConfigs", + path: "/tasks/" + string(taskID) + "/pushNotificationConfigs", }, } diff --git a/a2asrv/workqueue/queue.go b/a2asrv/workqueue/queue.go index c01ef173..8523f108 100644 --- a/a2asrv/workqueue/queue.go +++ b/a2asrv/workqueue/queue.go @@ -41,10 +41,10 @@ type Payload struct { Type PayloadType // TaskID is an ID of the task to execute or cancel. TaskID a2a.TaskID - // CancelParams defines the cancelation parameters. It is only set for [PayloadTypeCancel]. - CancelParams *a2a.TaskIDParams - // ExecuteParams defines the execution parameters. It is only set for [PayloadTypeExecute]. - ExecuteParams *a2a.MessageSendParams + // CancelRequest defines the cancelation parameters. It is only set for [PayloadTypeCancel]. + CancelRequest *a2a.CancelTaskRequest + // ExecuteRequest defines the execution parameters. It is only set for [PayloadTypeExecute]. + ExecuteRequest *a2a.SendMessageRequest } // HandlerFn starts agent execution for the provided payload. diff --git a/e2e/extensions_durations_test.go b/e2e/extensions_durations_test.go index 67dce1ab..af46b147 100644 --- a/e2e/extensions_durations_test.go +++ b/e2e/extensions_durations_test.go @@ -98,8 +98,7 @@ func TestDurationsExtension(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { serverCard := &a2a.AgentCard{ - PreferredTransport: a2a.TransportProtocolJSONRPC, - Capabilities: a2a.AgentCapabilities{Extensions: tc.serverDeclares}, + Capabilities: a2a.AgentCapabilities{Extensions: tc.serverDeclares}, } agentExecutor := testexecutor.FromEventGenerator(func(execCtx *a2asrv.ExecutorContext) []a2a.Event { @@ -108,7 +107,10 @@ func TestDurationsExtension(t *testing.T) { handler := a2asrv.NewHandler(agentExecutor, a2asrv.WithCallInterceptors(&durationTracker{})) server := httptest.NewServer(a2asrv.NewJSONRPCHandler(handler)) - serverCard.URL = server.URL + serverCard.SupportedInterfaces = []a2a.AgentInterface{{ + URL: server.URL, + ProtocolBinding: a2a.TransportProtocolJSONRPC, + }} defer server.Close() client, err := a2aclient.NewFromCard(ctx, serverCard, a2aclient.WithCallInterceptors( @@ -117,8 +119,8 @@ func TestDurationsExtension(t *testing.T) { if err != nil { t.Fatalf("a2aclient.NewFromCard() error = %v", err) } - result, err := client.SendMessage(ctx, &a2a.MessageSendParams{ - Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "ping"}), + result, err := client.SendMessage(ctx, &a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("ping")), }) if err != nil { t.Fatalf("SendMessage failed: %v", err) diff --git a/e2e/jsonrpc_test.go b/e2e/jsonrpc_test.go index e550208b..d520dd84 100644 --- a/e2e/jsonrpc_test.go +++ b/e2e/jsonrpc_test.go @@ -35,14 +35,13 @@ func TestJSONRPC_Streaming(t *testing.T) { executor := testexecutor.FromEventGenerator(func(execCtx *a2asrv.ExecutorContext) []a2a.Event { task := &a2a.Task{ID: execCtx.TaskID, ContextID: execCtx.ContextID} - artifact := a2a.NewArtifactEvent(task, a2a.TextPart{Text: "Hello"}) - finalUpdate := a2a.NewStatusUpdateEvent(task, a2a.TaskStateCompleted, a2a.NewMessage(a2a.MessageRoleAgent, a2a.TextPart{Text: "Done!"})) - finalUpdate.Final = true + artifact := a2a.NewArtifactEvent(task, a2a.NewTextPart("Hello")) + finalUpdate := a2a.NewStatusUpdateEvent(task, a2a.TaskStateCompleted, a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart("Done!"))) return []a2a.Event{ a2a.NewSubmittedTask(execCtx, execCtx.Message), a2a.NewStatusUpdateEvent(task, a2a.TaskStateWorking, nil), artifact, - a2a.NewArtifactUpdateEvent(task, artifact.Artifact.ID, a2a.TextPart{Text: ", world!"}), + a2a.NewArtifactUpdateEvent(task, artifact.Artifact.ID, a2a.NewTextPart(", world!")), finalUpdate, } }) @@ -64,7 +63,7 @@ func TestJSONRPC_Streaming(t *testing.T) { client := mustCreateClient(t, card) var received []a2a.Event - msg := &a2a.MessageSendParams{Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Work"})} + msg := &a2a.SendMessageRequest{Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Work"))} for event, err := range client.SendStreamingMessage(ctx, msg) { if err != nil { t.Fatalf("client.SendStreamingMessage() error = %v", err) @@ -93,7 +92,7 @@ func TestJSONRPC_ExecutionScopeStreamingPanic(t *testing.T) { client := mustCreateClient(t, newAgentCard(server.URL)) var gotErr error - msg := &a2a.MessageSendParams{Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Work"})} + msg := &a2a.SendMessageRequest{Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Work"))} for _, err := range client.SendStreamingMessage(ctx, msg) { gotErr = err } @@ -120,7 +119,7 @@ func TestJSONRPC_RequestScopeStreamingPanic(t *testing.T) { client := mustCreateClient(t, newAgentCard(server.URL)) var gotErr error - msg := &a2a.MessageSendParams{Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Work"})} + msg := &a2a.SendMessageRequest{Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Work"))} for _, err := range client.SendStreamingMessage(ctx, msg) { gotErr = err } @@ -140,9 +139,10 @@ func mustCreateClient(t *testing.T, card *a2a.AgentCard) *a2aclient.Client { func newAgentCard(url string) *a2a.AgentCard { return &a2a.AgentCard{ - URL: url, - PreferredTransport: a2a.TransportProtocolJSONRPC, - Capabilities: a2a.AgentCapabilities{Streaming: true}, + SupportedInterfaces: []a2a.AgentInterface{ + {URL: url, ProtocolBinding: a2a.TransportProtocolJSONRPC}, + }, + Capabilities: a2a.AgentCapabilities{Streaming: true}, } } diff --git a/e2e/tck/sut.go b/e2e/tck/sut.go index 1fece961..0b08e3c5 100644 --- a/e2e/tck/sut.go +++ b/e2e/tck/sut.go @@ -25,7 +25,7 @@ import ( "time" "github.com/a2aproject/a2a-go/a2a" - "github.com/a2aproject/a2a-go/a2agrpc" + // "github.com/a2aproject/a2a-go/a2agrpc" "github.com/a2aproject/a2a-go/a2asrv" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -37,9 +37,9 @@ type intercepter struct { func (i *intercepter) Before(ctx context.Context, callCtx *a2asrv.CallContext, req *a2asrv.Request) (context.Context, any, error) { if callCtx.Method() == "OnSendMessage" { - sendParams := req.Payload.(*a2a.MessageSendParams) + sendParams := req.Payload.(*a2a.SendMessageRequest) if sendParams.Config == nil { - sendParams.Config = &a2a.MessageSendConfig{} + sendParams.Config = &a2a.SendMessageConfig{} } if sendParams.Config.Blocking == nil { blocking := false @@ -70,11 +70,12 @@ func main() { } agentCard := &a2a.AgentCard{ - Name: "TCK Core Agent", - Description: "A complete A2A agent implementation designed specifically for testing with the A2A Technology Compatibility Kit (TCK)", - URL: cardUrl, + Name: "TCK Core Agent", + Description: "A complete A2A agent implementation designed specifically for testing with the A2A Technology Compatibility Kit (TCK)", + SupportedInterfaces: []a2a.AgentInterface{ + {URL: cardUrl, ProtocolBinding: preferredTransport}, + }, Version: "1.0.0", - PreferredTransport: preferredTransport, DefaultInputModes: []string{"text"}, DefaultOutputModes: []string{"text"}, Capabilities: a2a.AgentCapabilities{Streaming: true}, @@ -112,9 +113,11 @@ func startGRPCServer(port int, handler a2asrv.RequestHandler) error { } log.Printf("Starting a gRPC server on 127.0.0.1:%d", port) - grpcHandler := a2agrpc.NewHandler(handler) + // TODO: uncomment and fix after pbconv is implemented + + // grpcHandler := a2agrpc.NewHandler(handler) grpcServer := grpc.NewServer() - grpcHandler.RegisterWith(grpcServer) + // grpcHandler.RegisterWith(grpcServer) return grpcServer.Serve(grpcListener) } diff --git a/e2e/tck/sut_agent_executor.go b/e2e/tck/sut_agent_executor.go index cd8d2b71..61af74b6 100644 --- a/e2e/tck/sut_agent_executor.go +++ b/e2e/tck/sut_agent_executor.go @@ -41,7 +41,6 @@ func (c *SUTAgentExecutor) Execute(ctx context.Context, execCtx *a2asrv.Executor } time.Sleep(1 * time.Second) event := a2a.NewStatusUpdateEvent(execCtx, a2a.TaskStateCompleted, nil) - event.Final = true yield(event, nil) } } @@ -55,7 +54,6 @@ func (c *SUTAgentExecutor) Cancel(ctx context.Context, execCtx *a2asrv.ExecutorC } event := a2a.NewStatusUpdateEvent(execCtx, a2a.TaskStateCanceled, nil) - event.Final = true yield(event, nil) } } diff --git a/examples/clustermode/client/main.go b/examples/clustermode/client/main.go index 57c46bf1..bacdb4f2 100644 --- a/examples/clustermode/client/main.go +++ b/examples/clustermode/client/main.go @@ -45,9 +45,10 @@ func main() { ctx := context.Background() card := &a2a.AgentCard{ - URL: fmt.Sprintf("%s/invoke", *server), - PreferredTransport: a2a.TransportProtocolJSONRPC, - Capabilities: a2a.AgentCapabilities{Streaming: true}, + SupportedInterfaces: []a2a.AgentInterface{ + {URL: fmt.Sprintf("%s/invoke", *server), ProtocolBinding: a2a.TransportProtocolJSONRPC}, + }, + Capabilities: a2a.AgentCapabilities{Streaming: true}, } httpClient := &http.Client{Timeout: 5 * time.Minute} @@ -86,8 +87,8 @@ func main() { } func send(ctx context.Context, client *a2aclient.Client, text string) error { - msg := &a2a.MessageSendParams{ - Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: text}), + msg := &a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart(text)), } final := false taskID := a2a.TaskID("") @@ -100,6 +101,7 @@ func send(ctx context.Context, client *a2aclient.Client, text string) error { } if ev, ok := event.(*a2a.Task); ok { taskID = ev.ID + final = ev.Status.State.Terminal() } if ev, ok := event.(*a2a.TaskStatusUpdateEvent); ok { final = ev.Status.State.Terminal() @@ -112,7 +114,7 @@ func send(ctx context.Context, client *a2aclient.Client, text string) error { } func cancel(ctx context.Context, client *a2aclient.Client, id string) error { - task, err := client.CancelTask(ctx, &a2a.TaskIDParams{ID: a2a.TaskID(id)}) + task, err := client.CancelTask(ctx, &a2a.CancelTaskRequest{ID: a2a.TaskID(id)}) if err != nil { return fmt.Errorf("failed to cancel task: %w", err) } @@ -123,13 +125,16 @@ func cancel(ctx context.Context, client *a2aclient.Client, id string) error { func subscribe(ctx context.Context, client *a2aclient.Client, id a2a.TaskID) error { final := false for !final { - for event, err := range client.ResubscribeToTask(ctx, &a2a.TaskIDParams{ID: id}) { + for event, err := range client.SubscribeToTask(ctx, &a2a.SubscribeToTaskRequest{ID: id}) { if err != nil { return fmt.Errorf("error receiving event: %w", err) } if err := printEvent(event); err != nil { return fmt.Errorf("error printing event: %w", err) } + if ev, ok := event.(*a2a.Task); ok { + final = ev.Status.State.Terminal() + } if ev, ok := event.(*a2a.TaskStatusUpdateEvent); ok { final = ev.Status.State.Terminal() } @@ -141,10 +146,14 @@ func subscribe(ctx context.Context, client *a2aclient.Client, id a2a.TaskID) err func printEvent(event a2a.Event) error { switch v := event.(type) { case *a2a.TaskArtifactUpdateEvent: - fmt.Printf("[update]: %s\n", v.Artifact.Parts[0].(a2a.TextPart).Text) + fmt.Printf("[update]: %s\n", v.Artifact.Parts[0].Text()) case *a2a.TaskStatusUpdateEvent: - fmt.Printf("[state=%q]: %s\n", v.Status.State, v.Status.Message.Parts[0].(a2a.TextPart).Text) + var msgText string + if v.Status.Message != nil && len(v.Status.Message.Parts) > 0 { + msgText = v.Status.Message.Parts[0].Text() + } + fmt.Printf("[state=%q]: %s\n", v.Status.State, msgText) default: data, err := json.MarshalIndent(event, "", " ") diff --git a/examples/clustermode/server/agent.go b/examples/clustermode/server/agent.go index 7ccbbb14..1d997183 100644 --- a/examples/clustermode/server/agent.go +++ b/examples/clustermode/server/agent.go @@ -47,7 +47,7 @@ func (a *agentExecutor) Execute(ctx context.Context, execCtx *a2asrv.ExecutorCon return func(yield func(a2a.Event, error) bool) { log.Info(ctx, "agent received task", "task_id", execCtx.TaskID) - text := execCtx.Message.Parts[0].(a2a.TextPart).Text + text := execCtx.Message.Parts[0].Text() fs := flag.NewFlagSet("agent", flag.ContinueOnError) countTo := fs.Int("count-to", 0, "number to count to") @@ -56,7 +56,7 @@ func (a *agentExecutor) Execute(ctx context.Context, execCtx *a2asrv.ExecutorCon if err := fs.Parse(strings.Fields(text)); err != nil { log.Info(ctx, "failed to interpret task", "task_id", execCtx.TaskID) - msg := a2a.NewMessage(a2a.MessageRoleAgent, a2a.TextPart{Text: fmt.Sprintf("failed to interpret task: %v", err)}) + msg := a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart(fmt.Sprintf("failed to interpret task: %v", err))) yield(msg, nil) return } @@ -65,7 +65,7 @@ func (a *agentExecutor) Execute(ctx context.Context, execCtx *a2asrv.ExecutorCon log.Info(ctx, "failed to interpret task", "task_id", execCtx.TaskID) msg := a2a.NewMessage( a2a.MessageRoleAgent, - a2a.TextPart{Text: "Hello, world! Use --count-to=N --die-every=M --interval=I to test me."}, + a2a.NewTextPart("Hello, world! Use --count-to=N --die-every=M --interval=I to test me."), ) yield(msg, nil) return @@ -80,7 +80,7 @@ func (a *agentExecutor) Execute(ctx context.Context, execCtx *a2asrv.ExecutorCon } else if len(execCtx.StoredTask.Artifacts) > 0 { lastArtifact := execCtx.StoredTask.Artifacts[len(execCtx.StoredTask.Artifacts)-1] if len(lastArtifact.Parts) > 0 { - lastCount := lastArtifact.Parts[len(lastArtifact.Parts)-1].(a2a.TextPart).Text + lastCount := lastArtifact.Parts[len(lastArtifact.Parts)-1].Text() countPart := strings.Split(lastCount, ": ")[1] if val, err := strconv.Atoi(countPart); err == nil { start = val + 1 @@ -108,9 +108,9 @@ func (a *agentExecutor) Execute(ctx context.Context, execCtx *a2asrv.ExecutorCon chunk := fmt.Sprintf("%s: %d", a.workerID, i+start) var event *a2a.TaskArtifactUpdateEvent if artifactID == "" { - event = a2a.NewArtifactEvent(execCtx, a2a.TextPart{Text: chunk}) + event = a2a.NewArtifactEvent(execCtx, a2a.NewTextPart(chunk)) } else { - event = a2a.NewArtifactUpdateEvent(execCtx, artifactID, a2a.TextPart{Text: chunk}) + event = a2a.NewArtifactUpdateEvent(execCtx, artifactID, a2a.NewTextPart(chunk)) } if !yield(event, nil) { return @@ -120,9 +120,8 @@ func (a *agentExecutor) Execute(ctx context.Context, execCtx *a2asrv.ExecutorCon taskCompleted := a2a.NewStatusUpdateEvent( execCtx, a2a.TaskStateCompleted, - a2a.NewMessage(a2a.MessageRoleAgent, a2a.TextPart{Text: "Done!"}), + a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart("Done!")), ) - taskCompleted.Final = true yield(taskCompleted, nil) } } @@ -132,7 +131,7 @@ func (*agentExecutor) Cancel(ctx context.Context, execCtx *a2asrv.ExecutorContex yield(a2a.NewStatusUpdateEvent( execCtx, a2a.TaskStateCanceled, - a2a.NewMessage(a2a.MessageRoleAgent, a2a.TextPart{Text: "Task cancelled"}), + a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart("Task cancelled")), ), nil) } } diff --git a/examples/clustermode/server/eventqueue.go b/examples/clustermode/server/eventqueue.go index 12147e1c..ff50ca05 100644 --- a/examples/clustermode/server/eventqueue.go +++ b/examples/clustermode/server/eventqueue.go @@ -17,6 +17,7 @@ package main import ( "context" "database/sql" + "encoding/json" "fmt" "sync" "time" @@ -130,14 +131,14 @@ func newDBEventQueue(db *sql.DB, taskID a2a.TaskID, pollFromID string) *dbEventQ closeSQLRows(ctx, rows) continue } - event, err := a2a.UnmarshalEventJSON([]byte(eventJSON)) - if err != nil { + var sr a2a.StreamResponse + if err := json.Unmarshal([]byte(eventJSON), &sr); err != nil { log.Error(ctx, "failed to unmarshal event", err) continue } select { case queue.eventsCh <- &versionedEvent{ - event: event, + event: sr.Event, version: taskstore.TaskVersion(version), }: case <-queue.closeSignal: diff --git a/examples/clustermode/server/main.go b/examples/clustermode/server/main.go index 01017a59..cb8761b9 100644 --- a/examples/clustermode/server/main.go +++ b/examples/clustermode/server/main.go @@ -30,18 +30,20 @@ import ( ) var ( - port = flag.Int("port", 9001, "Port for a gGRPC A2A server to listen on.") + port = flag.Int("port", 9001, "Port for a JSONRPC A2A server to listen on.") dbName = flag.String("db", "", "Database connection string (DSN).") ) func main() { flag.Parse() + addr := fmt.Sprintf("http://127.0.0.1:%d/invoke", *port) agentCard := &a2a.AgentCard{ - Name: "A2A Cluster", - URL: fmt.Sprintf("http://127.0.0.1:%d/invoke", *port), - PreferredTransport: a2a.TransportProtocolJSONRPC, - Capabilities: a2a.AgentCapabilities{Streaming: true}, + Name: "A2A Cluster", + SupportedInterfaces: []a2a.AgentInterface{ + {URL: addr, ProtocolBinding: a2a.TransportProtocolJSONRPC}, + }, + Capabilities: a2a.AgentCapabilities{Streaming: true}, } listener, err := net.Listen("tcp", fmt.Sprintf(":%d", *port)) diff --git a/examples/helloworld/client/main.go b/examples/helloworld/client/main.go index 4d429ff6..9291f065 100644 --- a/examples/helloworld/client/main.go +++ b/examples/helloworld/client/main.go @@ -22,9 +22,8 @@ import ( "github.com/a2aproject/a2a-go/a2a" "github.com/a2aproject/a2a-go/a2aclient" "github.com/a2aproject/a2a-go/a2aclient/agentcard" - - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" + // "google.golang.org/grpc" + // "google.golang.org/grpc/credentials/insecure" ) var cardURL = flag.String("card-url", "http://127.0.0.1:9001", "Base URL of AgentCard server.") @@ -40,17 +39,17 @@ func main() { } // Insecure connection is used for example purposes - withInsecureGRPC := a2aclient.WithGRPCTransport(grpc.WithTransportCredentials(insecure.NewCredentials())) + // withInsecureGRPC := a2aclient.WithGRPCTransport(grpc.WithTransportCredentials(insecure.NewCredentials())) // Create a client connected to one of the interfaces specified in the AgentCard. - client, err := a2aclient.NewFromCard(ctx, card, withInsecureGRPC) + client, err := a2aclient.NewFromCard(ctx, card) if err != nil { log.Fatalf("Failed to create a client: %v", err) } // Send a message and log the response. - msg := a2a.NewMessage(a2a.MessageRoleUser, a2a.TextPart{Text: "Hello, world"}) - resp, err := client.SendMessage(ctx, &a2a.MessageSendParams{Message: msg}) + msg := a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("Hello, world")) + resp, err := client.SendMessage(ctx, &a2a.SendMessageRequest{Message: msg}) if err != nil { log.Fatalf("Failed to send a message: %v", err) } diff --git a/examples/helloworld/server/grpc/main.go b/examples/helloworld/server/grpc/main.go index 2999247f..5b00a890 100644 --- a/examples/helloworld/server/grpc/main.go +++ b/examples/helloworld/server/grpc/main.go @@ -24,7 +24,7 @@ import ( "net/http" "github.com/a2aproject/a2a-go/a2a" - "github.com/a2aproject/a2a-go/a2agrpc" + // "github.com/a2aproject/a2a-go/a2agrpc" "github.com/a2aproject/a2a-go/a2asrv" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -38,7 +38,7 @@ var _ a2asrv.AgentExecutor = (*agentExecutor)(nil) func (*agentExecutor) Execute(ctx context.Context, execCtx *a2asrv.ExecutorContext) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { - response := a2a.NewMessage(a2a.MessageRoleAgent, a2a.TextPart{Text: "Hello, world!"}) + response := a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart("Hello, world!")) yield(response, nil) } } @@ -54,16 +54,18 @@ func startGRPCServer(port int, card *a2a.AgentCard) error { } log.Printf("Starting a gRPC server on 127.0.0.1:%d", port) + // TODO: uncomment and fix after pbconv is implemented + // A transport-agnostic implementation of A2A protocol methods. // The behavior is configurable using option-arguments of form a2asrv.With*(), for example: // a2asrv.NewHandler(executor, a2asrv.WithTaskStore(customStore)) - requestHandler := a2asrv.NewHandler(&agentExecutor{}, a2asrv.WithExtendedAgentCard(card)) + // requestHandler := a2asrv.NewHandler(&agentExecutor{}, a2asrv.WithExtendedAgentCard(card)) // A gRPC-transport implementation for A2A. - grpcHandler := a2agrpc.NewHandler(requestHandler) + // grpcHandler := a2agrpc.NewHandler(requestHandler) s := grpc.NewServer() - grpcHandler.RegisterWith(s) + // grpcHandler.RegisterWith(s) return s.Serve(listener) } @@ -88,11 +90,13 @@ var ( func main() { flag.Parse() + addr := fmt.Sprintf("http://127.0.0.1:%d", *grpcPort) agentCard := &a2a.AgentCard{ - Name: "Hello World Agent", - Description: "Just a hello world agent", - URL: fmt.Sprintf("127.0.0.1:%d", *grpcPort), - PreferredTransport: a2a.TransportProtocolGRPC, + Name: "Hello World Agent", + Description: "Just a hello world agent", + SupportedInterfaces: []a2a.AgentInterface{ + {URL: addr, ProtocolBinding: a2a.TransportProtocolGRPC}, + }, DefaultInputModes: []string{"text"}, DefaultOutputModes: []string{"text"}, Capabilities: a2a.AgentCapabilities{Streaming: true}, diff --git a/examples/helloworld/server/jsonrpc/main.go b/examples/helloworld/server/jsonrpc/main.go index 41e21762..e4cd4213 100644 --- a/examples/helloworld/server/jsonrpc/main.go +++ b/examples/helloworld/server/jsonrpc/main.go @@ -35,7 +35,7 @@ var _ a2asrv.AgentExecutor = (*agentExecutor)(nil) func (*agentExecutor) Execute(ctx context.Context, execCtx *a2asrv.ExecutorContext) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { - response := a2a.NewMessage(a2a.MessageRoleAgent, a2a.TextPart{Text: "Hello, world!"}) + response := a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart("Hello, world!")) yield(response, nil) } } @@ -44,16 +44,18 @@ func (*agentExecutor) Cancel(ctx context.Context, execCtx *a2asrv.ExecutorContex return func(yield func(a2a.Event, error) bool) {} } -var port = flag.Int("port", 9001, "Port for a gGRPC A2A server to listen on.") +var port = flag.Int("port", 9001, "Port for a JSONRPC A2A server to listen on.") func main() { flag.Parse() + addr := fmt.Sprintf("http://127.0.0.1:%d/invoke", *port) agentCard := &a2a.AgentCard{ - Name: "Hello World Agent", - Description: "Just a hello world agent", - URL: fmt.Sprintf("http://127.0.0.1:%d/invoke", *port), - PreferredTransport: a2a.TransportProtocolJSONRPC, + Name: "Hello World Agent", + Description: "Just a hello world agent", + SupportedInterfaces: []a2a.AgentInterface{ + {URL: addr, ProtocolBinding: a2a.TransportProtocolJSONRPC}, + }, DefaultInputModes: []string{"text"}, DefaultOutputModes: []string{"text"}, Capabilities: a2a.AgentCapabilities{Streaming: true}, diff --git a/examples/helloworld/server/rest/main.go b/examples/helloworld/server/rest/main.go index 741b1585..83b75645 100644 --- a/examples/helloworld/server/rest/main.go +++ b/examples/helloworld/server/rest/main.go @@ -34,7 +34,7 @@ type agentExecutor struct{} func (*agentExecutor) Execute(ctx context.Context, execCtx *a2asrv.ExecutorContext) iter.Seq2[a2a.Event, error] { return func(yield func(a2a.Event, error) bool) { - response := a2a.NewMessage(a2a.MessageRoleAgent, a2a.TextPart{Text: "Hello from REST server!"}) + response := a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart("Hello from REST server!")) yield(response, nil) } } @@ -67,17 +67,19 @@ var ( func main() { flag.Parse() + addr := fmt.Sprintf("http://127.0.0.1:%d", *port) agentCard := &a2a.AgentCard{ - Name: "REST Hello World Agent", - Description: "Just a rest hello world agent", - URL: fmt.Sprintf("http://127.0.0.1:%d", *port), - PreferredTransport: a2a.TransportProtocolHTTPJSON, + Name: "REST Hello World Agent", + Description: "Just a rest hello world agent", + SupportedInterfaces: []a2a.AgentInterface{ + {URL: addr, ProtocolBinding: a2a.TransportProtocolHTTPJSON}, + }, DefaultInputModes: []string{"text"}, DefaultOutputModes: []string{"text"}, Capabilities: a2a.AgentCapabilities{Streaming: true}, Skills: []a2a.AgentSkill{ { - ID: "", + ID: "hello_world", Name: "REST Hello world!", Description: "Returns a 'Hello from REST server!'", Tags: []string{"hello world"}, diff --git a/internal/rest/rest.go b/internal/rest/rest.go index cd106a0e..83a49566 100644 --- a/internal/rest/rest.go +++ b/internal/rest/rest.go @@ -23,7 +23,7 @@ import ( "github.com/a2aproject/a2a-go/a2a" ) -func MakeTasksListPath() string { +func MakeListTasksPath() string { return "/tasks" }