do not overwrite tagger outputs with the same output path, fixes #113 (#114)

* do not overwrite tagger outputs with the same output path

* added test for failure

* removed unused import

* caught error

---------

Co-authored-by: Luca Soldaini <[email protected]>
peterbjorgensen and soldni authored Feb 7, 2024
1 parent 4b696f9 commit 9968464
Showing 2 changed files with 74 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/dolma/core/runtime.py
@@ -197,7 +197,8 @@ def _write_sample_to_streams(

            # if not set; it will potentially not write to the output stream
            # in case a tagger emits no spans
            attributes_by_stream[tagger_output.path] = {}
            if tagger_output.path not in attributes_by_stream:
                attributes_by_stream[tagger_output.path] = {}

            for tagger_key, tagger_value in tagger_data.items():
                tagger_key = f"{tagger_output.exp}__{tagger_output.name}__{make_variable_name(tagger_key)}"
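For context, a minimal standalone sketch of the behavior this hunk fixes (illustrative names and values only, not the dolma code itself): when several taggers resolve to the same attributes path, as happens when an experiment name groups their outputs into one file, unconditionally re-initializing the entry for that path discards the attributes already collected from the previous tagger, whereas the membership check keeps them so the per-path dict accumulates across taggers.

# Sketch: two taggers share one output path (as when an experiment name is set).
# All paths and keys below are made-up examples.
attributes_by_stream: dict = {}

tagger_outputs = [
    ("attributes/experiment_name/000.json.gz", {"c4_attr": []}),
    ("attributes/experiment_name/000.json.gz", {"gopher_attr": []}),
]

for path, tagger_data in tagger_outputs:
    # Before the fix this assignment ran unconditionally, wiping whatever the
    # previous tagger had already stored for `path`:
    #     attributes_by_stream[path] = {}
    if path not in attributes_by_stream:
        attributes_by_stream[path] = {}
    attributes_by_stream[path].update(tagger_data)

# With the guard, attributes from both taggers survive under the shared path.
assert len(attributes_by_stream["attributes/experiment_name/000.json.gz"]) == 2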
72 changes: 72 additions & 0 deletions tests/python/test_runtime.py
@@ -1,7 +1,9 @@
import json
import os
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional
from unittest import TestCase

import smart_open
@@ -156,6 +158,76 @@ def test_alt_src(self):
        self.assertEqual(attributes_full_name, attributes_star_in_path)
        self.assertEqual(attributes_full_name, attributes_only_dir)

    def test_multiple_taggers(self, experiment_name: Optional[str] = None):
        documents_dir = Path(f"{LOCAL_DATA}/provided/documents")
        taggers = ["c4_v1", "gopher_v1"]

        with TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)
            (temp_path / "documents").mkdir(exist_ok=True)

            for path in documents_dir.iterdir():
                shutil.copy(path, temp_path / "documents" / path.name)

            create_and_run_tagger(
                documents=[os.path.join(temp_dir, "documents") + "/*"],
                taggers=taggers,
                experiment=experiment_name,
                debug=True,
            )

            if experiment_name is None:
                all_attribute_dirs = [temp_path / "attributes" / t for t in taggers]
            else:
                all_attribute_dirs = [temp_path / "attributes" / experiment_name]

            for d in all_attribute_dirs:
                # check if a folder for each tagger was created
                self.assertTrue(os.path.exists(d))

            # collect all attributes for all documents here
            attributes = []

            for fn in documents_dir.iterdir():
                # collect all attributes for the current document here
                current_attrs: List[dict] = []

                for attr_path in all_attribute_dirs:
                    # check if attribute to corresponding document was created
                    attr_fp = attr_path / fn.name
                    self.assertTrue(attr_fp.exists())

                    if len(current_attrs) == 0:
                        with smart_open.open(attr_fp, "rt") as f:
                            # no attributes for this file name loaded in yet
                            current_attrs = [json.loads(ln) for ln in f]
                    else:
                        with smart_open.open(attr_fp, "rt") as f:
                            for i, attr_doc in enumerate(json.loads(ln) for ln in f):
                                # check if attributes are aligned
                                self.assertTrue(attr_doc["id"] == current_attrs[i]["id"])
                                current_attrs[i]["attributes"].update(attr_doc["attributes"])

                attributes.extend(current_attrs)

            for row in attributes:
                # check if name of attribute files is correct
                attribute_files_names = set(k.split("__")[0] for k in row["attributes"].keys())

                if experiment_name is None:
                    self.assertEqual(attribute_files_names, set(taggers))
                else:
                    self.assertEqual(attribute_files_names, {experiment_name})

                # check if name of taggers is correct
                tagger_names = set(k.split("__")[1] for k in row["attributes"].keys())
                self.assertEqual(tagger_names, set(taggers))

    def test_multiple_with_exp_name(self):
        # same as test_multiple_taggers, but provide an experiment name
        # this is to test failure reported here: https://github.com/allenai/dolma/pull/113
        self.test_multiple_taggers(experiment_name="experiment_name")

    def test_alt_exp(self):
        documents_path = f"{LOCAL_DATA}/provided/documents/000.json.gz"
        taggers = ["c4_v1"]
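The new test recovers experiment and tagger names by splitting attribute keys on "__", relying on the "<exp>__<name>__<attribute>" convention visible in the runtime.py hunk above. A small self-contained illustration (the key strings here are made-up examples, not actual tagger output):

# Keys follow "<experiment>__<tagger>__<attribute>"; splitting on "__" recovers
# the first two components, which is what the assertions in the test check.
keys = [
    "experiment_name__c4_v1__attr_a",
    "experiment_name__gopher_v1__attr_b",
]

attribute_files_names = set(k.split("__")[0] for k in keys)  # {"experiment_name"}
tagger_names = set(k.split("__")[1] for k in keys)           # {"c4_v1", "gopher_v1"}

assert attribute_files_names == {"experiment_name"}
assert tagger_names == {"c4_v1", "gopher_v1"}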
