diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp index 8adb47ab0..4d85d558f 100644 --- a/test/sampling_tests.cpp +++ b/test/sampling_tests.cpp @@ -461,7 +461,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) { // Run test for (int i = 0; i < num_iter; i++) { auto generator = Generators::CreateGenerator(*model, *params); - Generators::DeviceSpan logits_gpu = params->p_device->Allocate(config.model.vocab_size * batch_size); + Generators::DeviceSpan logits_gpu = params->p_device->Allocate(vocab_size * batch_size); auto cpu_span = logits_gpu.CpuSpan(); // Shuffle integers 1 to k randomly into cpu_span for (int i = 0; i < batch_size; i++) {