Skip to content

Commit

Permalink
Infinigram counting script for paper
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Feb 18, 2025
1 parent 6020122 commit e4f9b19
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions scripts/infinigram_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llam

# Determine valid starting indices based on word-split boundaries.
valid_positions = []
for i in range(len(token_ids) - ngram_size + 1):
start_offset = offsets[i][0]
if start_offset == 0 or (start_offset > 0 and text[start_offset - 1] == " "):
valid_positions.append(i)
# for i in range(len(token_ids) - ngram_size + 1):
# start_offset = offsets[i][0]
# if start_offset == 0 or (start_offset > 0 and text[start_offset - 1] == " "):
# valid_positions.append(i)

if not valid_positions:
# Fallback: if no valid positions are found, use all possible positions.
Expand All @@ -91,7 +91,7 @@ def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llam
ngram_token_ids = token_ids[idx: idx+ngram_size]
ngram_str = tokenizer.decode(ngram_token_ids, clean_up_tokenization_spaces=True)
# Only accept n-grams that contain only allowed characters.
if ALLOWED_RE.fullmatch(ngram_str):
if ALLOWED_RE.fullmatch(ngram_str) and len(ngram_str.strip()) > ngram_size * 3:
count = query_infinigram(ngram_str, index=index)
flag = "YES" if count > 0 else "NO"
valid_ngram_details.append((flag, ngram_str))
Expand All @@ -107,7 +107,7 @@ def main():
)
parser.add_argument("N", type=int, help="Number of random .jsonl files to process")
parser.add_argument("s3_path", type=str, help="S3 path to a prefix containing .jsonl files (e.g., s3://my-bucket/my-prefix/)")
parser.add_argument("--index", type=str, default="v4_rpj_llama_s4", help="Infini-gram index to use (default: v4_rpj_llama_s4)")
parser.add_argument("--index", type=str, default="v4_dolma-v1_7_llama", help="Infini-gram index to use (default: v4_rpj_llama_s4)")
parser.add_argument("--ngram_size", type=int, default=10, help="Size of the n-gram to sample (default: 10)")
parser.add_argument("--num_ngrams", type=int, default=100, help="Number of random n-grams to sample from each document (default: 100)")
args = parser.parse_args()
Expand Down

0 comments on commit e4f9b19

Please sign in to comment.