diff --git a/Cargo.lock b/Cargo.lock index aa076c379..6d6645ea9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1584,7 +1584,7 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_with", - "sha-1", + "sha1", "sha2", "snap", "socket2", @@ -2488,17 +2488,6 @@ dependencies = [ "syn 2.0.101", ] -[[package]] -name = "sha-1" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - [[package]] name = "sha1" version = "0.10.6" diff --git a/Cargo.toml b/Cargo.toml index 8dfcb604b..0c9ba57e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -105,7 +105,7 @@ rayon = { version = "1.5.3", optional = true } rustc_version_runtime = "0.3.0" rustls-pemfile = { version = "1.0.1", optional = true } serde_with = "3.8.1" -sha-1 = "0.10.0" +sha1 = "0.10.0" sha2 = "0.10.2" snap = { version = "1.0.5", optional = true } socket2 = "0.5.5" diff --git a/src/cmap/conn.rs b/src/cmap/conn.rs index 85b9ade9d..d7c39cca7 100644 --- a/src/cmap/conn.rs +++ b/src/cmap/conn.rs @@ -222,6 +222,7 @@ impl Connection { self.command_executing = true; + let max_message_size = self.max_message_size_bytes(); #[cfg(any( feature = "zstd-compression", feature = "zlib-compression", @@ -230,30 +231,30 @@ impl Connection { let write_result = match self.compressor { Some(ref compressor) if message.should_compress => { message - .write_op_compressed_to(&mut self.stream, compressor) + .write_op_compressed_to(&mut self.stream, compressor, max_message_size) + .await + } + _ => { + message + .write_op_msg_to(&mut self.stream, max_message_size) .await } - _ => message.write_op_msg_to(&mut self.stream).await, }; #[cfg(all( not(feature = "zstd-compression"), not(feature = "zlib-compression"), not(feature = "snappy-compression") ))] - let write_result = message.write_op_msg_to(&mut self.stream).await; + let write_result = message + .write_op_msg_to(&mut self.stream, max_message_size) + .await; if let Err(ref err) = write_result { self.error = Some(err.clone()); } write_result?; - let response_message_result = Message::read_from( - &mut self.stream, - self.stream_description - .as_ref() - .map(|d| d.max_message_size_bytes), - ) - .await; + let response_message_result = Message::read_from(&mut self.stream, max_message_size).await; self.command_executing = false; if let Err(ref err) = response_message_result { self.error = Some(err.clone()); @@ -306,6 +307,12 @@ impl Connection { pub(crate) fn is_streaming(&self) -> bool { self.more_to_come } + + fn max_message_size_bytes(&self) -> Option { + self.stream_description + .as_ref() + .map(|d| d.max_message_size_bytes) + } } /// A handle to a pinned connection - the connection itself can be retrieved or returned to the diff --git a/src/cmap/conn/wire/message.rs b/src/cmap/conn/wire/message.rs index c746c8b95..fcece7a1f 100644 --- a/src/cmap/conn/wire/message.rs +++ b/src/cmap/conn/wire/message.rs @@ -274,6 +274,7 @@ impl Message { pub(crate) async fn write_op_msg_to( &self, mut writer: T, + max_message_size_bytes: Option, ) -> Result<()> { let sections = self.get_sections_bytes()?; @@ -286,6 +287,15 @@ impl Message { .map(std::mem::size_of_val) .unwrap_or(0); + let max_len = + Checked::try_from(max_message_size_bytes.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE_BYTES))?; + if total_length > max_len { + return Err(ErrorKind::InvalidArgument { + message: format!("Message length {} over maximum {}", total_length, max_len), + } + .into()); + } + let header = Header { length: total_length.try_into()?, request_id: self.request_id.unwrap_or_else(next_request_id), @@ -316,6 +326,7 @@ impl Message { &self, mut writer: T, compressor: &Compressor, + max_message_size_bytes: Option, ) -> Result<()> { let flag_bytes = &self.flags.bits().to_le_bytes(); let section_bytes = self.get_sections_bytes()?; @@ -329,6 +340,15 @@ impl Message { + std::mem::size_of::() + compressed_bytes.len(); + let max_len = + Checked::try_from(max_message_size_bytes.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE_BYTES))?; + if total_length > max_len { + return Err(ErrorKind::InvalidArgument { + message: format!("Message length {} over maximum {}", total_length, max_len), + } + .into()); + } + let header = Header { length: total_length.try_into()?, request_id: self.request_id.unwrap_or_else(next_request_id),