diff --git a/open_lm/datapreprocess/ray/tokenize_shuffle.py b/open_lm/datapreprocess/ray/tokenize_shuffle.py
index 665dbbb5..e0ba33fa 100644
--- a/open_lm/datapreprocess/ray/tokenize_shuffle.py
+++ b/open_lm/datapreprocess/ray/tokenize_shuffle.py
@@ -90,6 +90,7 @@ class RawFileType(enum.Enum):
     ZSTD_JSONL_COMPRESSED = 2
     GZIP_JSONL_COMPRESSED = 3
     TAR = 4
+    TAR_PRETOK = 5
     UNKNOWN = -1
 
 
@@ -118,13 +119,17 @@ def tar_reader(fh: BinaryIO, content_key: str):
     """
     content_key: where in the tarfile to find the text/tokens. Options:
         "txt" - read text file as string
+        "json" - read json file
         "json:key" - read json[key] as string
+        "json.gz" - same as json, but gzipped
+        "json.gz:key" - same as json:key, but gzipped
         "npy" - read numpy array as tokens
     """
+    # TODO(gsmyrnis): It is unclear whether some of these modes (namely npy) are still useful - consider
+    # removing them in the future.
     content_ext = content_key.split(":")[0]
     buffer = io.BytesIO(fh.read())
     with tarfile.open(fileobj=buffer, mode="r") as tar:
-        samples = []
         for member in tar.getmembers():
             if member.isfile() and member.name.endswith(f".{content_ext}"):
                 with tar.extractfile(member) as fileobj:
@@ -132,8 +137,20 @@ def tar_reader(fh: BinaryIO, content_key: str):
                         if content_ext == "txt":
                             content = fileobj.read().decode("utf-8")
                         elif content_ext == "json":
-                            json_dict, json_key = json.load(fileobj), content_key.split(":")[1]
-                            content = json_dict[json_key]
+                            json_data = json.load(fileobj)
+                            if isinstance(json_data, dict):
+                                json_key = content_key.split(":")[1]
+                                content = json_data[json_key]
+                            else:
+                                content = json_data
+                        elif content_ext == "json.gz":
+                            with gzip.open(fileobj, "rb") as fileobj_unzip:
+                                json_data = json.load(fileobj_unzip)
+                                if isinstance(json_data, dict):
+                                    json_key = content_key.split(":")[1]
+                                    content = json_data[json_key]
+                                else:
+                                    content = json_data
                         elif content_ext == "npy":
                             token_array = np.load(io.BytesIO(fileobj.read()), allow_pickle=True)
                             content = token_array.reshape(-1).tolist()
@@ -234,7 +251,7 @@ def _flush_buffer(self, folder, counter):
                 tokens = [int(x) for x in self.buffer[i]["tokens"]]
                 token_count += len(tokens)
                 json_string = json.dumps(tokens)
-                uid = hashlib.md5(json_string.encode()).hexdigest()
+                uid = f"{tar_index_str}_{i:0{digits}}"
                 sample = {"__key__": uid, "json.gz": json_string}
                 sink.write(sample)
         bio.seek(0)
@@ -256,6 +273,7 @@ def preprocess(
     do_sample: bool = False,
     sources: enum.Enum = None,
     source_counter: GlobalCounter = None,
+    pretok_tars: bool = False,
 ):
     tokenizer_fn, vocab_size = tokenizer
     rng = random.Random(hash(key) + seed)
