From 015d3c1d59afc71e76f731449e47a00ee11fdbd9 Mon Sep 17 00:00:00 2001
From: Luca Soldaini
Date: Mon, 24 Feb 2025 16:10:32 -0800
Subject: [PATCH 1/2] first draft of safe tokenization

---
 python/dolma/tokenizer/tokenizer.py | 29 ++++++++++++++++++++++-------
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/python/dolma/tokenizer/tokenizer.py b/python/dolma/tokenizer/tokenizer.py
index 4a7d68c1..3b3e964b 100644
--- a/python/dolma/tokenizer/tokenizer.py
+++ b/python/dolma/tokenizer/tokenizer.py
@@ -367,23 +367,38 @@ def tokenize_file(
     tokenizer = make_tokenizer(tokenizer_name_or_path, **tokenizer_kwargs)
     dtype = deepcopy(tokenizer.dtype)
     decoder = msgspec.json.Decoder(InputSpec)
+    force_refresh = False
+
     with smart_open.open(path, mode="rt") as input_stream:
         for i, line in enumerate(input_stream, start=1):
             try:
                 row = decoder.decode(line)
-                if text := row.text.strip():
+                if not (text := row.text.strip()):
                     # skip empty docs
-                    tokens = tokenizer.encode(text, add_special_tokens=True)
-                    if refresh_tokenizer_every:
-                        # extra copy to prevent memory leaks
-                        tokens = np.array(tokens, dtype=dtype)
-                    yield TokenizerOutput.from_tokens(id=row.id, src=path, loc=i, tokens=tokens)  # pyright: ignore
+                    continue
 
-                if refresh_tokenizer_every > 0 and i % refresh_tokenizer_every == 0:
+                try:
+                    tokens = tokenizer.encode(text, add_special_tokens=True)
+                except Exception:
+                    # in case of failure, we log the error and continue
+                    # We refresh the tokenizer to prevent memory leaks from affecting the rest of the processing
+                    logger.warning("Error tokenizing %s:%d", path, i)
+                    force_refresh = True
+                    continue
+
+                if refresh_tokenizer_every:
+                    # extra copy to prevent memory leaks
+                    tokens = np.array(tokens, dtype=dtype)
+                yield TokenizerOutput.from_tokens(id=row.id, src=path, loc=i, tokens=tokens)  # pyright: ignore
+
+                if (refresh_tokenizer_every > 0 and i % refresh_tokenizer_every == 0) or force_refresh:
                     # to prevent memory leaks, we refresh the tokenizer every so often
                     del tokenizer
                     gc.collect()
                     tokenizer = make_tokenizer(tokenizer_name_or_path, **tokenizer_kwargs)
+                    # we reset the flag after refreshing the tokenizer
+                    force_refresh = False
+
             except Exception as ex:
                 logger.error("Error processing %s:%d", path, i, exc_info=ex)
 

From ef6fa4c35a5e6bba29751af39975af18f246ba81 Mon Sep 17 00:00:00 2001
From: Luca Soldaini
Date: Mon, 24 Feb 2025 16:11:35 -0800
Subject: [PATCH 2/2] safe decoding

---
 python/dolma/tokenizer/tokenizer.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/python/dolma/tokenizer/tokenizer.py b/python/dolma/tokenizer/tokenizer.py
index 3b3e964b..04017cdc 100644
--- a/python/dolma/tokenizer/tokenizer.py
+++ b/python/dolma/tokenizer/tokenizer.py
@@ -372,12 +372,13 @@ def tokenize_file(
     with smart_open.open(path, mode="rt") as input_stream:
         for i, line in enumerate(input_stream, start=1):
             try:
-                row = decoder.decode(line)
-                if not (text := row.text.strip()):
-                    # skip empty docs
-                    continue
-
                 try:
+                    row = decoder.decode(line)
+                    if not (text := row.text.strip()):
+                        # skip empty docs
+                        continue
+
+                    # the actual tokenization happens here
                     tokens = tokenizer.encode(text, add_special_tokens=True)
                 except Exception:
                     # in case of failure, we log the error and continue
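
For reference, the pattern these two patches introduce — decode and tokenize each document inside its own try/except, flag a refresh on failure, and periodically rebuild the tokenizer after gc.collect() so a bad document cannot leak memory into the rest of the run — can be sketched in isolation as below. This is a minimal, hypothetical illustration, not dolma's code: safe_tokenize, build_tokenizer, and the toy tokenizer in the __main__ block are stand-ins for tokenize_file, make_tokenizer, and Tokenizer.encode.

import gc
import json
import logging
from typing import Callable, Iterable, Iterator, List

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("safe_tokenization_sketch")


def safe_tokenize(
    lines: Iterable[str],
    build_tokenizer: Callable[[], Callable[[str], List[int]]],
    refresh_every: int = 0,
) -> Iterator[List[int]]:
    """Yield token ids per document, rebuilding the tokenizer after a failure
    or every `refresh_every` documents to keep memory growth bounded."""
    encode = build_tokenizer()
    force_refresh = False

    for i, line in enumerate(lines, start=1):
        try:
            # decode the JSON line and skip empty documents
            text = json.loads(line).get("text", "").strip()
            if not text:
                continue

            try:
                tokens = encode(text)
            except Exception:
                # log and move on; request a refresh so a corrupted tokenizer
                # state cannot affect the remaining documents
                logger.warning("Error tokenizing document %d", i)
                force_refresh = True
                continue

            yield tokens

            # rebuild the tokenizer periodically, or right after a flagged failure
            if (refresh_every > 0 and i % refresh_every == 0) or force_refresh:
                del encode
                gc.collect()
                encode = build_tokenizer()
                force_refresh = False
        except Exception as ex:
            # malformed JSON (or any other per-line error) is logged and skipped
            logger.error("Error processing document %d", i, exc_info=ex)


if __name__ == "__main__":
    docs = ['{"text": "hello world"}', '{"text": ""}', "not json", '{"text": "more text"}']
    # toy "tokenizer": map each whitespace token to its length
    outputs = list(safe_tokenize(docs, lambda: (lambda t: [len(w) for w in t.split()]), refresh_every=2))
    print(outputs)  # [[5, 5], [4, 4]]

As in the patches, a flagged refresh is deferred: the failing iteration continues past the refresh check, so the rebuild happens once a later document reaches that check.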