2022-10-19 21:32:34 +08:00

1680 lines
59 KiB
Go

// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"bytes"
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/internal"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
// defaultLocalThreshold is the latency window handed to description.LatencySelector when
// selectServer has to build a default selector because op.Selector is nil.
const defaultLocalThreshold = 15 * time.Millisecond

// dollarCmd is appended to the database name by createQueryWireMessage to form the OP_QUERY
// full collection name ("<db>.$cmd").
var dollarCmd = [...]byte{'.', '$', 'c', 'm', 'd'}

var (
	// ErrNoDocCommandResponse occurs when the server indicated a response existed, but none was found.
	ErrNoDocCommandResponse = errors.New("command returned no documents")
	// ErrMultiDocCommandResponse occurs when the server sent multiple documents in response to a command.
	ErrMultiDocCommandResponse = errors.New("command returned multiple documents")
	// ErrReplyDocumentMismatch occurs when the number of documents returned in an OP_QUERY does not match the numberReturned field.
	ErrReplyDocumentMismatch = errors.New("number of documents returned does not match numberReturned field")
	// ErrNonPrimaryReadPref is returned when a read is attempted in a transaction with a non-primary read preference.
	ErrNonPrimaryReadPref = errors.New("read preference in a transaction must be primary")
)

const (
	// maximum BSON object size when client side encryption is enabled (2 MiB). Execute uses it
	// as the target batch size for batched writes when auto-encryption is active.
	cryptMaxBsonObjectSize uint32 = 2097152
	// minimum wire version necessary to use automatic encryption
	cryptMinWireVersion int32 = 8
	// minimum wire version necessary to use read snapshots
	readSnapshotMinWireVersion int32 = 13
)
// RetryablePoolError is a connection pool error that can be retried while executing an operation.
// Execute type-asserts errors from getServerAndConnection against this interface and, when
// Retryable reports true and retries remain, selects a new server and connection instead of
// returning the error.
type RetryablePoolError interface {
	// Retryable reports whether the pool error permits another attempt.
	Retryable() bool
}
// InvalidOperationError is returned from Validate when a required field of an Operation has
// not been set. MissingField holds the Go name of that field (e.g. "CommandFn").
type InvalidOperationError struct{ MissingField string }

