@@ -121,7 +121,8 @@ namespace ctranslate2 {
                                               const Padder* input_padder,
                                               const Padder* memory_padder,
                                               bool return_normalized_attention,
-                                              StorageView* position_bias) const {
+                                              StorageView* position_bias,
+                                              dim_t offset) const {
       PROFILE("TransformerDecoderLayer");
 
       const DataType dtype = input.dtype();
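The hunk above threads a new dim_t offset argument through TransformerDecoderLayer::operator() so the self-attention call can be told how far into the full sequence the current chunk starts. A minimal standalone sketch of that bookkeeping (plain C++, not the CTranslate2 API; chunk_positions is a hypothetical helper):

#include <cstdint>
#include <vector>

using dim_t = int64_t;

// Absolute position indices for a chunk of `chunk_len` tokens that starts
// `offset` tokens into the full prompt; position-dependent computations can
// consume these instead of 0..chunk_len-1.
std::vector<dim_t> chunk_positions(dim_t offset, dim_t chunk_len) {
  std::vector<dim_t> positions(chunk_len);
  for (dim_t t = 0; t < chunk_len; ++t)
    positions[t] = offset + t;
  return positions;
}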
@@ -149,7 +150,8 @@ namespace ctranslate2 {
                       input_padder,
                       input_padder,
                       true,
-                      position_bias);
+                      position_bias,
+                      offset);
 
       if (_post_attention_layer_norm)
         (*_post_attention_layer_norm)(input, hidden);
@@ -172,7 +174,8 @@ namespace ctranslate2 {
                       input_padder,
                       input_padder,
                       true,
-                      position_bias);
+                      position_bias,
+                      offset);
 
       StorageView context(dtype, device);
       if (_encoder_attention) {
@@ -330,7 +333,8 @@ namespace ctranslate2 {
                            ? nullptr
                            : build_position_encoder(model, scope + "/position_encodings", _embeddings))
       , _with_encoder_attention(_layers.front()->has_cross_attention())
-      , _proj(model, scope + "/projection") {
+      , _proj(model, scope + "/projection")
+      , _sliding_window(model.get_attribute_with_default<int32_t>(scope + "/sliding_window", 0)) {
 
       dim_t alignment_layer = (
         model.get_attribute_with_default<int32_t>(scope + "/alignment_layer", -1));
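_sliding_window is initialized from the model attribute <scope>/sliding_window with a default of 0, which the rest of this diff treats as "sliding window disabled". A hedged sketch of that lookup convention, with a plain map standing in for the real model object:

#include <cstdint>
#include <string>
#include <unordered_map>

// Stand-in for an attribute lookup with a default: return the stored value
// if the attribute exists, otherwise 0, meaning the sliding-window code
// paths added in this diff are skipped entirely.
int32_t get_sliding_window(const std::unordered_map<std::string, int32_t>& attributes,
                           const std::string& scope) {
  const auto it = attributes.find(scope + "/sliding_window");
  return it != attributes.end() ? it->second : 0;
}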
@@ -467,7 +471,13 @@ namespace ctranslate2 {
         (*_layernorm_embedding)(layer_in, layer_in);
 
       const dim_t batch_size = layer_in.dim(0);
-      const dim_t max_time = layer_in.dim(1);
+      dim_t max_time;
+
+      if (_sliding_window > 0 && layer_in.dim(1) > _sliding_window) {
+        max_time = _sliding_window;
+      } else
+        max_time = layer_in.dim(1);
+
       const bool allow_padding_removal = Padder::allow_padding_removal(_device, _compute_type);
 
       std::unique_ptr<const Padder> input_padder;
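The replacement of max_time above caps the padder/mask time dimension at the window size when the prompt is longer than the configured window. The same decision in isolation, as a sketch that assumes sliding_window <= 0 means "disabled":

#include <cstdint>

using dim_t = int64_t;

// Effective time dimension for padding removal and the length mask: the full
// prompt length, unless a sliding window is configured and the prompt exceeds
// it, in which case only one window's worth of tokens is handled at a time.
dim_t effective_max_time(dim_t prompt_len, dim_t sliding_window) {
  if (sliding_window > 0 && prompt_len > sliding_window)
    return sliding_window;
  return prompt_len;
}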
@@ -479,14 +489,14 @@ namespace ctranslate2 {
         lengths = input_lengths.get();
       }
 
+      bool multi_query = _layers.front()->get_self_attention().multi_query();
+
       if (lengths) {
         if (allow_padding_removal) {
           input_padder = std::make_unique<Padder>(*lengths, max_time);
           input_padder->remove_padding(layer_in);
         }
 
-        const bool multi_query = _layers.front()->get_self_attention().multi_query();
-
         StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
           *lengths,
           _num_heads,
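multi_query is hoisted out of the if (lengths) block because the chunk loop further down rebuilds a length mask for every chunk after the first, even when no explicit lengths were given. A rough sketch of the mask shape that per-chunk rebuild aims for, using a boolean matrix instead of a StorageView and ignoring the exact layout produced by prepare_length_mask and ops::Slide:

#include <cstdint>
#include <vector>

using dim_t = int64_t;

// For a chunk after the first, each of the chunk_len query positions may
// attend to the `window` cached key positions of the previous chunk plus the
// causal prefix of the current chunk. This mirrors building a causal mask
// over window + chunk_len keys and keeping only the last chunk_len query rows.
std::vector<std::vector<bool>> chunk_attention_mask(dim_t window, dim_t chunk_len) {
  const dim_t max_tokens = window + chunk_len;
  std::vector<std::vector<bool>> mask(chunk_len, std::vector<bool>(max_tokens, false));
  for (dim_t q = 0; q < chunk_len; ++q)
    for (dim_t k = 0; k < max_tokens; ++k)
      mask[q][k] = (k <= window + q);
  return mask;
}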
@@ -531,47 +541,86 @@ namespace ctranslate2 {
 
       StorageView position_bias(dtype, device);
 
-      for (size_t l = 0; l < _layers.size(); ++l) {
-        StorageView* cached_self_attn_keys = nullptr;
-        StorageView* cached_self_attn_values = nullptr;
-        StorageView* cached_attn_keys = nullptr;
-        StorageView* cached_attn_values = nullptr;
-
-        if (step >= 0) {
-          const std::string l_str = std::to_string(l);
-          cached_self_attn_keys = &state.at("self_keys_" + l_str);
-          cached_self_attn_values = &state.at("self_values_" + l_str);
-          if (_with_encoder_attention) {
-            cached_attn_keys = &state.at("memory_keys_" + l_str);
-            cached_attn_values = &state.at("memory_values_" + l_str);
-          }
+      std::vector<StorageView> layer_ins;
+
+      while (true) {
+        dim_t prompt_size = layer_in.dim(1);
+        if (_sliding_window == 0 || prompt_size <= _sliding_window) {
+          layer_ins.push_back(std::move(layer_in));
+          break;
         }
+        if (layer_in.dim(1) > _sliding_window) {
+          StorageView tmp(dtype, device);
+          const ops::Split split_op(1, {_sliding_window, prompt_size - _sliding_window});
+          split_op(layer_in, tmp, layer_in);
+          layer_ins.push_back(std::move(tmp));
+        }
+      }
 
-        std::unique_ptr<StorageView> heads_to_select = get_layer_alignment_heads(l, batch_size);
-        std::unique_ptr<StorageView> layer_attention;
-        if (attention && heads_to_select)
-          layer_attention = std::make_unique<StorageView>(dtype, device);
+      for (size_t i = 0; i < layer_ins.size(); ++i) {
+        auto layer_in_chunk = layer_ins[i];
+        for (size_t l = 0; l < _layers.size(); ++l) {
+          StorageView* cached_self_attn_keys = nullptr;
+          StorageView* cached_self_attn_values = nullptr;
+          StorageView* cached_attn_keys = nullptr;
+          StorageView* cached_attn_values = nullptr;
+
+          if (step >= 0) {
+            const std::string l_str = std::to_string(l);
+            cached_self_attn_keys = &state.at("self_keys_" + l_str);
+            cached_self_attn_values = &state.at("self_values_" + l_str);
+            if (_with_encoder_attention) {
+              cached_attn_keys = &state.at("memory_keys_" + l_str);
+              cached_attn_values = &state.at("memory_values_" + l_str);
+            }
+          }
 
-        (*_layers[l])(layer_in,
-                      input_lengths_mask.get(),
-                      memory,
-                      memory_lengths_mask.get(),
-                      cached_self_attn_keys,
-                      cached_self_attn_values,
-                      cached_attn_keys,
-                      cached_attn_values,
-                      layer_out,
-                      layer_attention.get(),
-                      input_padder.get(),
-                      memory_padder.get(),
-                      return_normalized_attention(),
-                      &position_bias);
-        layer_in = std::move(layer_out);
+          std::unique_ptr<StorageView> heads_to_select = get_layer_alignment_heads(l, batch_size);
+          std::unique_ptr<StorageView> layer_attention;
+          if (attention && heads_to_select)
+            layer_attention = std::make_unique<StorageView>(dtype, device);
+
+          dim_t offset = _sliding_window * i + step;
+          offset = offset < 0 ? 0 : offset;
+          if (i > 0) {
+            auto max_tokens = _sliding_window + layer_in_chunk.dim(1);
+            StorageView tmp_lengths = StorageView(Shape{layer_in_chunk.dim(0)}, int32_t(max_tokens), device);
+            StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
+              tmp_lengths,
+              _num_heads,
+              max_tokens,
+              /*mask_future=*/true,
+              multi_query);
+
+            const ops::Slide slide_lengths_op(2, _sliding_window, layer_in_chunk.dim(1));
+            // reuse tmp_lengths
+            slide_lengths_op(lengths_mask, tmp_lengths);
+            input_lengths_mask = std::make_unique<StorageView>(std::move(tmp_lengths));
+          }
 
-        if (layer_attention) {
-          alignment_heads.emplace_back(dtype, device);
-          ops::Gather(1, 1)(*layer_attention, *heads_to_select, alignment_heads.back());
+          (*_layers[l])(layer_in_chunk,
+                        input_lengths_mask.get(),
+                        memory,
+                        memory_lengths_mask.get(),
+                        cached_self_attn_keys,
+                        cached_self_attn_values,
+                        cached_attn_keys,
+                        cached_attn_values,
+                        layer_out,
+                        layer_attention.get(),
+                        input_padder.get(),
+                        memory_padder.get(),
+                        return_normalized_attention(),
+                        &position_bias,
+                        offset);
+          layer_in_chunk = std::move(layer_out);
+
+          if (layer_attention) {
+            alignment_heads.emplace_back(dtype, device);
+            ops::Gather(1, 1)(*layer_attention, *heads_to_select, alignment_heads.back());
+          }
         }
+        layer_in = std::move(layer_in_chunk);
       }
 
       if (step == 0) {
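Putting the new control flow together: the while loop splits a long prompt into window-sized chunks, and each chunk is pushed through all layers with an offset of _sliding_window * i + step (clamped at 0) so the cache and position handling line up across chunks. A simplified standalone sketch of that splitting and offset arithmetic, using plain vectors instead of StorageView and ops::Split, and assuming prompt_len > 0:

#include <algorithm>
#include <cstdint>
#include <vector>

using dim_t = int64_t;

struct Chunk {
  dim_t start;   // first token index covered by this chunk
  dim_t length;  // number of tokens in the chunk
  dim_t offset;  // value passed to the layers: window * index + step, clamped at 0
};

std::vector<Chunk> split_prompt(dim_t prompt_len, dim_t sliding_window, dim_t step) {
  std::vector<Chunk> chunks;
  if (sliding_window <= 0) {
    // Sliding window disabled: process the whole prompt as a single chunk.
    chunks.push_back({0, prompt_len, std::max<dim_t>(step, 0)});
    return chunks;
  }
  dim_t start = 0;
  dim_t i = 0;
  while (start < prompt_len) {
    const dim_t length = std::min(sliding_window, prompt_len - start);
    const dim_t offset = std::max<dim_t>(sliding_window * i + step, 0);
    chunks.push_back({start, length, offset});
    start += length;
    ++i;
  }
  return chunks;
}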