73 changes: 57 additions & 16 deletions open_lm/datapreprocess/ray/tokenize_shuffle.py
@@ -47,14 +47,17 @@

import logging


import yaml
import pathlib

# Initialize an empty dictionary for sampling frequencies

DIR = pathlib.Path(__file__).parent.absolute()

DEFAULT_SPECIAL_TOKENS = {
"EleutherAI/gpt-neox-20b": {"eos_token_id": 0, "pad_token_id": 1},
}
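One nuance worth noting: the lookup in load_tokenizer keys on the exact --tokenizer string, so these pinned ids only apply when that string matches. If another tokenizer ever needed pinned defaults, an entry of the same shape could presumably be added; the second entry below is purely hypothetical.

# Hypothetical extension; "my-org/my-tokenizer" and its ids are made up for illustration.
DEFAULT_SPECIAL_TOKENS = {
    "EleutherAI/gpt-neox-20b": {"eos_token_id": 0, "pad_token_id": 1},
    "my-org/my-tokenizer": {"eos_token_id": 2, "pad_token_id": 3},
}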


def load_from_yaml(filename):
SAMPLING_FREQUENCIES = {}
@@ -257,10 +260,9 @@ def preprocess(
sources: enum.Enum = None,
source_counter: GlobalCounter = None,
):
tokenizer_fn, vocab_size = tokenizer
tokenizer_fn, EOS, PAD = tokenizer
rng = random.Random(hash(key) + seed)
EOT = SpecialTokens.END_OF_TEXT.value % (vocab_size + len(SpecialTokens))
PAD = SpecialTokens.PAD.value % (vocab_size + len(SpecialTokens))

if do_sample:
assert sources is not None
sample_freq = sources.get_sampling_frequency(key)
@@ -274,7 +276,7 @@
pbar.set_description(key)
for string in pbar:
tokens = tokenizer_fn(string)
tokens.append(EOT)
tokens.append(EOS)
buffer += tokens
idx = 0
while idx < len(buffer) - seqlen:
@@ -352,12 +354,6 @@ def process_keys(data, tokenizer, seqlen, seed, content_key, do_sample, sources=
fh.close()


class SpecialTokens(Enum):
END_OF_TEXT = 0
PAD = -1
END_OF_DOCUMENT = -2
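For context on what this removal changes: the old preprocess lines derived EOT and PAD by wrapping these enum values modulo (vocab_size + len(SpecialTokens)), which parked PAD outside the tokenizer's own id range. A quick worked example with an illustrative vocab size:

# Worked example of the removed id arithmetic; vocab_size is illustrative.
vocab_size = 50_000
num_special = 3                                     # END_OF_TEXT, PAD, END_OF_DOCUMENT
EOT = 0 % (vocab_size + num_special)                # 0
PAD = -1 % (vocab_size + num_special)               # 50_002, past the end of the real vocab
END_OF_DOCUMENT = -2 % (vocab_size + num_special)   # 50_001

The new path instead takes EOS and PAD straight from the tokenizer (or the overwrite flags), and load_tokenizer below validates that they fall inside the vocabulary.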


def parse_s3_path(s3_path):
"""
Extract the bucket and key from an S3 path.
@@ -442,7 +438,7 @@ def write_to_location(folder, tar_name, bio):
assert False, f"error is {path} and {e}"


def load_tokenizer(tokenizer):
def load_tokenizer(tokenizer, eos_overwrite=None, pad_overwrite=None):
enc = None
if pathlib.Path(tokenizer).exists() and pathlib.Path(tokenizer).is_file():
enc = PreTrainedTokenizerFast(tokenizer_file=tokenizer)
@@ -453,7 +449,44 @@ def load_tokenizer(tokenizer):
print(str(e))
raise ValueError(f"Unknown Tokenizer: {tokenizer}")

return (lambda x: enc(x).input_ids, enc.vocab_size)
eos_token_id, pad_token_id = enc.eos_token_id, enc.pad_token_id
if tokenizer in DEFAULT_SPECIAL_TOKENS:
eos_token_id = DEFAULT_SPECIAL_TOKENS[tokenizer]["eos_token_id"]
pad_token_id = DEFAULT_SPECIAL_TOKENS[tokenizer]["pad_token_id"]

if eos_overwrite is not None:
assert eos_overwrite < len(
enc.vocab
), f"eos_overwrite {eos_overwrite} is greater than vocab size {len(enc.vocab)}"
if eos_token_id is not None and eos_overwrite != eos_token_id:
logger.warning(
f"Default EOS id for {tokenizer} is {eos_token_id} and you are overriding it to be {eos_overwrite}."
)
eos_token_id = eos_overwrite

if pad_overwrite is not None:
assert pad_overwrite < len(
enc.vocab
), f"pad_overwrite {pad_overwrite} is greater than vocab size {len(enc.vocab)}"
if pad_token_id is not None and pad_overwrite != pad_token_id:
logger.warning(
f"Default PAD id for {tokenizer} is {pad_token_id} and you are overriding it to be {pad_overwrite}."
)
pad_token_id = pad_overwrite

if eos_token_id is None:
raise ValueError(
"Tokenizer does not have a specified EOS token id. Please manually pass one in via --eos_overwrite"
)
if pad_token_id is None:
raise ValueError(
"Tokenizer does not have a specified PAD token id. Please manually pass one in via --pad_overwrite"
)

logger.info(f'EOS token id has been set to {eos_token_id} which decodes to "{enc.decode([eos_token_id])}".')
logger.info(f'PAD token id has been set to {pad_token_id} which decodes to "{enc.decode([pad_token_id])}".')

return (lambda x: enc(x).input_ids, eos_token_id, pad_token_id)
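A minimal usage sketch of the new contract (assuming the module imports from this path; the ids match the DEFAULT_SPECIAL_TOKENS entry above):

# Illustrative sketch; not part of the diff.
from open_lm.datapreprocess.ray.tokenize_shuffle import load_tokenizer

tokenizer_fn, eos_id, pad_id = load_tokenizer("EleutherAI/gpt-neox-20b")
# eos_id == 0, pad_id == 1 via DEFAULT_SPECIAL_TOKENS; explicit overrides are checked against
# the vocab size, and a warning is logged when they differ from the tokenizer's defaults.
tokenizer_fn, eos_id, pad_id = load_tokenizer("EleutherAI/gpt-neox-20b", eos_overwrite=0, pad_overwrite=1)
token_ids = tokenizer_fn("hello world")  # list of token ids; preprocess appends EOS itself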


def glob_files(path, suffixes):
@@ -557,7 +590,7 @@ def main(args):
parser.add_argument("--content_key", type=str, default="text")
parser.add_argument("--seqlen", type=int, default=2048)
parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b")
parser.add_argument("--vocab_size", type=int, default=None) # for pre-tokenized data, don't load tokenizer
parser.add_argument("--pretokenized", action="store_true") # For pre-tokenized data, don't load tokenizer
parser.add_argument("--wds_chunk_size", type=int, default=8192)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--subset", type=int, default=None)
@@ -576,6 +609,8 @@
parser.add_argument("--suffixes", nargs="+", default=[".json", ".jsonl", ".zst", ".zstd", ".tar", ".gz"])
parser.add_argument("--presort", action="store_true")
parser.add_argument("--allow_imbalanced_write", action="store_true")
parser.add_argument("--eos_overwrite", type=int, default=None)
parser.add_argument("--pad_overwrite", type=int, default=None)

args = parser.parse_args(args)
if args.do_sample:
@@ -608,6 +643,7 @@
dashboard_host=args.ray_dashboard_host,
)
num_nodes = len(ray.nodes())

input_folders = args.input.split(",")
input_paths = []
for inp_folder in input_folders:
@@ -637,7 +673,12 @@
ctx.execution_options.resource_limits.object_store_memory = float("inf")
ray.data.DataContext.get_current().execution_options.verbose_progress = True
start_time = time.time()
tokenizer = load_tokenizer(args.tokenizer) if args.vocab_size is None else (lambda x: x, args.vocab_size)

if args.pretokenized:
tokenizer = (lambda x: x, args.eos_overwrite, args.pad_overwrite)
else:
tokenizer = load_tokenizer(args.tokenizer, args.eos_overwrite, args.pad_overwrite)
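Either branch yields the same tuple shape that preprocess unpacks; schematically (a sketch, not part of the diff):

# Schematic of how preprocess() consumes the tuple; values are the gpt-neox-20b defaults above.
tokenizer_fn, EOS, PAD = tokenizer
ids = tokenizer_fn("an example document")  # identity for --pretokenized, real encoding otherwise
ids.append(EOS)                            # each document is terminated with a single EOS
# The trailing partial sequence of a shard is then padded out with PAD, which is what the
# padded_sequences check in the tests below counts.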

logger.info(f"Total number of keys = {len(input_paths)}")
df = pd.DataFrame(input_paths, columns=["path"])
ds = ray.data.from_pandas(pd.DataFrame(input_paths, columns=["path"])).repartition(parallelism)
@@ -687,7 +728,7 @@ def main(args):
ds = ds.repartition(1)
ds = ds.sort(key="shard")
jsonl_lines = ds.take_all()
token_count_from_manifest = sum([x["num_sequences"][0] for x in jsonl_lines] * seqlen)
token_count_from_manifest = sum([x["num_sequences"] for x in jsonl_lines] * seqlen)
write_manifest(jsonl_lines, args)
else:
write_status = ds.map_batches(
49 changes: 48 additions & 1 deletion tests/test_tokenize_shuffle.py
@@ -15,16 +15,63 @@ def run_around_tests():
def test_tokenize_shuffle_simple():
content_len = 2048
NUM_TOKENS = 86058
NUM_PAGES = 160
NUM_JSONLS = 16
EOS = 1
PAD = 0

exit_value = os.system(
f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input s3://dcnlp-west-test/tokenize_shuffle_test/C4_V3_tiny/ --content_key content --output test_output/ --seqlen {content_len}"
)
assert exit_value == 0
ds = wds.WebDataset("test_output/00000001.tar").decode()
total = 0
eos_tokens = 0
padded_sequences = 0
for x in ds:
assert len(x["json.gz"]) == content_len + 1
total += len(x["json.gz"])
eos_tokens += x["json.gz"].count(EOS)
padded_sequences += 1 if x["json.gz"][-1] == PAD else 0

# assert total == NUM_TOKENS
assert eos_tokens == NUM_PAGES
assert padded_sequences == NUM_JSONLS

with open("test_output/manifest.jsonl", "rb") as f:
out = f.read()
out = [json.loads(o) for o in out.decode("utf-8").split("\n")[:-1]]

# assert out[0]["shard"] == "00000001"
# assert out[0]["num_sequences"] == NUM_TOKENS // (content_len + 1)


def test_tokenize_shuffle_override_eos_and_pad():
content_len = 2048
NUM_TOKENS = 86058
NUM_PAGES = 160
NUM_JSONLS = 16
EOS = 1
PAD = 0

# Swap the identity of EOS and PAD special tokens to test whether --eos_overwrite and --pad_overwrite flags work correctly.
exit_value = os.system(
f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input s3://dcnlp-west-test/tokenize_shuffle_test/C4_V3_tiny/ --content_key content --output test_output/ --seqlen {content_len} --eos_overwrite {EOS} --pad_overwrite {PAD}"
)
assert exit_value == 0
ds = wds.WebDataset("test_output/00000001.tar").decode()
total = 0
eos_tokens = 0
padded_sequences = 0
for x in ds:
assert len(x["json.gz"]) == content_len + 1
total += len(x["json.gz"])
eos_tokens += x["json.gz"].count(EOS)
padded_sequences += 1 if x["json.gz"][-1] == PAD else 0

# assert total == NUM_TOKENS
assert eos_tokens == NUM_PAGES
assert padded_sequences == NUM_JSONLS

with open("test_output/manifest.jsonl", "rb") as f:
out = f.read()
@@ -40,7 +87,7 @@ def test_tokenize_shuffle_tar(content_key, NUM_TOKENS):

params = f"--content_key {content_key}"
if content_key == "npy":
params += " --vocab_size 16384"
params += " --pretokenized"

exit_value = os.system(
f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input s3://dcnlp-west-test/tokenize_shuffle_test/webvid_tiny/ {params} --output test_output/ --seqlen {content_len}"