// Error implements the error interface by naming the unset Operation field.
func (e InvalidOperationError) Error() string {
	return fmt.Sprintf("the %s field must be set on Operation", e.MissingField)
}
// opReply stores information returned in an OP_REPLY response from the server.
// The err field stores any error that occurred when decoding or validating the OP_REPLY response.
type opReply struct {
	// responseFlags are the OP_REPLY response flags.
	responseFlags wiremessage.ReplyFlag
	// cursorID is the cursor ID carried by the reply (presumably 0 when no cursor remains — confirm).
	cursorID int64
	// startingFrom is the OP_REPLY startingFrom field.
	startingFrom int32
	// numReturned is the OP_REPLY numberReturned field; a mismatch with len(documents) is what
	// ErrReplyDocumentMismatch describes.
	numReturned int32
	// documents are the BSON documents contained in the reply.
	documents []bsoncore.Document
	// err is any error encountered while decoding or validating the reply.
	err error
}
// startedInformation keeps track of all of the information necessary for monitoring started events.
// Execute fills in most of these fields right after building the wire message and before
// publishing the started event.
type startedInformation struct {
	// cmd is the command document that was encoded into the wire message.
	cmd bsoncore.Document
	// requestID is the wire message request ID (from wiremessage.NextRequestID).
	requestID int32
	// cmdName is the command name, derived from cmd via getCommandName.
	cmdName string
	// documentSequenceIncluded is not set in the visible code; presumably true when the OP_MSG
	// carries a document sequence section — confirm at the writer.
	documentSequenceIncluded bool
	// connID is the driver-side connection ID (conn.ID()).
	connID string
	// serverConnID is the server-side connection ID, if known (conn.ServerConnectionID()).
	serverConnID *int32
	// redacted reports whether the command is redacted in monitoring events (redactCommand).
	redacted bool
	// serviceID is the load-balanced service ID from the connection description, if any.
	serviceID *primitive.ObjectID
}
// finishedInformation keeps track of all of the information necessary for monitoring success and failure events.
// Execute copies the identifying fields from the corresponding startedInformation.
type finishedInformation struct {
	// cmdName is the command name reported in the event.
	cmdName string
	// requestID is the request ID of the wire message that was sent.
	requestID int32
	// response is the server response document, if any.
	response bsoncore.Document
	// cmdErr is the error produced by the round trip (nil on success).
	cmdErr error
	// connID is the driver-side connection ID.
	connID string
	// serverConnID is the server-side connection ID, if known.
	serverConnID *int32
	// startTime is recorded just before the round trip; presumably used to compute the event duration.
	startTime time.Time
	// redacted reports whether the command/response are redacted in monitoring events.
	redacted bool
	// serviceID is the load-balanced service ID, if any.
	serviceID *primitive.ObjectID
}
// ResponseInfo contains the context required to parse a server response. It is the argument passed
// to Operation.ProcessResponseFn.
type ResponseInfo struct {
	// ServerResponse is the decoded server response document. It may be nil when no response was
	// received (for example, when the round trip failed).
	ServerResponse bsoncore.Document
	// Server is the server the command was sent to.
	Server Server
	// Connection is the connection the command was sent on.
	Connection Connection
	// ConnectionDescription is the server description associated with Connection.
	ConnectionDescription description.Server
	// CurrentIndex is, for batched writes, the index (across all batches) of the first document
	// in the batch that produced this response; 0 otherwise.
	CurrentIndex int
}
// Operation is used to execute an operation. It contains all of the common code required to
// select a server, transform an operation into a command, write the command to a connection from
// the selected server, read a response from that connection, process the response, and potentially
// retry.
//
// The required fields are Database, CommandFn, and Deployment. All other fields are optional.
//
// While an Operation can be constructed manually, drivergen should be used to generate an
// implementation of an operation instead. This will ensure that there are helpers for constructing
// the operation and that this type isn't configured incorrectly.
type Operation struct {
	// CommandFn is used to create the command that will be wrapped in a wire message and sent to
	// the server. This function should only add the elements of the command and not start or end
	// the enclosing BSON document. Per the command API, the first element must be the name of the
	// command to run. This field is required.
	CommandFn func(dst []byte, desc description.SelectedServer) ([]byte, error)

	// Database is the database that the command will be run against. This field is required.
	Database string

	// Deployment is the MongoDB Deployment to use. While most of the time this will be multiple
	// servers, commands that need to run against a single, preselected server can use the
	// SingleServerDeployment type. Commands that need to run on a preselected connection can use
	// the SingleConnectionDeployment type.
	Deployment Deployment

	// ProcessResponseFn is called after a response to the command is returned. The server is
	// provided for types like Cursor that are required to run subsequent commands using the same
	// server.
	ProcessResponseFn func(ResponseInfo) error

	// Selector is the server selector that's used during both initial server selection and
	// subsequent selection for retries. Depending on the Deployment implementation, the
	// SelectServer method may not actually be called.
	Selector description.ServerSelector

	// ReadPreference is the read preference that will be attached to the command. If this field is
	// not specified a default read preference of primary will be used.
	ReadPreference *readpref.ReadPref

	// ReadConcern is the read concern used when running read commands. This field should not be set
	// for write operations. If this field is set, it will be encoded onto the commands sent to the
	// server.
	ReadConcern *readconcern.ReadConcern

	// MinimumReadConcernWireVersion specifies the minimum wire version to add the read concern to
	// the command being executed.
	MinimumReadConcernWireVersion int32

	// WriteConcern is the write concern used when running write commands. This field should not be
	// set for read operations. If this field is set, it will be encoded onto the commands sent to
	// the server.
	WriteConcern *writeconcern.WriteConcern

	// MinimumWriteConcernWireVersion specifies the minimum wire version to add the write concern to
	// the command being executed.
	MinimumWriteConcernWireVersion int32

	// Client is the session used with this operation. This can be either an implicit or explicit
	// session. If the server selected does not support sessions and Client is specified the
	// behavior depends on the session type. If the session is implicit, the session fields will not
	// be encoded onto the command. If the session is explicit, an error will be returned. The
	// caller is responsible for ensuring that this field is nil if the Deployment does not support
	// sessions.
	Client *session.Client

	// Clock is a cluster clock, different from the one contained within a session.Client. This
	// allows updating cluster times for a global cluster clock while allowing individual session's
	// cluster clocks to be only updated as far as the last command that's been run.
	Clock *session.ClusterClock

	// RetryMode specifies how to retry. There are three modes that enable retry: RetryOnce,
	// RetryOncePerCommand, and RetryContext. For more information about what these modes do, please
	// refer to their definitions. Both RetryMode and Type must be set for retryability to be enabled.
	RetryMode *RetryMode

	// Type specifies the kind of operation this is. Execute considers two kinds for retries: Write
	// and Read (writes additionally require a session). For more information about these kinds,
	// please refer to their definitions. Both Type and RetryMode must be set for retryability to be
	// enabled.
	Type Type

	// Batches contains the documents that are split when executing a write command that potentially
	// has more documents than can fit in a single command. This should only be specified for
	// commands that are batch compatible. For more information, please refer to the definition of
	// Batches.
	Batches *Batches

	// Legacy sets the legacy type for this operation. Execute dispatches five legacy kinds to
	// dedicated implementations on older servers: find, getMore, killCursors, listCollections, and
	// listIndexes. For more information about LegacyOperationKind, please refer to its definition.
	Legacy LegacyOperationKind

	// CommandMonitor specifies the monitor to use for APM events. If this field is not set,
	// no events will be reported.
	CommandMonitor *event.CommandMonitor

	// Crypt specifies a Crypt object to use for automatic client side encryption and decryption.
	Crypt Crypt

	// ServerAPI specifies options used to configure the API version sent to the server.
	ServerAPI *ServerAPIOptions

	// IsOutputAggregate specifies whether this operation is an aggregate with an output stage. If true,
	// read preference will not be added to the command on wire versions < 13.
	IsOutputAggregate bool

	// Timeout is the amount of time that this operation can execute before returning an error. The default value
	// is nil, which means that the timeout of the operation's caller will be used.
	Timeout *time.Duration

	// cmdName is only set when serializing OP_MSG and is used internally in readWireMessage.
	cmdName string
}
// shouldEncrypt reports whether automatic client-side encryption applies to this operation:
// a Crypt must be configured and it must not be set to bypass auto-encryption.
func (op Operation) shouldEncrypt() bool {
	if op.Crypt == nil {
		return false
	}
	return !op.Crypt.BypassAutoEncryption()
}
// selectServer validates the operation and asks the Deployment for a server.
//
// An explicit op.Selector is used as-is. Otherwise a selector is composed from the operation's
// read preference (primary when unset) and a latency window of defaultLocalThreshold.
func (op Operation) selectServer(ctx context.Context) (Server, error) {
	if err := op.Validate(); err != nil {
		return nil, err
	}

	if op.Selector != nil {
		return op.Deployment.SelectServer(ctx, op.Selector)
	}

	readPref := op.ReadPreference
	if readPref == nil {
		readPref = readpref.Primary()
	}
	fallback := description.CompositeSelector([]description.ServerSelector{
		description.ReadPrefSelector(readPref),
		description.LatencySelector(defaultLocalThreshold),
	})
	return op.Deployment.SelectServer(ctx, fallback)
}
// getServerAndConnection selects a server for the operation and returns it along with the
// connection the operation should run on.
//
// When the session already has a pinned connection (a transaction behind a load balancer), that
// connection is reused. Otherwise a connection is checked out of the server's pool. If the
// deployment is load balanced and the session's transaction is just starting, the new connection
// is pinned to the session. The checked-out connection is closed on any pinning failure so it is
// not leaked.
func (op Operation) getServerAndConnection(ctx context.Context) (Server, Connection, error) {
	srv, err := op.selectServer(ctx)
	if err != nil {
		return nil, nil, err
	}

	// A pinned connection means we're in a transaction and the target server is behind a load
	// balancer, so every command in the transaction must use that same connection.
	if sess := op.Client; sess != nil && sess.PinnedConnection != nil {
		return srv, sess.PinnedConnection, nil
	}

	conn, err := srv.Connection(ctx)
	if err != nil {
		return nil, nil, err
	}

	startingLoadBalancedTxn := conn.Description().LoadBalanced() && op.Client != nil && op.Client.TransactionStarting()
	if !startingLoadBalancedTxn {
		return srv, conn, nil
	}

	pinnable, isPinnable := conn.(PinnedConnection)
	if !isPinnable {
		// Release the checked-out connection before failing.
		_ = conn.Close()
		return nil, nil, fmt.Errorf("expected Connection used to start a transaction to be a PinnedConnection, but got %T", conn)
	}
	if pinErr := pinnable.PinToTransaction(); pinErr != nil {
		// Release the checked-out connection before failing.
		_ = conn.Close()
		return nil, nil, fmt.Errorf("error incrementing connection reference count when starting a transaction: %v", pinErr)
	}
	op.Client.PinnedConnection = pinnable
	return srv, conn, nil
}
// Validate checks that the required fields (CommandFn, Deployment, Database) are set, reporting
// the first missing one as an InvalidOperationError. It also rejects a session combined with an
// unacknowledged write concern.
func (op Operation) Validate() error {
	switch {
	case op.CommandFn == nil:
		return InvalidOperationError{MissingField: "CommandFn"}
	case op.Deployment == nil:
		return InvalidOperationError{MissingField: "Deployment"}
	case op.Database == "":
		return InvalidOperationError{MissingField: "Database"}
	case op.Client != nil && !writeconcern.AckWrite(op.WriteConcern):
		return errors.New("session provided for an unacknowledged write")
	default:
		return nil
	}
}
// Execute runs this operation. The scratch parameter will be used and overwritten (potentially many
// times), this should mainly be used to enable pooling of byte slices.
//
// Execute validates the operation, then loops through the following steps:
//   - select a server and connection;
//   - build and send the command;
//   - publish started/finished command monitoring events;
//   - classify the resulting error to decide whether to retry, advance to the next batch (when
//     op.Batches is valid), or return.
//
// Write errors accumulated across batches are returned together as a WriteCommandError after the
// loop.
func (op Operation) Execute(ctx context.Context, scratch []byte) error {
	err := op.Validate()
	if err != nil {
		return err
	}

	// If no deadline is set on the passed-in context, op.Timeout is set, and context is not already
	// a Timeout context, honor op.Timeout in new Timeout context for operation execution.
	if _, deadlineSet := ctx.Deadline(); !deadlineSet && op.Timeout != nil && !internal.IsTimeoutContext(ctx) {
		newCtx, cancelFunc := internal.MakeTimeoutContext(ctx, *op.Timeout)
		// Redefine ctx to be the new timeout-derived context.
		ctx = newCtx
		// Cancel the timeout-derived context at the end of Execute to avoid a context leak.
		defer cancelFunc()
	}

	if op.Client != nil {
		if err := op.Client.StartCommand(); err != nil {
			return err
		}
	}

	// retries is the number of retry attempts remaining. A negative value (RetryContext) means
	// "retry indefinitely": resetForRetry only ever decrements it, so it never reaches 0. Writes
	// are only made retryable when a session (op.Client) is present.
	var retries int
	if op.RetryMode != nil {
		switch op.Type {
		case Write:
			if op.Client == nil {
				break
			}
			switch *op.RetryMode {
			case RetryOnce, RetryOncePerCommand:
				retries = 1
			case RetryContext:
				retries = -1
			}
		case Read:
			switch *op.RetryMode {
			case RetryOnce, RetryOncePerCommand:
				retries = 1
			case RetryContext:
				retries = -1
			}
		}
	}

	var srvr Server
	var conn Connection
	var res bsoncore.Document
	// operationErr accumulates write errors across batches; it is returned after the loop if non-empty.
	var operationErr WriteCommandError
	var prevErr error
	batching := op.Batches.Valid()
	retryEnabled := op.RetryMode != nil && op.RetryMode.Enabled()
	retrySupported := false
	first := true
	// currIndex is the index, across all batches, of the first document in the current batch.
	currIndex := 0

	// resetForRetry records the error that caused the retry, decrements retries, and resets the
	// retry loop variables to request a new server and a new connection for the next attempt.
	resetForRetry := func(err error) {
		retries--
		prevErr = err
		// If we got a connection, close it immediately to release pool resources for
		// subsequent retries.
		if conn != nil {
			conn.Close()
		}
		// Set the server and connection to nil to request a new server and connection.
		srvr = nil
		conn = nil
	}

	for {
		// If the server or connection are nil, try to select a new server and get a new connection.
		if srvr == nil || conn == nil {
			srvr, conn, err = op.getServerAndConnection(ctx)
			if err != nil {
				// If the returned error is retryable and there are retries remaining (negative
				// retries means retry indefinitely), then retry the operation. Set the server
				// and connection to nil to request a new server and connection.
				if rerr, ok := err.(RetryablePoolError); ok && rerr.Retryable() && retries != 0 {
					resetForRetry(err)
					continue
				}

				// If this is a retry and there's an error from a previous attempt, return the previous
				// error instead of the current connection error.
				if prevErr != nil {
					return prevErr
				}
				return err
			}
			// NOTE(review): this defer is registered once per connection obtained in the loop and
			// only runs when Execute returns. Connections already closed by resetForRetry are
			// closed a second time here, so this relies on Connection.Close tolerating repeated
			// calls — confirm.
			defer conn.Close()
		}

		// Run steps that must only be run on the first attempt, but not again for retries.
		if first {
			// Determine if retries are supported for the current operation on the current server
			// description. Per the retryable writes specification, only determine this for the
			// first server selected:
			//
			//   If the server selected for the first attempt of a retryable write operation does
			//   not support retryable writes, drivers MUST execute the write as if retryable writes
			//   were not enabled.
			retrySupported = op.retryable(conn.Description())

			// If retries are supported for the current operation on the current server description,
			// client retries are enabled, the operation type is write, and we haven't incremented
			// the txn number yet, enable retry writes on the session and increment the txn number.
			// Calling IncrementTxnNumber() for server descriptions or topologies that do not
			// support retries (e.g. standalone topologies) will cause server errors. Only do this
			// check for the first attempt to keep retried writes in the same transaction.
			if retrySupported && op.RetryMode != nil && op.Type == Write && op.Client != nil {
				op.Client.RetryWrite = false
				if op.RetryMode.Enabled() {
					op.Client.RetryWrite = true
					if !op.Client.Committing && !op.Client.Aborting {
						op.Client.IncrementTxnNumber()
					}
				}
			}

			first = false
		}

		desc := description.SelectedServer{Server: conn.Description(), Kind: op.Deployment.Kind()}
		scratch = scratch[:0]

		// For unknown or old wire versions, operations with a Legacy kind are delegated to their
		// legacy implementations, which return directly without the retry handling below.
		if desc.WireVersion == nil || desc.WireVersion.Max < 4 {
			switch op.Legacy {
			case LegacyFind:
				return op.legacyFind(ctx, scratch, srvr, conn, desc)
			case LegacyGetMore:
				return op.legacyGetMore(ctx, scratch, srvr, conn, desc)
			case LegacyKillCursors:
				return op.legacyKillCursors(ctx, scratch, srvr, conn, desc)
			}
		}
		if desc.WireVersion == nil || desc.WireVersion.Max < 3 {
			switch op.Legacy {
			case LegacyListCollections:
				return op.legacyListCollections(ctx, scratch, srvr, conn, desc)
			case LegacyListIndexes:
				return op.legacyListIndexes(ctx, scratch, srvr, conn, desc)
			}
		}

		if batching {
			targetBatchSize := desc.MaxDocumentSize
			maxDocSize := desc.MaxDocumentSize
			if op.shouldEncrypt() {
				// For client-side encryption, we want the batch to be split at 2 MiB instead of 16MiB.
				// If there's only one document in the batch, it can be up to 16MiB, so we set target batch size to
				// 2MiB but max document size to 16MiB. This will allow the AdvanceBatch call to create a batch
				// with a single large document.
				targetBatchSize = cryptMaxBsonObjectSize
			}

			err = op.Batches.AdvanceBatch(int(desc.MaxBatchCount), int(targetBatchSize), int(maxDocSize))
			if err != nil {
				// TODO(GODRIVER-982): Should we also be returning operationErr?
				return err
			}
		}

		// Calculate value of 'maxTimeMS' field to potentially append to the wire message based on the current
		// context's deadline and the 90th percentile RTT if the ctx is a Timeout Context.
		var maxTimeMS uint64
		if internal.IsTimeoutContext(ctx) {
			if deadline, ok := ctx.Deadline(); ok {
				remainingTimeout := time.Until(deadline)

				maxTimeMSVal := int64(remainingTimeout/time.Millisecond) -
					int64(srvr.RTT90()/time.Millisecond)

				// A maxTimeMS value <= 0 indicates that we are already at or past the Context's deadline.
				if maxTimeMSVal <= 0 {
					return internal.WrapErrorf(ErrDeadlineWouldBeExceeded,
						"Context deadline has already been surpassed by %v", remainingTimeout)
				}
				maxTimeMS = uint64(maxTimeMSVal)
			}
		}

		// convert to wire message
		if len(scratch) > 0 {
			scratch = scratch[:0]
		}
		// NOTE: `:=` here declares a new err scoped to the loop body that shadows the function-level
		// err; everything below in this iteration (deadline check, round trip, error switch) uses it.
		wm, startedInfo, err := op.createWireMessage(ctx, scratch, desc, maxTimeMS, conn)
		if err != nil {
			return err
		}

		// set extra data and send event if possible
		startedInfo.connID = conn.ID()
		startedInfo.cmdName = op.getCommandName(startedInfo.cmd)
		op.cmdName = startedInfo.cmdName
		startedInfo.redacted = op.redactCommand(startedInfo.cmdName, startedInfo.cmd)
		startedInfo.serviceID = conn.Description().ServiceID
		startedInfo.serverConnID = conn.ServerConnectionID()
		op.publishStartedEvent(ctx, startedInfo)

		// get the moreToCome flag information before we compress
		moreToCome := wiremessage.IsMsgMoreToCome(wm)

		// compress wiremessage if allowed
		if compressor, ok := conn.(Compressor); ok && op.canCompress(startedInfo.cmdName) {
			wm, err = compressor.CompressWireMessage(wm, nil)
			if err != nil {
				return err
			}
		}

		finishedInfo := finishedInformation{
			cmdName:      startedInfo.cmdName,
			requestID:    startedInfo.requestID,
			startTime:    time.Now(),
			connID:       startedInfo.connID,
			serverConnID: startedInfo.serverConnID,
			redacted:     startedInfo.redacted,
			serviceID:    startedInfo.serviceID,
		}

		// Forget the response from any earlier attempt or batch. res lives outside the loop, so
		// without this reset a failed deadline check below would hand the previous round trip's
		// response to the finished event and to ProcessResponseFn as if it belonged to this attempt.
		res = nil

		// Check if there's enough time to perform a round trip before the Context deadline. If ctx is
		// a Timeout Context, use the 90th percentile RTT as a threshold. Otherwise, use the minimum observed
		// RTT.
		if deadline, ok := ctx.Deadline(); ok {
			if internal.IsTimeoutContext(ctx) && time.Now().Add(srvr.RTT90()).After(deadline) {
				err = internal.WrapErrorf(ErrDeadlineWouldBeExceeded,
					"Remaining timeout %v applied from Timeout is less than 90th percentile RTT", time.Until(deadline))
			} else if time.Now().Add(srvr.MinRTT()).After(deadline) {
				err = op.networkError(context.DeadlineExceeded)
			}
		}

		if err == nil {
			// roundtrip using either the full roundTripper or a special one for when the moreToCome
			// flag is set
			var roundTrip = op.roundTrip
			if moreToCome {
				roundTrip = op.moreToComeRoundTrip
			}
			res, err = roundTrip(ctx, conn, wm)

			if ep, ok := srvr.(ErrorProcessor); ok {
				_ = ep.ProcessError(err, conn)
			}
		}

		finishedInfo.response = res
		finishedInfo.cmdErr = err
		op.publishFinishedEvent(ctx, finishedInfo)

		var perr error
		switch tt := err.(type) {
		case WriteCommandError:
			if retrySupported && op.Type == Write && tt.UnsupportedStorageEngine() {
				return ErrUnsupportedStorageEngine
			}

			connDesc := conn.Description()
			retryableErr := tt.Retryable(connDesc.WireVersion)
			preRetryWriteLabelVersion := connDesc.WireVersion != nil && connDesc.WireVersion.Max < 9
			inTransaction := op.Client != nil &&
				!(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning()
			// If retry is enabled and the operation isn't in a transaction, add a RetryableWriteError label for
			// retryable errors from pre-4.4 servers
			if retryableErr && preRetryWriteLabelVersion && retryEnabled && !inTransaction {
				tt.Labels = append(tt.Labels, RetryableWriteError)
			}

			// If retries are supported for the current operation on the first server description,
			// the error is considered retryable, and there are retries remaining (negative retries
			// means retry indefinitely), then retry the operation.
			if retrySupported && retryableErr && retries != 0 {
				if op.Client != nil && op.Client.Committing {
					// Apply majority write concern for retries
					op.Client.UpdateCommitTransactionWriteConcern()
					op.WriteConcern = op.Client.CurrentWc
				}
				resetForRetry(tt)
				continue
			}

			// If the operation isn't being retried, process the response
			if op.ProcessResponseFn != nil {
				info := ResponseInfo{
					ServerResponse:        res,
					Server:                srvr,
					Connection:            conn,
					ConnectionDescription: desc.Server,
					CurrentIndex:          currIndex,
				}
				_ = op.ProcessResponseFn(info)
			}

			// Write error indexes are relative to the current batch; shift them so they are
			// relative to the full set of documents.
			if batching && len(tt.WriteErrors) > 0 && currIndex > 0 {
				for i := range tt.WriteErrors {
					tt.WriteErrors[i].Index += int64(currIndex)
				}
			}

			// If batching is enabled and either ordered is the default (which is true) or
			// explicitly set to true and we have write errors, return the errors.
			if batching && (op.Batches.Ordered == nil || *op.Batches.Ordered) && len(tt.WriteErrors) > 0 {
				return tt
			}
			if op.Client != nil && op.Client.Committing && tt.WriteConcernError != nil {
				// When running commitTransaction we return WriteConcernErrors as an Error.
				err := Error{
					Name:    tt.WriteConcernError.Name,
					Code:    int32(tt.WriteConcernError.Code),
					Message: tt.WriteConcernError.Message,
					Labels:  tt.Labels,
					Raw:     tt.Raw,
				}
				// The UnknownTransactionCommitResult label is added to all writeConcernErrors besides unknownReplWriteConcernCode
				// and unsatisfiableWriteConcernCode
				if err.Code != unknownReplWriteConcernCode && err.Code != unsatisfiableWriteConcernCode {
					err.Labels = append(err.Labels, UnknownTransactionCommitResult)
				}
				if retryableErr && retryEnabled {
					err.Labels = append(err.Labels, RetryableWriteError)
				}
				return err
			}
			// Unordered batches (or no write errors): remember the errors and keep going.
			operationErr.WriteConcernError = tt.WriteConcernError
			operationErr.WriteErrors = append(operationErr.WriteErrors, tt.WriteErrors...)
			operationErr.Labels = tt.Labels
			operationErr.Raw = tt.Raw
		case Error:
			if tt.HasErrorLabel(TransientTransactionError) || tt.HasErrorLabel(UnknownTransactionCommitResult) {
				// NOTE(review): op.Client may be nil here; this relies on ClearPinnedResources
				// being safe to call on a nil *session.Client — confirm.
				if err := op.Client.ClearPinnedResources(); err != nil {
					return err
				}
			}

			if retrySupported && op.Type == Write && tt.UnsupportedStorageEngine() {
				return ErrUnsupportedStorageEngine
			}

			connDesc := conn.Description()
			var retryableErr bool
			if op.Type == Write {
				retryableErr = tt.RetryableWrite(connDesc.WireVersion)
				preRetryWriteLabelVersion := connDesc.WireVersion != nil && connDesc.WireVersion.Max < 9
				inTransaction := op.Client != nil &&
					!(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning()
				// If retryWrites is enabled and the operation isn't in a transaction, add a RetryableWriteError label
				// for network errors and retryable errors from pre-4.4 servers
				if retryEnabled && !inTransaction &&
					(tt.HasErrorLabel(NetworkError) || (retryableErr && preRetryWriteLabelVersion)) {
					tt.Labels = append(tt.Labels, RetryableWriteError)
				}
			} else {
				retryableErr = tt.RetryableRead()
			}

			// If retries are supported for the current operation on the first server description,
			// the error is considered retryable, and there are retries remaining (negative retries
			// means retry indefinitely), then retry the operation.
			if retrySupported && retryableErr && retries != 0 {
				if op.Client != nil && op.Client.Committing {
					// Apply majority write concern for retries
					op.Client.UpdateCommitTransactionWriteConcern()
					op.WriteConcern = op.Client.CurrentWc
				}
				resetForRetry(tt)
				continue
			}

			// If the operation isn't being retried, process the response
			if op.ProcessResponseFn != nil {
				info := ResponseInfo{
					ServerResponse:        res,
					Server:                srvr,
					Connection:            conn,
					ConnectionDescription: desc.Server,
					CurrentIndex:          currIndex,
				}
				_ = op.ProcessResponseFn(info)
			}

			if op.Client != nil && op.Client.Committing && (retryableErr || tt.Code == 50) {
				// If we got a retryable error or MaxTimeMSExpired error, we add UnknownTransactionCommitResult.
				tt.Labels = append(tt.Labels, UnknownTransactionCommitResult)
			}
			return tt
		case nil:
			// An OP_MSG sent with moreToCome gets no server reply, so the write is unacknowledged.
			if moreToCome {
				return ErrUnacknowledgedWrite
			}
			if op.ProcessResponseFn != nil {
				info := ResponseInfo{
					ServerResponse:        res,
					Server:                srvr,
					Connection:            conn,
					ConnectionDescription: desc.Server,
					CurrentIndex:          currIndex,
				}
				perr = op.ProcessResponseFn(info)
			}
			if perr != nil {
				return perr
			}
		default:
			if op.ProcessResponseFn != nil {
				info := ResponseInfo{
					ServerResponse:        res,
					Server:                srvr,
					Connection:            conn,
					ConnectionDescription: desc.Server,
					CurrentIndex:          currIndex,
				}
				_ = op.ProcessResponseFn(info)
			}
			return err
		}

		// If we're batching and there are batches remaining, advance to the next batch. This isn't
		// a retry, so increment the transaction number, reset the retries number, and don't set
		// server or connection to nil to continue using the same connection.
		if batching && len(op.Batches.Documents) > 0 {
			if retrySupported && op.Client != nil && op.RetryMode != nil {
				if *op.RetryMode > RetryNone {
					op.Client.IncrementTxnNumber()
				}
				if *op.RetryMode == RetryOncePerCommand {
					retries = 1
				}
			}
			currIndex += len(op.Batches.Current)
			op.Batches.ClearBatch()
			continue
		}
		break
	}
	if len(operationErr.WriteErrors) > 0 || operationErr.WriteConcernError != nil {
		return operationErr
	}
	return nil
}
// retryable reports whether this operation may be retried against a server with the given
// description.
//
// commitTransaction/abortTransaction (session Committing or Aborting) are always retryable, for
// both Write and Read operation types. Otherwise a wire version of at least 6 is required, and
// the session, if any, must not be inside a transaction. Writes additionally require:
//   - retryWritesSupported(desc);
//   - a non-nil session;
//   - an acknowledged write concern.
//
// Operation types other than Write and Read are never retryable.
func (op Operation) retryable(desc description.Server) bool {
	if op.Type != Write && op.Type != Read {
		return false
	}

	sess := op.Client
	if sess != nil && (sess.Committing || sess.Aborting) {
		return true
	}

	modernWire := desc.WireVersion != nil && desc.WireVersion.Max >= 6
	if op.Type == Read {
		return modernWire && (sess == nil || !(sess.TransactionInProgress() || sess.TransactionStarting()))
	}

	return retryWritesSupported(desc) && modernWire && sess != nil &&
		!(sess.TransactionInProgress() || sess.TransactionStarting()) &&
		writeconcern.AckWrite(op.WriteConcern)
}
// roundTrip sends wm on conn and returns the decoded reply. Write failures are wrapped with
// networkError. The wm buffer is handed to readWireMessage so its storage can be reused for the
// reply.
func (op Operation) roundTrip(ctx context.Context, conn Connection, wm []byte) ([]byte, error) {
	if writeErr := conn.WriteWireMessage(ctx, wm); writeErr != nil {
		return nil, op.networkError(writeErr)
	}
	return op.readWireMessage(ctx, conn, wm)
}
// readWireMessage reads one reply from conn (reusing wm's storage) and decompresses and decodes
// it. Read failures are wrapped with networkError.
//
// Cluster time, operation time and the recovery token are updated from the decoded result before
// any decode error is returned, so session state stays current even for error replies. The
// snapshot time is also updated for find/aggregate/distinct. On success, the result is passed
// through op.Crypt.Decrypt when a Crypt is configured.
func (op Operation) readWireMessage(ctx context.Context, conn Connection, wm []byte) ([]byte, error) {
	raw, err := conn.ReadWireMessage(ctx, wm[:0])
	if err != nil {
		return nil, op.networkError(err)
	}

	// A streamable connection tracks whether the server will keep sending replies (moreToCome).
	if streamer, isStreamer := conn.(StreamerConnection); isStreamer {
		streamer.SetStreaming(wiremessage.IsMsgMoreToCome(raw))
	}

	raw, err = op.decompressWireMessage(raw)
	if err != nil {
		return nil, err
	}

	res, decodeErr := op.decodeResult(raw)

	op.updateClusterTimes(res)
	op.updateOperationTime(res)
	// NOTE(review): op.Client may be nil; these calls rely on the session methods being nil-safe — confirm.
	op.Client.UpdateRecoveryToken(bson.Raw(res))
	switch op.cmdName {
	case "find", "aggregate", "distinct":
		op.Client.UpdateSnapshotTime(res)
	}

	if decodeErr != nil {
		return res, decodeErr
	}
	if op.Crypt == nil {
		return res, nil
	}
	return op.Crypt.Decrypt(ctx, res)
}
// networkError wraps err in an Error labeled NetworkError and returns nil for a nil err.
//
// When a session is attached, the session is first marked dirty. One transaction label is then
// added:
//   - UnknownTransactionCommitResult while committing;
//   - otherwise TransientTransactionError while a transaction is running.
func (op Operation) networkError(err error) error {
	if err == nil {
		return nil
	}

	labels := []string{NetworkError}
	if sess := op.Client; sess != nil {
		sess.MarkDirty()
		switch {
		case sess.Committing:
			labels = append(labels, UnknownTransactionCommitResult)
		case sess.TransactionRunning():
			labels = append(labels, TransientTransactionError)
		}
	}
	return Error{Message: err.Error(), Labels: labels, Wrapped: err}
}
// moreToComeRoundTrip writes wm to conn without waiting for a reply. It is used for OP_MSG
// messages sent with the moreToCome bit set. A synthetic {ok: 1} document is always returned as
// the "response".
//
// On a write failure the session (if any) is marked dirty, and the error is wrapped in an Error
// labeled TransientTransactionError and NetworkError.
func (op *Operation) moreToComeRoundTrip(ctx context.Context, conn Connection, wm []byte) ([]byte, error) {
	writeErr := conn.WriteWireMessage(ctx, wm)
	okReply := bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "ok", 1))
	if writeErr == nil {
		return okReply, nil
	}

	if op.Client != nil {
		op.Client.MarkDirty()
	}
	return okReply, Error{Message: writeErr.Error(), Labels: []string{TransientTransactionError, NetworkError}, Wrapped: writeErr}
}
// decompressWireMessage handles decompressing a wiremessage. If the wiremessage is not an
// OP_COMPRESSED message, it is returned unchanged.
//
// The OP_COMPRESSED layout read here is:
//   - the 16-byte message header;
//   - the original opcode (4 bytes);
//   - the uncompressed size (4 bytes);
//   - the compressor ID (1 byte);
//   - the compressed payload.
//
// The length and size fields come from the peer, so they are validated before being used to
// slice or allocate. Each of the following is reported as a malformed message instead of causing
// a runtime panic:
//   - a declared message length shorter than the 25-byte prefix (which would make the payload
//     size negative);
//   - an uncompressed size that is negative;
//   - an uncompressed size that would overflow int32 once the 16-byte header is added.
func (Operation) decompressWireMessage(wm []byte) ([]byte, error) {
	const (
		headerLen = 16
		// header (16) + original opcode (4) + uncompressed size (4) + compressor ID (1)
		compressedPrefixLen = headerLen + 4 + 4 + 1
		maxInt32            = 1<<31 - 1
	)

	// read the header and ensure this is a compressed wire message
	length, reqid, respto, opcode, rem, ok := wiremessage.ReadHeader(wm)
	if !ok || len(wm) < int(length) {
		return nil, errors.New("malformed wire message: insufficient bytes")
	}
	if opcode != wiremessage.OpCompressed {
		return wm, nil
	}

	// get the original opcode and uncompressed size
	opcode, rem, ok = wiremessage.ReadCompressedOriginalOpCode(rem)
	if !ok {
		return nil, errors.New("malformed OP_COMPRESSED: missing original opcode")
	}
	uncompressedSize, rem, ok := wiremessage.ReadCompressedUncompressedSize(rem)
	if !ok {
		return nil, errors.New("malformed OP_COMPRESSED: missing uncompressed size")
	}
	// The size is used as an allocation capacity and, plus the header, as an int32 message length.
	if uncompressedSize < 0 || uncompressedSize > maxInt32-headerLen {
		return nil, errors.New("malformed OP_COMPRESSED: invalid uncompressed size")
	}

	// get the compressor ID and decompress the message
	compressorID, rem, ok := wiremessage.ReadCompressedCompressorID(rem)
	if !ok {
		return nil, errors.New("malformed OP_COMPRESSED: missing compressor ID")
	}

	// A declared length shorter than the fixed prefix would yield a negative payload size.
	if length < compressedPrefixLen {
		return nil, errors.New("malformed OP_COMPRESSED: insufficient bytes for compressed wiremessage")
	}
	compressedSize := length - compressedPrefixLen

	// return the original wiremessage
	msg, _, ok := wiremessage.ReadCompressedCompressedMessage(rem, compressedSize)
	if !ok {
		return nil, errors.New("malformed OP_COMPRESSED: insufficient bytes for compressed wiremessage")
	}

	header := make([]byte, 0, uncompressedSize+headerLen)
	header = wiremessage.AppendHeader(header, uncompressedSize+headerLen, reqid, respto, opcode)
	opts := CompressionOpts{
		Compressor:       compressorID,
		UncompressedSize: uncompressedSize,
	}
	uncompressed, err := DecompressPayload(msg, opts)
	if err != nil {
		return nil, err
	}

	return append(header, uncompressed...), nil
}
// createWireMessage encodes the operation into a wire message appended to dst.
// OP_MSG is used when any of the following holds:
//   - the topology is load balanced;
//   - a server API version is declared;
//   - the server's maximum wire version is known and at least
//     wiremessage.OpmsgWireVersion.
//
// Otherwise the legacy OP_QUERY encoding is used.
func (op Operation) createWireMessage(
	ctx context.Context,
	dst []byte,
	desc description.SelectedServer,
	maxTimeMS uint64,
	conn Connection) ([]byte, startedInformation, error) {
	useOpMsg := desc.Kind == description.LoadBalanced ||
		op.ServerAPI != nil ||
		(desc.WireVersion != nil && desc.WireVersion.Max >= wiremessage.OpmsgWireVersion)
	if !useOpMsg {
		return op.createQueryWireMessage(maxTimeMS, dst, desc)
	}
	return op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn)
}
// addBatchArray appends the current batch to dst as a BSON array element named
// op.Batches.Identifier. Its elements are keyed "0", "1", ... in batch order.
// op.Batches is dereferenced unconditionally and must therefore be non-nil.
func (op Operation) addBatchArray(dst []byte) []byte {
	arrayStart, out := bsoncore.AppendArrayElementStart(dst, op.Batches.Identifier)
	for n := range op.Batches.Current {
		out = bsoncore.AppendDocumentElement(out, strconv.Itoa(n), op.Batches.Current[n])
	}
	out, _ = bsoncore.AppendArrayEnd(out, arrayStart)
	return out
}
// createQueryWireMessage encodes the operation as a legacy OP_QUERY command
// against "<op.Database>.$cmd" and appends it to dst. It returns the extended
// slice and the information needed to publish a command-started event.
//
// When a read preference applies, the command is wrapped as
// {$query: <command>, $readPreference: <rp>}. Command monitoring (info.cmd)
// only sees the inner command document. The current batch, if any, is embedded
// in the command as a BSON array. maxTimeMS is appended only when greater than 0.
//
// On error, the partially written dst is returned along with the error.
func (op Operation) createQueryWireMessage(maxTimeMS uint64, dst []byte, desc description.SelectedServer) ([]byte, startedInformation, error) {
	var info startedInformation
	flags := op.secondaryOK(desc)
	var wmindex int32
	info.requestID = wiremessage.NextRequestID()
	wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery)
	dst = wiremessage.AppendQueryFlags(dst, flags)
	// FullCollectionName
	dst = append(dst, op.Database...)
	dst = append(dst, dollarCmd[:]...)
	dst = append(dst, 0x00)
	dst = wiremessage.AppendQueryNumberToSkip(dst, 0)
	dst = wiremessage.AppendQueryNumberToReturn(dst, -1)
	wrapper := int32(-1)
	rp, err := op.createReadPref(desc, true)
	if err != nil {
		return dst, info, err
	}
	if len(rp) > 0 {
		// Open the {$query: ..., $readPreference: ...} wrapper document.
		wrapper, dst = bsoncore.AppendDocumentStart(dst)
		dst = bsoncore.AppendHeader(dst, bsontype.EmbeddedDocument, "$query")
	}
	idx, dst := bsoncore.AppendDocumentStart(dst)
	dst, err = op.CommandFn(dst, desc)
	if err != nil {
		return dst, info, err
	}
	if op.Batches != nil && len(op.Batches.Current) > 0 {
		dst = op.addBatchArray(dst)
	}
	dst, err = op.addReadConcern(dst, desc)
	if err != nil {
		return dst, info, err
	}
	dst, err = op.addWriteConcern(dst, desc)
	if err != nil {
		return dst, info, err
	}
	dst, err = op.addSession(dst, desc)
	if err != nil {
		return dst, info, err
	}
	dst = op.addClusterTime(dst, desc)
	dst = op.addServerAPI(dst)
	// If maxTimeMS is greater than 0 append it to wire message. A maxTimeMS value of 0 only explicitly
	// specifies the default behavior of no timeout server-side.
	if maxTimeMS > 0 {
		dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(maxTimeMS))
	}
	dst, _ = bsoncore.AppendDocumentEnd(dst, idx)
	// Command monitoring only reports the document inside $query
	info.cmd = dst[idx:]
	if len(rp) > 0 {
		// Reuse the outer err rather than shadowing it with a new declaration.
		dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp)
		dst, err = bsoncore.AppendDocumentEnd(dst, wrapper)
		if err != nil {
			return dst, info, err
		}
	}
	return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil
}
// createMsgWireMessage encodes the operation as an OP_MSG wire message appended
// to dst. It returns the extended slice and the information needed to publish a
// command-started event.
//
// The body (a wiremessage.SingleDocument section) is the command document from
// addCommandFields, followed in this order by:
//   - read concern, write concern and session fields;
//   - $clusterTime and the server API fields;
//   - maxTimeMS (only when > 0);
//   - $db;
//   - $readPreference (only when one applies).
//
// When auto-encryption is not in use and there is a current batch, the batch is
// appended after the body as a wiremessage.DocumentSequence section. With
// auto-encryption, addCommandFields has already embedded it in the command
// document.
//
// Flags:
//   - MoreToCome is set for unacknowledged writes when not batching or when
//     encoding the last batch.
//   - ExhaustAllowed is set when conn is a StreamerConnection that supports
//     streaming.
//
// On error, the partially written dst is returned along with the error.
func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, dst []byte, desc description.SelectedServer,
	conn Connection) ([]byte, startedInformation, error) {
	var info startedInformation
	var flags wiremessage.MsgFlag
	var wmindex int32
	// We set the MoreToCome bit if we have a write concern, it's unacknowledged, and we either
	// aren't batching or we are encoding the last batch.
	if op.WriteConcern != nil && !writeconcern.AckWrite(op.WriteConcern) && (op.Batches == nil || len(op.Batches.Documents) == 0) {
		flags = wiremessage.MoreToCome
	}
	// Set the ExhaustAllowed flag if the connection supports streaming. This will tell the server that it can
	// respond with the MoreToCome flag and then stream responses over this connection.
	if streamer, ok := conn.(StreamerConnection); ok && streamer.SupportsStreaming() {
		flags |= wiremessage.ExhaustAllowed
	}
	info.requestID = wiremessage.NextRequestID()
	// The message length is written later by UpdateLength using wmindex.
	wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpMsg)
	dst = wiremessage.AppendMsgFlags(dst, flags)
	// Body
	dst = wiremessage.AppendMsgSectionType(dst, wiremessage.SingleDocument)
	idx, dst := bsoncore.AppendDocumentStart(dst)
	dst, err := op.addCommandFields(ctx, dst, desc)
	if err != nil {
		return dst, info, err
	}
	dst, err = op.addReadConcern(dst, desc)
	if err != nil {
		return dst, info, err
	}
	dst, err = op.addWriteConcern(dst, desc)
	if err != nil {
		return dst, info, err
	}
	dst, err = op.addSession(dst, desc)
	if err != nil {
		return dst, info, err
	}
	dst = op.addClusterTime(dst, desc)
	dst = op.addServerAPI(dst)
	// If maxTimeMS is greater than 0 append it to wire message. A maxTimeMS value of 0 only explicitly
	// specifies the default behavior of no timeout server-side.
	if maxTimeMS > 0 {
		dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(maxTimeMS))
	}
	dst = bsoncore.AppendStringElement(dst, "$db", op.Database)
	rp, err := op.createReadPref(desc, false)
	if err != nil {
		return dst, info, err
	}
	if len(rp) > 0 {
		dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp)
	}
	dst, _ = bsoncore.AppendDocumentEnd(dst, idx)
	// The command document for monitoring shouldn't include the type 1 payload as a document sequence.
	// Note: info.cmd is a subslice of dst, not a copy.
	info.cmd = dst[idx:]
	// add batch as a document sequence if auto encryption is not enabled
	// if auto encryption is enabled, the batch will already be an array in the command document
	if !op.shouldEncrypt() && op.Batches != nil && len(op.Batches.Current) > 0 {
		// Recorded so publishStartedEvent can splice the batch back into the monitored command.
		info.documentSequenceIncluded = true
		dst = wiremessage.AppendMsgSectionType(dst, wiremessage.DocumentSequence)
		// Section layout: int32 size, NUL-terminated identifier, then the raw documents.
		idx, dst = bsoncore.ReserveLength(dst)
		dst = append(dst, op.Batches.Identifier...)
		dst = append(dst, 0x00)
		for _, doc := range op.Batches.Current {
			dst = append(dst, doc...)
		}
		dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
	}
	return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil
}
// addCommandFields appends the command's fields to dst. It assumes the start of
// the command document has already been written to dst, and it does not write
// the document's terminating 0 byte.
//
// Without auto-encryption (op.shouldEncrypt() is false) this just delegates to
// op.CommandFn. With auto-encryption, the command is first built as a standalone
// document, with the current batch (if any) embedded as a BSON array. That
// document is encrypted with op.Crypt.Encrypt, and its elements (without the
// length prefix and terminator) are appended to dst.
//
// An error is returned in three cases:
//   - auto-encryption is requested but the server's wire version is unknown or
//     below cryptMinWireVersion;
//   - building the command fails;
//   - encrypting the command fails.
func (op Operation) addCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) ([]byte, error) {
	if !op.shouldEncrypt() {
		return op.CommandFn(dst, desc)
	}
	// An unknown (nil) wire version is treated as unsupported rather than dereferenced.
	if desc.WireVersion == nil || desc.WireVersion.Max < cryptMinWireVersion {
		return dst, errors.New("auto-encryption requires a MongoDB version of 4.2")
	}
	// create temporary command document
	cidx, cmdDst := bsoncore.AppendDocumentStart(nil)
	var err error
	cmdDst, err = op.CommandFn(cmdDst, desc)
	if err != nil {
		return dst, err
	}
	// use a BSON array instead of a type 1 payload because mongocryptd will convert to arrays regardless
	if op.Batches != nil && len(op.Batches.Current) > 0 {
		cmdDst = op.addBatchArray(cmdDst)
	}
	cmdDst, _ = bsoncore.AppendDocumentEnd(cmdDst, cidx)
	// encrypt the command
	encrypted, err := op.Crypt.Encrypt(ctx, op.Database, cmdDst)
	if err != nil {
		return dst, err
	}
	// append encrypted command to original destination, removing the first 4 bytes (length) and final byte (terminator)
	dst = append(dst, encrypted[4:len(encrypted)-1]...)
	return dst, nil
}
// addServerAPI appends the server API fields to dst when op.ServerAPI is set.
// "apiVersion" is always written. "apiStrict" and "apiDeprecationErrors" are
// written only when the corresponding option is non-nil. Without a ServerAPI,
// dst is returned untouched.
func (op Operation) addServerAPI(dst []byte) []byte {
	if op.ServerAPI == nil {
		return dst
	}
	api := op.ServerAPI
	dst = bsoncore.AppendStringElement(dst, "apiVersion", api.ServerAPIVersion)
	if strict := api.Strict; strict != nil {
		dst = bsoncore.AppendBooleanElement(dst, "apiStrict", *strict)
	}
	if depErrs := api.DeprecationErrors; depErrs != nil {
		dst = bsoncore.AppendBooleanElement(dst, "apiDeprecationErrors", *depErrs)
	}
	return dst
}
// addReadConcern appends a "readConcern" document to dst when one applies.
//
// The read concern is resolved as follows:
//   - Start with op.ReadConcern.
//   - A starting transaction's CurrentRc overrides it.
//   - If nothing is set yet and a transaction is starting on a causally
//     consistent session with a known operation time, use an empty read concern
//     so that afterClusterTime can be attached.
//   - Snapshot sessions always use a snapshot read concern.
//
// When sessions are supported, afterClusterTime (causally consistent sessions)
// and atClusterTime (snapshot sessions) are added to the document.
//
// Nothing is written in two cases: the server does not satisfy
// op.MinimumReadConcernWireVersion, or the resulting document is empty.
//
// An error is returned for a snapshot session when the server's wire version is
// unknown or below readSnapshotMinWireVersion, or when the read concern fails
// to marshal.
func (op Operation) addReadConcern(dst []byte, desc description.SelectedServer) ([]byte, error) {
	if op.MinimumReadConcernWireVersion > 0 && (desc.WireVersion == nil || !desc.WireVersion.Includes(op.MinimumReadConcernWireVersion)) {
		return dst, nil
	}
	rc := op.ReadConcern
	client := op.Client
	// Starting transaction's read concern overrides all others
	if client != nil && client.TransactionStarting() && client.CurrentRc != nil {
		rc = client.CurrentRc
	}
	// start transaction must append afterclustertime IF causally consistent and operation time exists
	if rc == nil && client != nil && client.TransactionStarting() && client.Consistent && client.OperationTime != nil {
		rc = readconcern.New()
	}
	if client != nil && client.Snapshot {
		// An unknown (nil) wire version is treated as unsupported rather than dereferenced.
		if desc.WireVersion == nil || desc.WireVersion.Max < readSnapshotMinWireVersion {
			return dst, errors.New("snapshot reads require MongoDB 5.0 or later")
		}
		rc = readconcern.Snapshot()
	}
	if rc == nil {
		return dst, nil
	}
	_, data, err := rc.MarshalBSONValue() // always returns a document
	if err != nil {
		return dst, err
	}
	if sessionsSupported(desc.WireVersion) && client != nil {
		if client.Consistent && client.OperationTime != nil {
			data = data[:len(data)-1] // remove the null byte
			data = bsoncore.AppendTimestampElement(data, "afterClusterTime", client.OperationTime.T, client.OperationTime.I)
			data, _ = bsoncore.AppendDocumentEnd(data, 0)
		}
		if client.Snapshot && client.SnapshotTime != nil {
			data = data[:len(data)-1] // remove the null byte
			data = bsoncore.AppendTimestampElement(data, "atClusterTime", client.SnapshotTime.T, client.SnapshotTime.I)
			data, _ = bsoncore.AppendDocumentEnd(data, 0)
		}
	}
	if len(data) == bsoncore.EmptyDocumentLength {
		return dst, nil
	}
	return bsoncore.AppendDocumentElement(dst, "readConcern", data), nil
}
// addWriteConcern appends op.WriteConcern to dst as a "writeConcern" element.
//
// Nothing is appended when any of the following holds:
//   - the server's wire version does not satisfy
//     op.MinimumWriteConcernWireVersion;
//   - no write concern is set;
//   - the write concern marshals to writeconcern.ErrEmptyWriteConcern.
//
// Any other marshalling error is returned.
func (op Operation) addWriteConcern(dst []byte, desc description.SelectedServer) ([]byte, error) {
	if op.MinimumWriteConcernWireVersion > 0 && (desc.WireVersion == nil || !desc.WireVersion.Includes(op.MinimumWriteConcernWireVersion)) {
		return dst, nil
	}
	wc := op.WriteConcern
	if wc == nil {
		return dst, nil
	}
	t, data, err := wc.MarshalBSONValue()
	// errors.Is also matches the sentinel if it is ever returned wrapped.
	if errors.Is(err, writeconcern.ErrEmptyWriteConcern) {
		return dst, nil
	}
	if err != nil {
		return dst, err
	}
	return append(bsoncore.AppendHeader(dst, t, "writeConcern"), data...), nil
}
// addSession appends the session fields to dst:
//   - "lsid", always;
//   - "txnNumber", for retryable writes and for transactions (including commit
//     retries);
//   - "startTransaction" and "autocommit", for transactions.
//
// It does nothing when there is no session, the server does not support
// sessions, or the server reports no session timeout. Otherwise it first
// refreshes the session's use time. Its error is the result of
// ApplyCommand(desc.Server) on the session.
func (op Operation) addSession(dst []byte, desc description.SelectedServer) ([]byte, error) {
	sess := op.Client
	if sess == nil || !sessionsSupported(desc.WireVersion) || desc.SessionTimeoutMinutes == 0 {
		return dst, nil
	}
	if err := sess.UpdateUseTime(); err != nil {
		return dst, err
	}
	dst = bsoncore.AppendDocumentElement(dst, "lsid", sess.SessionID)
	retryableWrite := op.Type == Write && sess.RetryWrite
	inTransaction := sess.TransactionRunning() || sess.RetryingCommit
	// txnNumber is written at most once, even when both conditions hold.
	if retryableWrite || inTransaction {
		dst = bsoncore.AppendInt64Element(dst, "txnNumber", sess.TxnNumber)
	}
	if inTransaction {
		if sess.TransactionStarting() {
			dst = bsoncore.AppendBooleanElement(dst, "startTransaction", true)
		}
		dst = bsoncore.AppendBooleanElement(dst, "autocommit", false)
	}
	return dst, sess.ApplyCommand(desc.Server)
}
// addClusterTime appends the "$clusterTime" element to dst. Its value is the
// greater of the cluster clock's time and the session's cluster time, as
// computed by session.MaxClusterTime.
//
// dst is returned unchanged when any of the following holds:
//   - there is neither a clock nor a session;
//   - the server does not support sessions;
//   - no cluster time is known;
//   - the cluster time document lacks a "$clusterTime" field.
func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer) []byte {
	client, clock := op.Client, op.Clock
	if (clock == nil && client == nil) || !sessionsSupported(desc.WireVersion) {
		return dst
	}
	// op.Clock may be nil when only a session is attached. Consult the clock only
	// when it is set instead of relying on GetClusterTime accepting a nil receiver.
	var clusterTime bson.Raw
	if clock != nil {
		clusterTime = clock.GetClusterTime()
	}
	if client != nil {
		clusterTime = session.MaxClusterTime(clusterTime, client.ClusterTime)
	}
	if clusterTime == nil {
		return dst
	}
	val, err := clusterTime.LookupErr("$clusterTime")
	if err != nil {
		return dst
	}
	return append(bsoncore.AppendHeader(dst, val.Type, "$clusterTime"), val.Value...)
}
// updateClusterTimes advances the cluster time of the operation's session and
// of its cluster clock to the "$clusterTime" reported in response. It does
// nothing when the response carries no "$clusterTime". Any error from the
// session's AdvanceClusterTime is deliberately discarded, since this method has
// no error result.
func (op Operation) updateClusterTimes(response bsoncore.Document) {
	ctValue, err := response.LookupErr("$clusterTime")
	if err != nil {
		// $clusterTime not included by the server
		return
	}
	// Re-wrap the value in a {$clusterTime: ...} document before handing it on.
	ctDoc := bson.Raw(bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendValueElement(nil, "$clusterTime", ctValue)))
	if sess := op.Client; sess != nil {
		_ = sess.AdvanceClusterTime(ctDoc)
	}
	if clock := op.Clock; clock != nil {
		clock.AdvanceClusterTime(ctDoc)
	}
}
// updateOperationTime advances the operation time of the operation's session to
// the "operationTime" timestamp reported in response. It does nothing when
// there is no session or the response has no "operationTime". Any error from
// AdvanceOperationTime is deliberately discarded, since this method has no error
// result.
func (op Operation) updateOperationTime(response bsoncore.Document) {
	if op.Client == nil {
		return
	}
	opTime, err := response.LookupErr("operationTime")
	if err != nil {
		// operationTime not included by the server
		return
	}
	var ts primitive.Timestamp
	ts.T, ts.I = opTime.Timestamp()
	_ = op.Client.AdvanceOperationTime(&ts)
}
// getReadPrefBasedOnTransaction returns the read preference to use for this
// operation. Outside a running transaction this is op.ReadPreference. Inside
// one it is the transaction's CurrentRp, which may be nil. ErrNonPrimaryReadPref
// is returned when that preference is non-primary, unless the transaction is
// just starting.
func (op Operation) getReadPrefBasedOnTransaction() (*readpref.ReadPref, error) {
	sess := op.Client
	if sess == nil || !sess.TransactionRunning() {
		return op.ReadPreference, nil
	}
	// The transaction's read preference always takes priority.
	txnRP := sess.CurrentRp
	// Reads in a transaction must be primary; this is not enforced for the
	// command that starts the transaction.
	if txnRP != nil && !sess.TransactionStarting() && txnRP.Mode() != readpref.PrimaryMode {
		return nil, ErrNonPrimaryReadPref
	}
	return txnRP, nil
}
// createReadPref builds the $readPreference document to send with the command.
// It returns nil when no read preference should be sent.
//
// No read preference is sent for:
//   - standalone servers;
//   - non-mongos servers when isOpQuery is true;
//   - writes;
//   - aggregates with an output stage on servers whose wire version is below 13
//     or unknown.
//
// Otherwise the preference comes from getReadPrefBasedOnTransaction:
//   - nil: {mode: "primaryPreferred"} for a direct (Single topology) connection
//     to a non-mongos server, otherwise nothing.
//   - Primary mode: omitted for mongos, sent as primaryPreferred for a Single
//     topology.
//   - A plain secondaryPreferred (no maxStaleness, tags or hedge): omitted for
//     mongos over OP_QUERY.
//
// Tag sets, maxStalenessSeconds and hedge options are added when present.
//
// An error is returned if getReadPrefBasedOnTransaction fails or the hedge
// sub-document cannot be completed.
func (op Operation) createReadPref(desc description.SelectedServer, isOpQuery bool) (bsoncore.Document, error) {
	// TODO(GODRIVER-2231): Instead of checking if isOutputAggregate and desc.Server.WireVersion.Max < 13, somehow check
	// TODO if supplied readPreference was "overwritten" with primary in description.selectForReplicaSet.
	// An unknown (nil) wire version is treated like an old server instead of being dereferenced.
	if desc.Server.Kind == description.Standalone || (isOpQuery && desc.Server.Kind != description.Mongos) ||
		op.Type == Write ||
		(op.IsOutputAggregate && (desc.Server.WireVersion == nil || desc.Server.WireVersion.Max < 13)) {
		// Don't send read preference for:
		// 1. all standalones
		// 2. non-mongos when using OP_QUERY
		// 3. all writes
		// 4. when operation is an aggregate with an output stage, and selected server's wire
		//    version is < 13 (or unknown)
		return nil, nil
	}
	idx, doc := bsoncore.AppendDocumentStart(nil)
	rp, err := op.getReadPrefBasedOnTransaction()
	if err != nil {
		return nil, err
	}
	if rp == nil {
		if desc.Kind == description.Single && desc.Server.Kind != description.Mongos {
			doc = bsoncore.AppendStringElement(doc, "mode", "primaryPreferred")
			doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
			return doc, nil
		}
		return nil, nil
	}
	switch rp.Mode() {
	case readpref.PrimaryMode:
		if desc.Server.Kind == description.Mongos {
			return nil, nil
		}
		if desc.Kind == description.Single {
			doc = bsoncore.AppendStringElement(doc, "mode", "primaryPreferred")
			doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
			return doc, nil
		}
		doc = bsoncore.AppendStringElement(doc, "mode", "primary")
	case readpref.PrimaryPreferredMode:
		doc = bsoncore.AppendStringElement(doc, "mode", "primaryPreferred")
	case readpref.SecondaryPreferredMode:
		_, ok := rp.MaxStaleness()
		if desc.Server.Kind == description.Mongos && isOpQuery && !ok && len(rp.TagSets()) == 0 && rp.HedgeEnabled() == nil {
			return nil, nil
		}
		doc = bsoncore.AppendStringElement(doc, "mode", "secondaryPreferred")
	case readpref.SecondaryMode:
		doc = bsoncore.AppendStringElement(doc, "mode", "secondary")
	case readpref.NearestMode:
		doc = bsoncore.AppendStringElement(doc, "mode", "nearest")
	}
	sets := make([]bsoncore.Document, 0, len(rp.TagSets()))
	for _, ts := range rp.TagSets() {
		i, set := bsoncore.AppendDocumentStart(nil)
		for _, t := range ts {
			set = bsoncore.AppendStringElement(set, t.Name, t.Value)
		}
		set, _ = bsoncore.AppendDocumentEnd(set, i)
		sets = append(sets, set)
	}
	if len(sets) > 0 {
		var aidx int32
		aidx, doc = bsoncore.AppendArrayElementStart(doc, "tags")
		for i, set := range sets {
			doc = bsoncore.AppendDocumentElement(doc, strconv.Itoa(i), set)
		}
		doc, _ = bsoncore.AppendArrayEnd(doc, aidx)
	}
	if d, ok := rp.MaxStaleness(); ok {
		doc = bsoncore.AppendInt32Element(doc, "maxStalenessSeconds", int32(d.Seconds()))
	}
	if hedgeEnabled := rp.HedgeEnabled(); hedgeEnabled != nil {
		var hedgeIdx int32
		hedgeIdx, doc = bsoncore.AppendDocumentElementStart(doc, "hedge")
		doc = bsoncore.AppendBooleanElement(doc, "enabled", *hedgeEnabled)
		doc, err = bsoncore.AppendDocumentEnd(doc, hedgeIdx)
		if err != nil {
			// %w keeps the underlying error inspectable with errors.Is / errors.As.
			return nil, fmt.Errorf("error creating hedge document: %w", err)
		}
	}
	doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
	return doc, nil
}
// secondaryOK returns the OP_QUERY SecondaryOK flag when the command may be
// served by a non-primary. That is the case for a direct (Single topology)
// connection to anything other than a mongos, or when op.ReadPreference is set
// to a non-primary mode. Otherwise it returns 0 (no flags).
func (op Operation) secondaryOK(desc description.SelectedServer) wiremessage.QueryFlag {
	switch {
	case desc.Kind == description.Single && desc.Server.Kind != description.Mongos:
		return wiremessage.SecondaryOK
	case op.ReadPreference != nil && op.ReadPreference.Mode() != readpref.PrimaryMode:
		return wiremessage.SecondaryOK
	default:
		return 0
	}
}
// canCompress reports whether the named command may be sent compressed. It
// returns false for the handshake, authentication and user-management commands
// listed below, and true for every other command name.
func (Operation) canCompress(cmd string) bool {
	switch cmd {
	case internal.LegacyHello, "hello", "saslStart", "saslContinue", "getnonce", "authenticate",
		"createUser", "updateUser", "copydbSaslStart", "copydbgetnonce", "copydb":
		return false
	default:
		return true
	}
}
// decodeOpReply extracts the necessary information from an OP_REPLY wire message.
// includesHeader: specifies whether or not wm includes the message header
// Returns the decoded OP_REPLY. If the err field of the returned opReply is non-nil, an error occurred while decoding
// or validating the response and the other fields are undefined.
//
// reply.err is set to one of the following:
//   - a descriptive error when a field cannot be read;
//   - a QueryFailureError when the QueryFailure flag is set, with the first
//     returned document (if any) as its Response;
//   - ErrCursorNotFound when the CursorNotFound flag is set;
//   - ErrReplyDocumentMismatch when numberReturned does not match the number of
//     documents read.
func (Operation) decodeOpReply(wm []byte, includesHeader bool) opReply {
	var reply opReply
	var ok bool
	if includesHeader {
		wmLength := len(wm)
		var length int32
		var opcode wiremessage.OpCode
		length, _, _, opcode, wm, ok = wiremessage.ReadHeader(wm)
		if !ok || int(length) > wmLength {
			reply.err = errors.New("malformed wire message: insufficient bytes")
			return reply
		}
		if opcode != wiremessage.OpReply {
			reply.err = errors.New("malformed wire message: incorrect opcode")
			return reply
		}
	}
	reply.responseFlags, wm, ok = wiremessage.ReadReplyFlags(wm)
	if !ok {
		reply.err = errors.New("malformed OP_REPLY: missing flags")
		return reply
	}
	reply.cursorID, wm, ok = wiremessage.ReadReplyCursorID(wm)
	if !ok {
		reply.err = errors.New("malformed OP_REPLY: missing cursorID")
		return reply
	}
	reply.startingFrom, wm, ok = wiremessage.ReadReplyStartingFrom(wm)
	if !ok {
		reply.err = errors.New("malformed OP_REPLY: missing startingFrom")
		return reply
	}
	reply.numReturned, wm, ok = wiremessage.ReadReplyNumberReturned(wm)
	if !ok {
		reply.err = errors.New("malformed OP_REPLY: missing numberReturned")
		return reply
	}
	reply.documents, _, ok = wiremessage.ReadReplyDocuments(wm)
	if !ok {
		// Return immediately. Otherwise a later check would overwrite this error, and
		// the QueryFailure branch could index an empty documents slice.
		reply.err = errors.New("malformed OP_REPLY: could not read documents from reply")
		return reply
	}
	if reply.responseFlags&wiremessage.QueryFailure == wiremessage.QueryFailure {
		qfErr := QueryFailureError{Message: "command failure"}
		// Guard the index: a failure reply with no documents must not panic.
		if len(reply.documents) > 0 {
			qfErr.Response = reply.documents[0]
		}
		reply.err = qfErr
		return reply
	}
	if reply.responseFlags&wiremessage.CursorNotFound == wiremessage.CursorNotFound {
		reply.err = ErrCursorNotFound
		return reply
	}
	if reply.numReturned != int32(len(reply.documents)) {
		reply.err = ErrReplyDocumentMismatch
		return reply
	}
	return reply
}
// decodeResult decodes the server's reply to a command from wm, a complete wire
// message including its header. OP_REPLY and OP_MSG replies are supported.
//
// For OP_REPLY, exactly one document must be returned. For OP_MSG, the last
// SingleDocument section is used; DocumentSequence sections are read and
// discarded. The document is validated. Any server-reported error inside it is
// extracted with ExtractErrorFromServerResponse, so a non-nil document may be
// returned together with a non-nil error.
func (op Operation) decodeResult(wm []byte) (bsoncore.Document, error) {
	wmLength := len(wm)
	length, _, _, opcode, wm, ok := wiremessage.ReadHeader(wm)
	// The declared length must cover at least the 16-byte header and must not
	// exceed the bytes actually available.
	if !ok || length < 16 || int(length) > wmLength {
		return nil, errors.New("malformed wire message: insufficient bytes")
	}
	// Constrain to just this wiremessage, in case there are multiple in the slice.
	// ReadHeader consumed the 16-byte header, so the body is the declared length
	// minus 16. The slice length cannot be used here, because it may include
	// following messages.
	wm = wm[:int(length)-16]
	switch opcode {
	case wiremessage.OpReply:
		reply := op.decodeOpReply(wm, false)
		if reply.err != nil {
			return nil, reply.err
		}
		if reply.numReturned == 0 {
			return nil, ErrNoDocCommandResponse
		}
		if reply.numReturned > 1 {
			return nil, ErrMultiDocCommandResponse
		}
		rdr := reply.documents[0]
		if err := rdr.Validate(); err != nil {
			return nil, NewCommandResponseError("malformed OP_REPLY: invalid document", err)
		}
		return rdr, ExtractErrorFromServerResponse(rdr)
	case wiremessage.OpMsg:
		_, wm, ok = wiremessage.ReadMsgFlags(wm)
		if !ok {
			return nil, errors.New("malformed wire message: missing OP_MSG flags")
		}
		var res bsoncore.Document
		for len(wm) > 0 {
			var stype wiremessage.SectionType
			stype, wm, ok = wiremessage.ReadMsgSectionType(wm)
			if !ok {
				return nil, errors.New("malformed wire message: insufficient bytes to read section type")
			}
			switch stype {
			case wiremessage.SingleDocument:
				res, wm, ok = wiremessage.ReadMsgSectionSingleDocument(wm)
				if !ok {
					return nil, errors.New("malformed wire message: insufficient bytes to read single document")
				}
			case wiremessage.DocumentSequence:
				// TODO(GODRIVER-617): Implement document sequence returns.
				_, _, wm, ok = wiremessage.ReadMsgSectionDocumentSequence(wm)
				if !ok {
					return nil, errors.New("malformed wire message: insufficient bytes to read document sequence")
				}
			default:
				return nil, fmt.Errorf("malformed wire message: unknown section type %v", stype)
			}
		}
		err := res.Validate()
		if err != nil {
			return nil, NewCommandResponseError("malformed OP_MSG: invalid document", err)
		}
		return res, ExtractErrorFromServerResponse(res)
	default:
		return nil, fmt.Errorf("cannot decode result from %s", opcode)
	}
}
// getCommandName returns the name of the command from the given BSON document,
// which is the key of its first element. It returns "" when doc is too short
// to hold a first element key, or when that key is not NUL-terminated, rather
// than panicking on an out-of-range slice.
func (op Operation) getCommandName(doc []byte) string {
	// skip 4 bytes for document length and 1 byte for element type
	const keyStart = 5
	if len(doc) < keyStart {
		return ""
	}
	idx := bytes.IndexByte(doc[keyStart:], 0x00) // look for the 0 byte after the command name
	if idx < 0 {
		return ""
	}
	return string(doc[keyStart : keyStart+idx])
}
// redactCommand reports whether the command's contents must be hidden from
// command monitoring.
//
// The fixed list of authentication and user-management commands below is
// always redacted. A hello or legacy hello command is redacted only when it
// carries a "speculativeAuthenticate" field; "hello" is matched exactly, and
// the legacy name is compared after lowercasing cmd. Every other command is
// left unredacted.
func (op *Operation) redactCommand(cmd string, doc bsoncore.Document) bool {
	switch cmd {
	case "authenticate", "saslStart", "saslContinue", "getnonce", "createUser",
		"updateUser", "copydbgetnonce", "copydbsaslstart", "copydb":
		return true
	}
	isHello := cmd == "hello" || strings.ToLower(cmd) == internal.LegacyHelloLowercase
	if !isHello {
		return false
	}
	// A hello without speculative authentication can be monitored.
	_, lookupErr := doc.LookupErr("speculativeAuthenticate")
	return lookupErr == nil
}
// publishStartedEvent publishes a CommandStartedEvent to the operation's command
// monitor. It does nothing unless a monitor with a Started callback is
// configured.
//
// The event carries a copy of info.cmd, so the monitor never aliases the
// wire-message buffer. Redacted commands are reported as an empty bson.Raw.
// When the batch was sent as an OP_MSG document sequence, it is spliced back
// into the copy as a BSON array.
func (op Operation) publishStartedEvent(ctx context.Context, info startedInformation) {
	if op.CommandMonitor == nil || op.CommandMonitor.Started == nil {
		return
	}
	command := bson.Raw{}
	if !info.redacted {
		command = make(bson.Raw, len(info.cmd))
		copy(command, info.cmd)
		if info.documentSequenceIncluded {
			command = command[:len(command)-1]                  // drop the trailing 0 byte
			command = op.addBatchArray(command)                 // re-add the batch as an array
			command, _ = bsoncore.AppendDocumentEnd(command, 0) // restore the 0 byte and fix the length
		}
	}
	op.CommandMonitor.Started(ctx, &event.CommandStartedEvent{
		Command:            command,
		DatabaseName:       op.Database,
		CommandName:        info.cmdName,
		RequestID:          int64(info.requestID),
		ConnectionID:       info.connID,
		ServerConnectionID: info.serverConnID,
		ServiceID:          info.serviceID,
	})
}
// publishFinishedEvent publishes a CommandSucceededEvent or a CommandFailedEvent
// to the operation's command monitor. It does nothing when there is no monitor
// or the matching callback is unset.
//
// A nil error, or an error of type WriteCommandError, is reported as success;
// any other error is reported as failure. The duration is left at 0 when
// info.startTime is the zero time.Time. Successful replies are copied, except
// for redacted commands, which are reported as an empty bson.Raw.
func (op Operation) publishFinishedEvent(ctx context.Context, info finishedInformation) {
	_, isWriteCmdErr := info.cmdErr.(WriteCommandError)
	succeeded := info.cmdErr == nil || isWriteCmdErr
	mon := op.CommandMonitor
	if mon == nil {
		return
	}
	if (succeeded && mon.Succeeded == nil) || (!succeeded && mon.Failed == nil) {
		return
	}
	var elapsed int64
	if info.startTime != (time.Time{}) {
		elapsed = time.Since(info.startTime).Nanoseconds()
	}
	common := event.CommandFinishedEvent{
		CommandName:        info.cmdName,
		RequestID:          int64(info.requestID),
		ConnectionID:       info.connID,
		DurationNanos:      elapsed,
		ServerConnectionID: info.serverConnID,
		ServiceID:          info.serviceID,
	}
	if !succeeded {
		mon.Failed(ctx, &event.CommandFailedEvent{
			Failure:              info.cmdErr.Error(),
			CommandFinishedEvent: common,
		})
		return
	}
	reply := bson.Raw{}
	// Only copy the reply for commands that are not security sensitive.
	if !info.redacted {
		reply = make(bson.Raw, len(info.response))
		copy(reply, info.response)
	}
	mon.Succeeded(ctx, &event.CommandSucceededEvent{
		Reply:                reply,
		CommandFinishedEvent: common,
	})
}
// sessionsSupported reports whether a server with the given wire version range
// supports sessions. The range must be known (non-nil) and its maximum must be
// at least 6.
func sessionsSupported(wireVersion *description.VersionRange) bool {
	if wireVersion == nil {
		return false
	}
	return wireVersion.Max >= 6
}
// retryWritesSupported reports whether server s supports retryable writes.
// The server must not be a standalone and must report a non-zero session
// timeout.
func retryWritesSupported(s description.Server) bool {
	if s.Kind == description.Standalone {
		return false
	}
	return s.SessionTimeoutMinutes != 0
}