@@ -24,115 +24,47 @@ use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
2424
2525use crate :: { DeserializeBytes , Error , SerializeBytes , Size } ;
2626
27- #[ cfg( not( feature = "mls" ) ) ]
28- const MAX_LEN : u64 = ( 1 << 62 ) - 1 ;
29- #[ cfg( not( feature = "mls" ) ) ]
30- const MAX_LEN_LEN_LOG : usize = 3 ;
3127#[ cfg( feature = "mls" ) ]
32- const MAX_LEN : u64 = ( 1 << 30 ) - 1 ;
33- #[ cfg( feature = "mls" ) ]
34- const MAX_LEN_LEN_LOG : usize = 2 ;
35-
36- #[ inline( always) ]
37- fn check_min_length ( length : usize , len_len : usize ) -> Result < ( ) , Error > {
38- if cfg ! ( feature = "mls" ) {
39- // ensure that len_len is minimal for the given length
40- let min_len_len = length_encoding_bytes ( length as u64 ) ?;
41- if min_len_len != len_len {
42- return Err ( Error :: InvalidVectorLength ) ;
43- }
44- } ;
45- Ok ( ( ) )
46- }
28+ const MAX_MLS_LEN : u64 = ( 1 << 30 ) - 1 ;
4729
48- #[ inline( always) ]
49- fn calculate_length ( len_len_byte : u8 ) -> Result < ( usize , usize ) , Error > {
50- let length: usize = ( len_len_byte & 0x3F ) . into ( ) ;
51- let len_len_log = ( len_len_byte >> 6 ) . into ( ) ;
52- if !cfg ! ( fuzzing) {
53- debug_assert ! ( len_len_log <= MAX_LEN_LEN_LOG ) ;
54- }
55- if len_len_log > MAX_LEN_LEN_LOG {
56- return Err ( Error :: InvalidVectorLength ) ;
57- }
58- let len_len = match len_len_log {
59- 0 => 1 ,
60- 1 => 2 ,
61- 2 => 4 ,
62- 3 => 8 ,
63- _ => unreachable ! ( ) ,
64- } ;
65- Ok ( ( length, len_len) )
66- }
67-
68- #[ inline( always) ]
69- fn read_variable_length_bytes ( bytes : & [ u8 ] ) -> Result < ( ( usize , usize ) , & [ u8 ] ) , Error > {
70- // The length is encoded in the first two bits of the first byte.
30+ /// Thin wrapper around [`TlsVarInt`] representing the length of encoded vector content in bytes.
31+ ///
32+ /// When `mls` feature is enabled, the maximum length is limited to 30-bit. Otherwise, this type is
33+ /// no-op.
34+ struct ContentLength ( super :: TlsVarInt ) ;
7135
72- let ( len_len_byte, mut remainder) = u8:: tls_deserialize_bytes ( bytes) ?;
36+ impl ContentLength {
37+ #[ cfg( not( feature = "mls" ) ) ]
38+ #[ allow( dead_code) ] // used in arbitrary
39+ const MAX : u64 = crate :: TlsVarInt :: MAX ;
7340
74- let ( mut length, len_len) = calculate_length ( len_len_byte) ?;
41+ #[ cfg( feature = "mls" ) ]
42+ const MAX : u64 = MAX_MLS_LEN ;
7543
76- for _ in 1 ..len_len {
77- let ( next, next_remainder) = u8:: tls_deserialize_bytes ( remainder) ?;
78- remainder = next_remainder;
79- length = ( length << 8 ) + usize:: from ( next) ;
44+ fn new ( value : super :: TlsVarInt ) -> Result < Self , Error > {
45+ #[ cfg( feature = "mls" ) ]
46+ if Self :: MAX < value. value ( ) {
47+ return Err ( Error :: InvalidVectorLength ) ;
48+ }
49+ Ok ( Self ( value) )
8050 }
8151
82- check_min_length ( length , len_len ) ? ;
83-
84- Ok ( ( ( length , len_len ) , remainder ) )
52+ fn from_usize ( value : usize ) -> Result < Self , Error > {
53+ Self :: new ( super :: TlsVarInt :: try_new ( value . try_into ( ) ? ) ? )
54+ }
8555}
8656
87- #[ inline( always) ]
88- fn length_encoding_bytes ( length : u64 ) -> Result < usize , Error > {
89- if !cfg ! ( fuzzing) {
90- debug_assert ! ( length <= MAX_LEN ) ;
91- }
92- if length > MAX_LEN {
93- return Err ( Error :: InvalidVectorLength ) ;
57+ impl Size for ContentLength {
58+ fn tls_serialized_len ( & self ) -> usize {
59+ self . 0 . tls_serialized_len ( )
9460 }
95-
96- Ok ( if length <= 0x3f {
97- 1
98- } else if length <= 0x3fff {
99- 2
100- } else if length <= 0x3fff_ffff {
101- 4
102- } else {
103- 8
104- } )
10561}
10662
107- #[ inline( always) ]
108- pub fn write_variable_length ( content_length : usize ) -> Result < Vec < u8 > , Error > {
109- let len_len = length_encoding_bytes ( content_length. try_into ( ) ?) ?;
110- if !cfg ! ( fuzzing) {
111- debug_assert ! ( len_len <= 8 , "Invalid vector len_len {len_len}" ) ;
112- }
113- if len_len > 8 {
114- return Err ( Error :: LibraryError ) ;
115- }
116- let mut length_bytes = vec ! [ 0u8 ; len_len] ;
117- match len_len {
118- 1 => length_bytes[ 0 ] = 0x00 ,
119- 2 => length_bytes[ 0 ] = 0x40 ,
120- 4 => length_bytes[ 0 ] = 0x80 ,
121- 8 => length_bytes[ 0 ] = 0xc0 ,
122- _ => {
123- if !cfg ! ( fuzzing) {
124- debug_assert ! ( false , "Invalid vector len_len {len_len}" ) ;
125- }
126- return Err ( Error :: InvalidVectorLength ) ;
127- }
128- }
129- let mut len = content_length;
130- for b in length_bytes. iter_mut ( ) . rev ( ) {
131- * b |= ( len & 0xFF ) as u8 ;
132- len >>= 8 ;
63+ impl DeserializeBytes for ContentLength {
64+ fn tls_deserialize_bytes ( bytes : & [ u8 ] ) -> Result < ( Self , & [ u8 ] ) , Error > {
65+ let ( value, remainder) = super :: TlsVarInt :: tls_deserialize_bytes ( bytes) ?;
66+ Ok ( ( Self ( value) , remainder) )
13367 }
134-
135- Ok ( length_bytes)
13668}
13769
13870impl < T : Size > Size for Vec < T > {
@@ -152,7 +84,9 @@ impl<T: Size> Size for &Vec<T> {
15284impl < T : DeserializeBytes > DeserializeBytes for Vec < T > {
15385 #[ inline( always) ]
15486 fn tls_deserialize_bytes ( bytes : & [ u8 ] ) -> Result < ( Self , & [ u8 ] ) , Error > {
155- let ( ( length, len_len) , mut remainder) = read_variable_length_bytes ( bytes) ?;
87+ let ( length, mut remainder) = ContentLength :: tls_deserialize_bytes ( bytes) ?;
88+ let len_len = length. 0 . bytes_len ( ) ;
89+ let length: usize = length. 0 . value ( ) . try_into ( ) ?;
15690
15791 if length == 0 {
15892 // An empty vector.
@@ -178,11 +112,12 @@ impl<T: SerializeBytes> SerializeBytes for &[T] {
178112 // This requires more computations but the other option would be to buffer
179113 // the entire content, which can end up requiring a lot of memory.
180114 let content_length = self . iter ( ) . fold ( 0 , |acc, e| acc + e. tls_serialized_len ( ) ) ;
181- let mut length = write_variable_length ( content_length) ?;
182- let len_len = length. len ( ) ;
115+ let length = ContentLength :: from_usize ( content_length) ?;
116+ let len_len = length. 0 . bytes_len ( ) ;
183117
184118 let mut out = Vec :: with_capacity ( content_length + len_len) ;
185- out. append ( & mut length) ;
119+ out. resize ( len_len, 0 ) ;
120+ length. 0 . write_bytes ( & mut out) ?;
186121
187122 // Serialize the elements
188123 for e in self . iter ( ) {
@@ -214,11 +149,13 @@ impl<T: Size> Size for &[T] {
214149 #[ inline( always) ]
215150 fn tls_serialized_len ( & self ) -> usize {
216151 let content_length = self . iter ( ) . fold ( 0 , |acc, e| acc + e. tls_serialized_len ( ) ) ;
217- let len_len = length_encoding_bytes ( content_length as u64 ) . unwrap_or ( {
218- // We can't do anything about the error unless we change the trait.
219- // Let's say there's no content for now.
220- 0
221- } ) ;
152+ let len_len = ContentLength :: from_usize ( content_length)
153+ . map ( |content_length| content_length. 0 . bytes_len ( ) )
154+ . unwrap_or ( {
155+ // We can't do anything about the error unless we change the trait.
156+ // Let's say there's no content for now.
157+ 0
158+ } ) ;
222159 content_length + len_len
223160 }
224161}
@@ -327,10 +264,12 @@ impl From<VLBytes> for Vec<u8> {
327264#[ inline( always) ]
328265fn tls_serialize_bytes_len ( bytes : & [ u8 ] ) -> usize {
329266 let content_length = bytes. len ( ) ;
330- let len_len = length_encoding_bytes ( content_length as u64 ) . unwrap_or ( {
331- // We can't do anything about the error. Let's say there's no content.
332- 0
333- } ) ;
267+ let len_len = ContentLength :: from_usize ( content_length)
268+ . map ( |content_length| content_length. 0 . bytes_len ( ) )
269+ . unwrap_or ( {
270+ // We can't do anything about the error. Let's say there's no content.
271+ 0
272+ } ) ;
334273 content_length + len_len
335274}
336275
@@ -344,22 +283,13 @@ impl Size for VLBytes {
344283impl DeserializeBytes for VLBytes {
345284 #[ inline( always) ]
346285 fn tls_deserialize_bytes ( bytes : & [ u8 ] ) -> Result < ( Self , & [ u8 ] ) , Error > {
347- let ( ( length, _) , remainder) = read_variable_length_bytes ( bytes) ?;
286+ let ( length, remainder) = ContentLength :: tls_deserialize_bytes ( bytes) ?;
287+ let length: usize = length. 0 . value ( ) . try_into ( ) ?;
288+
348289 if length == 0 {
349290 return Ok ( ( Self :: new ( vec ! [ ] ) , remainder) ) ;
350291 }
351292
352- if !cfg ! ( fuzzing) {
353- debug_assert ! (
354- length <= MAX_LEN as usize ,
355- "Trying to allocate {length} bytes. Only {MAX_LEN} allowed." ,
356- ) ;
357- }
358- if length > MAX_LEN as usize {
359- return Err ( Error :: DecodingError ( format ! (
360- "Trying to allocate {length} bytes. Only {MAX_LEN} allowed." ,
361- ) ) ) ;
362- }
363293 match remainder. get ( ..length) . ok_or ( Error :: EndOfStream ) {
364294 Ok ( vec) => Ok ( ( Self { vec : vec. to_vec ( ) } , & remainder[ length..] ) ) ,
365295 Err ( _e) => {
@@ -422,6 +352,19 @@ pub mod rw {
422352 use super :: * ;
423353 use crate :: { Deserialize , Serialize } ;
424354
355+ impl Deserialize for ContentLength {
356+ fn tls_deserialize < R : std:: io:: Read > ( bytes : & mut R ) -> Result < Self , Error > {
357+ ContentLength :: new ( crate :: TlsVarInt :: tls_deserialize ( bytes) ?)
358+ }
359+ }
360+
361+ impl Serialize for ContentLength {
362+ #[ inline( always) ]
363+ fn tls_serialize < W : std:: io:: Write > ( & self , writer : & mut W ) -> Result < usize , Error > {
364+ self . 0 . tls_serialize ( writer)
365+ }
366+ }
367+
425368 /// Read the length of a variable-length vector.
426369 ///
427370 /// This function assumes that the reader is at the start of a variable length
@@ -430,26 +373,9 @@ pub mod rw {
430373 /// The length and number of bytes read are returned.
431374 #[ inline]
432375 pub fn read_length < R : std:: io:: Read > ( bytes : & mut R ) -> Result < ( usize , usize ) , Error > {
433- // The length is encoded in the first two bits of the first byte.
434- let mut len_len_byte = [ 0u8 ; 1 ] ;
435- if bytes. read ( & mut len_len_byte) ? == 0 {
436- // There must be at least one byte for the length.
437- // If we don't even have a length byte, this is not a valid
438- // variable-length encoded vector.
439- return Err ( Error :: InvalidVectorLength ) ;
440- }
441- let len_len_byte = len_len_byte[ 0 ] ;
442-
443- let ( mut length, len_len) = calculate_length ( len_len_byte) ?;
444-
445- for _ in 1 ..len_len {
446- let mut next = [ 0u8 ; 1 ] ;
447- bytes. read_exact ( & mut next) ?;
448- length = ( length << 8 ) + usize:: from ( next[ 0 ] ) ;
449- }
450-
451- check_min_length ( length, len_len) ?;
452-
376+ let length = ContentLength :: tls_deserialize ( bytes) ?;
377+ let len_len = length. 0 . bytes_len ( ) ;
378+ let length: usize = length. 0 . value ( ) . try_into ( ) ?;
453379 Ok ( ( length, len_len) )
454380 }
455381
@@ -479,10 +405,7 @@ pub mod rw {
479405 writer : & mut W ,
480406 content_length : usize ,
481407 ) -> Result < usize , Error > {
482- let buf = super :: write_variable_length ( content_length) ?;
483- let buf_len = buf. len ( ) ;
484- writer. write_all ( & buf) ?;
485- Ok ( buf_len)
408+ ContentLength :: from_usize ( content_length) ?. tls_serialize ( writer)
486409 }
487410
488411 impl < T : Serialize + std:: fmt:: Debug > Serialize for Vec < T > {
@@ -538,19 +461,7 @@ mod rw_bytes {
538461 // large and write it out.
539462 let content_length = bytes. len ( ) ;
540463
541- if !cfg ! ( fuzzing) {
542- debug_assert ! (
543- content_length as u64 <= MAX_LEN ,
544- "Vector can't be encoded. It's too large. {content_length} >= {MAX_LEN}" ,
545- ) ;
546- }
547- if content_length as u64 > MAX_LEN {
548- return Err ( Error :: InvalidVectorLength ) ;
549- }
550-
551- let length_bytes = write_variable_length ( content_length) ?;
552- let len_len = length_bytes. len ( ) ;
553- writer. write_all ( & length_bytes) ?;
464+ let len_len = ContentLength :: from_usize ( content_length) ?. tls_serialize ( writer) ?;
554465
555466 // Now serialize the elements
556467 writer. write_all ( bytes) ?;
@@ -574,24 +485,14 @@ mod rw_bytes {
574485
575486 impl Deserialize for VLBytes {
576487 fn tls_deserialize < R : std:: io:: Read > ( bytes : & mut R ) -> Result < Self , Error > {
577- let ( length, _) = rw:: read_length ( bytes) ?;
578- if length == 0 {
488+ let length = ContentLength :: tls_deserialize ( bytes) ?;
489+
490+ if length. 0 . value ( ) == 0 {
579491 return Ok ( Self :: new ( vec ! [ ] ) ) ;
580492 }
581493
582- if !cfg ! ( fuzzing) {
583- debug_assert ! (
584- length <= MAX_LEN as usize ,
585- "Trying to allocate {length} bytes. Only {MAX_LEN} allowed." ,
586- ) ;
587- }
588- if length > MAX_LEN as usize {
589- return Err ( Error :: DecodingError ( format ! (
590- "Trying to allocate {length} bytes. Only {MAX_LEN} allowed." ,
591- ) ) ) ;
592- }
593494 let mut result = Self {
594- vec : vec ! [ 0u8 ; length] ,
495+ vec : vec ! [ 0u8 ; length. 0 . value ( ) . try_into ( ) ? ] ,
595496 } ;
596497 bytes. read_exact ( result. vec . as_mut_slice ( ) ) ?;
597498 Ok ( result)
@@ -682,7 +583,7 @@ impl<'a> Arbitrary<'a> for VLBytes {
682583 // We generate an arbitrary `Vec<u8>` ...
683584 let mut vec = Vec :: arbitrary ( u) ?;
684585 // ... and truncate it to `MAX_LEN`.
685- vec. truncate ( MAX_LEN as usize ) ;
586+ vec. truncate ( ContentLength :: MAX as usize ) ;
686587 // We probably won't exceed `MAX_LEN` in practice, e.g., during fuzzing,
687588 // but better make sure that we generate valid instances.
688589
0 commit comments