diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/RemoveOperations.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/RemoveOperations.java index 61afe4c76bb..f5cea887f98 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/RemoveOperations.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/RemoveOperations.java @@ -13,11 +13,6 @@ public class RemoveOperations implements GoIntegration { private Map> SHAPES_TO_REMOVE = Map.of( - "bedrock runtime", List.of( - ShapeId.from("com.amazonaws.bedrockruntime#InvokeModelWithBidirectionalStream"), - ShapeId.from("com.amazonaws.bedrockruntime#InvokeModelWithBidirectionalStreamRequest"), - ShapeId.from("com.amazonaws.bedrockruntime#InvokeModelWithBidirectionalStreamResponse") - ), "sagemaker runtime http2", List.of( ShapeId.from("com.amazonaws.sagemakerruntimehttp2#InvokeEndpointWithBidirectionalStream"), ShapeId.from("com.amazonaws.sagemakerruntimehttp2#InvokeEndpointWithBidirectionalStreamInput"), diff --git a/service/bedrockruntime/api_client.go b/service/bedrockruntime/api_client.go index f612a068dea..619f9cd5e74 100644 --- a/service/bedrockruntime/api_client.go +++ b/service/bedrockruntime/api_client.go @@ -310,33 +310,41 @@ func (c *Client) invokeOperation( o.Meter = options.MeterProvider.Meter("github.com/aws/aws-sdk-go-v2/service/bedrockruntime") }) decorated := middleware.DecorateHandler(handler, stack) - result, metadata, err = decorated.Handle(ctx, params) - if err != nil { - span.SetProperty("exception.type", fmt.Sprintf("%T", err)) - span.SetProperty("exception.message", err.Error()) - - var aerr smithy.APIError - if errors.As(err, &aerr) { - span.SetProperty("api.error_code", aerr.ErrorCode()) - span.SetProperty("api.error_message", aerr.ErrorMessage()) - span.SetProperty("api.error_fault", aerr.ErrorFault().String()) - } + // set this on an if + // set a channel for early return + results := make(chan PartialResult[*InvokeModelWithBidirectionalStreamOutput], 1) + ctx = context.WithValue(ctx, "asyncChan", results) + + // do + go func() { + result, metadata, err = decorated.Handle(ctx, params) + if err != nil { + span.SetProperty("exception.type", fmt.Sprintf("%T", err)) + span.SetProperty("exception.message", err.Error()) + + var aerr smithy.APIError + if errors.As(err, &aerr) { + span.SetProperty("api.error_code", aerr.ErrorCode()) + span.SetProperty("api.error_message", aerr.ErrorMessage()) + span.SetProperty("api.error_fault", aerr.ErrorFault().String()) + } - err = &smithy.OperationError{ - ServiceID: ServiceID, - OperationName: opID, - Err: err, + err = &smithy.OperationError{ + ServiceID: ServiceID, + OperationName: opID, + Err: err, + } } - } - span.SetProperty("error", err != nil) - if err == nil { - span.SetStatus(tracing.SpanStatusOK) - } else { - span.SetStatus(tracing.SpanStatusError) - } - - return result, metadata, err + span.SetProperty("error", err != nil) + if err == nil { + span.SetStatus(tracing.SpanStatusOK) + } else { + span.SetStatus(tracing.SpanStatusError) + } + }() + res := <-results + return res.Output, res.Metadata, res.Error } type operationInputKey struct{} diff --git a/service/bedrockruntime/api_op_InvokeModelWithBidirectionalStream.go b/service/bedrockruntime/api_op_InvokeModelWithBidirectionalStream.go new file mode 100644 index 00000000000..98111002412 --- /dev/null +++ b/service/bedrockruntime/api_op_InvokeModelWithBidirectionalStream.go @@ -0,0 +1,319 @@ +// Code generated by smithy-go-codegen DO NOT EDIT. + +package bedrockruntime + +import ( + "context" + "fmt" + awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + "github.com/aws/smithy-go/middleware" + smithysync "github.com/aws/smithy-go/sync" + smithyhttp "github.com/aws/smithy-go/transport/http" + "sync" + "time" +) + +// Invoke the specified Amazon Bedrock model to run inference using the +// bidirectional stream. The response is returned in a stream that remains open for +// 8 minutes. A single session can contain multiple prompts and responses from the +// model. The prompts to the model are provided as audio files and the model's +// responses are spoken back to the user and transcribed. +// +// It is possible for users to interrupt the model's response with a new prompt, +// which will halt the response speech. The model will retain contextual awareness +// of the conversation while pivoting to respond to the new prompt. +func (c *Client) InvokeModelWithBidirectionalStream(ctx context.Context, params *InvokeModelWithBidirectionalStreamInput, optFns ...func(*Options)) (*InvokeModelWithBidirectionalStreamOutput, error) { + if params == nil { + params = &InvokeModelWithBidirectionalStreamInput{} + } + + result, metadata, err := c.invokeOperation(ctx, "InvokeModelWithBidirectionalStream", params, optFns, c.addOperationInvokeModelWithBidirectionalStreamMiddlewares) + if err != nil { + return nil, err + } + + out := result.(*InvokeModelWithBidirectionalStreamOutput) + out.ResultMetadata = metadata + return out, nil +} + +type InvokeModelWithBidirectionalStreamInput struct { + + // The model ID or ARN of the model ID to use. Currently, only + // amazon.nova-sonic-v1:0 is supported. + // + // This member is required. + ModelId *string + + noSmithyDocumentSerde +} + +type InvokeModelWithBidirectionalStreamOutput struct { + eventStream *InvokeModelWithBidirectionalStreamEventStream + + // Metadata pertaining to the operation's result. + ResultMetadata middleware.Metadata + + noSmithyDocumentSerde +} + +// GetStream returns the type to interact with the event stream. +func (o *InvokeModelWithBidirectionalStreamOutput) GetStream() *InvokeModelWithBidirectionalStreamEventStream { + return o.eventStream +} + +func (c *Client) addOperationInvokeModelWithBidirectionalStreamMiddlewares(stack *middleware.Stack, options Options) (err error) { + if err := stack.Serialize.Add(&setOperationInputMiddleware{}, middleware.After); err != nil { + return err + } + err = stack.Serialize.Add(&awsRestjson1_serializeOpInvokeModelWithBidirectionalStream{}, middleware.After) + if err != nil { + return err + } + err = stack.Deserialize.Add(&awsRestjson1_deserializeOpInvokeModelWithBidirectionalStream{}, middleware.After) + if err != nil { + return err + } + if err := addProtocolFinalizerMiddlewares(stack, options, "InvokeModelWithBidirectionalStream"); err != nil { + return fmt.Errorf("add protocol finalizers: %v", err) + } + + if err = addlegacyEndpointContextSetter(stack, options); err != nil { + return err + } + if err = addEventStreamInvokeModelWithBidirectionalStreamMiddleware(stack, options); err != nil { + return err + } + if err = smithyhttp.AddRequireMinimumProtocol(stack, 2, 0); err != nil { + return err + } + if err = addSetLoggerMiddleware(stack, options); err != nil { + return err + } + if err = addClientRequestID(stack); err != nil { + return err + } + if err = addResolveEndpointMiddleware(stack, options); err != nil { + return err + } + if err = addStreamingEventsPayload(stack); err != nil { + return err + } + if err = addContentSHA256Header(stack); err != nil { + return err + } + if err = addRetry(stack, options); err != nil { + return err + } + if err = addRawResponseToMetadata(stack); err != nil { + return err + } + if err = addRecordResponseTiming(stack); err != nil { + return err + } + if err = addSpanRetryLoop(stack, options); err != nil { + return err + } + if err = addClientUserAgent(stack, options); err != nil { + return err + } + if err = eventstreamapi.AddInitializeStreamWriter(stack); err != nil { + return err + } + if err = addSetLegacyContextSigningOptionsMiddleware(stack); err != nil { + return err + } + if err = addTimeOffsetBuild(stack, c); err != nil { + return err + } + if err = addUserAgentRetryMode(stack, options); err != nil { + return err + } + if err = addCredentialSource(stack, options); err != nil { + return err + } + if err = addOpInvokeModelWithBidirectionalStreamValidationMiddleware(stack); err != nil { + return err + } + if err = stack.Initialize.Add(newServiceMetadataMiddleware_opInvokeModelWithBidirectionalStream(options.Region), middleware.Before); err != nil { + return err + } + if err = addRecursionDetection(stack); err != nil { + return err + } + if err = addRequestIDRetrieverMiddleware(stack); err != nil { + return err + } + if err = addResponseErrorMiddleware(stack); err != nil { + return err + } + if err = addRequestResponseLogging(stack, options); err != nil { + return err + } + if err = addDisableHTTPSMiddleware(stack, options); err != nil { + return err + } + if err = addInterceptBeforeRetryLoop(stack, options); err != nil { + return err + } + if err = addInterceptAttempt(stack, options); err != nil { + return err + } + if err = addInterceptors(stack, options); err != nil { + return err + } + return nil +} + +func newServiceMetadataMiddleware_opInvokeModelWithBidirectionalStream(region string) *awsmiddleware.RegisterServiceMetadata { + return &awsmiddleware.RegisterServiceMetadata{ + Region: region, + ServiceID: ServiceID, + OperationName: "InvokeModelWithBidirectionalStream", + } +} + +// InvokeModelWithBidirectionalStreamEventStream provides the event stream handling for the InvokeModelWithBidirectionalStream operation. +// +// For testing and mocking the event stream this type should be initialized via +// the NewInvokeModelWithBidirectionalStreamEventStream constructor function. Using the functional options +// to pass in nested mock behavior. +type InvokeModelWithBidirectionalStreamEventStream struct { + // InvokeModelWithBidirectionalStreamInputWriter is the EventStream writer for the + // InvokeModelWithBidirectionalStreamInput events. This value is automatically set + // by the SDK when the API call is made Use this member when unit testing your code + // with the SDK to mock out the EventStream Writer. + // + // Must not be nil. + Writer InvokeModelWithBidirectionalStreamInputWriter + + // InvokeModelWithBidirectionalStreamOutputReader is the EventStream reader for + // the InvokeModelWithBidirectionalStreamOutput events. This value is automatically + // set by the SDK when the API call is made Use this member when unit testing your + // code with the SDK to mock out the EventStream Reader. + // + // Must not be nil. + Reader InvokeModelWithBidirectionalStreamOutputReader + + done chan struct{} + closeOnce sync.Once + err *smithysync.OnceErr +} + +// NewInvokeModelWithBidirectionalStreamEventStream initializes an InvokeModelWithBidirectionalStreamEventStream. +// This function should only be used for testing and mocking the InvokeModelWithBidirectionalStreamEventStream +// stream within your application. +// +// The Writer member must be set before writing events to the stream. +// +// The Reader member must be set before reading events from the stream. +func NewInvokeModelWithBidirectionalStreamEventStream(optFns ...func(*InvokeModelWithBidirectionalStreamEventStream)) *InvokeModelWithBidirectionalStreamEventStream { + es := &InvokeModelWithBidirectionalStreamEventStream{ + done: make(chan struct{}), + err: smithysync.NewOnceErr(), + } + for _, fn := range optFns { + fn(es) + } + return es +} + +// Send writes the event to the stream blocking until the event is written. +// Returns an error if the event was not written. +func (es *InvokeModelWithBidirectionalStreamEventStream) Send(ctx context.Context, event types.InvokeModelWithBidirectionalStreamInput) error { + return es.Writer.Send(ctx, event) +} + +// Events returns a channel to read events from. +func (es *InvokeModelWithBidirectionalStreamEventStream) Events() <-chan types.InvokeModelWithBidirectionalStreamOutput { + return es.Reader.Events() +} + +// Close closes the stream. This will also cause the stream to be closed. +// Close must be called when done using the stream API. Not calling Close +// may result in resource leaks. +// +// Will close the underlying EventStream writer and reader, and no more events can be +// sent or received. +func (es *InvokeModelWithBidirectionalStreamEventStream) Close() error { + es.closeOnce.Do(es.safeClose) + return es.Err() +} + +func (es *InvokeModelWithBidirectionalStreamEventStream) safeClose() { + close(es.done) + + t := time.NewTicker(time.Second) + defer t.Stop() + writeCloseDone := make(chan error) + go func() { + if err := es.Writer.Close(); err != nil { + es.err.SetError(err) + } + close(writeCloseDone) + }() + select { + case <-t.C: + case <-writeCloseDone: + } + + es.Reader.Close() +} + +// Err returns any error that occurred while reading or writing EventStream Events +// from the service API's response. Returns nil if there were no errors. +func (es *InvokeModelWithBidirectionalStreamEventStream) Err() error { + if err := es.err.Err(); err != nil { + return err + } + + if err := es.Writer.Err(); err != nil { + return err + } + + if err := es.Reader.Err(); err != nil { + return err + } + + return nil +} + +func (es *InvokeModelWithBidirectionalStreamEventStream) waitStreamClose() { + type errorSet interface { + ErrorSet() <-chan struct{} + } + + var inputErrCh <-chan struct{} + if v, ok := es.Writer.(errorSet); ok { + inputErrCh = v.ErrorSet() + } + + var outputErrCh <-chan struct{} + if v, ok := es.Reader.(errorSet); ok { + outputErrCh = v.ErrorSet() + } + var outputClosedCh <-chan struct{} + if v, ok := es.Reader.(interface{ Closed() <-chan struct{} }); ok { + outputClosedCh = v.Closed() + } + + select { + case <-es.done: + case <-inputErrCh: + es.err.SetError(es.Writer.Err()) + es.Close() + + case <-outputErrCh: + es.err.SetError(es.Reader.Err()) + es.Close() + + case <-outputClosedCh: + if err := es.Reader.Err(); err != nil { + es.err.SetError(es.Reader.Err()) + } + es.Close() + + } +} diff --git a/service/bedrockruntime/deserializers.go b/service/bedrockruntime/deserializers.go index 5eafbd26f5c..83d876c3552 100644 --- a/service/bedrockruntime/deserializers.go +++ b/service/bedrockruntime/deserializers.go @@ -1145,6 +1145,124 @@ func awsRestjson1_deserializeOpDocumentInvokeModelOutput(v *InvokeModelOutput, b return nil } +type awsRestjson1_deserializeOpInvokeModelWithBidirectionalStream struct { +} + +func (*awsRestjson1_deserializeOpInvokeModelWithBidirectionalStream) ID() string { + return "OperationDeserializer" +} + +func (m *awsRestjson1_deserializeOpInvokeModelWithBidirectionalStream) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) ( + out middleware.DeserializeOutput, metadata middleware.Metadata, err error, +) { + out, metadata, err = next.HandleDeserialize(ctx, in) + if err != nil { + return out, metadata, err + } + + _, span := tracing.StartSpan(ctx, "OperationDeserializer") + endTimer := startMetricTimer(ctx, "client.call.deserialization_duration") + defer endTimer() + defer span.End() + response, ok := out.RawResponse.(*smithyhttp.Response) + if !ok { + return out, metadata, &smithy.DeserializationError{Err: fmt.Errorf("unknown transport type %T", out.RawResponse)} + } + + if response.StatusCode < 200 || response.StatusCode >= 300 { + return out, metadata, awsRestjson1_deserializeOpErrorInvokeModelWithBidirectionalStream(response, &metadata) + } + output := &InvokeModelWithBidirectionalStreamOutput{} + out.Result = output + + span.End() + return out, metadata, err +} + +func awsRestjson1_deserializeOpErrorInvokeModelWithBidirectionalStream(response *smithyhttp.Response, metadata *middleware.Metadata) error { + var errorBuffer bytes.Buffer + if _, err := io.Copy(&errorBuffer, response.Body); err != nil { + return &smithy.DeserializationError{Err: fmt.Errorf("failed to copy error response body, %w", err)} + } + errorBody := bytes.NewReader(errorBuffer.Bytes()) + + errorCode := "UnknownError" + errorMessage := errorCode + + headerCode := response.Header.Get("X-Amzn-ErrorType") + if len(headerCode) != 0 { + errorCode = restjson.SanitizeErrorCode(headerCode) + } + + var buff [1024]byte + ringBuffer := smithyio.NewRingBuffer(buff[:]) + + body := io.TeeReader(errorBody, ringBuffer) + decoder := json.NewDecoder(body) + decoder.UseNumber() + jsonCode, message, err := restjson.GetErrorInfo(decoder) + if err != nil { + var snapshot bytes.Buffer + io.Copy(&snapshot, ringBuffer) + err = &smithy.DeserializationError{ + Err: fmt.Errorf("failed to decode response body, %w", err), + Snapshot: snapshot.Bytes(), + } + return err + } + + errorBody.Seek(0, io.SeekStart) + if len(headerCode) == 0 && len(jsonCode) != 0 { + errorCode = restjson.SanitizeErrorCode(jsonCode) + } + if len(message) != 0 { + errorMessage = message + } + + switch { + case strings.EqualFold("AccessDeniedException", errorCode): + return awsRestjson1_deserializeErrorAccessDeniedException(response, errorBody) + + case strings.EqualFold("InternalServerException", errorCode): + return awsRestjson1_deserializeErrorInternalServerException(response, errorBody) + + case strings.EqualFold("ModelErrorException", errorCode): + return awsRestjson1_deserializeErrorModelErrorException(response, errorBody) + + case strings.EqualFold("ModelNotReadyException", errorCode): + return awsRestjson1_deserializeErrorModelNotReadyException(response, errorBody) + + case strings.EqualFold("ModelStreamErrorException", errorCode): + return awsRestjson1_deserializeErrorModelStreamErrorException(response, errorBody) + + case strings.EqualFold("ModelTimeoutException", errorCode): + return awsRestjson1_deserializeErrorModelTimeoutException(response, errorBody) + + case strings.EqualFold("ResourceNotFoundException", errorCode): + return awsRestjson1_deserializeErrorResourceNotFoundException(response, errorBody) + + case strings.EqualFold("ServiceQuotaExceededException", errorCode): + return awsRestjson1_deserializeErrorServiceQuotaExceededException(response, errorBody) + + case strings.EqualFold("ServiceUnavailableException", errorCode): + return awsRestjson1_deserializeErrorServiceUnavailableException(response, errorBody) + + case strings.EqualFold("ThrottlingException", errorCode): + return awsRestjson1_deserializeErrorThrottlingException(response, errorBody) + + case strings.EqualFold("ValidationException", errorCode): + return awsRestjson1_deserializeErrorValidationException(response, errorBody) + + default: + genericError := &smithy.GenericAPIError{ + Code: errorCode, + Message: errorMessage, + } + return genericError + + } +} + type awsRestjson1_deserializeOpInvokeModelWithResponseStream struct { } @@ -1638,7 +1756,7 @@ func awsRestjson1_deserializeOpDocumentStartAsyncInvokeOutput(v **StartAsyncInvo return nil } -func awsRestjson1_deserializeEventStreamResponseStream(v *types.ResponseStream, msg *eventstream.Message) error { +func awsRestjson1_deserializeEventStreamInvokeModelWithBidirectionalStreamOutput(v *types.InvokeModelWithBidirectionalStreamOutput, msg *eventstream.Message) error { if v == nil { return fmt.Errorf("unexpected serialization of nil %T", v) } @@ -1650,8 +1768,8 @@ func awsRestjson1_deserializeEventStreamResponseStream(v *types.ResponseStream, switch { case strings.EqualFold("chunk", eventType.String()): - vv := &types.ResponseStreamMemberChunk{} - if err := awsRestjson1_deserializeEventMessagePayloadPart(&vv.Value, msg); err != nil { + vv := &types.InvokeModelWithBidirectionalStreamOutputMemberChunk{} + if err := awsRestjson1_deserializeEventMessageBidirectionalOutputPayloadPart(&vv.Value, msg); err != nil { return err } *v = vv @@ -1669,7 +1787,7 @@ func awsRestjson1_deserializeEventStreamResponseStream(v *types.ResponseStream, } } -func awsRestjson1_deserializeEventStreamExceptionResponseStream(msg *eventstream.Message) error { +func awsRestjson1_deserializeEventStreamExceptionInvokeModelWithBidirectionalStreamOutput(msg *eventstream.Message) error { exceptionType := msg.Headers.Get(eventstreamapi.ExceptionTypeHeader) if exceptionType == nil { return fmt.Errorf("%s event header not present", eventstreamapi.ExceptionTypeHeader) @@ -1724,7 +1842,7 @@ func awsRestjson1_deserializeEventStreamExceptionResponseStream(msg *eventstream } } -func awsRestjson1_deserializeEventMessagePayloadPart(v *types.PayloadPart, msg *eventstream.Message) error { +func awsRestjson1_deserializeEventMessageBidirectionalOutputPayloadPart(v *types.BidirectionalOutputPayloadPart, msg *eventstream.Message) error { if v == nil { return fmt.Errorf("unexpected serialization of nil %T", v) } @@ -1747,7 +1865,7 @@ func awsRestjson1_deserializeEventMessagePayloadPart(v *types.PayloadPart, msg * return err } - if err := awsRestjson1_deserializeDocumentPayloadPart(&v, shape); err != nil { + if err := awsRestjson1_deserializeDocumentBidirectionalOutputPayloadPart(&v, shape); err != nil { if err != nil { var snapshot bytes.Buffer io.Copy(&snapshot, ringBuffer) @@ -1972,6 +2090,50 @@ func awsRestjson1_deserializeEventMessageExceptionServiceUnavailableException(ms return v } +func awsRestjson1_deserializeDocumentBidirectionalOutputPayloadPart(v **types.BidirectionalOutputPayloadPart, value interface{}) error { + if v == nil { + return fmt.Errorf("unexpected nil of type %T", v) + } + if value == nil { + return nil + } + + shape, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected JSON type %v", value) + } + + var sv *types.BidirectionalOutputPayloadPart + if *v == nil { + sv = &types.BidirectionalOutputPayloadPart{} + } else { + sv = *v + } + + for key, value := range shape { + switch key { + case "bytes": + if value != nil { + jtv, ok := value.(string) + if !ok { + return fmt.Errorf("expected PartBody to be []byte, got %T instead", value) + } + dv, err := base64.StdEncoding.DecodeString(jtv) + if err != nil { + return fmt.Errorf("failed to base64 decode PartBody, %w", err) + } + sv.Bytes = dv + } + + default: + _, _ = key, value + + } + } + *v = sv + return nil +} + func awsRestjson1_deserializeDocumentInternalServerException(v **types.InternalServerException, value interface{}) error { if v == nil { return fmt.Errorf("unexpected nil of type %T", v) @@ -2114,7 +2276,7 @@ func awsRestjson1_deserializeDocumentModelTimeoutException(v **types.ModelTimeou return nil } -func awsRestjson1_deserializeDocumentPayloadPart(v **types.PayloadPart, value interface{}) error { +func awsRestjson1_deserializeDocumentServiceUnavailableException(v **types.ServiceUnavailableException, value interface{}) error { if v == nil { return fmt.Errorf("unexpected nil of type %T", v) } @@ -2127,26 +2289,22 @@ func awsRestjson1_deserializeDocumentPayloadPart(v **types.PayloadPart, value in return fmt.Errorf("unexpected JSON type %v", value) } - var sv *types.PayloadPart + var sv *types.ServiceUnavailableException if *v == nil { - sv = &types.PayloadPart{} + sv = &types.ServiceUnavailableException{} } else { sv = *v } for key, value := range shape { switch key { - case "bytes": + case "message", "Message": if value != nil { jtv, ok := value.(string) if !ok { - return fmt.Errorf("expected PartBody to be []byte, got %T instead", value) - } - dv, err := base64.StdEncoding.DecodeString(jtv) - if err != nil { - return fmt.Errorf("failed to base64 decode PartBody, %w", err) + return fmt.Errorf("expected NonBlankString to be of type string, got %T instead", value) } - sv.Bytes = dv + sv.Message = ptr.String(jtv) } default: @@ -2158,7 +2316,7 @@ func awsRestjson1_deserializeDocumentPayloadPart(v **types.PayloadPart, value in return nil } -func awsRestjson1_deserializeDocumentServiceUnavailableException(v **types.ServiceUnavailableException, value interface{}) error { +func awsRestjson1_deserializeDocumentThrottlingException(v **types.ThrottlingException, value interface{}) error { if v == nil { return fmt.Errorf("unexpected nil of type %T", v) } @@ -2171,9 +2329,9 @@ func awsRestjson1_deserializeDocumentServiceUnavailableException(v **types.Servi return fmt.Errorf("unexpected JSON type %v", value) } - var sv *types.ServiceUnavailableException + var sv *types.ThrottlingException if *v == nil { - sv = &types.ServiceUnavailableException{} + sv = &types.ThrottlingException{} } else { sv = *v } @@ -2198,7 +2356,7 @@ func awsRestjson1_deserializeDocumentServiceUnavailableException(v **types.Servi return nil } -func awsRestjson1_deserializeDocumentThrottlingException(v **types.ThrottlingException, value interface{}) error { +func awsRestjson1_deserializeDocumentValidationException(v **types.ValidationException, value interface{}) error { if v == nil { return fmt.Errorf("unexpected nil of type %T", v) } @@ -2211,9 +2369,9 @@ func awsRestjson1_deserializeDocumentThrottlingException(v **types.ThrottlingExc return fmt.Errorf("unexpected JSON type %v", value) } - var sv *types.ThrottlingException + var sv *types.ValidationException if *v == nil { - sv = &types.ThrottlingException{} + sv = &types.ValidationException{} } else { sv = *v } @@ -2238,7 +2396,131 @@ func awsRestjson1_deserializeDocumentThrottlingException(v **types.ThrottlingExc return nil } -func awsRestjson1_deserializeDocumentValidationException(v **types.ValidationException, value interface{}) error { +func awsRestjson1_deserializeEventStreamResponseStream(v *types.ResponseStream, msg *eventstream.Message) error { + if v == nil { + return fmt.Errorf("unexpected serialization of nil %T", v) + } + + eventType := msg.Headers.Get(eventstreamapi.EventTypeHeader) + if eventType == nil { + return fmt.Errorf("%s event header not present", eventstreamapi.EventTypeHeader) + } + + switch { + case strings.EqualFold("chunk", eventType.String()): + vv := &types.ResponseStreamMemberChunk{} + if err := awsRestjson1_deserializeEventMessagePayloadPart(&vv.Value, msg); err != nil { + return err + } + *v = vv + return nil + + default: + buffer := bytes.NewBuffer(nil) + eventstream.NewEncoder().Encode(buffer, *msg) + *v = &types.UnknownUnionMember{ + Tag: eventType.String(), + Value: buffer.Bytes(), + } + return nil + + } +} + +func awsRestjson1_deserializeEventStreamExceptionResponseStream(msg *eventstream.Message) error { + exceptionType := msg.Headers.Get(eventstreamapi.ExceptionTypeHeader) + if exceptionType == nil { + return fmt.Errorf("%s event header not present", eventstreamapi.ExceptionTypeHeader) + } + + switch { + case strings.EqualFold("internalServerException", exceptionType.String()): + return awsRestjson1_deserializeEventMessageExceptionInternalServerException(msg) + + case strings.EqualFold("modelStreamErrorException", exceptionType.String()): + return awsRestjson1_deserializeEventMessageExceptionModelStreamErrorException(msg) + + case strings.EqualFold("modelTimeoutException", exceptionType.String()): + return awsRestjson1_deserializeEventMessageExceptionModelTimeoutException(msg) + + case strings.EqualFold("serviceUnavailableException", exceptionType.String()): + return awsRestjson1_deserializeEventMessageExceptionServiceUnavailableException(msg) + + case strings.EqualFold("throttlingException", exceptionType.String()): + return awsRestjson1_deserializeEventMessageExceptionThrottlingException(msg) + + case strings.EqualFold("validationException", exceptionType.String()): + return awsRestjson1_deserializeEventMessageExceptionValidationException(msg) + + default: + br := bytes.NewReader(msg.Payload) + var buff [1024]byte + ringBuffer := smithyio.NewRingBuffer(buff[:]) + + body := io.TeeReader(br, ringBuffer) + decoder := json.NewDecoder(body) + decoder.UseNumber() + code, message, err := restjson.GetErrorInfo(decoder) + if err != nil { + return err + } + errorCode := "UnknownError" + errorMessage := errorCode + if ev := exceptionType.String(); len(ev) > 0 { + errorCode = ev + } else if ev := code; len(ev) > 0 { + errorCode = ev + } + if ev := message; len(ev) > 0 { + errorMessage = ev + } + return &smithy.GenericAPIError{ + Code: errorCode, + Message: errorMessage, + } + + } +} + +func awsRestjson1_deserializeEventMessagePayloadPart(v *types.PayloadPart, msg *eventstream.Message) error { + if v == nil { + return fmt.Errorf("unexpected serialization of nil %T", v) + } + + br := bytes.NewReader(msg.Payload) + var buff [1024]byte + ringBuffer := smithyio.NewRingBuffer(buff[:]) + + body := io.TeeReader(br, ringBuffer) + decoder := json.NewDecoder(body) + decoder.UseNumber() + var shape interface{} + if err := decoder.Decode(&shape); err != nil && err != io.EOF { + var snapshot bytes.Buffer + io.Copy(&snapshot, ringBuffer) + err = &smithy.DeserializationError{ + Err: fmt.Errorf("failed to decode response body, %w", err), + Snapshot: snapshot.Bytes(), + } + return err + } + + if err := awsRestjson1_deserializeDocumentPayloadPart(&v, shape); err != nil { + if err != nil { + var snapshot bytes.Buffer + io.Copy(&snapshot, ringBuffer) + err = &smithy.DeserializationError{ + Err: fmt.Errorf("failed to decode response body, %w", err), + Snapshot: snapshot.Bytes(), + } + return err + } + + } + return nil +} + +func awsRestjson1_deserializeDocumentPayloadPart(v **types.PayloadPart, value interface{}) error { if v == nil { return fmt.Errorf("unexpected nil of type %T", v) } @@ -2251,22 +2533,26 @@ func awsRestjson1_deserializeDocumentValidationException(v **types.ValidationExc return fmt.Errorf("unexpected JSON type %v", value) } - var sv *types.ValidationException + var sv *types.PayloadPart if *v == nil { - sv = &types.ValidationException{} + sv = &types.PayloadPart{} } else { sv = *v } for key, value := range shape { switch key { - case "message", "Message": + case "bytes": if value != nil { jtv, ok := value.(string) if !ok { - return fmt.Errorf("expected NonBlankString to be of type string, got %T instead", value) + return fmt.Errorf("expected PartBody to be []byte, got %T instead", value) } - sv.Message = ptr.String(jtv) + dv, err := base64.StdEncoding.DecodeString(jtv) + if err != nil { + return fmt.Errorf("failed to base64 decode PartBody, %w", err) + } + sv.Bytes = dv } default: diff --git a/service/bedrockruntime/eventstream.go b/service/bedrockruntime/eventstream.go index 56f1ffac22c..480254b5f2e 100644 --- a/service/bedrockruntime/eventstream.go +++ b/service/bedrockruntime/eventstream.go @@ -3,11 +3,14 @@ package bedrockruntime import ( + "bytes" "context" "fmt" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream" "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi" + "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + internalauthsmithy "github.com/aws/aws-sdk-go-v2/internal/auth/smithy" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" smithy "github.com/aws/smithy-go" "github.com/aws/smithy-go/middleware" @@ -16,8 +19,19 @@ import ( "io" "io/ioutil" "sync" + "time" ) +// InvokeModelWithBidirectionalStreamInputWriter provides the interface for +// writing events to a stream. +// +// The writer's Close method must allow multiple concurrent calls. +type InvokeModelWithBidirectionalStreamInputWriter interface { + Send(context.Context, types.InvokeModelWithBidirectionalStreamInput) error + Close() error + Err() error +} + // ConverseStreamOutputReader provides the interface for reading events from a // stream. // @@ -28,6 +42,16 @@ type ConverseStreamOutputReader interface { Err() error } +// InvokeModelWithBidirectionalStreamOutputReader provides the interface for +// reading events from a stream. +// +// The writer's Close method must allow multiple concurrent calls. +type InvokeModelWithBidirectionalStreamOutputReader interface { + Events() <-chan types.InvokeModelWithBidirectionalStreamOutput + Close() error + Err() error +} + // ResponseStreamReader provides the interface for reading events from a stream. // // The writer's Close method must allow multiple concurrent calls. @@ -37,6 +61,345 @@ type ResponseStreamReader interface { Err() error } +type eventStreamSigner interface { + GetSignature(ctx context.Context, headers, payload []byte, signingTime time.Time, optFns ...func(*v4.StreamSignerOptions)) ([]byte, error) +} + +type asyncInvokeModelWithBidirectionalStreamInput struct { + Event types.InvokeModelWithBidirectionalStreamInput + Result chan<- error +} + +func (e asyncInvokeModelWithBidirectionalStreamInput) ReportResult(cancel <-chan struct{}, err error) bool { + select { + case e.Result <- err: + return true + + case <-cancel: + return false + + } +} + +type invokeModelWithBidirectionalStreamInputWriter struct { + encoder *eventstream.Encoder + signer eventStreamSigner + stream chan asyncInvokeModelWithBidirectionalStreamInput + serializationBuffer *bytes.Buffer + signingBuffer *bytes.Buffer + eventStream io.WriteCloser + done chan struct{} + closeOnce sync.Once + err *smithysync.OnceErr +} + +func newInvokeModelWithBidirectionalStreamInputWriter(stream io.WriteCloser, encoder *eventstream.Encoder, signer eventStreamSigner) *invokeModelWithBidirectionalStreamInputWriter { + w := &invokeModelWithBidirectionalStreamInputWriter{ + encoder: encoder, + signer: signer, + stream: make(chan asyncInvokeModelWithBidirectionalStreamInput), + eventStream: stream, + done: make(chan struct{}), + err: smithysync.NewOnceErr(), + serializationBuffer: bytes.NewBuffer(nil), + signingBuffer: bytes.NewBuffer(nil), + } + + go w.writeStream() + + return w + +} + +func (w *invokeModelWithBidirectionalStreamInputWriter) Send(ctx context.Context, event types.InvokeModelWithBidirectionalStreamInput) error { + return w.send(ctx, event) +} + +func (w *invokeModelWithBidirectionalStreamInputWriter) send(ctx context.Context, event types.InvokeModelWithBidirectionalStreamInput) error { + if err := w.err.Err(); err != nil { + return err + } + + resultCh := make(chan error) + + wrapped := asyncInvokeModelWithBidirectionalStreamInput{ + Event: event, + Result: resultCh, + } + + select { + case w.stream <- wrapped: + case <-ctx.Done(): + return ctx.Err() + case <-w.done: + return fmt.Errorf("stream closed, unable to send event") + + } + + select { + case err := <-resultCh: + return err + case <-ctx.Done(): + return ctx.Err() + case <-w.done: + return fmt.Errorf("stream closed, unable to send event") + + } + +} + +func (w *invokeModelWithBidirectionalStreamInputWriter) writeStream() { + defer w.Close() + + for { + select { + case wrapper := <-w.stream: + err := w.writeEvent(wrapper.Event) + wrapper.ReportResult(w.done, err) + if err != nil { + w.err.SetError(err) + return + } + + case <-w.done: + if err := w.closeStream(); err != nil { + w.err.SetError(err) + } + return + + } + } +} + +func (w *invokeModelWithBidirectionalStreamInputWriter) writeEvent(event types.InvokeModelWithBidirectionalStreamInput) error { + // serializedEvent returned bytes refers to an underlying byte buffer and must not + // escape this writeEvent scope without first copying. Any previous bytes stored in + // the buffer are cleared by this call. + serializedEvent, err := w.serializeEvent(event) + if err != nil { + return err + } + + // signedEvent returned bytes refers to an underlying byte buffer and must not + // escape this writeEvent scope without first copying. Any previous bytes stored in + // the buffer are cleared by this call. + signedEvent, err := w.signEvent(serializedEvent) + if err != nil { + return err + } + + // bytes are now copied to the underlying stream writer + _, err = io.Copy(w.eventStream, bytes.NewReader(signedEvent)) + return err +} + +func (w *invokeModelWithBidirectionalStreamInputWriter) serializeEvent(event types.InvokeModelWithBidirectionalStreamInput) ([]byte, error) { + w.serializationBuffer.Reset() + + eventMessage := eventstream.Message{} + + if err := awsRestjson1_serializeEventStreamInvokeModelWithBidirectionalStreamInput(event, &eventMessage); err != nil { + return nil, err + } + + if err := w.encoder.Encode(w.serializationBuffer, eventMessage); err != nil { + return nil, err + } + + return w.serializationBuffer.Bytes(), nil +} + +func (w *invokeModelWithBidirectionalStreamInputWriter) signEvent(payload []byte) ([]byte, error) { + w.signingBuffer.Reset() + + date := time.Now().UTC() + + var msg eventstream.Message + msg.Headers.Set(eventstreamapi.DateHeader, eventstream.TimestampValue(date)) + msg.Payload = payload + + var headers bytes.Buffer + if err := eventstream.EncodeHeaders(&headers, msg.Headers); err != nil { + return nil, err + } + + sig, err := w.signer.GetSignature(context.Background(), headers.Bytes(), msg.Payload, date) + if err != nil { + return nil, err + } + + msg.Headers.Set(eventstreamapi.ChunkSignatureHeader, eventstream.BytesValue(sig)) + + if err := w.encoder.Encode(w.signingBuffer, msg); err != nil { + return nil, err + } + + return w.signingBuffer.Bytes(), nil +} + +func (w *invokeModelWithBidirectionalStreamInputWriter) closeStream() (err error) { + defer func() { + if cErr := w.eventStream.Close(); cErr != nil && err == nil { + err = cErr + } + }() + + // Per the protocol, a signed empty message is used to indicate the end of the stream, + // and that no subsequent events will be sent. + signedEvent, err := w.signEvent([]byte{}) + if err != nil { + return err + } + + _, err = io.Copy(w.eventStream, bytes.NewReader(signedEvent)) + return err +} + +func (w *invokeModelWithBidirectionalStreamInputWriter) ErrorSet() <-chan struct{} { + return w.err.ErrorSet() +} + +func (w *invokeModelWithBidirectionalStreamInputWriter) Close() error { + w.closeOnce.Do(w.safeClose) + return w.Err() +} + +func (w *invokeModelWithBidirectionalStreamInputWriter) safeClose() { + close(w.done) +} + +func (w *invokeModelWithBidirectionalStreamInputWriter) Err() error { + return w.err.Err() +} + +type invokeModelWithBidirectionalStreamOutputReader struct { + stream chan types.InvokeModelWithBidirectionalStreamOutput + decoder *eventstream.Decoder + eventStream io.ReadCloser + err *smithysync.OnceErr + payloadBuf []byte + done chan struct{} + closeOnce sync.Once +} + +func newInvokeModelWithBidirectionalStreamOutputReader(readCloser io.ReadCloser, decoder *eventstream.Decoder) *invokeModelWithBidirectionalStreamOutputReader { + w := &invokeModelWithBidirectionalStreamOutputReader{ + stream: make(chan types.InvokeModelWithBidirectionalStreamOutput), + decoder: decoder, + eventStream: readCloser, + err: smithysync.NewOnceErr(), + done: make(chan struct{}), + payloadBuf: make([]byte, 10*1024), + } + + go w.readEventStream() + + return w +} + +func (r *invokeModelWithBidirectionalStreamOutputReader) Events() <-chan types.InvokeModelWithBidirectionalStreamOutput { + return r.stream +} + +func (r *invokeModelWithBidirectionalStreamOutputReader) readEventStream() { + defer r.Close() + defer close(r.stream) + + for { + r.payloadBuf = r.payloadBuf[0:0] + decodedMessage, err := r.decoder.Decode(r.eventStream, r.payloadBuf) + if err != nil { + if err == io.EOF { + return + } + select { + case <-r.done: + return + default: + r.err.SetError(err) + return + } + } + + event, err := r.deserializeEventMessage(&decodedMessage) + if err != nil { + r.err.SetError(err) + return + } + + select { + case r.stream <- event: + case <-r.done: + return + } + + } +} + +func (r *invokeModelWithBidirectionalStreamOutputReader) deserializeEventMessage(msg *eventstream.Message) (types.InvokeModelWithBidirectionalStreamOutput, error) { + messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader) + if messageType == nil { + return nil, fmt.Errorf("%s event header not present", eventstreamapi.MessageTypeHeader) + } + + switch messageType.String() { + case eventstreamapi.EventMessageType: + var v types.InvokeModelWithBidirectionalStreamOutput + if err := awsRestjson1_deserializeEventStreamInvokeModelWithBidirectionalStreamOutput(&v, msg); err != nil { + return nil, err + } + return v, nil + + case eventstreamapi.ExceptionMessageType: + return nil, awsRestjson1_deserializeEventStreamExceptionInvokeModelWithBidirectionalStreamOutput(msg) + + case eventstreamapi.ErrorMessageType: + errorCode := "UnknownError" + errorMessage := errorCode + if header := msg.Headers.Get(eventstreamapi.ErrorCodeHeader); header != nil { + errorCode = header.String() + } + if header := msg.Headers.Get(eventstreamapi.ErrorMessageHeader); header != nil { + errorMessage = header.String() + } + return nil, &smithy.GenericAPIError{ + Code: errorCode, + Message: errorMessage, + } + + default: + mc := msg.Clone() + return nil, &UnknownEventMessageError{ + Type: messageType.String(), + Message: &mc, + } + + } +} + +func (r *invokeModelWithBidirectionalStreamOutputReader) ErrorSet() <-chan struct{} { + return r.err.ErrorSet() +} + +func (r *invokeModelWithBidirectionalStreamOutputReader) Close() error { + r.closeOnce.Do(r.safeClose) + return r.Err() +} + +func (r *invokeModelWithBidirectionalStreamOutputReader) safeClose() { + close(r.done) + r.eventStream.Close() + +} + +func (r *invokeModelWithBidirectionalStreamOutputReader) Err() error { + return r.err.Err() +} + +func (r *invokeModelWithBidirectionalStreamOutputReader) Closed() <-chan struct{} { + return r.done +} + type responseStreamReader struct { stream chan types.ResponseStream decoder *eventstream.Decoder @@ -381,6 +744,196 @@ func addEventStreamConverseStreamMiddleware(stack *middleware.Stack, options Opt } +type awsRestjson1_deserializeOpEventStreamInvokeModelWithBidirectionalStream struct { + LogEventStreamWrites bool + LogEventStreamReads bool +} + +func (*awsRestjson1_deserializeOpEventStreamInvokeModelWithBidirectionalStream) ID() string { + return "OperationEventStreamDeserializer" +} + +type deserializeResult struct { + reader io.ReadCloser + err error +} + +type asyncEventStreamReader struct { + pipeReader *io.PipeReader + pipeWriter *io.PipeWriter +} + +func newAsyncEventStreamReader(resultChan <-chan deserializeResult) *asyncEventStreamReader { + pipeReader, pipeWriter := io.Pipe() + + reader := &asyncEventStreamReader{ + pipeReader: pipeReader, + pipeWriter: pipeWriter, + } + + // Start background copying + go func() { + result := <-resultChan + if result.err != nil { + pipeWriter.CloseWithError(result.err) + return + } + + // Copy from real response body to pipe + _, err := io.Copy(pipeWriter, result.reader) + pipeWriter.CloseWithError(err) + }() + + return reader +} + +type PartialResult[T any] struct { + Output T + Metadata middleware.Metadata + Error error +} + +func (m *awsRestjson1_deserializeOpEventStreamInvokeModelWithBidirectionalStream) HandleDeserialize( + ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler, +) (out middleware.DeserializeOutput, metadata middleware.Metadata, err error, +) { + defer func() { + if err == nil { + return + } + m.closeResponseBody(out) + }() + + logger := middleware.GetLogger(ctx) + + request, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, fmt.Errorf("unknown transport type: %T", in.Request) + } + _ = request + + if err := eventstreamapi.ApplyHTTPTransportFixes(request); err != nil { + return out, metadata, err + } + + requestSignature, err := v4.GetSignedRequestSignature(request.Request) + if err != nil { + return out, metadata, fmt.Errorf("failed to get event stream seed signature: %v", err) + } + + identity := getIdentity(ctx) + if identity == nil { + return out, metadata, fmt.Errorf("no identity") + } + + creds, ok := identity.(*internalauthsmithy.CredentialsAdapter) + if !ok { + return out, metadata, fmt.Errorf("identity is not sigv4 credentials") + } + + rscheme := getResolvedAuthScheme(ctx) + if rscheme == nil { + return out, metadata, fmt.Errorf("no resolved auth scheme") + } + + name, ok := smithyhttp.GetSigV4SigningName(&rscheme.SignerProperties) + if !ok { + return out, metadata, fmt.Errorf("no sigv4 signing name") + } + + region, ok := smithyhttp.GetSigV4SigningRegion(&rscheme.SignerProperties) + if !ok { + return out, metadata, fmt.Errorf("no sigv4 signing region") + } + signer := v4.NewStreamSigner(creds.Credentials, name, region, requestSignature) + + eventWriter := newInvokeModelWithBidirectionalStreamInputWriter( + eventstreamapi.GetInputStreamWriter(ctx), + eventstream.NewEncoder(func(options *eventstream.EncoderOptions) { + options.Logger = logger + options.LogMessages = m.LogEventStreamWrites + + }), + signer, + ) + defer func() { + if err == nil { + return + } + _ = eventWriter.Close() + }() + // Create async result channel + asyncResult := make(chan deserializeResult, 1) + asyncReader := newAsyncEventStreamReader(asyncResult) + eventReader := newInvokeModelWithBidirectionalStreamOutputReader( + asyncReader.pipeReader, + eventstream.NewDecoder(func(options *eventstream.DecoderOptions) { + options.Logger = logger + options.LogMessages = m.LogEventStreamReads + }), + ) + defer func() { + if err == nil { + return + } + _ = eventReader.Close() + }() + + output := &InvokeModelWithBidirectionalStreamOutput{} + output.eventStream = NewInvokeModelWithBidirectionalStreamEventStream(func(stream *InvokeModelWithBidirectionalStreamEventStream) { + stream.Writer = eventWriter + stream.Reader = eventReader + }) + + go output.eventStream.waitStreamClose() + + ch := ctx.Value("asyncChan") + if ch == nil { + panic("missing asyncChan") + } + + c, ok := ch.(chan PartialResult[*InvokeModelWithBidirectionalStreamOutput]) + if !ok { + panic("asyncChan was not a partialResult") + } + partial := PartialResult[*InvokeModelWithBidirectionalStreamOutput]{ + Output: output, + Metadata: middleware.Metadata{}, + Error: nil, + } + c <- partial + + out, metadata, err = next.HandleDeserialize(ctx, in) + + if err == nil { + // Extract actual response and create real reader + resp := out.RawResponse.(*smithyhttp.Response) + // TODO lmadrig this should have more than just the body + asyncResult <- deserializeResult{reader: resp.Body, err: nil} + } else { + asyncResult <- deserializeResult{reader: nil, err: err} + } + return out, metadata, err +} + +func (*awsRestjson1_deserializeOpEventStreamInvokeModelWithBidirectionalStream) closeResponseBody(out middleware.DeserializeOutput) { + if resp, ok := out.RawResponse.(*smithyhttp.Response); ok && resp != nil && resp.Body != nil { + _, _ = io.Copy(ioutil.Discard, resp.Body) + _ = resp.Body.Close() + } +} + +func addEventStreamInvokeModelWithBidirectionalStreamMiddleware(stack *middleware.Stack, options Options) error { + if err := stack.Deserialize.Insert(&awsRestjson1_deserializeOpEventStreamInvokeModelWithBidirectionalStream{ + LogEventStreamWrites: options.ClientLogMode.IsRequestEventMessage(), + LogEventStreamReads: options.ClientLogMode.IsResponseEventMessage(), + }, "OperationDeserializer", middleware.Before); err != nil { + return err + } + return nil + +} + type awsRestjson1_deserializeOpEventStreamInvokeModelWithResponseStream struct { LogEventStreamWrites bool LogEventStreamReads bool @@ -487,6 +1040,10 @@ func setSafeEventStreamClientLogMode(o *Options, operation string) { toggleEventStreamClientLogMode(o, false, true) return + case "InvokeModelWithBidirectionalStream": + toggleEventStreamClientLogMode(o, true, true) + return + case "InvokeModelWithResponseStream": toggleEventStreamClientLogMode(o, false, true) return diff --git a/service/bedrockruntime/generated.json b/service/bedrockruntime/generated.json index 58f35cc2f4b..eb3bb5ad1d0 100644 --- a/service/bedrockruntime/generated.json +++ b/service/bedrockruntime/generated.json @@ -15,6 +15,7 @@ "api_op_CountTokens.go", "api_op_GetAsyncInvoke.go", "api_op_InvokeModel.go", + "api_op_InvokeModelWithBidirectionalStream.go", "api_op_InvokeModelWithResponseStream.go", "api_op_ListAsyncInvokes.go", "api_op_StartAsyncInvoke.go", diff --git a/service/bedrockruntime/go.mod b/service/bedrockruntime/go.mod index 920918f2082..e0e553692a4 100644 --- a/service/bedrockruntime/go.mod +++ b/service/bedrockruntime/go.mod @@ -17,3 +17,5 @@ replace github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream => ../../aws/proto replace github.com/aws/aws-sdk-go-v2/internal/configsources => ../../internal/configsources/ replace github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 => ../../internal/endpoints/v2/ + +replace github.com/aws/smithy-go => ../../../smithy-go \ No newline at end of file diff --git a/service/bedrockruntime/serializers.go b/service/bedrockruntime/serializers.go index 884a7ead009..78f1639ccb1 100644 --- a/service/bedrockruntime/serializers.go +++ b/service/bedrockruntime/serializers.go @@ -6,6 +6,8 @@ import ( "bytes" "context" "fmt" + "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream" + "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document" internaldocument "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/internal/document" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" @@ -751,6 +753,79 @@ func awsRestjson1_serializeOpHttpBindingsInvokeModelInput(v *InvokeModelInput, e return nil } +type awsRestjson1_serializeOpInvokeModelWithBidirectionalStream struct { +} + +func (*awsRestjson1_serializeOpInvokeModelWithBidirectionalStream) ID() string { + return "OperationSerializer" +} + +func (m *awsRestjson1_serializeOpInvokeModelWithBidirectionalStream) HandleSerialize(ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler) ( + out middleware.SerializeOutput, metadata middleware.Metadata, err error, +) { + _, span := tracing.StartSpan(ctx, "OperationSerializer") + endTimer := startMetricTimer(ctx, "client.call.serialization_duration") + defer endTimer() + defer span.End() + request, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, &smithy.SerializationError{Err: fmt.Errorf("unknown transport type %T", in.Request)} + } + + input, ok := in.Parameters.(*InvokeModelWithBidirectionalStreamInput) + _ = input + if !ok { + return out, metadata, &smithy.SerializationError{Err: fmt.Errorf("unknown input parameters type %T", in.Parameters)} + } + + opPath, opQuery := httpbinding.SplitURI("/model/{modelId}/invoke-with-bidirectional-stream") + request.URL.Path = smithyhttp.JoinPath(request.URL.Path, opPath) + request.URL.RawQuery = smithyhttp.JoinRawQuery(request.URL.RawQuery, opQuery) + request.Method = "POST" + var restEncoder *httpbinding.Encoder + if request.URL.RawPath == "" { + restEncoder, err = httpbinding.NewEncoder(request.URL.Path, request.URL.RawQuery, request.Header) + } else { + request.URL.RawPath = smithyhttp.JoinPath(request.URL.RawPath, opPath) + restEncoder, err = httpbinding.NewEncoderWithRawPath(request.URL.Path, request.URL.RawPath, request.URL.RawQuery, request.Header) + } + + if err != nil { + return out, metadata, &smithy.SerializationError{Err: err} + } + + if err := awsRestjson1_serializeOpHttpBindingsInvokeModelWithBidirectionalStreamInput(input, restEncoder); err != nil { + return out, metadata, &smithy.SerializationError{Err: err} + } + + restEncoder.SetHeader("Content-Type").String("application/vnd.amazon.eventstream") + + if request.Request, err = restEncoder.Encode(request.Request); err != nil { + return out, metadata, &smithy.SerializationError{Err: err} + } + in.Request = request + + endTimer() + span.End() + return next.HandleSerialize(ctx, in) +} +func awsRestjson1_serializeOpHttpBindingsInvokeModelWithBidirectionalStreamInput(v *InvokeModelWithBidirectionalStreamInput, encoder *httpbinding.Encoder) error { + if v == nil { + return fmt.Errorf("unsupported serialization of nil %T", v) + } + + if v.ModelId == nil || len(*v.ModelId) == 0 { + return &smithy.SerializationError{Err: fmt.Errorf("input member modelId must not be empty")} + } + if v.ModelId != nil { + if err := encoder.SetURI("modelId").String(*v.ModelId); err != nil { + return err + } + } + + return nil +} + type awsRestjson1_serializeOpInvokeModelWithResponseStream struct { } @@ -1066,6 +1141,48 @@ func awsRestjson1_serializeOpDocumentStartAsyncInvokeInput(v *StartAsyncInvokeIn return nil } +func awsRestjson1_serializeEventStreamInvokeModelWithBidirectionalStreamInput(v types.InvokeModelWithBidirectionalStreamInput, msg *eventstream.Message) error { + if v == nil { + return fmt.Errorf("unexpected serialization of nil %T", v) + } + + switch vv := v.(type) { + case *types.InvokeModelWithBidirectionalStreamInputMemberChunk: + msg.Headers.Set(eventstreamapi.EventTypeHeader, eventstream.StringValue("chunk")) + return awsRestjson1_serializeEventMessageBidirectionalInputPayloadPart(&vv.Value, msg) + + default: + return fmt.Errorf("unexpected event message type: %v", v) + + } +} +func awsRestjson1_serializeEventMessageBidirectionalInputPayloadPart(v *types.BidirectionalInputPayloadPart, msg *eventstream.Message) error { + if v == nil { + return fmt.Errorf("unexpected serialization of nil %T", v) + } + + msg.Headers.Set(eventstreamapi.MessageTypeHeader, eventstream.StringValue(eventstreamapi.EventMessageType)) + msg.Headers.Set(eventstreamapi.ContentTypeHeader, eventstream.StringValue("application/json")) + jsonEncoder := smithyjson.NewEncoder() + if err := awsRestjson1_serializeDocumentBidirectionalInputPayloadPart(v, jsonEncoder.Value); err != nil { + return err + } + msg.Payload = jsonEncoder.Bytes() + return nil +} + +func awsRestjson1_serializeDocumentBidirectionalInputPayloadPart(v *types.BidirectionalInputPayloadPart, value smithyjson.Value) error { + object := value.Object() + defer object.Close() + + if v.Bytes != nil { + ok := object.Key("bytes") + ok.Base64EncodeBytes(v.Bytes) + } + + return nil +} + func awsRestjson1_serializeDocumentAdditionalModelResponseFieldPaths(v []string, value smithyjson.Value) error { array := value.Array() defer array.Close() diff --git a/service/bedrockruntime/snapshot_test.go b/service/bedrockruntime/snapshot_test.go index 4af9657f829..5c0ecb68ab5 100644 --- a/service/bedrockruntime/snapshot_test.go +++ b/service/bedrockruntime/snapshot_test.go @@ -134,6 +134,18 @@ func TestCheckSnapshot_InvokeModel(t *testing.T) { } } +func TestCheckSnapshot_InvokeModelWithBidirectionalStream(t *testing.T) { + svc := New(Options{}) + _, err := svc.InvokeModelWithBidirectionalStream(context.Background(), nil, func(o *Options) { + o.APIOptions = append(o.APIOptions, func(stack *middleware.Stack) error { + return testSnapshot(stack, "InvokeModelWithBidirectionalStream") + }) + }) + if _, ok := err.(snapshotOK); !ok && err != nil { + t.Fatal(err) + } +} + func TestCheckSnapshot_InvokeModelWithResponseStream(t *testing.T) { svc := New(Options{}) _, err := svc.InvokeModelWithResponseStream(context.Background(), nil, func(o *Options) { @@ -241,6 +253,18 @@ func TestUpdateSnapshot_InvokeModel(t *testing.T) { } } +func TestUpdateSnapshot_InvokeModelWithBidirectionalStream(t *testing.T) { + svc := New(Options{}) + _, err := svc.InvokeModelWithBidirectionalStream(context.Background(), nil, func(o *Options) { + o.APIOptions = append(o.APIOptions, func(stack *middleware.Stack) error { + return updateSnapshot(stack, "InvokeModelWithBidirectionalStream") + }) + }) + if _, ok := err.(snapshotOK); !ok && err != nil { + t.Fatal(err) + } +} + func TestUpdateSnapshot_InvokeModelWithResponseStream(t *testing.T) { svc := New(Options{}) _, err := svc.InvokeModelWithResponseStream(context.Background(), nil, func(o *Options) { diff --git a/service/bedrockruntime/sra_operation_order_test.go b/service/bedrockruntime/sra_operation_order_test.go index 95fa3e1c77a..78229e007b1 100644 --- a/service/bedrockruntime/sra_operation_order_test.go +++ b/service/bedrockruntime/sra_operation_order_test.go @@ -229,6 +229,41 @@ func TestOpInvokeModelSRAOperationOrder(t *testing.T) { t.Errorf("order mismatch:\nexpect: %v\nactual: %v\nall: %v", expect, actual, all) } } +func TestOpInvokeModelWithBidirectionalStreamSRAOperationOrder(t *testing.T) { + expect := []string{ + "OperationSerializer", + "Retry", + "ResolveAuthScheme", + "GetIdentity", + "ResolveEndpointV2", + "Signing", + "OperationDeserializer", + } + + var captured middleware.Stack + svc := New(Options{ + APIOptions: []func(*middleware.Stack) error{ + captureMiddlewareStack(&captured), + }, + }) + _, err := svc.InvokeModelWithBidirectionalStream(context.Background(), nil) + if err != nil && !errors.Is(err, errTestReturnEarly) { + t.Fatalf("unexpected error: %v", err) + } + + var actual, all []string + for _, step := range strings.Split(captured.String(), "\n") { + trimmed := strings.TrimSpace(step) + all = append(all, trimmed) + if slices.Contains(expect, trimmed) { + actual = append(actual, trimmed) + } + } + + if !slices.Equal(expect, actual) { + t.Errorf("order mismatch:\nexpect: %v\nactual: %v\nall: %v", expect, actual, all) + } +} func TestOpInvokeModelWithResponseStreamSRAOperationOrder(t *testing.T) { expect := []string{ "OperationSerializer", diff --git a/service/bedrockruntime/types/types.go b/service/bedrockruntime/types/types.go index 85b18b495c9..397a2a77a6a 100644 --- a/service/bedrockruntime/types/types.go +++ b/service/bedrockruntime/types/types.go @@ -127,6 +127,25 @@ type AutoToolChoice struct { noSmithyDocumentSerde } +// Payload content for the bidirectional input. The input is an audio stream. +type BidirectionalInputPayloadPart struct { + + // The audio content for the bidirectional input. + Bytes []byte + + noSmithyDocumentSerde +} + +// Output from the bidirectional stream. The output is speech and a text +// transcription. +type BidirectionalOutputPayloadPart struct { + + // The speech output of the bidirectional stream. + Bytes []byte + + noSmithyDocumentSerde +} + // Defines a section of content to be cached for reuse in subsequent API calls. type CachePointBlock struct { @@ -2006,6 +2025,45 @@ type InvokeModelTokensRequest struct { noSmithyDocumentSerde } +// Payload content, the speech chunk, for the bidirectional input of the +// invocation step. +// +// The following types satisfy this interface: +// +// InvokeModelWithBidirectionalStreamInputMemberChunk +type InvokeModelWithBidirectionalStreamInput interface { + isInvokeModelWithBidirectionalStreamInput() +} + +// The audio chunk that is used as input for the invocation step. +type InvokeModelWithBidirectionalStreamInputMemberChunk struct { + Value BidirectionalInputPayloadPart + + noSmithyDocumentSerde +} + +func (*InvokeModelWithBidirectionalStreamInputMemberChunk) isInvokeModelWithBidirectionalStreamInput() { +} + +// Output from the bidirectional stream that was used for model invocation. +// +// The following types satisfy this interface: +// +// InvokeModelWithBidirectionalStreamOutputMemberChunk +type InvokeModelWithBidirectionalStreamOutput interface { + isInvokeModelWithBidirectionalStreamOutput() +} + +// The speech chunk that was provided as output from the invocation step. +type InvokeModelWithBidirectionalStreamOutputMemberChunk struct { + Value BidirectionalOutputPayloadPart + + noSmithyDocumentSerde +} + +func (*InvokeModelWithBidirectionalStreamOutputMemberChunk) isInvokeModelWithBidirectionalStreamOutput() { +} + // A message input, or returned from, a call to [Converse] or [ConverseStream]. // // [Converse]: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html @@ -2839,32 +2897,34 @@ type UnknownUnionMember struct { noSmithyDocumentSerde } -func (*UnknownUnionMember) isAsyncInvokeOutputDataConfig() {} -func (*UnknownUnionMember) isCitationGeneratedContent() {} -func (*UnknownUnionMember) isCitationLocation() {} -func (*UnknownUnionMember) isCitationSourceContent() {} -func (*UnknownUnionMember) isContentBlock() {} -func (*UnknownUnionMember) isContentBlockDelta() {} -func (*UnknownUnionMember) isContentBlockStart() {} -func (*UnknownUnionMember) isConverseOutput() {} -func (*UnknownUnionMember) isConverseStreamOutput() {} -func (*UnknownUnionMember) isCountTokensInput() {} -func (*UnknownUnionMember) isDocumentContentBlock() {} -func (*UnknownUnionMember) isDocumentSource() {} -func (*UnknownUnionMember) isGuardrailAutomatedReasoningFinding() {} -func (*UnknownUnionMember) isGuardrailContentBlock() {} -func (*UnknownUnionMember) isGuardrailConverseContentBlock() {} -func (*UnknownUnionMember) isGuardrailConverseImageSource() {} -func (*UnknownUnionMember) isGuardrailImageSource() {} -func (*UnknownUnionMember) isImageSource() {} -func (*UnknownUnionMember) isPromptVariableValues() {} -func (*UnknownUnionMember) isReasoningContentBlock() {} -func (*UnknownUnionMember) isReasoningContentBlockDelta() {} -func (*UnknownUnionMember) isResponseStream() {} -func (*UnknownUnionMember) isSystemContentBlock() {} -func (*UnknownUnionMember) isTool() {} -func (*UnknownUnionMember) isToolChoice() {} -func (*UnknownUnionMember) isToolInputSchema() {} -func (*UnknownUnionMember) isToolResultBlockDelta() {} -func (*UnknownUnionMember) isToolResultContentBlock() {} -func (*UnknownUnionMember) isVideoSource() {} +func (*UnknownUnionMember) isAsyncInvokeOutputDataConfig() {} +func (*UnknownUnionMember) isCitationGeneratedContent() {} +func (*UnknownUnionMember) isCitationLocation() {} +func (*UnknownUnionMember) isCitationSourceContent() {} +func (*UnknownUnionMember) isContentBlock() {} +func (*UnknownUnionMember) isContentBlockDelta() {} +func (*UnknownUnionMember) isContentBlockStart() {} +func (*UnknownUnionMember) isConverseOutput() {} +func (*UnknownUnionMember) isConverseStreamOutput() {} +func (*UnknownUnionMember) isCountTokensInput() {} +func (*UnknownUnionMember) isDocumentContentBlock() {} +func (*UnknownUnionMember) isDocumentSource() {} +func (*UnknownUnionMember) isGuardrailAutomatedReasoningFinding() {} +func (*UnknownUnionMember) isGuardrailContentBlock() {} +func (*UnknownUnionMember) isGuardrailConverseContentBlock() {} +func (*UnknownUnionMember) isGuardrailConverseImageSource() {} +func (*UnknownUnionMember) isGuardrailImageSource() {} +func (*UnknownUnionMember) isImageSource() {} +func (*UnknownUnionMember) isInvokeModelWithBidirectionalStreamInput() {} +func (*UnknownUnionMember) isInvokeModelWithBidirectionalStreamOutput() {} +func (*UnknownUnionMember) isPromptVariableValues() {} +func (*UnknownUnionMember) isReasoningContentBlock() {} +func (*UnknownUnionMember) isReasoningContentBlockDelta() {} +func (*UnknownUnionMember) isResponseStream() {} +func (*UnknownUnionMember) isSystemContentBlock() {} +func (*UnknownUnionMember) isTool() {} +func (*UnknownUnionMember) isToolChoice() {} +func (*UnknownUnionMember) isToolInputSchema() {} +func (*UnknownUnionMember) isToolResultBlockDelta() {} +func (*UnknownUnionMember) isToolResultContentBlock() {} +func (*UnknownUnionMember) isVideoSource() {} diff --git a/service/bedrockruntime/types/types_exported_test.go b/service/bedrockruntime/types/types_exported_test.go index 562f9aa7f8f..c283c06b0a1 100644 --- a/service/bedrockruntime/types/types_exported_test.go +++ b/service/bedrockruntime/types/types_exported_test.go @@ -480,6 +480,42 @@ func ExampleImageSource_outputUsage() { var _ *types.S3Location var _ []byte +func ExampleInvokeModelWithBidirectionalStreamInput_outputUsage() { + var union types.InvokeModelWithBidirectionalStreamInput + // type switches can be used to check the union value + switch v := union.(type) { + case *types.InvokeModelWithBidirectionalStreamInputMemberChunk: + _ = v.Value // Value is types.BidirectionalInputPayloadPart + + case *types.UnknownUnionMember: + fmt.Println("unknown tag:", v.Tag) + + default: + fmt.Println("union is nil or unknown type") + + } +} + +var _ *types.BidirectionalInputPayloadPart + +func ExampleInvokeModelWithBidirectionalStreamOutput_outputUsage() { + var union types.InvokeModelWithBidirectionalStreamOutput + // type switches can be used to check the union value + switch v := union.(type) { + case *types.InvokeModelWithBidirectionalStreamOutputMemberChunk: + _ = v.Value // Value is types.BidirectionalOutputPayloadPart + + case *types.UnknownUnionMember: + fmt.Println("unknown tag:", v.Tag) + + default: + fmt.Println("union is nil or unknown type") + + } +} + +var _ *types.BidirectionalOutputPayloadPart + func ExamplePromptVariableValues_outputUsage() { var union types.PromptVariableValues // type switches can be used to check the union value diff --git a/service/bedrockruntime/validators.go b/service/bedrockruntime/validators.go index 2b4f3cc2361..782c8dbebca 100644 --- a/service/bedrockruntime/validators.go +++ b/service/bedrockruntime/validators.go @@ -130,6 +130,26 @@ func (m *validateOpInvokeModel) HandleInitialize(ctx context.Context, in middlew return next.HandleInitialize(ctx, in) } +type validateOpInvokeModelWithBidirectionalStream struct { +} + +func (*validateOpInvokeModelWithBidirectionalStream) ID() string { + return "OperationInputValidation" +} + +func (m *validateOpInvokeModelWithBidirectionalStream) HandleInitialize(ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, +) { + input, ok := in.Parameters.(*InvokeModelWithBidirectionalStreamInput) + if !ok { + return out, metadata, fmt.Errorf("unknown input parameters type %T", in.Parameters) + } + if err := validateOpInvokeModelWithBidirectionalStreamInput(input); err != nil { + return out, metadata, err + } + return next.HandleInitialize(ctx, in) +} + type validateOpInvokeModelWithResponseStream struct { } @@ -194,6 +214,10 @@ func addOpInvokeModelValidationMiddleware(stack *middleware.Stack) error { return stack.Initialize.Add(&validateOpInvokeModel{}, middleware.After) } +func addOpInvokeModelWithBidirectionalStreamValidationMiddleware(stack *middleware.Stack) error { + return stack.Initialize.Add(&validateOpInvokeModelWithBidirectionalStream{}, middleware.After) +} + func addOpInvokeModelWithResponseStreamValidationMiddleware(stack *middleware.Stack) error { return stack.Initialize.Add(&validateOpInvokeModelWithResponseStream{}, middleware.After) } @@ -1289,6 +1313,21 @@ func validateOpInvokeModelInput(v *InvokeModelInput) error { } } +func validateOpInvokeModelWithBidirectionalStreamInput(v *InvokeModelWithBidirectionalStreamInput) error { + if v == nil { + return nil + } + invalidParams := smithy.InvalidParamsError{Context: "InvokeModelWithBidirectionalStreamInput"} + if v.ModelId == nil { + invalidParams.Add(smithy.NewErrParamRequired("ModelId")) + } + if invalidParams.Len() > 0 { + return invalidParams + } else { + return nil + } +} + func validateOpInvokeModelWithResponseStreamInput(v *InvokeModelWithResponseStreamInput) error { if v == nil { return nil