diff --git a/src/enc/test.rs b/src/enc/test.rs index aacc8dce..ce983d9b 100755 --- a/src/enc/test.rs +++ b/src/enc/test.rs @@ -553,7 +553,7 @@ fn test_roundtrip_empty() { #[cfg(feature="std")] #[test] fn test_compress_into_short_buffer() { - use std::io::{Cursor, Write}; + use std::io::{Cursor, Write, ErrorKind}; // this plaintext should compress to 11 bytes let plaintext = [0u8; 2048]; @@ -564,7 +564,8 @@ fn test_compress_into_short_buffer() { let mut w = crate::CompressorWriter::new(&mut output_cursor, 4096, 4, 22); - w.write(&plaintext).unwrap_err(); + assert_eq!(w.write(&plaintext).unwrap(), 2048); + assert_eq!(w.flush().unwrap_err().kind(), ErrorKind::WriteZero); w.into_inner(); println!("{output_buffer:?}"); diff --git a/src/enc/writer.rs b/src/enc/writer.rs index a511ea70..f8cbb077 100644 --- a/src/enc/writer.rs +++ b/src/enc/writer.rs @@ -41,6 +41,7 @@ impl, Alloc: BrotliAlloc> buffer, alloc, Error::new(ErrorKind::InvalidData, "Invalid Data"), + Error::new(ErrorKind::WriteZero, "No room in output."), q, lgwin, )) @@ -127,14 +128,24 @@ pub struct CompressorWriterCustomIo< output: Option, error_if_invalid_data: Option, state: BrotliEncoderStateStruct, + error_if_zero_bytes_written: Option, } -pub fn write_all>( +pub fn write_all, ErrMaker: FnMut() -> Option>( writer: &mut W, mut buf: &[u8], + mut error_to_return_if_zero_bytes_written: ErrMaker, ) -> Result<(), ErrType> { while !buf.is_empty() { match writer.write(buf) { - Ok(bytes_written) => buf = &buf[bytes_written..], + Ok(bytes_written) => if bytes_written != 0 { + buf = &buf[bytes_written..] + } else { + if let Some(err) = error_to_return_if_zero_bytes_written() { + return Err(err); + } else { + return Ok(()); + } + }, Err(e) => return Err(e), } } @@ -148,6 +159,7 @@ impl, BufferType: SliceWrapperMut, Alloc: B buffer: BufferType, alloc: Alloc, invalid_data_error_type: ErrType, + error_if_zero_bytes_written: ErrType, q: u32, lgwin: u32, ) -> Self { @@ -157,6 +169,7 @@ impl, BufferType: SliceWrapperMut, Alloc: B output: Some(w), state: BrotliEncoderStateStruct::new(alloc), error_if_invalid_data: Some(invalid_data_error_type), + error_if_zero_bytes_written: Some(error_if_zero_bytes_written), }; ret.state .set_parameter(BrotliEncoderParameter::BROTLI_PARAM_QUALITY, q); @@ -189,9 +202,17 @@ impl, BufferType: SliceWrapperMut, Alloc: B &mut nop_callback, ); if output_offset > 0 { + let zero_err = &mut self.error_if_zero_bytes_written; + let fallback = &mut self.error_if_invalid_data; match write_all( self.output.as_mut().unwrap(), &self.output_buffer.slice_mut()[..output_offset], + || { + if let Some(err) = zero_err.take() { + return Some(err); + } + fallback.take() + }, ) { Ok(_) => {} Err(e) => return Err(e), @@ -266,12 +287,23 @@ impl, BufferType: SliceWrapperMut, Alloc: B &mut nop_callback, ); if output_offset > 0 { + let zero_err = &mut self.error_if_zero_bytes_written; + let fallback = &mut self.error_if_invalid_data; match write_all( self.output.as_mut().unwrap(), &self.output_buffer.slice_mut()[..output_offset], + || { + if let Some(err) = zero_err.take() { + return Some(err); + } + fallback.take() + }, + ) { Ok(_) => {} - Err(e) => return Err(e), + Err(e) => { + return Err(e) + }, } } if !ret {