Skip to content

RUST-1992 Update the driver for bson cstr API changes #1412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

99 changes: 43 additions & 56 deletions src/bson_compat.rs
Original file line number Diff line number Diff line change
@@ -1,84 +1,71 @@
use crate::bson::RawBson;
#[cfg(feature = "bson-3")]
pub(crate) type CStr = crate::bson::raw::CStr;
#[cfg(feature = "bson-3")]
pub(crate) type CString = crate::bson::raw::CString;
#[cfg(feature = "bson-3")]
pub(crate) use crate::bson::raw::cstr;

pub(crate) trait RawDocumentBufExt: Sized {
fn append_err(&mut self, key: impl AsRef<str>, value: impl Into<RawBson>) -> RawResult<()>;
#[cfg(not(feature = "bson-3"))]
pub(crate) type CStr = str;
#[cfg(not(feature = "bson-3"))]
pub(crate) type CString = String;
#[cfg(not(feature = "bson-3"))]
macro_rules! cstr {
($text:literal) => {
$text
};
}
#[cfg(not(feature = "bson-3"))]
pub(crate) use cstr;

fn append_ref_err<'a>(
pub(crate) fn cstr_to_str(cs: &CStr) -> &str {
#[cfg(feature = "bson-3")]
{
cs.as_str()
}
#[cfg(not(feature = "bson-3"))]
{
cs
}
}

pub(crate) trait RawDocumentBufExt: Sized {
fn append_ref_compat<'a>(
&mut self,
key: impl AsRef<str>,
value: impl Into<crate::bson::raw::RawBsonRef<'a>>,
) -> RawResult<()>;
key: impl AsRef<CStr>,
value: impl Into<crate::bson::raw::RawBsonRef<'a>> + 'a,
);

#[cfg(not(feature = "bson-3"))]
fn decode_from_bytes(data: Vec<u8>) -> RawResult<Self>;
}

#[cfg(feature = "bson-3")]
impl RawDocumentBufExt for crate::bson::RawDocumentBuf {
fn append_err(&mut self, key: impl AsRef<str>, value: impl Into<RawBson>) -> RawResult<()> {
self.append(key, value.into())
}

fn append_ref_err<'a>(
fn append_ref_compat<'a>(
&mut self,
key: impl AsRef<str>,
value: impl Into<crate::bson::raw::RawBsonRef<'a>>,
) -> RawResult<()> {
self.append(key, value)
key: impl AsRef<CStr>,
value: impl Into<crate::bson::raw::RawBsonRef<'a>> + 'a,
) {
self.append(key, value);
}
}

#[cfg(not(feature = "bson-3"))]
impl RawDocumentBufExt for crate::bson::RawDocumentBuf {
fn append_err(&mut self, key: impl AsRef<str>, value: impl Into<RawBson>) -> RawResult<()> {
self.append(key, value);
Ok(())
}

fn append_ref_err<'a>(
fn append_ref_compat<'a>(
&mut self,
key: impl AsRef<str>,
key: impl AsRef<CStr>,
value: impl Into<crate::bson::raw::RawBsonRef<'a>>,
) -> RawResult<()> {
self.append_ref(key, value);
Ok(())
) {
self.append_ref(key, value)
}

fn decode_from_bytes(data: Vec<u8>) -> RawResult<Self> {
Self::from_bytes(data)
}
}

pub(crate) trait RawArrayBufExt: Sized {
#[allow(dead_code)]
fn from_iter_err<V: Into<RawBson>, I: IntoIterator<Item = V>>(iter: I) -> RawResult<Self>;

fn push_err(&mut self, value: impl Into<RawBson>) -> RawResult<()>;
}

#[cfg(feature = "bson-3")]
impl RawArrayBufExt for crate::bson::RawArrayBuf {
fn from_iter_err<V: Into<RawBson>, I: IntoIterator<Item = V>>(iter: I) -> RawResult<Self> {
Self::from_iter(iter.into_iter().map(|v| v.into()))
}

fn push_err(&mut self, value: impl Into<RawBson>) -> RawResult<()> {
self.push(value.into())
}
}

#[cfg(not(feature = "bson-3"))]
impl RawArrayBufExt for crate::bson::RawArrayBuf {
fn from_iter_err<V: Into<RawBson>, I: IntoIterator<Item = V>>(iter: I) -> RawResult<Self> {
Ok(Self::from_iter(iter))
}

fn push_err(&mut self, value: impl Into<RawBson>) -> RawResult<()> {
self.push(value);
Ok(())
}
}