@@ -273,8 +291,11 @@ def preprocess(
         pbar = tqdm(file_reader(fh), mininterval=10)
         pbar.set_description(key)
         for string in pbar:
-            tokens = tokenizer_fn(string)
-            tokens.append(EOT)
+            if file_type == RawFileType.TAR and pretok_tars:
+                tokens = string
+            else:
+                tokens = tokenizer_fn(string)
+                tokens.append(EOT)
             buffer += tokens
             while len(buffer) >= seqlen:
                 if do_sample:
@@ -308,7 +329,9 @@ def preprocess(
         return []
 
 
-def process_keys(data, tokenizer, seqlen, seed, content_key, do_sample, sources=None, source_counters=None):
+def process_keys(
+    data, tokenizer, seqlen, seed, content_key, do_sample, pretok_tars, sources=None, source_counters=None
+):
     path = data["path"]
 
     if path.startswith("s3"):
@@ -337,6 +360,7 @@ def process_keys(data, tokenizer, seqlen, seed, content_key, do_sample, sources=
             do_sample=do_sample,
             sources=sources,
             source_counter=source_counter,
+            pretok_tars=pretok_tars,
         )
         # Ensure that all operations on the file handle are done
         # within this block
@@ -569,8 +593,14 @@ def main(args):
         "--ray_dashboard_host", type=str, default="127.0.0.1"
     )  # default is localhost; for slurm jobs do 0.0.0.0
     parser.add_argument("--suffixes", nargs="+", default=[".json", ".jsonl", ".zst", ".zstd", ".tar", ".gz"])
+    parser.add_argument("--pretok_tars", action="store_true", help="Assume tars contain pretokenized data.")
     args = parser.parse_args(args)
+
+    assert not args.pretok_tars or args.suffixes == [
+        ".tar"
+    ], "Mixing tokenized and untokenized data at the same time is currently not supported."
+
 
     if args.do_sample:
         Sources, SAMPLING_FREQUENCIES = load_from_yaml(args.default_dataset_yaml)
         logger.info(f"SOURCES:\n {Sources}")
@@ -612,6 +642,7 @@ def main(args):
         input_paths = input_paths[: args.subset]
     if args.subfraction is not None:
         input_paths = input_paths[: int(args.subfraction * len(input_paths))]
+    print("Files considered: \n", input_paths)
     print(f"num files ={len(input_paths)}")
     num_files = len(input_paths)
 
@@ -650,6 +681,7 @@ def main(args):
             seed=args.seed,
            content_key=content_key,
             do_sample=args.do_sample,
+            pretok_tars=args.pretok_tars,
             sources=Sources,
             source_counters=source_counters,
         )
diff --git a/tests/test_tokenize_shuffle.py b/tests/test_tokenize_shuffle.py
index 97d73169..7e2240fc 100644
--- a/tests/test_tokenize_shuffle.py
+++ b/tests/test_tokenize_shuffle.py
@@ -115,3 +115,37 @@ def test_tokenize_shuffle_local_read_local_write():
         total += len(x["json.gz"])
     assert total == NUM_TOKENS
     assert exit_value == 0
+
+
+def test_tokenize_shuffle_with_pretokenized():
+    content_len = 2048
+    NUM_TOKENS = 24508089
+    # download a small test json file and store at ./test_input
+    os.system("mkdir test_input")
+    os.system("mkdir test_output")
+    os.system(
+        "wget -O ./test_input/wikipedia_sample.jsonl https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample/resolve/main/wikipedia_sample.jsonl"
+    )
+    # run tokenize script
+    exit_value_1 = os.system(
+        f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input ./test_input --content_key text --seqlen {content_len} --output ./test_output/"
+    )
+    assert exit_value_1 == 0
+
+    os.system("cp -r ./test_output ./test_input/2a/")
+    os.system("cp -r ./test_output ./test_input/2b/")
+
+    exit_value_2 = os.system(
+        f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input ./test_input/2a,./test_input/2b --content_key json.gz --seqlen {content_len} --output ./test_output/2 --pretok_tars --suffixes .tar"
+    )
+    assert exit_value_2 == 0
+
+    tars = [os.path.join("test_output/2", fname) for fname in os.listdir("test_output/2") if fname.endswith(".tar")]
+    total = 0
+    for tar in tars:
+        ds = wds.WebDataset(tar).decode()
+        for x in ds:
+            assert len(x["json.gz"]) == content_len + 1
+            total += len(x["json.gz"])
+
+    assert total == 2 * NUM_TOKENS