Commit 8bce69f
Authored Jun 25, 2024

Merge pull request #1 from reinfer/dev/prashanth/return_logits
Return logits and add instructions to build wheel locally

2 parents: 174e4ee + 0a49019

32 files changed, +505 -107 lines

‎CHANGELOG.md

+25 lines

@@ -4,6 +4,31 @@
 
 ### Fixes and improvements
 
+## [v3.24.0](https://github.com/OpenNMT/CTranslate2/releases/tag/v3.23.0) (2024-01-08)
+
+### New features
+* Support of new option offset to ignore token score of special tokens
+
+## [v3.23.0](https://github.com/OpenNMT/CTranslate2/releases/tag/v3.23.0) (2023-12-05)
+
+### New features
+* Support Phi model
+
+### Fixes and improvements
+* Fix the conversion for whisper without the "alignment_heads" in the "generation_config.json"
+* Fix forward batch
+
+## [v3.22.0](https://github.com/OpenNMT/CTranslate2/releases/tag/v3.22.0) (2023-11-22)
+
+### New features
+* Support "sliding window" and "chunking input" for Mistral
+
+### Fixes and improvements
+* Take into account the "generation_config.json" and fix "lang_ids" getter for Whisper converter
+* Accept callback even on "generate_tokens" method
+* Fix iomp5 linking with latest Intel OpenAPI on Ubuntu
+* Fixed "decoder_start_token_id" for T5
+
 ## [v3.21.0](https://github.com/OpenNMT/CTranslate2/releases/tag/v3.21.0) (2023-11-09)
 
 ### New features

‎CMakeLists.txt

+4 -2 lines

@@ -134,7 +134,7 @@ set(SOURCES
   src/ops/bias_add.cc
   src/ops/bias_add_cpu.cc
   src/ops/concat.cc
-  src/ops/concat_split_cpu.cc
+  src/ops/concat_split_slide_cpu.cc
   src/ops/conv1d.cc
   src/ops/conv1d_cpu.cc
   src/ops/cos.cc
@@ -168,6 +168,7 @@ set(SOURCES
   src/ops/softmax.cc
   src/ops/softmax_cpu.cc
   src/ops/split.cc
+  src/ops/slide.cc
   src/ops/sub.cc
   src/ops/swish.cc
   src/ops/tanh.cc
@@ -263,6 +264,7 @@ if(NOT OPENMP_RUNTIME STREQUAL "NONE")
     ${INTEL_ROOT}/oneAPI/compiler/latest/windows/compiler/lib/intel64_win
     ${INTEL_ROOT}/oneapi/compiler/latest/linux/compiler/lib/intel64_lin
     ${INTEL_ROOT}/oneapi/compiler/latest/mac/compiler/lib
+    ${INTEL_ROOT}/oneapi/compiler/latest/lib
     )
   if(IOMP5_LIBRARY)
     list(APPEND LIBRARIES ${IOMP5_LIBRARY})
@@ -505,7 +507,7 @@ if (WITH_CUDA)
   src/cuda/utils.cc
   src/ops/alibi_add_gpu.cu
   src/ops/bias_add_gpu.cu
-  src/ops/concat_split_gpu.cu
+  src/ops/concat_split_slide_gpu.cu
  src/ops/conv1d_gpu.cu
  src/ops/dequantize_gpu.cu
  src/ops/gather_gpu.cu

‎README.md

+17 lines

@@ -123,3 +123,20 @@ Executed with CUDA 11 on a [*g5.xlarge*](https://aws.amazon.com/ec2/instance-typ
 * [Documentation](https://opennmt.net/CTranslate2)
 * [Forum](https://forum.opennmt.net)
 * [Gitter](https://gitter.im/OpenNMT/CTranslate2)
+
+
+# To build locally
+
+cd CTranslate2
+mkdir build
+sudo cmake -DCMAKE_INSTALL_PREFIX=/usr/local -DWITH_CUDA=ON -DWITH_CUDNN=ON -DWITH_MKL=ON -DOPENMP_RUNTIME=COMP -DCMAKE_BUILD_TYPE=Release ..
+sudo make -j4
+sudo make install
+sudo ldconfig
+
+# LD_LIBRARY_PATH should contain the ctranslate install path
+
+# Build python wheel
+cd python
+python setup.py bdist_wheel --dist-dir <path_to_wheel_dir>
+auditwheel repair --plat manylinux_2_34_x86_64 <path_to_wheel>
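Note that the `cmake ... ..` and `make` commands added above are presumably meant to be run from inside the newly created `build` directory (i.e. after a `cd build`), since they reference the parent directory. A minimal sketch of checking the resulting wheel after `auditwheel repair` (not part of the commit; the wheel path is hypothetical):

```python
# pip install wheelhouse/ctranslate2-3.24.1-*.whl   <- hypothetical repaired wheel path
import ctranslate2

print(ctranslate2.__version__)              # 3.24.1 after the version bump in this commit
print(ctranslate2.get_cuda_device_count())  # sanity-check the CUDA-enabled build
```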

‎docker/Dockerfile

+5 -5 lines

@@ -1,4 +1,4 @@
-FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 as builder
+FROM nvidia/cuda:12.2.2-cudnn8-devel-ubuntu20.04 as builder
 
 RUN apt-get update && \
     apt-get install -y --no-install-recommends \
@@ -66,18 +66,18 @@ RUN cd python && \
     python3 -m pip --no-cache-dir install -r install_requirements.txt && \
     python3 setup.py bdist_wheel --dist-dir $CTRANSLATE2_ROOT
 
-FROM nvidia/cuda:11.2.2-base-ubuntu20.04
+FROM nvidia/cuda:12.2.2-base-ubuntu20.04
 
 # We remove the cuda-compat package because it conflicts with the CUDA Enhanced Compatibility.
 # See e.g. https://github.com/NVIDIA/nvidia-docker/issues/1515
 RUN apt-get update && \
     apt-get install -y --no-install-recommends \
-        libcublas-11-2 \
-        libcudnn8=8.1.1.33-1+cuda11.2 \
+        libcublas-12-2 \
+        libcudnn8=8.9.7.29-1+cuda12.2 \
         libgomp1 \
         python3-pip \
         && \
-    apt-get purge -y cuda-compat-11-2 && \
+    apt-get purge -y cuda-compat-12-2 && \
     apt-get clean && \
     rm -rf /var/lib/apt/lists/*
 

‎docker/build_all.sh

+1 -1 lines

@@ -42,4 +42,4 @@ build()
     fi
 }
 
-build Dockerfile ubuntu20.04-cuda11.2
+build Dockerfile ubuntu20.04-cuda12.2

‎include/ctranslate2/decoding.h

+2 lines

@@ -15,6 +15,8 @@ namespace ctranslate2 {
     std::vector<std::vector<size_t>> hypotheses;
     std::vector<float> scores;
     std::vector<std::vector<std::vector<float>>> attention;
+    std::vector<float> logits;
+    // (max_decoding_steps)
   };
 
   struct DecodingStepResult {

‎include/ctranslate2/layers/attention.h

+2 -1 lines

@@ -35,7 +35,8 @@ namespace ctranslate2 {
                     const Padder* queries_padder = nullptr,
                     const Padder* values_padder = nullptr,
                     bool return_normalized_attention = true,
-                    StorageView* position_bias = nullptr) const;
+                    StorageView* position_bias = nullptr,
+                    dim_t offset = 0) const;
 
     bool has_positional_embeddings() const {
       return _relative_position_keys || _relative_attention_bias || _rotary_embeddings || _alibi;

‎include/ctranslate2/layers/transformer.h

+3 -1 lines

@@ -91,7 +91,8 @@ namespace ctranslate2 {
                     const Padder* input_padder = nullptr,
                     const Padder* memory_padder = nullptr,
                     bool return_normalized_attention = true,
-                    StorageView* position_bias = nullptr) const;
+                    StorageView* position_bias = nullptr,
+                    dim_t offset = 0) const;
 
     DataType output_type() const override {
       return _ff.output_type();
@@ -209,6 +210,7 @@ namespace ctranslate2 {
     std::vector<std::vector<dim_t>> _alignment_heads;
     bool _average_alignment_heads;
     Dense _proj;
+    const dim_t _sliding_window;
   };
 
 }

‎include/ctranslate2/ops/ops.h

+1 line

@@ -36,3 +36,4 @@
 #include "median_filter.h"
 #include "rotary.h"
 #include "alibi_add.h"
+#include "slide.h"

‎include/ctranslate2/ops/slide.h

+26 lines (new file)

@@ -0,0 +1,26 @@
+#pragma once
+
+#include "op.h"
+
+namespace ctranslate2 {
+  namespace ops {
+
+    class Slide : public Op {
+    public:
+      Slide(dim_t axis, const dim_t& index, const dim_t& size, bool no_copy = false);
+
+      void operator()(const StorageView& input, StorageView& output) const;
+    private:
+      dim_t _axis;
+      dim_t _index;
+      dim_t _size;
+      bool _no_copy;
+
+      void check_arguments() const;
+
+      template <Device D, typename T>
+      void compute(const StorageView& input, StorageView& output, const dim_t& index) const;
+    };
+
+  }
+}

‎include/ctranslate2/scoring.h

+3 -1 lines

@@ -12,6 +12,7 @@ namespace ctranslate2 {
   struct ScoringOptions {
     // Truncate the inputs after this many tokens (set 0 to disable truncation).
     size_t max_input_length = 1024;
+    dim_t offset = 0;
   };
 
   struct ScoringResult {
@@ -38,6 +39,7 @@ namespace ctranslate2 {
               layers::DecoderState& state,
               const std::vector<std::vector<size_t>>& sequences,
               const Vocabulary& vocabulary,
-              const dim_t preferred_size_multiple = 1);
+              const dim_t preferred_size_multiple = 1,
+              const dim_t offset=0);
 
 }

‎include/ctranslate2/translation.h

+5 -1 lines

@@ -87,6 +87,7 @@ namespace ctranslate2 {
     std::vector<std::vector<std::string>> hypotheses;
     std::vector<float> scores;
     std::vector<std::vector<std::vector<float>>> attention;
+    std::vector<float> logits;
 
     TranslationResult(std::vector<std::vector<std::string>> hypotheses_)
       : hypotheses(std::move(hypotheses_))
@@ -95,10 +96,12 @@ namespace ctranslate2 {
 
     TranslationResult(std::vector<std::vector<std::string>> hypotheses_,
                       std::vector<float> scores_,
-                      std::vector<std::vector<std::vector<float>>> attention_)
+                      std::vector<std::vector<std::vector<float>>> attention_,
+                      std::vector<float> logits_)
       : hypotheses(std::move(hypotheses_))
       , scores(std::move(scores_))
       , attention(std::move(attention_))
+      , logits(std::move(logits_))
     {
     }
 
@@ -109,6 +112,7 @@ namespace ctranslate2 {
       : hypotheses(num_hypotheses)
      , scores(with_score ? num_hypotheses : 0, static_cast<float>(0))
       , attention(with_attention ? num_hypotheses : 0)
+      , logits(with_score ? num_hypotheses : 0)
     {
     }
 

‎python/cpp/translation_result.cc

+6 -1 lines

@@ -16,11 +16,14 @@ namespace ctranslate2 {
                     "Score of each translation hypothesis (empty if :obj:`return_scores` was disabled).")
       .def_readonly("attention", &TranslationResult::attention,
                     "Attention matrix of each translation hypothesis (empty if :obj:`return_attention` was disabled).")
+      .def_readonly("logits", &TranslationResult::logits,
+                    "Logits for each decoding step")
 
       .def("__repr__", [](const TranslationResult& result) {
         return "TranslationResult(hypotheses=" + std::string(py::repr(py::cast(result.hypotheses)))
           + ", scores=" + std::string(py::repr(py::cast(result.scores)))
           + ", attention=" + std::string(py::repr(py::cast(result.attention)))
+          + ", logits=" + std::string(py::repr(py::cast(result.logits)))
           + ")";
       })
 
@@ -39,8 +42,10 @@ namespace ctranslate2 {
           throw py::index_error();
         py::dict hypothesis;
         hypothesis["tokens"] = result.hypotheses[i];
-        if (result.has_scores())
+        if (result.has_scores()){
           hypothesis["score"] = result.scores[i];
+          hypothesis["logits"] = result.logits[i];
+        };
         if (result.has_attention())
          hypothesis["attention"] = result.attention[i];
         return hypothesis;
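The binding above exposes the new `logits` field of `TranslationResult` to Python. A minimal sketch of how it might be read (not part of the commit; the model directory is hypothetical, and per the `src/decoding.cc` change further down the field holds one per-step value and is only populated when scores are requested):

```python
import ctranslate2

translator = ctranslate2.Translator("ende_ct2")  # hypothetical converted model directory

# logits is only filled in when scores are requested (see src/decoding.cc below).
results = translator.translate_batch([["▁Hello", "▁world"]], return_scores=True)

best = results[0]
print(best.hypotheses[0])  # tokens of the best hypothesis
print(best.scores[0])      # cumulative score of the best hypothesis
print(best.logits)         # one value per decoding step (new in this commit)
```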

‎python/cpp/translator.cc

+8 -1 lines

@@ -228,10 +228,12 @@ namespace ctranslate2 {
                size_t max_batch_size,
                const std::string& batch_type_str,
                size_t max_input_length,
+               dim_t offset,
                bool asynchronous) {
       const auto batch_type = str_to_batch_type(batch_type_str);
       ScoringOptions options;
       options.max_input_length = max_input_length;
+      options.offset = offset;
 
       std::shared_lock lock(_mutex);
       assert_model_is_ready();
@@ -252,6 +254,7 @@ namespace ctranslate2 {
                size_t read_batch_size,
                const std::string& batch_type_str,
                size_t max_input_length,
+               dim_t offset,
                bool with_tokens_score,
                const TokenizeFn& source_tokenize_fn,
                const TokenizeFn& target_tokenize_fn,
@@ -263,7 +266,7 @@ namespace ctranslate2 {
       const auto batch_type = str_to_batch_type(batch_type_str);
       ScoringOptions options;
       options.max_input_length = max_input_length;
-
+      options.offset = offset;
       std::shared_lock lock(_mutex);
       assert_model_is_ready();
 
@@ -592,6 +595,7 @@ namespace ctranslate2 {
             py::arg("max_batch_size")=0,
             py::arg("batch_type")="examples",
             py::arg("max_input_length")=1024,
+            py::arg("offset") = 0,
             py::arg("asynchronous")=false,
             py::call_guard<py::gil_scoped_release>(),
             R"pbdoc(
@@ -606,6 +610,7 @@ namespace ctranslate2 {
                 minimized.
               batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
               max_input_length: Truncate inputs after this many tokens (0 to disable).
+              offset: Ignore the first n tokens in target in score calculation.
               asynchronous: Run the scoring asynchronously.
 
             Returns:
@@ -621,6 +626,7 @@ namespace ctranslate2 {
             py::arg("read_batch_size")=0,
             py::arg("batch_type")="examples",
             py::arg("max_input_length")=1024,
+            py::arg("offset")=0,
             py::arg("with_tokens_score")=false,
             py::arg("source_tokenize_fn")=nullptr,
             py::arg("target_tokenize_fn")=nullptr,
@@ -649,6 +655,7 @@ namespace ctranslate2 {
               batch_type: Whether :obj:`max_batch_size` and :obj:`read_batch_size` are the
                 number of "examples" or "tokens".
               max_input_length: Truncate inputs after this many tokens (0 to disable).
+              offset: Ignore the first n tokens in target in score calculation.
               with_tokens_score: Include the token-level scores in the output file.
               source_tokenize_fn: Function to tokenize source lines.
               target_tokenize_fn: Function to tokenize target lines.
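These bindings add an `offset` argument to `Translator.score_batch` and `Translator.score_file`, documented above as "Ignore the first n tokens in target in score calculation." A minimal usage sketch (not part of the commit; the model directory and tokens are hypothetical):

```python
import ctranslate2

translator = ctranslate2.Translator("ende_ct2")  # hypothetical converted model directory

source = [["▁Hello", "▁world"]]
target = [["<s>", "▁Hallo", "▁Welt"]]

# offset=1 skips the first target token (e.g. a prompt or special token)
# when computing the token-level scores.
results = translator.score_batch(source, target, offset=1)
print(results[0].tokens)     # scored target tokens
print(results[0].log_probs)  # one log-probability per scored token
```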

‎python/ctranslate2/converters/opennmt_py.py

+5 -4 lines

@@ -104,9 +104,10 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd
     activation_fn = getattr(opt, "pos_ffn_activation_fn", "relu")
     num_heads = getattr(opt, "heads", 8)
     num_kv = getattr(opt, "num_kv", 0)
-    if num_kv == num_heads:
+    if num_kv == num_heads or num_kv == 0:
         num_kv = None
     rotary_dim = 0 if with_rotary else None
+    rotary_interleave = getattr(opt, "rotary_interleave", True)
     ffn_glu = activation_fn == "silu"
     sliding_window = getattr(opt, "sliding_window", 0)
 
@@ -119,7 +120,7 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd
         alibi=with_alibi,
         rms_norm=opt.layer_norm == "rms",
         rotary_dim=rotary_dim,
-        rotary_interleave=True,
+        rotary_interleave=rotary_interleave,
         multi_query_attention=getattr(opt, "multiquery", False),
         num_heads_kv=num_kv,
         sliding_window=sliding_window,
@@ -329,7 +330,7 @@ def set_linear(spec, variables, scope):
     spec.weight = _get_variable(variables, "%s.weight" % scope)
     bias = variables.get("%s.bias" % scope)
     if bias is not None:
-        spec.bias = bias.numpy()
+        spec.bias = bias
 
 
 def set_embeddings(spec, variables, scope):
@@ -341,7 +342,7 @@ def set_position_encodings(spec, variables, scope):
 
 
 def _get_variable(variables, name):
-    return variables[name].numpy()
+    return variables[name]
 
 
 def main():

‎python/ctranslate2/converters/transformers.py

+94 -5 lines

@@ -889,12 +889,44 @@ def get_model_spec(self, model):
 
         return spec
 
+    def _get_lang_ids_from_tokenizer(self, tokenizer):
+        non_lang_special_tokens = [
+            "<|endoftext|>",
+            "<|startoftranscript|>",
+            "<|translate|>",
+            "<|transcribe|>",
+            "<|startoflm|>",
+            "<|startofprev|>",
+            "<|nocaptions|>",
+            "<|notimestamps|>",
+        ]
+        return [
+            token_id
+            for token_id, token in zip(
+                tokenizer.additional_special_tokens_ids,
+                tokenizer.additional_special_tokens,
+            )
+            if token not in non_lang_special_tokens
+        ]
+
     def set_config(self, config, model, tokenizer):
-        config.suppress_ids = model.config.suppress_tokens
-        config.suppress_ids_begin = model.config.begin_suppress_tokens
-        config.lang_ids = tokenizer.additional_special_tokens_ids[2:-6]
+        gen_config = getattr(model, "generation_config", None)
+
+        if gen_config is not None:
+            config.suppress_ids = gen_config.suppress_tokens
+            config.suppress_ids_begin = gen_config.begin_suppress_tokens
+            if hasattr(gen_config, "alignment_heads"):
+                config.alignment_heads = gen_config.alignment_heads
+            if hasattr(gen_config, "lang_to_id"):
+                config.lang_ids = sorted(gen_config.lang_to_id.values())
+        else:
+            config.suppress_ids = model.config.suppress_tokens
+            config.suppress_ids_begin = model.config.begin_suppress_tokens
+            config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
+
+        if getattr(config, "lang_ids", None) is None:
+            config.lang_ids = self._get_lang_ids_from_tokenizer(tokenizer)
 
-        config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
         if config.alignment_heads is None:
             # Use the last half layers for alignment by default.
             num_layers = model.config.decoder_layers
@@ -1024,7 +1056,12 @@ def set_config(self, config, model, tokenizer):
         config.bos_token = tokenizer.pad_token
         config.eos_token = tokenizer.eos_token
         config.unk_token = tokenizer.unk_token
-        config.decoder_start_token = tokenizer.pad_token
+        if hasattr(model.config, "decoder_start_token_id"):
+            config.decoder_start_token = tokenizer.convert_ids_to_tokens(
+                model.config.decoder_start_token_id
+            )
+        else:
+            config.decoder_start_token = tokenizer.pad_token
 
     def set_stack(self, spec, module, is_decoder=False):
         self.set_layer_norm(spec.layer_norm, module.final_layer_norm)
@@ -1493,6 +1530,58 @@ def set_decoder(self, spec, module):
             self.set_linear(layer_spec.ffn.linear_1, layer.mlp.fc2)
 
 
+@register_loader("PhiConfig")
+class PhiLoader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "AutoModelForCausalLM"
+
+    def get_model_spec(self, model):
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers=model.config.n_layer,
+            num_heads=model.config.n_head,
+            pre_norm=True,
+            activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function],
+            rotary_dim=model.config.rotary_dim,
+            rotary_interleave=False,
+            parallel_residual=True,
+            shared_layer_norm=True,
+        )
+
+        self.set_decoder(spec.decoder, model.transformer)
+        self.set_linear(spec.decoder.projection, model.lm_head.linear)
+        self.set_layer_norm(spec.decoder.layer_norm, model.lm_head.ln)
+        return spec
+
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = tokenizer.bos_token
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = tokenizer.unk_token
+
+    def set_decoder(self, spec, module):
+        spec.scale_embeddings = False
+        self.set_embeddings(spec.embeddings, module.embd.wte)
+
+        for layer_spec, layer in zip(spec.layer, module.h):
+            self.set_layer_norm(layer_spec.shared_layer_norm, layer.ln)
+            self.set_linear(layer_spec.self_attention.linear[0], layer.mixer.Wqkv)
+            self.set_linear(layer_spec.self_attention.linear[1], layer.mixer.out_proj)
+            self.set_linear(layer_spec.ffn.linear_0, layer.mlp.fc1)
+            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.fc2)
+
+
 @register_loader("RWConfig")
 class RWLoader(ModelLoader):
     @property
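The new `PhiLoader` registered above backs the "Support Phi model" entry in the changelog. A minimal conversion sketch (not part of the commit; the Hugging Face model id and output directory are hypothetical, and assume a checkpoint whose config class is `PhiConfig`):

```python
import ctranslate2

# The loader is selected from the checkpoint's config class ("PhiConfig").
converter = ctranslate2.converters.TransformersConverter("microsoft/phi-1_5")
converter.convert("phi_ct2")

# The converted decoder-only model can then be loaded as a generator.
generator = ctranslate2.Generator("phi_ct2")
```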

‎python/ctranslate2/specs/transformer_spec.py

+3 lines

@@ -171,6 +171,8 @@ def __init__(
         self.alibi = alibi
         self.alibi_use_positive_positions = alibi_use_positive_positions
         self.scale_alibi = scale_alibi
+        if sliding_window is not None:
+            self.sliding_window = np.dtype("int32").type(sliding_window)
         if (
             not relative_position
             and not relative_attention_bias
@@ -225,6 +227,7 @@ def __init__(
             relative_attention_bias=relative_attention_bias,
             rms_norm=rms_norm,
             num_heads_kv=num_heads_kv,
+            sliding_window=sliding_window,
         )
         self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)
 

‎python/ctranslate2/version.py

+1 -1 lines

@@ -1,3 +1,3 @@
 """Version information."""
 
-__version__ = "3.21.0"
+__version__ = "3.24.1"

‎python/setup.py

+1 line

@@ -112,6 +112,7 @@ def _maybe_add_library_root(lib_name):
         "numpy",
         "pyyaml>=5.3,<7",
     ],
+    include_package_data=True,
    entry_points={
        "console_scripts": [
            "ct2-fairseq-converter=ctranslate2.converters.fairseq:main",

‎python/tests/requirements.txt

+1 -1 lines

@@ -1,4 +1,4 @@
-transformers==4.29.*;platform_system=='Linux'
+transformers==4.35.*;platform_system=='Linux'
 fairseq==0.12.2;platform_system=='Linux' or platform_system=='Darwin'
 OpenNMT-py==2.2.*;platform_system=='Linux' or platform_system=='Darwin'
 OpenNMT-tf==2.30.*

‎python/tests/test_transformers.py

+6 -2 lines

@@ -984,13 +984,17 @@ def test_transformers_wav2vec2(
     w2v2_model = transformers.Wav2Vec2ForCTC.from_pretrained(model_name)
     del w2v2_model.wav2vec2.encoder.layers
     del w2v2_model.wav2vec2.encoder.layer_norm
-    torch.save(w2v2_model, output_dir + "/wav2vec2_partial.bin")
+    w2v2_model.save_pretrained(output_dir + "/wav2vec2_partial.bin")
     w2v2_processor = transformers.Wav2Vec2Processor.from_pretrained(model_name)
     torch.save(w2v2_processor, output_dir + "/wav2vec2_processor.bin")
 
     device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
     cpu_threads = int(os.environ.get("OMP_NUM_THREADS", 0))
-    w2v2_model = torch.load(output_dir + "/wav2vec2_partial.bin").to(device)
+    w2v2_model = transformers.Wav2Vec2ForCTC.from_pretrained(
+        output_dir + "/wav2vec2_partial.bin"
+    ).to(device)
+    del w2v2_model.wav2vec2.encoder.layers
+    del w2v2_model.wav2vec2.encoder.layer_norm
     w2v2_processor = torch.load(output_dir + "/wav2vec2_processor.bin")
     ct2_w2v2_model = ctranslate2.models.Wav2Vec2(
         output_dir,

‎python/tools/prepare_build_environment_linux.sh

+7 -9 lines

@@ -20,17 +20,15 @@ if [ "$CIBW_ARCHS" == "aarch64" ]; then
 
 else
 
-    # Install CUDA 11.2, see:
-    # * https://gitlab.com/nvidia/container-images/cuda/-/blob/master/dist/11.2.2/centos7-x86_64/base/Dockerfile
-    # * https://gitlab.com/nvidia/container-images/cuda/-/blob/master/dist/11.2.2/centos7-x86_64/devel/Dockerfile
+    # Install CUDA 12.2:
     yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo
     yum install --setopt=obsoletes=0 -y \
-        cuda-nvcc-11-2-11.2.152-1 \
-        cuda-cudart-devel-11-2-11.2.152-1 \
-        libcurand-devel-11-2-10.2.3.152-1 \
-        libcudnn8-devel-8.1.1.33-1.cuda11.2 \
-        libcublas-devel-11-2-11.4.1.1043-1
-    ln -s cuda-11.2 /usr/local/cuda
+        cuda-nvcc-12-2-12.2.140-1 \
+        cuda-cudart-devel-12-2-12.2.140-1 \
+        libcurand-devel-12-2-10.3.3.141-1 \
+        libcudnn8-devel-8.9.7.29-1.cuda12.2 \
+        libcublas-devel-12-2-12.2.5.6-1
+    ln -s cuda-12.2 /usr/local/cuda
 
     ONEAPI_VERSION=2023.2.0
     yum-config-manager --add-repo https://yum.repos.intel.com/oneapi

‎python/tools/prepare_build_environment_windows.sh

+9 -7 lines

@@ -3,15 +3,17 @@
 set -e
 set -x
 
-CUDA_ROOT="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.2"
-curl -L -nv -o cuda.exe https://developer.download.nvidia.com/compute/cuda/11.2.2/local_installers/cuda_11.2.2_461.33_win10.exe
-./cuda.exe -s nvcc_11.2 cudart_11.2 cublas_dev_11.2 curand_dev_11.2
+CUDA_ROOT="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.2"
+curl -L -nv -o cuda.exe https://developer.download.nvidia.com/compute/cuda/12.2.2/local_installers/cuda_12.2.2_537.13_windows.exe
+./cuda.exe -s nvcc_12.2 cudart_12.2 cublas_dev_12.2 curand_dev_12.2
 rm cuda.exe
 
-curl -L -nv -o cudnn.zip https://developer.download.nvidia.com/compute/redist/cudnn/v8.1.1/cudnn-11.2-windows-x64-v8.1.1.33.zip
-unzip cudnn.zip && rm cudnn.zip
-cp -r cuda/* "$CUDA_ROOT"
-rm -r cuda/
+CUDNN_ROOT="C:/Program Files/NVIDIA/CUDNN/v8.8"
+curl -L -nv -o cudnn.exe https://developer.download.nvidia.com/compute/redist/cudnn/v8.8.0/local_installers/12.0/cudnn_8.8.0.121_windows.exe
+./cudnn.exe -s
+sleep 10
+cp -r "$CUDNN_ROOT"/* "$CUDA_ROOT"
+rm cudnn.exe
 
 # See https://github.com/oneapi-src/oneapi-ci for installer URLs
 curl -L -nv -o webimage.exe https://registrationcenter-download.intel.com/akdlm/irc_nas/19078/w_BaseKit_p_2023.0.0.25940_offline.exe

‎src/decoding.cc

+18 -8 lines

@@ -779,24 +779,29 @@ namespace ctranslate2 {
     StorageView logits(dtype, device);
     std::vector<dim_t> batch_offset(batch_size);
     std::vector<DecodingResult> results(batch_size);
+
+    StorageView best_ids(DataType::INT32);
+    StorageView best_probs(dtype);
+    StorageView alive_seq(DataType::INT32);
+    StorageView attention_step;
+    StorageView attention_step_device(dtype, device);
+
+    const dim_t max_step = get_max_step(max_length, return_prefix, prefix_ids);
+
     for (dim_t i = 0; i < batch_size; ++i) {
       batch_offset[i] = i;
       sample_from.at<int32_t>(i) = start_ids[i];
       results[i].hypotheses.resize(1);
       if (return_scores)
+      {
         results[i].scores.resize(1, 0.f);
+        results[i].logits.resize(max_step);
+      };
+
       if (return_attention)
         results[i].attention.resize(1);
     }
 
-    StorageView best_ids(DataType::INT32);
-    StorageView best_probs(dtype);
-    StorageView alive_seq(DataType::INT32);
-    StorageView attention_step;
-    StorageView attention_step_device(dtype, device);
-
-    const dim_t max_step = get_max_step(max_length, return_prefix, prefix_ids);
-
     for (dim_t step = 0; step < max_step; ++step) {
       convert_to_original_word_ids(decoder, sample_from);
       decoder(start_step + step,
@@ -851,6 +856,8 @@ namespace ctranslate2 {
         const size_t batch_id = batch_offset[i];
         const dim_t prefix_length = prefix_ids ? prefix_ids->at(batch_id).size() : 0;
         const float score = best_probs.scalar_at<float>({i, 0});
+        // convert word_id from
+        const float log_prob = log_probs.scalar_at<float>({i, static_cast<int32_t>(word_id)});
 
         if ((!is_eos(word_id, end_ids) || include_eos_in_hypotheses)
             && (return_prefix || step >= prefix_length)) {
@@ -862,7 +869,10 @@ namespace ctranslate2 {
         }
 
         if (return_scores)
+        {
           results[batch_id].scores[0] += score;
+          results[batch_id].logits[step] = log_prob;
+        };
 
         bool is_finished = ((is_eos(word_id, end_ids) && step >= prefix_length)
                             || (is_last_step(step, max_length, prefix_length, return_prefix)));

‎src/layers/attention.cc

+23 -5 lines

@@ -430,7 +430,8 @@ namespace ctranslate2 {
                                      const Padder* queries_padder,
                                      const Padder* values_padder,
                                      bool return_normalized_attention,
-                                     StorageView* position_bias) const {
+                                     StorageView* position_bias,
+                                     dim_t offset) const {
      PROFILE("MultiHeadAttention");
      const Device device = queries.device();
      const DataType dtype = queries.dtype();
@@ -449,6 +450,8 @@ namespace ctranslate2 {
 
      dim_t beam_size = 1;
 
+     bool prefilling = (_sliding_window > 0 && values_lengths);
+
      if (!_self_attention) {
        queries_proj = std::move(fused_proj);
 
@@ -507,10 +510,6 @@ namespace ctranslate2 {
      }
 
      if (_rotary_embeddings) {
-       const dim_t offset = (cached_keys && !cached_keys->empty()
-                             ? cached_keys->dim(_cache_time_dim)
-                             : 0);
-
        if (_merge_time_and_head_dims) {
          queries_proj.reshape({queries_proj.dim(0), -1, _d_model});
          split_heads(queries_proj, _num_heads);
@@ -536,6 +535,15 @@ namespace ctranslate2 {
          concat_op({&tmp, &keys_proj}, *cached_keys);
          tmp = std::move(*cached_values);
          concat_op({&tmp, &values_proj}, *cached_values);
+
+         if (!prefilling && _sliding_window > 0 && cached_keys->shape()[2] > _sliding_window) {
+           // only for generation
+           const ops::Slide slide_op(2, 1, cached_keys->shape()[2] - 1);
+           slide_op(*cached_keys, tmp);
+           *cached_keys = std::move(tmp);
+           slide_op(*cached_values, tmp);
+           *cached_values = std::move(tmp);
+         }
        }
      }
    }
@@ -564,6 +572,16 @@ namespace ctranslate2 {
                   _alibi,
                   position_bias);
 
+    if (prefilling && cached_keys && cached_keys->shape()[2] > _sliding_window) {
+      // set only last sliding_window tokens to cached_keys and cached_values after computing attention
+      const ops::Slide slide_op(2, cached_keys->shape()[2] - _sliding_window, _sliding_window);
+      StorageView tmp(dtype, device);
+      slide_op(*cached_keys, tmp);
+      *cached_keys = std::move(tmp);
+      slide_op(*cached_values, tmp);
+      *cached_values = std::move(tmp);
+    }
+
     if (_merge_time_and_head_dims) {
       context.reshape(queries.shape());
       if (queries_padder)

‎src/layers/transformer.cc

+92 -43 lines

@@ -121,7 +121,8 @@ namespace ctranslate2 {
                                      const Padder* input_padder,
                                      const Padder* memory_padder,
                                      bool return_normalized_attention,
-                                     StorageView* position_bias) const {
+                                     StorageView* position_bias,
+                                     dim_t offset) const {
       PROFILE("TransformerDecoderLayer");
 
       const DataType dtype = input.dtype();
@@ -149,7 +150,8 @@ namespace ctranslate2 {
                        input_padder,
                        input_padder,
                        true,
-                       position_bias);
+                       position_bias,
+                       offset);
 
       if (_post_attention_layer_norm)
         (*_post_attention_layer_norm)(input, hidden);
@@ -172,7 +174,8 @@ namespace ctranslate2 {
                      input_padder,
                      input_padder,
                      true,
-                     position_bias);
+                     position_bias,
+                     offset);
 
       StorageView context(dtype, device);
       if (_encoder_attention) {
@@ -330,7 +333,8 @@ namespace ctranslate2 {
                           ? nullptr
                           : build_position_encoder(model, scope + "/position_encodings", _embeddings))
       , _with_encoder_attention(_layers.front()->has_cross_attention())
-      , _proj(model, scope + "/projection") {
+      , _proj(model, scope + "/projection")
+      , _sliding_window(model.get_attribute_with_default<int32_t>(scope + "/sliding_window", 0)) {
 
       dim_t alignment_layer = (
         model.get_attribute_with_default<int32_t>(scope + "/alignment_layer", -1));
@@ -467,7 +471,13 @@ namespace ctranslate2 {
         (*_layernorm_embedding)(layer_in, layer_in);
 
       const dim_t batch_size = layer_in.dim(0);
-      const dim_t max_time = layer_in.dim(1);
+      dim_t max_time;
+
+      if (_sliding_window > 0 && layer_in.dim(1) > _sliding_window) {
+        max_time = _sliding_window;
+      } else
+        max_time = layer_in.dim(1);
+
       const bool allow_padding_removal = Padder::allow_padding_removal(_device, _compute_type);
 
       std::unique_ptr<const Padder> input_padder;
@@ -479,14 +489,14 @@ namespace ctranslate2 {
         lengths = input_lengths.get();
       }
 
+      bool multi_query = _layers.front()->get_self_attention().multi_query();
+
       if (lengths) {
        if (allow_padding_removal) {
          input_padder = std::make_unique<Padder>(*lengths, max_time);
          input_padder->remove_padding(layer_in);
        }
 
-        const bool multi_query = _layers.front()->get_self_attention().multi_query();
-
        StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
          *lengths,
          _num_heads,
@@ -531,47 +541,86 @@ namespace ctranslate2 {
 
       StorageView position_bias(dtype, device);
 
-      for (size_t l = 0; l < _layers.size(); ++l) {
-        StorageView* cached_self_attn_keys = nullptr;
-        StorageView* cached_self_attn_values = nullptr;
-        StorageView* cached_attn_keys = nullptr;
-        StorageView* cached_attn_values = nullptr;
-
-        if (step >= 0) {
-          const std::string l_str = std::to_string(l);
-          cached_self_attn_keys = &state.at("self_keys_" + l_str);
-          cached_self_attn_values = &state.at("self_values_" + l_str);
-          if (_with_encoder_attention) {
-            cached_attn_keys = &state.at("memory_keys_" + l_str);
-            cached_attn_values = &state.at("memory_values_" + l_str);
-          }
+      std::vector<StorageView> layer_ins;
+
+      while (true) {
+        dim_t prompt_size = layer_in.dim(1);
+        if (_sliding_window == 0 || prompt_size <= _sliding_window) {
+          layer_ins.push_back(std::move(layer_in));
+          break;
         }
+        if (layer_in.dim(1) > _sliding_window) {
+          StorageView tmp(dtype, device);
+          const ops::Split split_op(1, {_sliding_window, prompt_size - _sliding_window});
+          split_op(layer_in, tmp, layer_in);
+          layer_ins.push_back(std::move(tmp));
+        }
+      }
 
-        std::unique_ptr<StorageView> heads_to_select = get_layer_alignment_heads(l, batch_size);
-        std::unique_ptr<StorageView> layer_attention;
-        if (attention && heads_to_select)
-          layer_attention = std::make_unique<StorageView>(dtype, device);
+      for (size_t i = 0; i < layer_ins.size(); ++i) {
+        auto layer_in_chunk = layer_ins[i];
+        for (size_t l = 0; l < _layers.size(); ++l) {
+          StorageView* cached_self_attn_keys = nullptr;
+          StorageView* cached_self_attn_values = nullptr;
+          StorageView* cached_attn_keys = nullptr;
+          StorageView* cached_attn_values = nullptr;
+
+          if (step >= 0) {
+            const std::string l_str = std::to_string(l);
+            cached_self_attn_keys = &state.at("self_keys_" + l_str);
+            cached_self_attn_values = &state.at("self_values_" + l_str);
+            if (_with_encoder_attention) {
+              cached_attn_keys = &state.at("memory_keys_" + l_str);
+              cached_attn_values = &state.at("memory_values_" + l_str);
+            }
+          }
 
-        (*_layers[l])(layer_in,
-                      input_lengths_mask.get(),
-                      memory,
-                      memory_lengths_mask.get(),
-                      cached_self_attn_keys,
-                      cached_self_attn_values,
-                      cached_attn_keys,
-                      cached_attn_values,
-                      layer_out,
-                      layer_attention.get(),
-                      input_padder.get(),
-                      memory_padder.get(),
-                      return_normalized_attention(),
-                      &position_bias);
-        layer_in = std::move(layer_out);
+          std::unique_ptr<StorageView> heads_to_select = get_layer_alignment_heads(l, batch_size);
+          std::unique_ptr<StorageView> layer_attention;
+          if (attention && heads_to_select)
+            layer_attention = std::make_unique<StorageView>(dtype, device);
+
+          dim_t offset = _sliding_window * i + step;
+          offset = offset < 0 ? 0 : offset;
+          if (i > 0) {
+            auto max_tokens = _sliding_window + layer_in_chunk.dim(1);
+            StorageView tmp_lengths = StorageView(Shape{layer_in_chunk.dim(0)}, int32_t(max_tokens), device);
+            StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
+              tmp_lengths,
+              _num_heads,
+              max_tokens,
+              /*mask_future=*/true,
+              multi_query);
+
+            const ops::Slide slide_lengths_op(2, _sliding_window, layer_in_chunk.dim(1));
+            // reuse tmp_lengths
+            slide_lengths_op(lengths_mask, tmp_lengths);
+            input_lengths_mask = std::make_unique<StorageView>(std::move(tmp_lengths));
+          }
 
-        if (layer_attention) {
-          alignment_heads.emplace_back(dtype, device);
-          ops::Gather(1, 1)(*layer_attention, *heads_to_select, alignment_heads.back());
+          (*_layers[l])(layer_in_chunk,
+                        input_lengths_mask.get(),
+                        memory,
+                        memory_lengths_mask.get(),
+                        cached_self_attn_keys,
+                        cached_self_attn_values,
+                        cached_attn_keys,
+                        cached_attn_values,
+                        layer_out,
+                        layer_attention.get(),
+                        input_padder.get(),
+                        memory_padder.get(),
+                        return_normalized_attention(),
+                        &position_bias,
+                        offset);
+          layer_in_chunk = std::move(layer_out);
+
+          if (layer_attention) {
+            alignment_heads.emplace_back(dtype, device);
+            ops::Gather(1, 1)(*layer_attention, *heads_to_select, alignment_heads.back());
+          }
         }
+        layer_in = std::move(layer_in_chunk);
       }
 
       if (step == 0) {

‎src/models/language_model.cc

+2 -1 lines

@@ -122,7 +122,8 @@ namespace ctranslate2 {
                          state,
                          ids,
                          vocabulary,
-                         _model->preferred_size_multiple());
+                         _model->preferred_size_multiple(),
+                         options.offset);
     }
 
     bool DecoderReplica::skip_scoring(const std::vector<std::string>& tokens,

‎src/models/sequence_to_sequence.cc

+6 -2 lines

@@ -256,7 +256,8 @@ namespace ctranslate2 {
                          state,
                          target_ids,
                          _model->get_target_vocabulary(),
-                         _model->preferred_size_multiple());
+                         _model->preferred_size_multiple(),
+                         options.offset);
     }
 
     bool EncoderDecoderReplica::skip_scoring(const std::vector<std::string>& source,
@@ -422,7 +423,8 @@ namespace ctranslate2 {
 
         final_results.emplace_back(std::move(hypotheses),
                                    std::move(result.scores),
-                                   std::move(result.attention));
+                                   std::move(result.attention),
+                                   std::move(result.logits));
       }
 
       return final_results;
@@ -461,6 +463,8 @@ namespace ctranslate2 {
          result.scores.emplace_back(0);
        if (options.return_attention)
          result.attention.emplace_back(attention);
+        if (options.return_scores)
+          result.logits.emplace_back(0);
       }
 
       return true;

‎src/ops/concat_split_cpu.cc ‎src/ops/concat_split_slide_cpu.cc

+30 -1 lines (renamed from src/ops/concat_split_cpu.cc)

@@ -1,5 +1,6 @@
 #include "ctranslate2/ops/concat.h"
 #include "ctranslate2/ops/split.h"
+#include "ctranslate2/ops/slide.h"
 
 #include "cpu/parallel.h"
 #include "type_dispatch.h"
@@ -71,13 +72,41 @@ namespace ctranslate2 {
       }
     }
 
+    template <Device D, typename T>
+    void Slide::compute(const StorageView& input, StorageView& output, const dim_t& index) const {
+      const dim_t axis = _axis < 0 ? input.rank() + _axis : _axis;
+      const dim_t stride_axis = input.stride(axis) == 0 ? 1 : input.stride(axis);
+      const dim_t step_size = input.dim(axis) * stride_axis;
+      const T* input_data = input.data<T>();
+
+      StorageView& x = output;
+      T* x_data = x.data<T>();
+
+      const dim_t copy_size = compute_copy_size(x, axis);
+      if (copy_size == 0)
+        return;
+
+      const dim_t iter_size = compute_iter_size(x, axis);
+
+      const dim_t grain_size = cpu::get_minimum_batch_copies_per_thread<T>(copy_size);
+      input_data += index * stride_axis;  // Read next with an offset.
+      cpu::parallel_for(0, iter_size, grain_size, [&](dim_t begin, dim_t end) {
+        for (dim_t i = begin; i < end; ++i)
+          primitives<D>::copy(input_data + i * step_size, x_data + i * copy_size, copy_size);
+      });
+    }
+
 #define DECLARE_IMPL(T) \
     template void \
     Concat::compute<Device::CPU, T>(const std::vector<const StorageView*>& inputs, \
                                     StorageView& output) const; \
     template void \
     Split::compute<Device::CPU, T>(const StorageView& input, \
-                                   std::vector<StorageView*>& outputs) const;
+                                   std::vector<StorageView*>& outputs) const; \
+    template void \
+    Slide::compute<Device::CPU, T>(const StorageView& input, \
+                                   StorageView& output, \
+                                   const dim_t& index) const;
 
 DECLARE_ALL_TYPES(DECLARE_IMPL)

‎src/ops/concat_split_gpu.cu ‎src/ops/concat_split_slide_gpu.cu

+49 -2 lines (renamed from src/ops/concat_split_gpu.cu)

@@ -1,5 +1,6 @@
 #include "ctranslate2/ops/concat.h"
 #include "ctranslate2/ops/split.h"
+#include "ctranslate2/ops/slide.h"
 
 #include <thrust/gather.h>
 #include <thrust/iterator/counting_iterator.h>
@@ -163,14 +164,60 @@ namespace ctranslate2 {
       }
     }
 
+    template <Device D, typename T>
+    void Slide::compute(const StorageView& input, StorageView& output, const dim_t& index) const {
+      const dim_t axis = _axis < 0 ? input.rank() + _axis : _axis;
+      const dim_t input_dim = input.dim(axis);
+      const dim_t inner_size = input.stride(axis) == 0 ? 1 : input.stride(axis);
+      const dim_t inner_bytes = inner_size * sizeof (T);
+      const T* input_data = input.data<T>();
+
+      T* output_data = output.data<T>();
+      const dim_t output_size = output.size();
+      const dim_t output_bytes = output_size * sizeof (T);
+      if (axis == 0) {
+        dim_t offset = index * output.stride(axis);
+        primitives<D>::copy(input_data + offset, output_data, output_size);
+      }
+      else {
+        const dim_t output_dim = output.dim(axis);
+
+        if (inner_size == 1) {
+          auto map_ids = thrust::make_transform_iterator(
+            thrust::counting_iterator<cuda::index_t>(0),
+            depth_offset_map<cuda::index_t>(index, output_dim, input_dim));
+          THRUST_CALL(thrust::gather, map_ids, map_ids + output_size, input_data, output_data);
+        } else if (inner_bytes % sizeof(uint4) == 0 && output_bytes % sizeof(uint4) == 0) {
+          auto map_ids = thrust::make_transform_iterator(
+            thrust::counting_iterator<cuda::index_t>(0),
+            inner_dim_offset_map<cuda::index_t>(index,
+                                                output_dim,
+                                                input_dim,
+                                                inner_bytes / sizeof(uint4)));
+          THRUST_CALL(thrust::gather,
+                      map_ids,
+                      map_ids + output_bytes / sizeof(uint4),
+                      reinterpret_cast<const uint4 *>(input_data),
+                      reinterpret_cast<uint4 *>(output_data));
+        } else {
+          auto map_ids = thrust::make_transform_iterator(
+            thrust::counting_iterator<cuda::index_t>(0),
+            inner_dim_offset_map<cuda::index_t>(index, output_dim, input_dim, inner_size));
+          THRUST_CALL(thrust::gather, map_ids, map_ids + output_size, input_data, output_data);
+        }
+      }
+    }
+
 #define DECLARE_IMPL(T) \
     template void \
     Concat::compute<Device::CUDA, T>(const std::vector<const StorageView*>& inputs, \
                                      StorageView& output) const; \
     template void \
     Split::compute<Device::CUDA, T>(const StorageView& input, \
-                                    std::vector<StorageView*>& outputs) const;
-
+                                    std::vector<StorageView*>& outputs) const; \
+    template void \
+    Slide::compute<Device::CUDA, T>(const StorageView& input, \
+                                    StorageView& output, const dim_t& index) const;
 DECLARE_ALL_TYPES(DECLARE_IMPL)
 
 }

‎src/ops/slide.cc

+47 lines (new file)

@@ -0,0 +1,47 @@
+#include "ctranslate2/ops/slide.h"
+
+#include <numeric>
+
+#include "dispatch.h"
+
+namespace ctranslate2 {
+  namespace ops {
+
+    Slide::Slide(dim_t axis, const dim_t& index, const dim_t& size, bool no_copy)
+      : _axis(axis)
+      , _index(index)
+      , _size(size)
+      , _no_copy(no_copy) {
+      check_arguments();
+    }
+
+    void Slide::operator()(const StorageView& input, StorageView& output) const {
+      PROFILE("Slide");
+      const dim_t axis = _axis < 0 ? input.rank() + _axis : _axis;
+
+      if (_index < 0 || _index >= input.dim(axis))
+        throw std::invalid_argument("Index or Size given is not valid");
+
+      dim_t offset = input.stride(0) * _index;
+      auto shape = input.shape();
+      shape[axis] = _size;
+      if (_no_copy) {
+        TYPE_DISPATCH(input.dtype(),
+                      output.view(const_cast<T*>(input.data<T>() + offset), std::move(shape)));
+      }
+      else {
+        output.resize(std::move(shape));
+      }
+
+      if (!_no_copy) {
+        DEVICE_AND_TYPE_DISPATCH(input.device(), input.dtype(), (compute<D, T>(input, output, _index)));
+      }
+    }
+
+    void Slide::check_arguments() const {
+      if (_no_copy && _axis != 0)
+        throw std::invalid_argument("no_copy is only defined when splitting across the first dimension");
+    }
+
+  }
+}

‎src/scoring.cc

+3 -2 lines

@@ -7,7 +7,8 @@ namespace ctranslate2 {
                          layers::DecoderState& state,
                          const std::vector<std::vector<size_t>>& sequences,
                          const Vocabulary& vocabulary,
-                         const dim_t preferred_size_multiple) {
+                         const dim_t preferred_size_multiple,
+                         const dim_t offset) {
     const dim_t batch_size = sequences.size();
     const Device device = decoder.device();
 
@@ -57,7 +58,7 @@ namespace ctranslate2 {
       auto& result = results[b];
      result.tokens.reserve(output_length);
      result.tokens_score.reserve(output_length);
-      for (dim_t t = 0; t < output_length; ++t) {
+      for (dim_t t = offset; t < output_length; ++t) {
        result.tokens.emplace_back(vocabulary.to_token(output_sequences[b][t]));
        result.tokens_score.emplace_back(scores.at<float>({b, t}));
      }