#[cfg(not(feature = "bson-3"))]
pub(crate) trait RawDocumentExt {
fn decode_from_bytes<D: AsRef<[u8]> + ?Sized>(data: &D) -> RawResult<&Self>;
Expand Down
24 changes: 12 additions & 12 deletions src/bson_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
RawBsonRef,
RawDocumentBuf,
},
bson_compat::{RawArrayBufExt, RawDocumentBufExt as _},
bson_compat::RawDocumentBufExt as _,
checked::Checked,
error::{Error, ErrorKind, Result},
runtime::SyncLittleEndianRead,
Expand Down Expand Up @@ -78,14 +78,14 @@ pub(crate) fn to_bson_array(docs: &[Document]) -> Bson {
pub(crate) fn to_raw_bson_array(docs: &[Document]) -> Result<RawBson> {
let mut array = RawArrayBuf::new();
for doc in docs {
array.push_err(RawDocumentBuf::from_document(doc)?)?;
array.push(RawDocumentBuf::from_document(doc)?);
}
Ok(RawBson::Array(array))
}
pub(crate) fn to_raw_bson_array_ser<T: Serialize>(values: &[T]) -> Result<RawBson> {
let mut array = RawArrayBuf::new();
for value in values {
array.push_err(crate::bson_compat::serialize_to_raw_document_buf(value)?)?;
array.push(crate::bson_compat::serialize_to_raw_document_buf(value)?);
}
Ok(RawBson::Array(array))
}
Expand Down Expand Up @@ -127,7 +127,7 @@ pub(crate) fn replacement_document_check(replacement: &Document) -> Result<()> {

pub(crate) fn replacement_raw_document_check(replacement: &RawDocumentBuf) -> Result<()> {
if let Some((key, _)) = replacement.iter().next().transpose()? {
if key.starts_with('$') {
if crate::bson_compat::cstr_to_str(key).starts_with('$') {
return Err(ErrorKind::InvalidArgument {
message: "replacement document must not contain update modifiers".to_string(),
}
Expand All @@ -147,12 +147,12 @@ pub(crate) fn array_entry_size_bytes(index: usize, doc_len: usize) -> Result<usi
(Checked::new(1) + num_decimal_digits(index) + 1 + doc_len).get()
}

pub(crate) fn vec_to_raw_array_buf(docs: Vec<RawDocumentBuf>) -> Result<RawArrayBuf> {
pub(crate) fn vec_to_raw_array_buf(docs: Vec<RawDocumentBuf>) -> RawArrayBuf {
let mut array = RawArrayBuf::new();
for doc in docs {
array.push_err(doc)?;
array.push(doc);
}
Ok(array)
array
}

/// The number of digits in `n` in base 10.
Expand Down Expand Up @@ -188,7 +188,7 @@ pub(crate) fn extend_raw_document_buf(
this: &mut RawDocumentBuf,
other: RawDocumentBuf,
) -> Result<()> {
let mut keys: HashSet<String> = HashSet::new();
let mut keys: HashSet<crate::bson_compat::CString> = HashSet::new();
for elem in this.iter_elements() {
keys.insert(elem?.key().to_owned());
}
Expand All @@ -200,27 +200,27 @@ pub(crate) fn extend_raw_document_buf(
k
)));
}
this.append_err(k, v.to_raw_bson())?;
this.append(k, v.to_raw_bson());
}
Ok(())
}

pub(crate) fn append_ser(
this: &mut RawDocumentBuf,
key: impl AsRef<str>,
key: impl AsRef<crate::bson_compat::CStr>,
value: impl Serialize,
) -> Result<()> {
#[derive(Serialize)]
struct Helper<T> {
value: T,
}
let raw_doc = crate::bson_compat::serialize_to_raw_document_buf(&Helper { value })?;
this.append_ref_err(
this.append_ref_compat(
key,
raw_doc
.get("value")?
.ok_or_else(|| Error::internal("no value"))?,
)?;
);
Ok(())
}

Expand Down
14 changes: 5 additions & 9 deletions src/client/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ mod x509;

use std::{borrow::Cow, fmt::Debug, str::FromStr};

use crate::{bson::RawDocumentBuf, bson_compat::RawDocumentBufExt as _};
use crate::{bson::RawDocumentBuf, bson_compat::cstr};
use derive_where::derive_where;
use hmac::{digest::KeyInit, Mac};
use rand::Rng;
Expand Down Expand Up @@ -447,17 +447,13 @@ impl Credential {

/// If the mechanism is missing, append the appropriate mechanism negotiation key-value-pair to
/// the provided hello or legacy hello command document.
pub(crate) fn append_needed_mechanism_negotiation(
&self,
command: &mut RawDocumentBuf,
) -> Result<()> {
pub(crate) fn append_needed_mechanism_negotiation(&self, command: &mut RawDocumentBuf) {
if let (Some(username), None) = (self.username.as_ref(), self.mechanism.as_ref()) {
command.append_err(
"saslSupportedMechs",
command.append(
cstr!("saslSupportedMechs"),
format!("{}.{}", self.resolved_source(), username),
)?;
);
}
Ok(())
}

/// Attempts to authenticate a stream according to this credential, returning an error
Expand Down
6 changes: 3 additions & 3 deletions src/client/auth/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use typed_builder::TypedBuilder;

use crate::{
bson::{doc, rawdoc, spec::BinarySubtype, Binary, Document},
bson_compat::RawDocumentBufExt as _,
bson_compat::cstr,
client::options::{ServerAddress, ServerApi},
cmap::{Command, Connection},
error::{Error, Result},
Expand Down Expand Up @@ -620,9 +620,9 @@ async fn send_sasl_start_command(
) -> Result<SaslResponse> {
let mut start_doc = rawdoc! {};
if let Some(access_token) = access_token {
start_doc.append_err("jwt", access_token)?;
start_doc.append(cstr!("jwt"), access_token);
} else if let Some(username) = credential.username.as_deref() {
start_doc.append_err("n", username)?;
start_doc.append(cstr!("n"), username);
}
let sasl_start = SaslStart::new(
source.to_string(),
Expand Down
4 changes: 2 additions & 2 deletions src/client/auth/sasl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::bson::{rawdoc, RawBson};

use crate::{
bson::{spec::BinarySubtype, Binary, Bson, Document},
bson_compat::RawDocumentBufExt as _,
bson_compat::cstr,
bson_util,
client::{auth::AuthMechanism, options::ServerApi},
cmap::Command,
Expand Down Expand Up @@ -42,7 +42,7 @@ impl SaslStart {
if self.mechanism == AuthMechanism::ScramSha1
|| self.mechanism == AuthMechanism::ScramSha256
{
body.append_err("options", rawdoc! { "skipEmptyExchange": true })?;
body.append(cstr!("options"), rawdoc! { "skipEmptyExchange": true });
}

let mut command = Command::new("saslStart", self.source, body);
Expand Down
4 changes: 2 additions & 2 deletions src/client/auth/scram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use tokio::sync::RwLock;

use crate::{
bson::{Bson, Document},
bson_compat::RawDocumentBufExt as _,
bson_compat::cstr,
client::{
auth::{
self,
Expand Down Expand Up @@ -461,7 +461,7 @@ impl ClientFirst {
let mut cmd = sasl_start.into_command()?;

if self.include_db {
cmd.body.append_err("db", self.source.clone())?;
cmd.body.append(cstr!("db"), self.source.clone());
}

Ok(cmd)
Expand Down
4 changes: 2 additions & 2 deletions src/client/auth/x509.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::bson::rawdoc;

use crate::{
bson::Document,
bson_compat::RawDocumentBufExt as _,
bson_compat::cstr,
client::options::ServerApi,
cmap::{Command, Connection, RawCommandResponse},
error::{Error, Result},
Expand All @@ -25,7 +25,7 @@ pub(crate) fn build_client_first(
};

if let Some(ref username) = credential.username {
auth_command_doc.append_err("username", username.as_str())?;
auth_command_doc.append(cstr!("username"), username.as_str());
}

let mut command = Command::new("authenticate", "$external", auth_command_doc);
Expand Down
18 changes: 8 additions & 10 deletions src/client/csfle/state_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{

use crate::{
bson::{rawdoc, Document, RawDocument, RawDocumentBuf},
bson_compat::RawDocumentBufExt as _,
bson_compat::{cstr, CString},
};
use futures_util::{stream, TryStreamExt};
use mongocrypt::ctx::{Ctx, KmsCtx, KmsProviderType, State};
Expand Down Expand Up @@ -245,6 +245,7 @@ impl CryptExecutor {
continue;
}

let prov_name: CString = provider.as_string().try_into()?;
match provider.provider_type() {
KmsProviderType::Aws => {
#[cfg(feature = "aws-auth")]
Expand All @@ -264,9 +265,9 @@ impl CryptExecutor {
"secretAccessKey": aws_creds.secret_key(),
};
if let Some(token) = aws_creds.session_token() {
creds.append_err("sessionToken", token)?;
creds.append(cstr!("sessionToken"), token);
}
kms_providers.append_err(provider.as_string(), creds)?;
kms_providers.append(prov_name, creds);
}
#[cfg(not(feature = "aws-auth"))]
{
Expand All @@ -279,10 +280,7 @@ impl CryptExecutor {
KmsProviderType::Azure => {
#[cfg(feature = "azure-kms")]
{
kms_providers.append_err(
provider.as_string(),
self.azure.get_token().await?,
)?;
kms_providers.append(prov_name, self.azure.get_token().await?);
}
#[cfg(not(feature = "azure-kms"))]
{
Expand Down Expand Up @@ -330,10 +328,10 @@ impl CryptExecutor {
.send()
.await
.map_err(|e| kms_error(e.to_string()))?;
kms_providers.append_err(
"gcp",
kms_providers.append(
cstr!("gcp"),
rawdoc! { "accessToken": response.access_token },
)?;
);
}
#[cfg(not(feature = "gcp-kms"))]
{
Expand Down
2 changes: 1 addition & 1 deletion src/client/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ impl Client {
let (server, effective_criteria) = match self
.select_server(
selection_criteria,
op.name(),
crate::bson_compat::cstr_to_str(op.name()),
retry.as_ref().map(|r| &r.first_server),
op.override_criteria(),
)
Expand Down
Loading