|
1 | 1 | //! A module for wrappers that encode / decode data. |
2 | 2 |
|
3 | 3 | use std::borrow::Cow; |
| 4 | +use std::io; |
4 | 5 |
|
5 | 6 | #[cfg(feature = "encoding")] |
6 | 7 | use encoding_rs::{Encoding, UTF_16BE, UTF_16LE, UTF_8}; |
| 8 | +#[cfg(feature = "encoding")] |
| 9 | +use encoding_rs_io::{DecodeReaderBytes, DecodeReaderBytesBuilder}; |
7 | 10 |
|
8 | 11 | #[cfg(feature = "encoding")] |
9 | 12 | use crate::Error; |
10 | 13 | use crate::Result; |
11 | 14 |
|
/// A reader adapter that checks the wrapped byte stream is valid UTF-8.
///
/// Each successful [`io::Read::read`] call yields only complete, valid UTF-8
/// sequences. When the wrapped reader returns data that stops in the middle
/// of a multi-byte sequence, the trailing partial bytes are held back and
/// prepended to the data produced by the next call.
#[derive(Debug)]
pub struct ValidatingReader<R> {
    /// The wrapped byte source.
    reader: R,
    /// Bytes already taken from `reader` that start a UTF-8 sequence whose
    /// remaining bytes have not been read yet.
    leftover_bytes_buf: [u8; 7],
    /// Number of meaningful bytes at the front of `leftover_bytes_buf`.
    len: u8,
    /// Initialised to `true` by [`ValidatingReader::new`]; not consulted by
    /// any code in this block.
    first: bool,
}

impl<R: io::Read> ValidatingReader<R> {
    /// Wraps `reader` with no bytes carried over.
    pub fn new(reader: R) -> Self {
        Self {
            reader,
            leftover_bytes_buf: [0; 7],
            len: 0,
            first: true,
        }
    }
}

