diff --git a/python/dolma/tokenizer/tokenizer.py b/python/dolma/tokenizer/tokenizer.py
index 4a7d68c1..04017cdc 100644
--- a/python/dolma/tokenizer/tokenizer.py
+++ b/python/dolma/tokenizer/tokenizer.py
@@ -367,23 +367,39 @@ def tokenize_file(
     tokenizer = make_tokenizer(tokenizer_name_or_path, **tokenizer_kwargs)
     dtype = deepcopy(tokenizer.dtype)
     decoder = msgspec.json.Decoder(InputSpec)
+    force_refresh = False
+
     with smart_open.open(path, mode="rt") as input_stream:
         for i, line in enumerate(input_stream, start=1):
             try:
-                row = decoder.decode(line)
-                if text := row.text.strip():
-                    # skip empty docs
-                    tokens = tokenizer.encode(text, add_special_tokens=True)
-                    if refresh_tokenizer_every:
-                        # extra copy to prevent memory leaks
-                        tokens = np.array(tokens, dtype=dtype)
-                    yield TokenizerOutput.from_tokens(id=row.id, src=path, loc=i, tokens=tokens)  # pyright: ignore
+                try:
+                    row = decoder.decode(line)
+                    if not (text := row.text.strip()):
+                        # skip empty docs
+                        continue

-                if refresh_tokenizer_every > 0 and i % refresh_tokenizer_every == 0:
+                    # the actual tokenization happens here
+                    tokens = tokenizer.encode(text, add_special_tokens=True)
+                except Exception:
+                    # in case of failure, we log the error and continue
+                    # We refresh the tokenizer to prevent memory leaks from affecting the rest of the processing
+                    logger.warning("Error tokenizing %s:%d", path, i)
+                    force_refresh = True
+                    continue
+
+                if refresh_tokenizer_every:
+                    # extra copy to prevent memory leaks
+                    tokens = np.array(tokens, dtype=dtype)
+                yield TokenizerOutput.from_tokens(id=row.id, src=path, loc=i, tokens=tokens)  # pyright: ignore
+
+                if (refresh_tokenizer_every > 0 and i % refresh_tokenizer_every == 0) or force_refresh:
                     # to prevent memory leaks, we refresh the tokenizer every so often
                     del tokenizer
                     gc.collect()
                     tokenizer = make_tokenizer(tokenizer_name_or_path, **tokenizer_kwargs)
+                    # we reset the flag after refreshing the tokenizer
+                    force_refresh = False
+
             except Exception as ex:
                 logger.error("Error processing %s:%d", path, i, exc_info=ex)
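
The hunk above splits error handling into two layers: an inner try/except isolates failures of JSON decoding and tokenization for a single line (log a warning, set force_refresh, skip the record), while the pre-existing outer except still catches anything unexpected; the tokenizer is then rebuilt either on the refresh_tokenizer_every cadence or immediately after a failure. Below is a minimal, self-contained sketch of that control flow, illustrative only; the names make_resource, process_lines, and refresh_every are hypothetical and not part of dolma.

# Sketch (assumed names, not dolma APIs) of the pattern introduced above:
# per-record failures set a flag that forces the expensive resource to be
# rebuilt, in addition to the regular rebuild cadence.
import gc
import logging
from typing import Iterable, Iterator

logger = logging.getLogger(__name__)


def make_resource() -> dict:
    # stand-in for make_tokenizer(...): any costly object that may leak memory
    return {"calls": 0}


def process_lines(lines: Iterable[str], refresh_every: int = 0) -> Iterator[str]:
    resource = make_resource()
    force_refresh = False

    for i, line in enumerate(lines, start=1):
        try:
            try:
                if not (text := line.strip()):
                    # skip empty records
                    continue
                resource["calls"] += 1
                result = text.upper()  # stand-in for tokenizer.encode(...)
            except Exception:
                # log the per-record failure, mark the resource as suspect, move on
                logger.warning("Error processing record %d", i)
                force_refresh = True
                continue

            yield result

            if (refresh_every > 0 and i % refresh_every == 0) or force_refresh:
                # rebuild the resource on schedule, or right after a failure
                del resource
                gc.collect()
                resource = make_resource()
                force_refresh = False
        except Exception as ex:
            logger.error("Unexpected error at record %d", i, exc_info=ex)

For example, list(process_lines(["a", "", "b"], refresh_every=2)) returns ["A", "B"]: the empty record is skipped, and the resource is rebuilt after every second line or whenever a record fails.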