diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 2028187..a70ee5d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -77,7 +77,7 @@ jobs: - name: Upload package distribution files if: matrix.task.name == 'Build' - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: package path: dist @@ -110,7 +110,7 @@ jobs: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV - name: Download package distribution files - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: package path: dist diff --git a/CHANGELOG.md b/CHANGELOG.md index cc027dd..4bdc53f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Fixed + +- Ensure query lengths are always padded up to a multiple of 128. + ## [v0.6.0](https://github.com/allenai/OLMo-in-loop-evals/releases/tag/v0.6.0) - 2024-12-19 ## [v0.5.0](https://github.com/allenai/OLMo-in-loop-evals/releases/tag/v0.5.0) - 2024-12-18 diff --git a/src/olmo_eval/tasks.py b/src/olmo_eval/tasks.py index bf92248..e630711 100644 --- a/src/olmo_eval/tasks.py +++ b/src/olmo_eval/tasks.py @@ -1,5 +1,6 @@ import abc import logging +import math import re from typing import Any, Dict, List, Optional, Sequence, Type, Union, cast @@ -157,6 +158,9 @@ def max_sequence_length(self) -> int: for sample in self.samples: if len(sample["query"]) > max_seq_len: max_seq_len = len(sample["query"]) + # Pad to multiple of 128 for efficiency. + # TODO (epwalsh): make the padding multiple configurable + max_seq_len = 128 * math.ceil(max_seq_len / 128) self._max_sequence_length = max_seq_len return self._max_sequence_length @@ -181,6 +185,11 @@ def collate_fn(self, data): if len(sample["dc_query"]) > max_dc_query_len: max_dc_query_len = len(sample["dc_query"]) + # Pad to multiple of 128 for efficiency.
+ # TODO (epwalsh): make the padding multiple configurable + max_query_len = 128 * math.ceil(max_query_len / 128) + assert max_query_len <= self.max_sequence_length + doc_ids = [] cont_ids = [] ctxs = []