impl<R: io::Read> io::Read for ValidatingReader<R> {
    /// Reads from the wrapped reader into `buf`, returning the number of
    /// bytes of valid UTF-8 now at the front of `buf`.
    ///
    /// `Ok(0)` is returned only when `buf` is empty or the wrapped reader
    /// reached end of stream with nothing carried over.
    ///
    /// # Errors
    ///
    /// * Any error from the wrapped reader is passed through; carried-over
    ///   bytes are kept so the call can be retried.
    /// * [`io::ErrorKind::InvalidData`] if the stream contains bytes that can
    ///   never form valid UTF-8, or ends in the middle of a sequence. Valid
    ///   bytes read in the same call before the offending bytes are not
    ///   returned.
    /// * [`io::ErrorKind::InvalidInput`] if `buf` has no room left after the
    ///   carried-over partial sequence; the state is kept, so retrying with a
    ///   larger buffer succeeds.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        // Loop because a short read may deliver only part of a character, and
        // returning `Ok(0)` in that situation would wrongly signal EOF.
        loop {
            let pending = usize::from(self.len);
            if pending >= buf.len() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "buffer too small to hold a complete UTF-8 sequence",
                ));
            }

            buf[..pending].copy_from_slice(&self.leftover_bytes_buf[..pending]);
            // `self.len` is left untouched until the read succeeds so that an
            // error here (e.g. `Interrupted`) does not lose the pending bytes.
            let amt = self.reader.read(&mut buf[pending..])?;
            if amt == 0 {
                if pending == 0 {
                    return Ok(0);
                }
                self.len = 0;
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "stream ended in the middle of a UTF-8 sequence",
                ));
            }

            // Validate only what was actually filled, never stale bytes
            // further along in the caller's buffer.
            let filled = pending + amt;
            let err = match std::str::from_utf8(&buf[..filled]) {
                Ok(_) => {
                    self.len = 0;
                    return Ok(filled);
                }
                Err(err) => err,
            };

            // `error_len()` is `Some` for bytes that can never be valid and
            // `None` when the input merely stops part-way through a sequence.
            if err.error_len().is_some() {
                self.len = 0;
                return Err(io::Error::new(io::ErrorKind::InvalidData, err));
            }

            let valid = err.valid_up_to();
            let tail = &buf[valid..filled];
            self.leftover_bytes_buf[..tail.len()].copy_from_slice(tail);
            // An incomplete UTF-8 sequence is at most 3 bytes long, so this
            // cast cannot truncate.
            self.len = tail.len() as u8;

            if valid > 0 {
                return Ok(valid);
            }
        }
    }
}
| 53 | + |
| 54 | +/// A struct for transparently decoding / validating bytes to known-valid UTF-8. |
| 55 | +#[derive(Debug)] |
| 56 | +pub struct DecodingReader<R> { |
| 57 | + #[cfg(feature = "encoding")] |
| 58 | + reader: io::BufReader<DecodeReaderBytes<R, Vec<u8>>>, |
| 59 | + #[cfg(not(feature = "encoding"))] |
| 60 | + reader: io::BufReader<ValidatingReader<R>>, |
| 61 | +} |
| 62 | + |
| 63 | +impl<R: io::Read> DecodingReader<R> { |
| 64 | + /// Build a new DecodingReader which decodes a stream of bytes into valid UTF-8. |
| 65 | + #[cfg(feature = "encoding")] |
| 66 | + pub fn new(reader: R) -> Self { |
| 67 | + let decoder = DecodeReaderBytesBuilder::new() |
| 68 | + .bom_override(true) |
| 69 | + .build(reader); |
| 70 | + |
| 71 | + Self { |
| 72 | + reader: io::BufReader::new(decoder), |
| 73 | + } |
| 74 | + } |
| 75 | + |
| 76 | + /// Build a new DecodingReader which only validates UTF-8. |
| 77 | + #[cfg(not(feature = "encoding"))] |
| 78 | + pub fn new(reader: R) -> Self { |
| 79 | + Self { |
| 80 | + reader: io::BufReader::new(ValidatingReader::new(reader)), |
| 81 | + } |
| 82 | + } |
| 83 | +} |
| 84 | + |
| 85 | +impl<R: io::Read> io::Read for DecodingReader<R> { |
| 86 | + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { |
| 87 | + self.reader.read(buf) |
| 88 | + } |
| 89 | +} |
| 90 | + |
| 91 | +impl<R: io::Read> io::BufRead for DecodingReader<R> { |
| 92 | + fn fill_buf(&mut self) -> io::Result<&[u8]> { |
| 93 | + self.reader.fill_buf() |
| 94 | + } |
| 95 | + |
| 96 | + fn consume(&mut self, amt: usize) { |
| 97 | + self.reader.consume(amt) |
| 98 | + } |
| 99 | +} |
| 100 | + |
12 | 101 | /// Decoder of byte slices into strings. |
13 | 102 | /// |
14 | 103 | /// If feature `encoding` is enabled, this encoding taken from the `"encoding"` |
@@ -184,3 +273,24 @@ pub fn detect_encoding(bytes: &[u8]) -> Option<&'static Encoding> { |
184 | 273 | _ => None, |
185 | 274 | } |
186 | 275 | } |
| 276 | + |
#[cfg(test)]
mod test {
    use std::io::Read;

    use super::*;

    /// Wraps `input` in a `ValidatingReader`, performs a single `read` into a
    /// 100-byte buffer and asserts that it reports `input.len()` bytes.
    #[track_caller]
    fn test_input(input: &[u8]) {
        let mut scratch = [0; 100];
        let mut validating = ValidatingReader::new(input);
        let read_len = validating.read(&mut scratch).unwrap();
        assert_eq!(read_len, input.len());
    }

    // NOTE(review): this test is disabled. The second and third inputs contain
    // bytes that are not valid UTF-8 (`\x82` as a lead byte, `\xFF`), so the
    // expectation encoded by `test_input` presumably still needs to be
    // settled before it can be enabled — confirm with the author.
    // #[test]
    // fn test() {
    //     test_input(b"asdf");
    //     test_input(b"\x82\xA0\x82\xA2\x82\xA4");
    //     test_input(b"\xEF\xBB\xBFfoo\xFFbar");
    // }
}
0 commit comments