Safe tokenization by skipping failing docs. #245

Open · wants to merge 2 commits into main
python/dolma/tokenizer/tokenizer.py (34 changes: 25 additions & 9 deletions)
@@ -367,23 +367,39 @@ def tokenize_file(
     tokenizer = make_tokenizer(tokenizer_name_or_path, **tokenizer_kwargs)
     dtype = deepcopy(tokenizer.dtype)
     decoder = msgspec.json.Decoder(InputSpec)
+    force_refresh = False
 
     with smart_open.open(path, mode="rt") as input_stream:
         for i, line in enumerate(input_stream, start=1):
-            try:
-                row = decoder.decode(line)
-                if text := row.text.strip():
-                    # skip empty docs
-                    tokens = tokenizer.encode(text, add_special_tokens=True)
-                    if refresh_tokenizer_every:
-                        # extra copy to prevent memory leaks
-                        tokens = np.array(tokens, dtype=dtype)
-                    yield TokenizerOutput.from_tokens(id=row.id, src=path, loc=i, tokens=tokens)  # pyright: ignore
+            try:
+                row = decoder.decode(line)
+                if not (text := row.text.strip()):
+                    # skip empty docs
+                    continue
+
+                # the actual tokenization happens here
+                tokens = tokenizer.encode(text, add_special_tokens=True)
+            except Exception:
+                # in case of failure, we log the error and continue
+                # We refresh the tokenizer to prevent memory leaks from affecting the rest of the processing
+                logger.warning("Error tokenizing %s:%d", path, i)
+                force_refresh = True
+                continue
+
+            if refresh_tokenizer_every:
+                # extra copy to prevent memory leaks
+                tokens = np.array(tokens, dtype=dtype)
+            yield TokenizerOutput.from_tokens(id=row.id, src=path, loc=i, tokens=tokens)  # pyright: ignore
 
-            if refresh_tokenizer_every > 0 and i % refresh_tokenizer_every == 0:
+            if (refresh_tokenizer_every > 0 and i % refresh_tokenizer_every == 0) or force_refresh:
                 # to prevent memory leaks, we refresh the tokenizer every so often
                 del tokenizer
                 gc.collect()
                 tokenizer = make_tokenizer(tokenizer_name_or_path, **tokenizer_kwargs)
+
+                # we reset the flag after refreshing the tokenizer
+                force_refresh = False
 
             except Exception as ex:
                 logger.error("Error processing %s:%d", path, i, exc_info=ex)
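For readers skimming the diff, here is a minimal, self-contained sketch of the skip-and-refresh pattern the patch introduces: a bad document is logged and skipped instead of aborting the file, and the encoder is rebuilt either on a fixed cadence or after a failure. `make_encoder`, `safe_encode_lines`, and the toy JSONL input are hypothetical stand-ins invented for illustration; the real code uses dolma's `make_tokenizer`, `tokenizer.encode`, and `TokenizerOutput`.

import gc
import json
import logging

logger = logging.getLogger(__name__)


def make_encoder():
    # hypothetical stand-in for make_tokenizer(tokenizer_name_or_path, ...)
    return lambda text: [ord(c) for c in text]


def safe_encode_lines(lines, refresh_every=0):
    encoder = make_encoder()
    force_refresh = False
    for i, line in enumerate(lines, start=1):
        try:
            row = json.loads(line)
            if not (text := row["text"].strip()):
                # skip empty docs
                continue
            tokens = encoder(text)
        except Exception:
            # a failing record is logged and skipped instead of
            # aborting the whole file
            logger.warning("Error tokenizing line %d", i)
            force_refresh = True
            continue

        yield i, tokens

        if (refresh_every > 0 and i % refresh_every == 0) or force_refresh:
            # rebuild the encoder to contain any leaked state
            del encoder
            gc.collect()
            encoder = make_encoder()
            force_refresh = False


lines = ['{"text": "ok"}', 'not json', '{"text": ""}', '{"text": "fine"}']
print(list(safe_encode_lines(lines, refresh_every=2)))
# -> [(1, [111, 107]), (4, [102, 105, 110, 101])]

One behavior worth noting, in both the sketch and the patch: a failure only sets `force_refresh`, and the `continue` skips the refresh block at the bottom of the loop, so the forced rebuild actually runs after the next document that tokenizes successfully.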