diff --git a/internal/codecutil/encoding.go b/internal/codecutil/encoding.go index aafb1e997c..6a0988f288 100644 --- a/internal/codecutil/encoding.go +++ b/internal/codecutil/encoding.go @@ -28,10 +28,12 @@ type MarshalError struct { // Error implements the error interface. func (e MarshalError) Error() string { - return fmt.Sprintf("cannot transform type %s to a BSON Document: %v", + return fmt.Sprintf("cannot marshal type %q to a BSON Document: %v", reflect.TypeOf(e.Value), e.Err) } +func (e MarshalError) Unwrap() error { return e.Err } + // EncoderFn is used to functionally construct an encoder for marshaling values. type EncoderFn func(io.Writer) *bson.Encoder diff --git a/mongo/change_stream.go b/mongo/change_stream.go index 009e68e4e4..8b3be5aad9 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -376,7 +376,7 @@ AggregateExecuteLoop: } } if err != nil { - cs.err = replaceErrors(err) + cs.err = wrapErrors(err) return cs.err } @@ -384,7 +384,7 @@ AggregateExecuteLoop: cr.Server = server cs.cursor, cs.err = driver.NewBatchCursor(cr, cs.sess, cs.client.clock, cs.cursorOptions) - if cs.err = replaceErrors(cs.err); cs.err != nil { + if cs.err = wrapErrors(cs.err); cs.err != nil { return cs.Err() } @@ -597,13 +597,13 @@ func (cs *ChangeStream) Decode(val interface{}) error { // Err returns the last error seen by the change stream, or nil if no errors has occurred. func (cs *ChangeStream) Err() error { if cs.err != nil { - return replaceErrors(cs.err) + return wrapErrors(cs.err) } if cs.cursor == nil { return nil } - return replaceErrors(cs.cursor.Err()) + return wrapErrors(cs.cursor.Err()) } // Close closes this change stream and the underlying cursor. Next and TryNext must not be called after Close has been @@ -619,7 +619,7 @@ func (cs *ChangeStream) Close(ctx context.Context) error { return nil // cursor is already closed } - cs.err = replaceErrors(cs.cursor.Close(ctx)) + cs.err = wrapErrors(cs.cursor.Close(ctx)) cs.cursor = nil return cs.Err() } @@ -678,7 +678,7 @@ func (cs *ChangeStream) next(ctx context.Context, nonBlocking bool) bool { if len(cs.batch) == 0 { cs.loopNext(ctx, nonBlocking) if cs.err != nil { - cs.err = replaceErrors(cs.err) + cs.err = wrapErrors(cs.err) return false } if len(cs.batch) == 0 { @@ -719,7 +719,7 @@ func (cs *ChangeStream) loopNext(ctx context.Context, nonBlocking bool) { return } - cs.err = replaceErrors(cs.cursor.Err()) + cs.err = wrapErrors(cs.cursor.Err()) if cs.err == nil { // Check if cursor is alive if cs.ID() == 0 { diff --git a/mongo/client.go b/mongo/client.go index f3bf5ed5fb..885246a211 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -239,7 +239,7 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) { if client.deployment == nil { client.deployment, err = topology.New(cfg) if err != nil { - return nil, replaceErrors(err) + return nil, wrapErrors(err) } } @@ -261,7 +261,7 @@ func (c *Client) connect() error { if connector, ok := c.deployment.(driver.Connector); ok { err := connector.Connect() if err != nil { - return replaceErrors(err) + return wrapErrors(err) } } @@ -293,7 +293,7 @@ func (c *Client) connect() error { if subscriber, ok := c.deployment.(driver.Subscriber); ok { sub, err := subscriber.Subscribe() if err != nil { - return replaceErrors(err) + return wrapErrors(err) } updateChan = sub.Updates } @@ -350,7 +350,7 @@ func (c *Client) Disconnect(ctx context.Context) error { } if disconnector, ok := c.deployment.(driver.Disconnector); ok { - return replaceErrors(disconnector.Disconnect(ctx)) + return wrapErrors(disconnector.Disconnect(ctx)) } return nil @@ -381,7 +381,7 @@ func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error { {"ping", 1}, }, options.RunCmd().SetReadPreference(rp)) - return replaceErrors(res.Err()) + return wrapErrors(res.Err()) } // StartSession starts a new session configured with the given options. @@ -434,7 +434,7 @@ func (c *Client) StartSession(opts ...options.Lister[options.SessionOptions]) (* sess, err := session.NewClientSession(c.sessionPool, c.id, coreOpts) if err != nil { - return nil, replaceErrors(err) + return nil, wrapErrors(err) } return &Session{ @@ -741,7 +741,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... err = op.Execute(ctx) if err != nil { - return ListDatabasesResult{}, replaceErrors(err) + return ListDatabasesResult{}, wrapErrors(err) } return newListDatabasesResultFromOperation(op.Result()), nil @@ -965,7 +965,7 @@ func (c *Client) BulkWrite(ctx context.Context, writes []ClientBulkWrite, op.result.Acknowledged = acknowledged op.result.HasVerboseResults = !op.errorsOnly err = op.execute(ctx) - return &op.result, replaceErrors(err) + return &op.result, wrapErrors(err) } // newLogger will use the LoggerOptions to create an internal logger and publish diff --git a/mongo/collection.go b/mongo/collection.go index d7693c4245..6acadc285e 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -248,7 +248,7 @@ func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel, err = op.execute(ctx) - return &op.result, replaceErrors(err) + return &op.result, wrapErrors(err) } func (coll *Collection) insert( @@ -1049,15 +1049,15 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption if errors.As(err, &wce) && wce.WriteConcernError != nil { return nil, *convertDriverWriteConcernError(wce.WriteConcernError) } - return nil, replaceErrors(err) + return nil, wrapErrors(err) } bc, err := op.Result(cursorOpts) if err != nil { - return nil, replaceErrors(err) + return nil, wrapErrors(err) } cursor, err := newCursorWithSession(bc, a.client.bsonOpts, a.registry, sess) - return cursor, replaceErrors(err) + return cursor, wrapErrors(err) } // CountDocuments returns the number of documents in the collection. For a fast count of the documents in the @@ -1132,7 +1132,7 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, err = op.Execute(ctx) if err != nil { - return 0, replaceErrors(err) + return 0, wrapErrors(err) } batch := op.ResultCursorResponse().FirstBatch @@ -1213,7 +1213,7 @@ func (coll *Collection) EstimatedDocumentCount( op.Retry(retry) err = op.Execute(ctx) - return op.Result().N, replaceErrors(err) + return op.Result().N, wrapErrors(err) } // Distinct executes a distinct command to find the unique values for a specified field in the collection. @@ -1302,7 +1302,7 @@ func (coll *Collection) Distinct( err = op.Execute(ctx) if err != nil { - return &DistinctResult{err: replaceErrors(err)} + return &DistinctResult{err: wrapErrors(err)} } arr, ok := op.Result().Values.ArrayOK() @@ -1504,12 +1504,12 @@ func (coll *Collection) find( op = op.Retry(retry) if err = op.Execute(ctx); err != nil { - return nil, replaceErrors(err) + return nil, wrapErrors(err) } bc, err := op.Result(cursorOpts) if err != nil { - return nil, replaceErrors(err) + return nil, wrapErrors(err) } return newCursorWithSession(bc, coll.bsonOpts, coll.registry, sess) } @@ -1560,7 +1560,7 @@ func (coll *Collection) FindOne(ctx context.Context, filter interface{}, cur: cursor, bsonOpts: coll.bsonOpts, reg: coll.registry, - err: replaceErrors(err), + err: wrapErrors(err), } } @@ -2044,7 +2044,7 @@ func (coll *Collection) drop(ctx context.Context) error { // ignore namespace not found errors var driverErr driver.Error if !errors.As(err, &driverErr) || !driverErr.NamespaceNotFound() { - return replaceErrors(err) + return wrapErrors(err) } return nil } diff --git a/mongo/cursor.go b/mongo/cursor.go index ee0e848c64..e8ab9caa11 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -194,7 +194,7 @@ func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool { // If we don't have a next batch if !c.bc.Next(ctx) { // Do we have an error? If so we return false. - c.err = replaceErrors(c.bc.Err()) + c.err = wrapErrors(c.bc.Err()) if c.err != nil { return false } @@ -289,7 +289,7 @@ func (c *Cursor) Err() error { return c.err } // the first call, any subsequent calls will not change the state. func (c *Cursor) Close(ctx context.Context) error { defer c.closeImplicitSession() - return replaceErrors(c.bc.Close(ctx)) + return wrapErrors(c.bc.Close(ctx)) } // All iterates the cursor and decodes each document into results. The results parameter must be a pointer to a slice. @@ -336,7 +336,7 @@ func (c *Cursor) All(ctx context.Context, results interface{}) error { batch = c.bc.Batch() } - if err = replaceErrors(c.bc.Err()); err != nil { + if err = wrapErrors(c.bc.Err()); err != nil { return err } diff --git a/mongo/database.go b/mongo/database.go index 6aa1627187..a3269ddf16 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -288,7 +288,7 @@ func (db *Database) RunCommandCursor( op, sess, err := db.processRunCommand(ctx, runCommand, true, opts...) if err != nil { closeImplicitSession(sess) - return nil, replaceErrors(err) + return nil, wrapErrors(err) } if err = op.Execute(ctx); err != nil { @@ -297,16 +297,16 @@ func (db *Database) RunCommandCursor( return nil, errors.New( "database response does not contain a cursor; try using RunCommand instead") } - return nil, replaceErrors(err) + return nil, wrapErrors(err) } bc, err := op.ResultCursor() if err != nil { closeImplicitSession(sess) - return nil, replaceErrors(err) + return nil, wrapErrors(err) } cursor, err := newCursorWithSession(bc, db.bsonOpts, db.registry, sess) - return cursor, replaceErrors(err) + return cursor, wrapErrors(err) } // Drop drops the database on the server. This method ignores "namespace not found" errors so it is safe to drop @@ -347,7 +347,7 @@ func (db *Database) Drop(ctx context.Context) error { var driverErr driver.Error if err != nil && (!errors.As(err, &driverErr) || !driverErr.NamespaceNotFound()) { - return replaceErrors(err) + return wrapErrors(err) } return nil } @@ -497,16 +497,16 @@ func (db *Database) ListCollections( err = op.Execute(ctx) if err != nil { closeImplicitSession(sess) - return nil, replaceErrors(err) + return nil, wrapErrors(err) } bc, err := op.Result(cursorOpts) if err != nil { closeImplicitSession(sess) - return nil, replaceErrors(err) + return nil, wrapErrors(err) } cursor, err := newCursorWithSession(bc, db.bsonOpts, db.registry, sess) - return cursor, replaceErrors(err) + return cursor, wrapErrors(err) } // ListCollectionNames executes a listCollections command and returns a slice containing the names of the collections @@ -944,7 +944,7 @@ func (db *Database) executeCreateOperation(ctx context.Context, op *operation.Cr Deployment(db.client.deployment). Crypt(db.client.cryptFLE) - return replaceErrors(op.Execute(ctx)) + return wrapErrors(op.Execute(ctx)) } // GridFSBucket is used to construct a GridFS bucket which can be used as a diff --git a/mongo/errors.go b/mongo/errors.go index 3d74c6495c..a353dbe707 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "net" + "reflect" "strings" "go.mongodb.org/mongo-driver/v2/bson" @@ -64,7 +65,11 @@ func (e ErrMapForOrderedArgument) Error() string { return fmt.Sprintf("multi-key map passed in for ordered parameter %v", e.ParamName) } -func replaceErrors(err error) error { +// wrapErrors wraps error types and values that are defined in "internal" and +// "x" packages with error types and values that are defined in this package. +// That allows users to inspect the errors using errors.Is/errors.As without +// relying on "internal" or "x" packages. +func wrapErrors(err error) error { // Return nil when err is nil to avoid costly reflection logic below. if err == nil { return nil @@ -78,26 +83,42 @@ func replaceErrors(err error) error { if errors.Is(err, driver.ErrUnacknowledgedWrite) { return nil } - if errors.Is(err, topology.ErrTopologyClosed) { return ErrClientDisconnected } - if de, ok := err.(driver.Error); ok { + + var de driver.Error + if errors.As(err, &de) { return CommandError{ Code: de.Code, Message: de.Message, Labels: de.Labels, Name: de.Name, - Wrapped: de.Wrapped, + Wrapped: err, Raw: bson.Raw(de.Raw), + + // Set wrappedMsgOnly=true here so that the Code and Message are not + // repeated multiple times in the error string. We expect that the + // wrapped driver.Error already contains that info in the error + // string. + wrappedMsgOnly: true, } } - if qe, ok := err.(driver.QueryFailureError); ok { + + var qe driver.QueryFailureError + if errors.As(err, &qe) { // qe.Message is "command failure" ce := CommandError{ Name: qe.Message, - Wrapped: qe.Wrapped, + Wrapped: err, Raw: bson.Raw(qe.Response), + + // Don't set wrappedMsgOnly=true here because the code below adds + // additional error context that is not provided by the + // driver.QueryFailureError. Additionally, driver.QueryFailureError + // is only returned when parsing OP_QUERY replies (OP_REPLY), so + // it's unlikely this block will ever be run now that MongoDB 3.6 is + // no longer supported. } dollarErr, err := qe.Response.LookupErr("$err") @@ -111,18 +132,37 @@ func replaceErrors(err error) error { return ce } - if me, ok := err.(mongocrypt.Error); ok { - return MongocryptError{Code: me.Code, Message: me.Message} + + var me mongocrypt.Error + if errors.As(err, &me) { + return MongocryptError{ + Code: me.Code, + Message: me.Message, + wrapped: err, + + // Set wrappedMsgOnly=true here so that the Code and Message are not + // repeated multiple times in the error string. We expect that the + // wrapped mongocrypt.Error already contains that info in the error + // string. + wrappedMsgOnly: true, + } } if errors.Is(err, codecutil.ErrNilValue) { return ErrNilValue } - if marshalErr, ok := err.(codecutil.MarshalError); ok { + var marshalErr codecutil.MarshalError + if errors.As(err, &marshalErr) { return MarshalError{ Value: marshalErr.Value, - Err: marshalErr.Err, + Err: err, + + // Set wrappedMsgOnly=true here so that the Value is not repeated + // multiple times in the error string. We expect that the wrapped + // codecutil.MarshalError already contains that info in the error + // string. + wrappedMsgOnly: true, } } @@ -195,17 +235,67 @@ func IsNetworkError(err error) bool { return errorHasLabel(err, "NetworkError") } +// MarshalError is returned when attempting to marshal a value into a document +// results in an error. +type MarshalError struct { + Value interface{} + Err error + + // If wrappedMsgOnly is true, Error() only returns the error message from + // the "Err" error. + // + // This is typically only set by the wrapErrors function, which uses + // MarshalError to wrap codecutil.MarshalError, allowing users to access the + // "Value" from the underlying error but preventing duplication in the error + // string. + wrappedMsgOnly bool +} + +// Error implements the error interface. +func (me MarshalError) Error() string { + // If the MarshalError was created with wrappedMsgOnly=true, only return the + // error from the wrapped error. See the MarshalError.wrappedMsgOnly docs + // for more info. + if me.wrappedMsgOnly { + return me.Err.Error() + } + + return fmt.Sprintf("cannot marshal type %s to a BSON Document: %v", reflect.TypeOf(me.Value), me.Err) +} + +func (me MarshalError) Unwrap() error { return me.Err } + // MongocryptError represents an libmongocrypt error during in-use encryption. type MongocryptError struct { Code int32 Message string + wrapped error + + // If wrappedMsgOnly is true, Error() only returns the error message from + // the "wrapped" error. + // + // This is typically only set by the wrapErrors function, which uses + // MarshalError to wrap mongocrypt.Error, allowing users to access the + // "Code" and "Message" from the underlying error but preventing duplication + // in the error string. + wrappedMsgOnly bool } // Error implements the error interface. func (m MongocryptError) Error() string { + // If the MongocryptError was created with wrappedMsgOnly=true, only return + // the error from the wrapped error. See the MongocryptError.wrappedMsgOnly + // docs for more info. + if m.wrappedMsgOnly { + return m.wrapped.Error() + } + return fmt.Sprintf("mongocrypt error %d: %v", m.Code, m.Message) } +// Unwrap returns the underlying error. +func (m MongocryptError) Unwrap() error { return m.wrapped } + // EncryptionKeyVaultError represents an error while communicating with the key vault collection during in-use // encryption. type EncryptionKeyVaultError struct { @@ -289,14 +379,38 @@ type CommandError struct { Name string // A human-readable name corresponding to the error code Wrapped error // The underlying error, if one exists. Raw bson.Raw // The original server response containing the error. + + // If wrappedMsgOnly is true, Error() only returns the error message from + // the "Wrapped" error. + // + // This is typically only set by the wrapErrors function, which uses + // CommandError to wrap driver.Error, allowing users to access the "Code", + // "Message", "Labels", "Name", and "Raw" from the underlying error but + // preventing duplication in the error string. + wrappedMsgOnly bool } // Error implements the error interface. func (e CommandError) Error() string { + // If the CommandError was created with wrappedMsgOnly=true, only return the + // error from the wrapped error. See the CommandError.wrappedMsgOnly docs + // for more info. + if e.wrappedMsgOnly { + return e.Wrapped.Error() + } + + var msg string if e.Name != "" { - return fmt.Sprintf("(%v) %v", e.Name, e.Message) + msg += fmt.Sprintf("(%v)", e.Name) + } + if e.Message != "" { + msg += " " + e.Message } - return e.Message + if e.Wrapped != nil { + msg += ": " + e.Wrapped.Error() + } + + return msg } // Unwrap returns the underlying error. @@ -745,7 +859,7 @@ func processWriteError(err error) (returnResult, error) { var wce driver.WriteCommandError if !errors.As(err, &wce) { - return rrNone, replaceErrors(err) + return rrNone, wrapErrors(err) } return rrMany, WriteException{ diff --git a/mongo/index_view.go b/mongo/index_view.go index c92bb651be..a0734ebf16 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -117,16 +117,16 @@ func (iv IndexView) List(ctx context.Context, opts ...options.Lister[options.Lis return newEmptyCursor(), nil } - return nil, replaceErrors(err) + return nil, wrapErrors(err) } bc, err := op.Result(cursorOpts) if err != nil { closeImplicitSession(sess) - return nil, replaceErrors(err) + return nil, wrapErrors(err) } cursor, err := newCursorWithSession(bc, iv.coll.bsonOpts, iv.coll.registry, sess) - return cursor, replaceErrors(err) + return cursor, wrapErrors(err) } // ListSpecifications executes a List command and returns a slice of returned IndexSpecifications @@ -410,7 +410,7 @@ func (iv IndexView) drop(ctx context.Context, index any, _ ...options.Lister[opt err = op.Execute(ctx) if err != nil { - return replaceErrors(err) + return wrapErrors(err) } return nil diff --git a/mongo/mongo.go b/mongo/mongo.go index b40ce15c07..c12e8ba4e4 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -31,18 +31,6 @@ type Dialer interface { DialContext(ctx context.Context, network, address string) (net.Conn, error) } -// MarshalError is returned when attempting to marshal a value into a document -// results in an error. -type MarshalError struct { - Value interface{} - Err error -} - -// Error implements the error interface. -func (me MarshalError) Error() string { - return fmt.Sprintf("cannot marshal type %s to a BSON Document: %v", reflect.TypeOf(me.Value), me.Err) -} - // Pipeline is a type that makes creating aggregation pipelines easier. It is a // helper and is intended for serializing to BSON. // diff --git a/mongo/session.go b/mongo/session.go index 418b06d3d0..8ff0de56d1 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -292,7 +292,7 @@ func (s *Session) CommitTransaction(ctx context.Context) error { // Return error without updating transaction state if it is a timeout, as the transaction has not // actually been committed. if IsTimeout(err) { - return replaceErrors(err) + return wrapErrors(err) } s.clientSession.Committing = false commitErr := s.clientSession.CommitTransaction() @@ -301,7 +301,7 @@ func (s *Session) CommitTransaction(ctx context.Context) error { s.clientSession.UpdateCommitTransactionWriteConcern() if err != nil { - return replaceErrors(err) + return wrapErrors(err) } return commitErr }