5353
5454#include  "datatypes.h" 
5555
56+ #if  defined(__SSE2__ )
57+ #include  <tmmintrin.h> 
58+ #endif 
59+ 
60+ #if  defined(_MSC_VER )
61+ #define  ALIGNMENT (N ) __declspec(align(N))
62+ #else 
63+ #define  ALIGNMENT (N ) __attribute__((aligned(N)))
64+ #endif 
65+ 
5666/* 
5767 * bit_string is a buffer that is used to hold output strings, e.g. 
5868 * for printing. 
@@ -123,6 +133,9 @@ char *v128_bit_string(v128_t *x)
123133
124134void  v128_copy_octet_string (v128_t  * x , const  uint8_t  s [16 ])
125135{
136+ #if  defined(__SSE2__ )
137+     _mm_storeu_si128 ((__m128i  * )(x ), _mm_loadu_si128 ((const  __m128i  * )(s )));
138+ #else 
126139#ifdef  ALIGNMENT_32BIT_REQUIRED 
127140    if  ((((uint32_t )& s [0 ]) &  0x3 ) !=  0 )
128141#endif 
@@ -151,8 +164,67 @@ void v128_copy_octet_string(v128_t *x, const uint8_t s[16])
151164        v128_copy (x , v );
152165    }
153166#endif 
167+ #endif  /* defined(__SSE2__) */ 
168+ }
169+ 
170+ #if  defined(__SSSE3__ )
171+ 
172+ /* clang-format off */ 
173+ 
174+ ALIGNMENT (16 )
175+ static  const  uint8_t  right_shift_masks [5 ][16 ] =  {
176+     { 0u , 1u , 2u , 3u ,  4u , 5u , 6u , 7u ,
177+       8u , 9u , 10u , 11u ,  12u , 13u , 14u , 15u  },
178+     { 0x80 , 0x80 , 0x80 , 0x80 ,  0u , 1u , 2u , 3u ,
179+       4u , 5u , 6u , 7u ,  8u , 9u , 10u , 11u  },
180+     { 0x80 , 0x80 , 0x80 , 0x80 ,  0x80 , 0x80 , 0x80 , 0x80 ,
181+       0u , 1u , 2u , 3u ,  4u , 5u , 6u , 7u  },
182+     { 0x80 , 0x80 , 0x80 , 0x80 ,  0x80 , 0x80 , 0x80 , 0x80 ,
183+       0x80 , 0x80 , 0x80 , 0x80 ,  0u , 1u , 2u , 3u  },
184+     /* needed for bitvector_left_shift */ 
185+     { 0x80 , 0x80 , 0x80 , 0x80 ,  0x80 , 0x80 , 0x80 , 0x80 ,
186+       0x80 , 0x80 , 0x80 , 0x80 ,  0x80 , 0x80 , 0x80 , 0x80  }
187+ };
188+ 
189+ ALIGNMENT (16 )
190+ static  const  uint8_t  left_shift_masks [4 ][16 ] =  {
191+     { 0u , 1u , 2u , 3u ,  4u , 5u , 6u , 7u ,
192+       8u , 9u , 10u , 11u ,  12u , 13u , 14u , 15u  },
193+     { 4u , 5u , 6u , 7u ,  8u , 9u , 10u , 11u ,
194+       12u , 13u , 14u , 15u ,  0x80 , 0x80 , 0x80 , 0x80  },
195+     { 8u , 9u , 10u , 11u ,  12u , 13u , 14u , 15u ,
196+       0x80 , 0x80 , 0x80 , 0x80 ,  0x80 , 0x80 , 0x80 , 0x80  },
197+     { 12u , 13u , 14u , 15u ,  0x80 , 0x80 , 0x80 , 0x80 ,
198+       0x80 , 0x80 , 0x80 , 0x80 ,  0x80 , 0x80 , 0x80 , 0x80  }
199+ };
200+ 
201+ /* clang-format on */ 
202+ 
203+ void  v128_left_shift (v128_t  * x , int  shift )
204+ {
205+     if  (shift  >  127 ) {
206+         v128_set_to_zero (x );
207+         return ;
208+     }
209+ 
210+     const  int  base_index  =  shift  >> 5 ;
211+     const  int  bit_index  =  shift  &  31 ;
212+ 
213+     __m128i  mm  =  _mm_loadu_si128 ((const  __m128i  * )x );
214+     __m128i  mm_shift_right  =  _mm_cvtsi32_si128 (bit_index );
215+     __m128i  mm_shift_left  =  _mm_cvtsi32_si128 (32  -  bit_index );
216+     mm  =  _mm_shuffle_epi8 (mm , ((const  __m128i  * )left_shift_masks )[base_index ]);
217+ 
218+     __m128i  mm1  =  _mm_srl_epi32 (mm , mm_shift_right );
219+     __m128i  mm2  =  _mm_sll_epi32 (mm , mm_shift_left );
220+     mm2  =  _mm_srli_si128 (mm2 , 4 );
221+     mm1  =  _mm_or_si128 (mm1 , mm2 );
222+ 
223+     _mm_storeu_si128 ((__m128i  * )x , mm1 );
154224}
155225
226+ #else  /* defined(__SSSE3__) */ 
227+ 
156228void  v128_left_shift (v128_t  * x , int  shift )
157229{
158230    int  i ;
@@ -179,6 +251,8 @@ void v128_left_shift(v128_t *x, int shift)
179251        x -> v32 [i ] =  0 ;
180252}
181253
254+ #endif  /* defined(__SSSE3__) */ 
255+ 
182256/* functions manipulating bitvector_t */ 
183257
184258int  bitvector_alloc (bitvector_t  * v , unsigned long  length )
@@ -190,6 +264,7 @@ int bitvector_alloc(bitvector_t *v, unsigned long length)
190264        (length  +  bits_per_word  -  1 ) &  ~(unsigned long )((bits_per_word  -  1 ));
191265
192266    l  =  length  / bits_per_word  *  bytes_per_word ;
267+     l  =  (l  +  15ul ) &  ~15ul ;
193268
194269    /* allocate memory, then set parameters */ 
195270    if  (l  ==  0 ) {
@@ -225,6 +300,73 @@ void bitvector_set_to_zero(bitvector_t *x)
225300    memset (x -> word , 0 , x -> length  >> 3 );
226301}
227302
303+ #if  defined(__SSSE3__ )
304+ 
305+ void  bitvector_left_shift (bitvector_t  * x , int  shift )
306+ {
307+     if  ((uint32_t )shift  >= x -> length ) {
308+         bitvector_set_to_zero (x );
309+         return ;
310+     }
311+ 
312+     const  int  base_index  =  shift  >> 5 ;
313+     const  int  bit_index  =  shift  &  31 ;
314+     const  int  vec_length  =  (x -> length  +  127u ) >> 7 ;
315+     const  __m128i  * from  =  ((const  __m128i  * )x -> word ) +  (base_index  >> 2 );
316+     __m128i  * to  =  (__m128i  * )x -> word ;
317+     __m128i  * const  end  =  to  +  vec_length ;
318+ 
319+     __m128i  mm_right_shift_mask  = 
320+         ((const  __m128i  * )right_shift_masks )[4u  -  (base_index  &  3u )];
321+     __m128i  mm_left_shift_mask  = 
322+         ((const  __m128i  * )left_shift_masks )[base_index  &  3u ];
323+     __m128i  mm_shift_right  =  _mm_cvtsi32_si128 (bit_index );
324+     __m128i  mm_shift_left  =  _mm_cvtsi32_si128 (32  -  bit_index );
325+ 
326+     __m128i  mm_current  =  _mm_loadu_si128 (from );
327+     __m128i  mm_current_r  =  _mm_srl_epi32 (mm_current , mm_shift_right );
328+     __m128i  mm_current_l  =  _mm_sll_epi32 (mm_current , mm_shift_left );
329+ 
330+     while  ((end  -  from ) >= 2 ) {
331+         ++ from ;
332+         __m128i  mm_next  =  _mm_loadu_si128 (from );
333+ 
334+         __m128i  mm_next_r  =  _mm_srl_epi32 (mm_next , mm_shift_right );
335+         __m128i  mm_next_l  =  _mm_sll_epi32 (mm_next , mm_shift_left );
336+         mm_current_l  =  _mm_alignr_epi8 (mm_next_l , mm_current_l , 4 );
337+         mm_current  =  _mm_or_si128 (mm_current_r , mm_current_l );
338+ 
339+         mm_current  =  _mm_shuffle_epi8 (mm_current , mm_left_shift_mask );
340+ 
341+         __m128i  mm_temp_next  =  _mm_srli_si128 (mm_next_l , 4 );
342+         mm_temp_next  =  _mm_or_si128 (mm_next_r , mm_temp_next );
343+ 
344+         mm_temp_next  =  _mm_shuffle_epi8 (mm_temp_next , mm_right_shift_mask );
345+         mm_current  =  _mm_or_si128 (mm_temp_next , mm_current );
346+ 
347+         _mm_storeu_si128 (to , mm_current );
348+         ++ to ;
349+ 
350+         mm_current_r  =  mm_next_r ;
351+         mm_current_l  =  mm_next_l ;
352+     }
353+ 
354+     mm_current_l  =  _mm_srli_si128 (mm_current_l , 4 );
355+     mm_current  =  _mm_or_si128 (mm_current_r , mm_current_l );
356+ 
357+     mm_current  =  _mm_shuffle_epi8 (mm_current , mm_left_shift_mask );
358+ 
359+     _mm_storeu_si128 (to , mm_current );
360+     ++ to ;
361+ 
362+     while  (to  <  end ) {
363+         _mm_storeu_si128 (to , _mm_setzero_si128 ());
364+         ++ to ;
365+     }
366+ }
367+ 
368+ #else  /* defined(__SSSE3__) */ 
369+ 
228370void  bitvector_left_shift (bitvector_t  * x , int  shift )
229371{
230372    int  i ;
@@ -253,16 +395,73 @@ void bitvector_left_shift(bitvector_t *x, int shift)
253395        x -> word [i ] =  0 ;
254396}
255397
398+ #endif  /* defined(__SSSE3__) */ 
399+ 
256400int  srtp_octet_string_is_eq (uint8_t  * a , uint8_t  * b , int  len )
257401{
258-     uint8_t  * end  =  b  +  len ;
259-     uint8_t  accumulator  =  0 ;
260- 
261402    /* 
262403     * We use this somewhat obscure implementation to try to ensure the running 
263404     * time only depends on len, even accounting for compiler optimizations. 
264405     * The accumulator ends up zero iff the strings are equal. 
265406     */ 
407+     uint8_t  * end  =  b  +  len ;
408+     uint32_t  accumulator  =  0 ;
409+ 
410+ #if  defined(__SSE2__ )
411+     __m128i  mm_accumulator1  =  _mm_setzero_si128 ();
412+     __m128i  mm_accumulator2  =  _mm_setzero_si128 ();
413+     for  (int  i  =  0 , n  =  len  >> 5 ; i  <  n ; ++ i , a  +=  32 , b  +=  32 ) {
414+         __m128i  mm_a1  =  _mm_loadu_si128 ((const  __m128i  * )a );
415+         __m128i  mm_b1  =  _mm_loadu_si128 ((const  __m128i  * )b );
416+         __m128i  mm_a2  =  _mm_loadu_si128 ((const  __m128i  * )(a  +  16 ));
417+         __m128i  mm_b2  =  _mm_loadu_si128 ((const  __m128i  * )(b  +  16 ));
418+         mm_a1  =  _mm_xor_si128 (mm_a1 , mm_b1 );
419+         mm_a2  =  _mm_xor_si128 (mm_a2 , mm_b2 );
420+         mm_accumulator1  =  _mm_or_si128 (mm_accumulator1 , mm_a1 );
421+         mm_accumulator2  =  _mm_or_si128 (mm_accumulator2 , mm_a2 );
422+     }
423+ 
424+     mm_accumulator1  =  _mm_or_si128 (mm_accumulator1 , mm_accumulator2 );
425+ 
426+     if  ((end  -  b ) >= 16 ) {
427+         __m128i  mm_a1  =  _mm_loadu_si128 ((const  __m128i  * )a );
428+         __m128i  mm_b1  =  _mm_loadu_si128 ((const  __m128i  * )b );
429+         mm_a1  =  _mm_xor_si128 (mm_a1 , mm_b1 );
430+         mm_accumulator1  =  _mm_or_si128 (mm_accumulator1 , mm_a1 );
431+         a  +=  16 ;
432+         b  +=  16 ;
433+     }
434+ 
435+     mm_accumulator1  =  _mm_or_si128 (
436+         mm_accumulator1 , _mm_unpackhi_epi64 (mm_accumulator1 , mm_accumulator1 ));
437+     mm_accumulator1  = 
438+         _mm_or_si128 (mm_accumulator1 , _mm_srli_si128 (mm_accumulator1 , 4 ));
439+     accumulator  =  _mm_cvtsi128_si32 (mm_accumulator1 );
440+ #else 
441+     uint32_t  accumulator2  =  0 ;
442+     for  (int  i  =  0 , n  =  len  >> 3 ; i  <  n ; ++ i , a  +=  8 , b  +=  8 ) {
443+         uint32_t  a_val1 , b_val1 ;
444+         uint32_t  a_val2 , b_val2 ;
445+         memcpy (& a_val1 , a , sizeof (a_val1 ));
446+         memcpy (& b_val1 , b , sizeof (b_val1 ));
447+         memcpy (& a_val2 , a  +  4 , sizeof (a_val2 ));
448+         memcpy (& b_val2 , b  +  4 , sizeof (b_val2 ));
449+         accumulator  |= a_val1  ^ b_val1 ;
450+         accumulator2  |= a_val2  ^ b_val2 ;
451+     }
452+ 
453+     accumulator  |= accumulator2 ;
454+ 
455+     if  ((end  -  b ) >= 4 ) {
456+         uint32_t  a_val , b_val ;
457+         memcpy (& a_val , a , sizeof (a_val ));
458+         memcpy (& b_val , b , sizeof (b_val ));
459+         accumulator  |= a_val  ^ b_val ;
460+         a  +=  4 ;
461+         b  +=  4 ;
462+     }
463+ #endif 
464+ 
266465    while  (b  <  end )
267466        accumulator  |= (* a ++  ^ * b ++ );
268467
@@ -272,9 +471,14 @@ int srtp_octet_string_is_eq(uint8_t *a, uint8_t *b, int len)
272471
273472void  srtp_cleanse (void  * s , size_t  len )
274473{
474+ #if  defined(__GNUC__ )
475+     memset (s , 0 , len );
476+     __asm__ __volatile__(""  : : "r" (s ) : "memory" );
477+ #else 
275478    volatile  unsigned char   * p  =  (volatile  unsigned char   * )s ;
276479    while  (len -- )
277480        * p ++  =  0 ;
481+ #endif 
278482}
279483
280484void  octet_string_set_to_zero (void  * s , size_t  len )
0 commit comments