diff --git a/src/client/auth.rs b/src/client/auth.rs index 0eee9174d..222b6483c 100644 --- a/src/client/auth.rs +++ b/src/client/auth.rs @@ -19,7 +19,7 @@ use crate::{bson::RawDocumentBuf, bson_compat::cstr, options::ClientOptions}; use derive_where::derive_where; use hmac::{digest::KeyInit, Mac}; use rand::Rng; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use typed_builder::TypedBuilder; use self::scram::ScramVersion; @@ -44,7 +44,7 @@ const MONGODB_OIDC_STR: &str = "MONGODB-OIDC"; /// The authentication mechanisms supported by MongoDB. /// /// Note: not all of these mechanisms are currently supported by the driver. -#[derive(Clone, Deserialize, PartialEq, Debug)] +#[derive(Clone, Deserialize, Serialize, PartialEq, Debug)] #[non_exhaustive] pub enum AuthMechanism { /// MongoDB Challenge Response nonce and MD5 based authentication system. It is currently @@ -558,8 +558,7 @@ impl Credential { mechanism.authenticate_stream(conn, self, opts).await } - #[cfg(test)] - pub(crate) fn serialize_for_client_options( + pub(crate) fn serialize( credential: &Option, serializer: S, ) -> std::result::Result diff --git a/src/client/options.rs b/src/client/options.rs index aa3058e52..53f61b3f5 100644 --- a/src/client/options.rs +++ b/src/client/options.rs @@ -7,7 +7,7 @@ mod resolver_config; use std::{ cmp::Ordering, - collections::HashSet, + collections::{HashMap, HashSet}, convert::TryFrom, fmt::{self, Display, Formatter, Write}, hash::{Hash, Hasher}, @@ -20,7 +20,7 @@ use std::{ use crate::bson::UuidRepresentation; use derive_where::derive_where; use macro_magic::export_tokens; -use serde::{de::Unexpected, Deserialize, Deserializer, Serialize}; +use serde::{de::Unexpected, Deserialize, Deserializer, Serialize, Serializer}; use serde_with::skip_serializing_none; use std::sync::LazyLock; use strsim::jaro_winkler; @@ -688,7 +688,7 @@ impl Serialize for ClientOptions { #[serde(serialize_with = "serde_util::serialize_duration_option_as_int_millis")] connecttimeoutms: &'a Option, - #[serde(flatten, serialize_with = "Credential::serialize_for_client_options")] + #[serde(flatten, serialize_with = "Credential::serialize")] credential: &'a Option, directconnection: &'a Option, @@ -708,7 +708,7 @@ impl Serialize for ClientOptions { maxconnecting: &'a Option, - #[serde(flatten, serialize_with = "ReadConcern::serialize_for_client_options")] + #[serde(flatten, serialize_with = "ReadConcern::serialize")] readconcern: &'a Option, replicaset: &'a Option, @@ -731,10 +731,10 @@ impl Serialize for ClientOptions { #[serde(serialize_with = "serde_util::serialize_duration_option_as_int_millis")] sockettimeoutms: &'a Option, - #[serde(flatten, serialize_with = "Tls::serialize_for_client_options")] + #[serde(flatten, serialize_with = "Tls::serialize")] tls: &'a Option, - #[serde(flatten, serialize_with = "WriteConcern::serialize_for_client_options")] + #[serde(flatten, serialize_with = "WriteConcern::serialize")] writeconcern: &'a Option, zlibcompressionlevel: &'a Option, @@ -784,10 +784,35 @@ impl Serialize for ClientOptions { } } +// For ConnectionString serialization +fn serialize_uuid_rep_option( + value: &Option, + serializer: S, +) -> std::result::Result +where + S: Serializer, +{ + #[non_exhaustive] + #[derive(Serialize)] + #[serde(remote = "UuidRepresentation")] + enum UuidRepresentationForSerialize { + Standard, + CSharpLegacy, + JavaLegacy, + PythonLegacy, + } + match value { + Some(rep) => UuidRepresentationForSerialize::serialize(rep, serializer), + None => serializer.serialize_none(), + } +} + /// Contains the options that can be set via a MongoDB connection string. /// /// The format of a MongoDB connection string is described [here](https://www.mongodb.com/docs/manual/reference/connection-string/#connection-string-formats). -#[derive(Debug, Default, PartialEq)] +#[skip_serializing_none] +#[derive(Debug, Default, PartialEq, Serialize)] +#[serde(rename_all = "camelCase")] #[non_exhaustive] pub struct ConnectionString { /// The initial list of seeds that the Client should connect to, or a DNS name used for SRV @@ -806,11 +831,13 @@ pub struct ConnectionString { /// The TLS configuration for the Client to use in its connections with the server. /// /// By default, TLS is disabled. + #[serde(serialize_with = "Tls::serialize")] pub tls: Option, /// The amount of time each monitoring thread should wait between performing server checks. /// /// The default value is 10 seconds. + #[serde(serialize_with = "serde_util::serialize_duration_option_as_int_millis")] pub heartbeat_frequency: Option, /// When running a read operation with a ReadPreference that allows selecting secondaries, @@ -824,10 +851,12 @@ pub struct ConnectionString { /// lowest average round trip time is eligible. /// /// The default value is 15 ms. + #[serde(serialize_with = "serde_util::serialize_duration_option_as_int_millis")] pub local_threshold: Option, /// Specifies the default read concern for operations performed on the Client. See the /// ReadConcern type documentation for more details. + #[serde(serialize_with = "ReadConcern::serialize")] pub read_concern: Option, /// The name of the replica set that the Client should connect to. @@ -835,12 +864,14 @@ pub struct ConnectionString { /// Specifies the default write concern for operations performed on the Client. See the /// WriteConcern type documentation for more details. + #[serde(serialize_with = "WriteConcern::serialize")] pub write_concern: Option, /// The amount of time the Client should attempt to select a server for an operation before /// timing outs /// /// The default value is 30 seconds. + #[serde(serialize_with = "serde_util::serialize_duration_option_as_int_millis")] pub server_selection_timeout: Option, /// The maximum amount of connections that the Client should allow to be created in a @@ -867,6 +898,7 @@ pub struct ConnectionString { /// closed. A value of zero indicates that connections should not be closed due to being idle. /// /// By default, connections will not be closed due to being idle. + #[serde(serialize_with = "serde_util::serialize_duration_option_as_int_millis")] pub max_idle_time: Option, #[cfg(any( @@ -884,6 +916,7 @@ pub struct ConnectionString { /// server. /// /// The default value is 10 seconds. + #[serde(serialize_with = "serde_util::serialize_duration_option_as_int_millis")] pub connect_timeout: Option, /// Whether or not the client should retry a read operation if the operation fails. @@ -908,6 +941,7 @@ pub struct ConnectionString { pub direct_connection: Option, /// The credential to use for authenticating connections made by this client. + #[serde(serialize_with = "Credential::serialize")] pub credential: Option, /// Default database for this client. @@ -920,6 +954,7 @@ pub struct ConnectionString { /// Amount of time spent attempting to send or receive on a socket before timing out; note that /// this only applies to application operations, not server discovery and monitoring. + #[serde(serialize_with = "serde_util::serialize_duration_option_as_int_millis")] pub socket_timeout: Option, /// Default read preference for the client. @@ -929,6 +964,7 @@ pub struct ConnectionString { /// the [`UuidOld`](crate::bson::spec::BinarySubtype::UuidOld) subtype. This is not used by /// the driver; client code can use this when deserializing relevant values with /// [`Binary::to_uuid_with_representation`](crate::bson::binary::Binary::to_uuid_with_representation). + #[serde(serialize_with = "serialize_uuid_rep_option")] pub uuid_representation: Option, /// Limit on the number of mongos connections that may be created for sharded topologies. @@ -937,10 +973,12 @@ pub struct ConnectionString { /// Overrides the default "mongodb" service name for SRV lookup in both discovery and polling pub srv_service_name: Option, + #[serde(serialize_with = "serde_util::serialize_duration_option_as_int_millis")] wait_queue_timeout: Option, tls_insecure: Option, #[cfg(test)] + #[serde(skip_serializing)] original_uri: String, } @@ -956,7 +994,7 @@ struct ConnectionStringParts { } /// Specification for mongodb server connections. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Clone, Serialize)] #[non_exhaustive] pub enum HostInfo { /// A set of addresses. @@ -1021,8 +1059,7 @@ impl From for Option { } impl Tls { - #[cfg(test)] - pub(crate) fn serialize_for_client_options( + pub(crate) fn serialize( tls: &Option, serializer: S, ) -> std::result::Result @@ -1030,9 +1067,7 @@ impl Tls { S: serde::Serializer, { match tls { - Some(Tls::Enabled(tls_options)) => { - TlsOptions::serialize_for_client_options(tls_options, serializer) - } + Some(Tls::Enabled(tls_options)) => TlsOptions::serialize(tls_options, serializer), _ => serializer.serialize_none(), } } @@ -1074,8 +1109,7 @@ pub struct TlsOptions { } impl TlsOptions { - #[cfg(test)] - pub(crate) fn serialize_for_client_options( + pub(crate) fn serialize( tls_options: &TlsOptions, serializer: S, ) -> std::result::Result @@ -1400,6 +1434,10 @@ fn percent_decode(s: &str, err_message: &str) -> Result { } } +fn percent_encode(s: &str) -> String { + percent_encoding::utf8_percent_encode(s, percent_encoding::NON_ALPHANUMERIC).to_string() +} + fn validate_and_parse_userinfo(s: &str, userinfo_type: &str) -> Result { if s.chars().any(|c| USERINFO_RESERVED_CHARACTERS.contains(&c)) { return Err(Error::invalid_argument(format!( @@ -1673,6 +1711,352 @@ impl ConnectionString { Ok(conn_str) } + /// Un-parses a [`ConnectionString`] struct back into a MongoDB connection string. + fn to_uri_str(&self) -> String { + let ConnectionString { + host_info, + app_name, + tls, + heartbeat_frequency, + local_threshold, + read_concern, + replica_set, + write_concern, + server_selection_timeout, + max_pool_size, + min_pool_size, + max_connecting, + max_idle_time, + #[cfg(any( + feature = "zstd-compression", + feature = "zlib-compression", + feature = "snappy-compression" + ))] + compressors, + connect_timeout, + retry_reads, + retry_writes, + server_monitoring_mode: _, + direct_connection, + credential, + default_database, + load_balanced, + socket_timeout, + read_preference, + uuid_representation, + srv_max_hosts, + srv_service_name: _, + wait_queue_timeout, + tls_insecure, + #[cfg(test)] + original_uri: _, + } = self; + + let mut res: String = String::new(); + let mut opts = String::new(); + + if self.is_srv() { + res.push_str("mongodb+srv://"); + } else { + res.push_str("mongodb://"); + } + + if let Some(credential) = credential { + if let Some(username) = &credential.username { + res.push_str(&percent_encode(username)); + if let Some(password) = &credential.password { + res.push_str(&format!(":{}", &percent_encode(password))) + } + res.push('@'); + } + } + + if self.is_srv() { + if let HostInfo::DnsRecord(dns) = host_info { + res.push_str(dns); + } + } else if let HostInfo::HostIdentifiers(hosts) = host_info { + res.push_str( + &hosts + .iter() + .map(|h| h.to_string()) + .collect::>() + .join(","), + ); + } + + res.push('/'); + + if let Some(authdb) = default_database { + res.push_str(authdb); + } + + if let Some(replica_set) = replica_set { + opts.push_str(&format!("&replicaSet={replica_set}")); + } + + if let Some(direct_connection) = direct_connection { + opts.push_str(&format!("&directConnection={direct_connection}")); + } + + if let Some(tls) = tls { + match tls { + Tls::Enabled(options) => { + opts.push_str("&tls=true"); + + if let Some(cert_key_file_path) = &options.cert_key_file_path { + opts.push_str(&format!( + "&tlsCertificateKeyFile={}", + cert_key_file_path.to_str().unwrap() + )); + } + + #[cfg(feature = "cert-key-password")] + if let Some(tls_certificate_key_file_password) = + &options.tls_certificate_key_file_password + { + opts.push_str(&format!( + "&tlsCertificateKeyFilePassword={}", + std::str::from_utf8(tls_certificate_key_file_password).unwrap() + )); + } + + if let Some(ca_file_path) = &options.ca_file_path { + opts.push_str(&format!("&tlsCAFile={}", ca_file_path.to_str().unwrap())); + } + + if let Some(allow_invalid_certificates) = options.allow_invalid_certificates { + opts.push_str(&format!( + "&tlsAllowInvalidCertificates={allow_invalid_certificates}" + )); + } + + #[cfg(feature = "openssl-tls")] + if let Some(allow_invalid_hostnames) = options.allow_invalid_hostnames { + opts.push_str(&format!( + "&tlsAllowInvalidHostnames={allow_invalid_hostnames}" + )); + } + + if let Some(tls_insecure) = tls_insecure { + opts.push_str(&format!("&tlsInsecure={tls_insecure}")); + } + } + Tls::Disabled => { + opts.push_str("&tls=false"); + } + } + } + + if let Some(connect_timeout) = connect_timeout { + opts.push_str(&format!( + "&connectTimeoutMS={}", + connect_timeout.as_millis() + )); + } + + if let Some(socket_timeout) = socket_timeout { + opts.push_str(&format!("&socketTimeoutMS={}", socket_timeout.as_millis())); + } + + #[cfg(any( + feature = "zstd-compression", + feature = "zlib-compression", + feature = "snappy-compression" + ))] + if let Some(compressors) = compressors { + opts.push_str(&format!( + "&compressors={}", + compressors + .iter() + .map(|c| c.name()) + .collect::>() + .join(",") + )); + } + + #[cfg(feature = "zlib-compression")] + if let Some(compressors) = compressors { + for compressor in compressors { + if let Compressor::Zlib { level: Some(level) } = compressor { + opts.push_str(&format!("&zlibCompressionLevel={level}")); + } + } + } + + if let Some(max_pool_size) = max_pool_size { + opts.push_str(&format!("&maxPoolSize={max_pool_size}")); + } + + if let Some(min_pool_size) = min_pool_size { + opts.push_str(&format!("&minPoolSize={min_pool_size}")); + } + + if let Some(max_connecting) = max_connecting { + opts.push_str(&format!("&maxConnecting={max_connecting}")); + } + + if let Some(max_idle_time) = max_idle_time { + opts.push_str(&format!("&maxIdleTimeMS={}", max_idle_time.as_millis())); + } + + if let Some(wait_queue_timeout) = wait_queue_timeout { + opts.push_str(&format!( + "&waitQueueTimeoutMS={}", + wait_queue_timeout.as_millis() + )); + } + + if let Some(write_concern) = write_concern { + if let Some(w) = &write_concern.w { + match w { + Acknowledgment::Nodes(i) => { + opts.push_str(&format!("&w={i}")); + } + Acknowledgment::Majority => { + opts.push_str("&w=majority"); + } + Acknowledgment::Custom(tag) => { + opts.push_str(&format!("&w={tag}")); + } + } + } + + if let Some(w_timeout) = write_concern.w_timeout { + opts.push_str(&format!("&wtimeoutMS={}", w_timeout.as_millis())); + } + + if let Some(journal) = write_concern.journal { + opts.push_str(&format!("&journal={journal}")); + } + } + + if let Some(read_concern) = read_concern { + opts.push_str(&format!( + "&readConcernLevel={}", + read_concern.level.as_str() + )); + } + + if let Some(read_preference) = read_preference { + opts.push_str(&format!("&readPreference={}", read_preference.mode())); + + if let Some(max_staleness) = read_preference.max_staleness() { + opts.push_str(&format!("&maxStalenessSeconds={}", max_staleness.as_secs())); + } + + if let Some(tag_sets) = read_preference.tag_sets() { + let ser_tag_set = |tag_set: &HashMap| -> String { + let tags = tag_set + .iter() + .map(|(k, v)| format!("{k}:{v}")) + .collect::>() + .join(","); + format!("&readPreferenceTags={tags}") + }; + opts.push_str( + &tag_sets + .iter() + .map(ser_tag_set) + .collect::>() + .join(""), + ) + } + } + + if let Some(auth_source) = credential + .as_ref() + .and_then(|c: &Credential| c.source.as_ref()) + { + opts.push_str(&format!("&authSource={auth_source}")); + } + + if let Some(auth_mechanism) = credential + .as_ref() + .and_then(|c: &Credential| c.mechanism.as_ref()) + { + opts.push_str(&format!( + "&authMechanism={}", + auth_mechanism.as_str().to_uppercase() + )); + } + + if let Some(auth_mechanism_properties) = credential + .as_ref() + .and_then(|c: &Credential| c.mechanism_properties.as_ref()) + { + if !auth_mechanism_properties.is_empty() { + opts.push_str(&format!( + "&authMechanismProperties={}", + &auth_mechanism_properties + .iter() + .map(|(k, v)| format!("{k}:{v}")) + .collect::>() + .join(",") + )) + } + } + + if let Some(local_threshold) = local_threshold { + opts.push_str(&format!( + "&localThresholdMS={}", + local_threshold.as_millis() + )); + } + + if let Some(server_selection_timeout) = server_selection_timeout { + opts.push_str(&format!( + "&serverSelectionTimeoutMS={}", + server_selection_timeout.as_millis() + )); + } + + if let Some(heartbeat_frequency) = heartbeat_frequency { + opts.push_str(&format!( + "&heartbeatFrequencyMS={}", + heartbeat_frequency.as_millis() + )); + } + + if let Some(app_name) = app_name { + opts.push_str(&format!("&appName={app_name}")); + } + + if let Some(retry_reads) = retry_reads { + opts.push_str(&format!("&retryReads={retry_reads}")); + } + + if let Some(retry_writes) = retry_writes { + opts.push_str(&format!("&retryWrites={retry_writes}")); + } + + if let Some(uuid_rep) = uuid_representation { + let s = match uuid_rep { + UuidRepresentation::Standard => "standard", + UuidRepresentation::CSharpLegacy => "csharpLegacy", + UuidRepresentation::JavaLegacy => "javaLegacy", + UuidRepresentation::PythonLegacy => "pythonLegacy", + _ => "", + }; + opts.push_str(&format!("&uuidRepresentation={s}")); + } + + if let Some(load_balanced) = load_balanced { + opts.push_str(&format!("&loadBalanced={load_balanced}")); + } + + if let Some(srv_max_hosts) = srv_max_hosts { + opts.push_str(&format!("&srvMaxHosts={srv_max_hosts}")); + } + + if !opts.is_empty() { + opts.replace_range(0..1, "?"); // mark start of options + res.push_str(&opts); + } + + res + } + /// Amount of time spent attempting to check out a connection from a server's connection pool /// before timing out. Not supported by the Rust driver. pub fn wait_queue_timeout(&self) -> Option { @@ -2280,6 +2664,12 @@ impl<'de> Deserialize<'de> for ConnectionString { } } +impl Display for ConnectionString { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.to_uri_str()) + } +} + struct ConnectionStringVisitor; impl serde::de::Visitor<'_> for ConnectionStringVisitor { diff --git a/src/compression/compressors.rs b/src/compression/compressors.rs index 842df3664..166e13cc6 100644 --- a/src/compression/compressors.rs +++ b/src/compression/compressors.rs @@ -1,10 +1,11 @@ +use serde::Serialize; use std::str::FromStr; use crate::error::{Error, ErrorKind, Result}; /// The compressors that may be used to compress messages sent to and decompress messages returned /// from the server. Note that each variant requires enabling a corresponding feature flag. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[non_exhaustive] pub enum Compressor { /// `zstd` compression. See [the `zstd` manual](http://facebook.github.io/zstd/zstd_manual.html) diff --git a/src/concern.rs b/src/concern.rs index 40b5d436c..4a3940968 100644 --- a/src/concern.rs +++ b/src/concern.rs @@ -85,8 +85,7 @@ impl ReadConcern { ReadConcernLevel::from_str(level.as_ref()).into() } - #[cfg(test)] - pub(crate) fn serialize_for_client_options( + pub(crate) fn serialize( read_concern: &Option, serializer: S, ) -> std::result::Result @@ -342,8 +341,7 @@ impl WriteConcern { Ok(()) } - #[cfg(test)] - pub(crate) fn serialize_for_client_options( + pub(crate) fn serialize( write_concern: &Option, serializer: S, ) -> std::result::Result