first #240

Draft · wants to merge 5 commits into main
Changes from 1 commit
1 change: 0 additions & 1 deletion python/dolma/cli/deduper.py
@@ -192,7 +192,6 @@ def run(cls, parsed_config: DeduperConfig):
# perform some path validation to make sure we don't call the mixer with invalid config
total_matching_documents = 0
for document in parsed_config.documents:

if not any(
fnmatch.fnmatch(dict_config["dedupe"]["document_dir"], part) for part in document.split(os.sep)
):
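For reference, the validation retained above checks that the deduper's configured document folder appears as a component of every document path; a minimal illustration of the fnmatch test, using made-up values:

import fnmatch
import os

# Hypothetical values for illustration only.
document = "s3://bucket/corpus/documents/train/part-000.jsonl.gz"
document_dir = "documents"  # stands in for dict_config["dedupe"]["document_dir"]

# Each path component is treated as a glob pattern; the check passes when the
# configured folder name matches at least one component (POSIX separators assumed).
matches = any(fnmatch.fnmatch(document_dir, part) for part in document.split(os.sep))
print(matches)  # True, so this path would not trigger the mismatch warning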
5 changes: 4 additions & 1 deletion python/dolma/cli/mixer.py
@@ -66,6 +66,9 @@ class StreamConfig:
"from the file extension."
),
)
document_dir: str = field(
default="documents", help="Folder in source path to replace with 'attributes' when looking for attributes"
)


@dataclass
@@ -145,7 +148,6 @@ def run(cls, parsed_config: MixerConfig):
# perform some path validation to make sure we don't call the mixer with invalid config
total_matching_documents = 0
for document in stream_config.documents:

current_matching_documents = sum(1 for _ in glob_path(document))
if current_matching_documents == 0:
# only raise a warning if no documents are found for a single path
@@ -159,6 +161,7 @@ def run(cls, parsed_config: MixerConfig):
# populate the stream config dict
stream_config_dict["name"] = stream_config.name
stream_config_dict["documents"] = [str(d) for d in stream_config.documents]
stream_config_dict["document_dir"] = stream_config.document_dir
stream_config_dict["attributes"] = [str(a) for a in list(stream_config.attributes)]
stream_config_dict["output"] = {
"path": str(stream_config.output.path),
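For context, a hedged sketch of how the new field might appear in a mixer stream entry; the names and paths below are hypothetical and not taken from this PR's configs:

# Hypothetical stream entry. With document_dir set to "docs", attribute files
# are located by swapping the "docs" segment of each document path for
# "attributes/<name>" instead of the previously hard-coded "documents" folder.
stream = {
    "name": "mixer-test",
    "documents": ["s3://bucket/corpus/docs/*.jsonl.gz"],
    "document_dir": "docs",  # new field; defaults to "documents"
    "attributes": ["dedupe_paragraphs"],
    "output": {"path": "s3://bucket/corpus/output"},
}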
5 changes: 5 additions & 0 deletions python/dolma/cli/tagger.py
@@ -91,6 +91,10 @@ class TaggerConfig:
default=False,
help="If true, only print the configuration and exit without running the taggers.",
)
document_dir: Optional[str] = field(
[Review comment] Member: should we parametrize attributes too?
[Review comment] Contributor Author: We could for the sake of symmetry but it has little utility that I can see

default="documents",
help="The folder in source paths to replace with 'attributes' to store results, if not 'documents'",
)


class TaggerCli(BaseCli):
@@ -140,6 +144,7 @@ def run(cls, parsed_config: TaggerConfig):
profile_output=parsed_config.profile.output,
profile_steps=parsed_config.profile.steps,
profile_sort_key=parsed_config.profile.sort_key,
document_dir=parsed_config.document_dir,
)


3 changes: 2 additions & 1 deletion python/dolma/core/runtime.py
@@ -392,6 +392,7 @@ def create_and_run_tagger(
profile_steps: Optional[int] = None,
profile_sort_key: str = "tottime",
profile_lines: int = 100,
document_dir: Optional[str] = "documents",
):
"""This function creates a tagger and runs it on a list of documents.

@@ -444,7 +445,7 @@ def create_and_run_tagger(

if destination is None:
try:
destination = _make_paths_from_substitution(documents, "documents", f"attributes/{experiment}")
destination = _make_paths_from_substitution(documents, document_dir, f"attributes/{experiment}")
except Exception as exp:
raise RuntimeError("Could not make destination paths from documents paths") from exp
elif isinstance(destination, str):
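To illustrate the destination derivation above: when no destination is given, the configured folder is swapped for attributes/<experiment> in each document path. A rough stand-in for that substitution (not the real _make_paths_from_substitution, whose body is outside this diff; the paths are made up):

def substitute_folder(paths, source_dir, replacement):
    # Simplified illustration: swap the "/<source_dir>/" segment for "/<replacement>/".
    return [p.replace(f"/{source_dir}/", f"/{replacement}/") for p in paths]

documents = ["s3://bucket/corpus/docs/split-0000.jsonl.gz"]
print(substitute_folder(documents, "docs", "attributes/my_experiment"))
# ['s3://bucket/corpus/attributes/my_experiment/split-0000.jsonl.gz']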
15 changes: 11 additions & 4 deletions python/dolma/warc/processor.py
@@ -107,6 +107,7 @@ def process_single(
pre_taggers_names: List[str] = kwargs.get("pre_taggers") or []
pre_taggers = {make_variable_name(name): TaggerRegistry.get(name)() for name in pre_taggers_names}


# create the html extractor
linearizer_name: str = kwargs.get("linearizer_name") or "resiliparse"
linearizer = LinearizerRegistry.get(linearizer_name)()
@@ -127,6 +128,7 @@ def process_single(
# whether to skip this document if post-taggers find nothing
skip_no_post_taggers: bool = kwargs.get("skip_no_post_taggers") or False

skip_linearization: bool = kwargs.get("skip_linearization") or False
# derive the destination path if it is not provided by splitting out all the
# extensions, removing gz and warc, and adding jsonl.gz
if not destination_path.endswith(".jsonl.gz"):
@@ -192,12 +194,15 @@ def process_single(
continue

# extract text
doc.text = linearizer.linearize(content=decoded_content)
if skip_linearization:
[Review comment] Member: can we have a no-op linearizer instead of a boolean flag?
[Review comment] Contributor Author: Will do
doc.text = decoded_content
else:
doc.text = linearizer.linearize(content=decoded_content)

# these are the properties extracted from the HTML content
post_attributes = {name: tagger.tag(doc) for name, tagger in post_taggers.items()}
if skip_no_post_taggers and not sum(map(len, post_attributes.values())):
continue
# post_attributes = {name: tagger.tag(doc) for name, tagger in post_taggers.items()}
# if skip_no_post_taggers and not sum(map(len, post_attributes.values())):
# continue

doc.attributes = {
f"{t_name}__{t_name}__{make_variable_name(a_name)}": attr_values
@@ -247,6 +252,7 @@ def create_and_run_warc_pipeline(
store_html_in_metadata: bool = False,
skip_no_pre_taggers: bool = False,
skip_no_post_taggers: bool = False,
skip_linearization: bool = False,
):
with ExitStack() as stack:
if metadata is None:
@@ -302,4 +308,5 @@
skip_no_pre_taggers=skip_no_pre_taggers,
skip_no_post_taggers=skip_no_post_taggers,
source_name=source_name,
skip_linearization=skip_linearization
)
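Regarding the inline review suggestion above, a no-op linearizer could eventually replace the skip_linearization flag. A minimal sketch, with the caveat that the registration mechanism for LinearizerRegistry is not shown in this diff and is assumed here:

class NoOpLinearizer:
    """Sketch of a linearizer that returns the raw HTML unchanged.

    If registered (e.g. under a name like "noop"), callers could select it via
    linearizer_name instead of passing skip_linearization=True; the exact
    registration call depends on LinearizerRegistry's API.
    """

    def linearize(self, content: str) -> str:
        # Return the decoded content as-is instead of extracting plain text.
        return content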
18 changes: 14 additions & 4 deletions src/shard.rs
@@ -40,6 +40,10 @@ impl Shard {
pub fn split_streams(streams: &Vec<StreamConfig>) -> Result<Vec<Shard>, IoError> {
let mut shards: Vec<Shard> = Vec::new();
for stream_config in streams {
let document_dir = format!(
"/{}/",
stream_config.document_dir.as_deref().unwrap_or("documents")
);
let mut stream_shard_count = 0;
log::info!("Computing shards for stream {}...", stream_config.name);
let stream_inputs = find_objects_matching_patterns(&stream_config.documents)?;
@@ -50,7 +54,7 @@
let mut attr_paths = Vec::new();
for prefix in stream_config.attributes.iter() {
let attr_prefix = format!("/attributes/{}/", prefix);
let attr_path = input.replace("/documents/", &attr_prefix);
let attr_path = input.replace(&document_dir, &attr_prefix);
attr_paths.push(attr_path);
}
(
@@ -135,13 +139,17 @@ impl Shard {
// dataset is a strict subset of the original and is intended to be unshuffled and unsharded.
let mut shards: Vec<Shard> = Vec::new();
for stream_config in streams {
let document_dir = format!(
"/{}/",
stream_config.document_dir.as_deref().unwrap_or("documents")
);
let stream_inputs = find_objects_matching_patterns(&stream_config.documents)?;
let input_count = stream_inputs.len();
let inputs = stream_inputs.into_iter().map(|input| {
let mut attr_paths = Vec::new();
for prefix in stream_config.attributes.iter() {
let attr_prefix = format!("/attributes/{}/", prefix);
let attr_path = input.replace("/documents/", &attr_prefix);
let attr_path = input.replace(&document_dir, &attr_prefix);
attr_paths.push(attr_path);
}
DocumentPaths {
@@ -152,10 +160,11 @@

for input in inputs {
let doc_path_clone = input.doc_path.clone();
let output_suffix = doc_path_clone.split("/documents/").last().unwrap();
let output_suffix = doc_path_clone.split(&document_dir).last().unwrap();
let output = format!(
"{}/documents/{}",
"{}{}{}",
stream_config.output.path.clone(),
document_dir,
output_suffix
);
log::info!("Creating shard for {}", output);
@@ -543,6 +552,7 @@ pub mod shard_config {
pub span_replacement: Option<Vec<SpanReplacementConfig>>,
pub output: StreamOutputConfig,
pub compression: Option<CompressionConfig>,
pub document_dir: Option<String>,
}

#[derive(Serialize, Deserialize, Clone)]
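The Rust changes above perform two substitutions with the configured folder; the same logic, expressed in Python purely for illustration (all paths below are made up):

document_dir = "/docs/"  # corresponds to format!("/{}/", ...) with default "documents"
input_path = "s3://bucket/corpus/docs/shard-0001.jsonl.gz"

# Attribute path: swap the document folder for the attribute prefix.
attr_path = input_path.replace(document_dir, "/attributes/dedupe_paragraphs/")
# -> "s3://bucket/corpus/attributes/dedupe_paragraphs/shard-0001.jsonl.gz"

# Output path: keep everything after the document folder and re-attach it under
# the stream's output path (previously hard-coded as "/documents/").
output_suffix = input_path.split(document_dir)[-1]
output = "s3://bucket/output" + document_dir + output_suffix
# -> "s3://bucket/output/docs/shard-0001.jsonl.gz"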
32 changes: 32 additions & 0 deletions tests/python/test_mixer.py
@@ -22,6 +22,8 @@
EMAIL_SPANS_JQ = Path(__file__).parent.parent / "config/email-spans-jq.yaml"
FILTER_BY_SPANS = Path(__file__).parent.parent / "config/filter-by-spans.json"
MIXER = Path(__file__).parent.parent / "config/mixer.json"
ALT_DOC_PATH_MIXER = Path(__file__).parent.parent / "config/alt-path-mixer.json"

PARAGRAPH_SPANS = Path(__file__).parent.parent / "config/paragraph-spans.json"


@@ -150,6 +152,36 @@ def test_remote_input_remote_output(self):
provided = self.checkAndRemoveProvenance(provided)
self.assertEqual(expected, provided)

def test_alt_doc_path_mixer(self):
if self.remote_test_prefix is None:
return self.skipTest("Skipping AWS tests")

with open(ALT_DOC_PATH_MIXER, mode="r", encoding="utf8") as f:
config = json.load(f)

# keep track of local output path
local_input = config["streams"][0]["documents"][0]
local_output = config["streams"][0]["output"]["path"]

# replace results path with s3 path
config["streams"][0]["output"]["path"] = f"{self.remote_test_prefix}/{local_output}"

# upload local input to s3, replace local input with s3 path
config["streams"][0]["documents"][0] = f"{self.remote_test_prefix}/{local_input}"

with NamedTemporaryFile("w") as f:
json.dump(config, f)
f.flush()

main(argv=["-c", f.name, "mix"])

download_s3_prefix(f"{self.remote_test_prefix}/tests/work", "tests/work/remote")
expected = load_jsonl("tests/data/expected/mixer.json.gz")
provided = load_jsonl("tests/work/remote/output/mixer/mixer-test-0000.json.gz")
provided = self.checkAndRemoveProvenance(provided)
self.assertEqual(expected, provided)


def test_remote_input_local_output(self):
if self.remote_test_prefix is None:
return self.skipTest("Skipping AWS tests")
2 changes: 0 additions & 2 deletions tests/python/test_paths.py
@@ -295,7 +295,6 @@ def test_split_glob(self):

class TestSplitExt(TestCase):
def test_file(self):

prot, parts, ext = split_ext("file.txt")

self.assertEqual(prot, "")
@@ -318,7 +317,6 @@ def test_file(self):
self.assertEqual(ext, ".")

def test_path(self):

prot, parts, ext = split_ext("path/to/file.txt")

self.assertEqual(prot, "")
51 changes: 51 additions & 0 deletions tests/python/test_warc.py
@@ -103,3 +103,54 @@ def test_pretag_html(self):
{"by_4_0", "by_3_0"},
)
self.assertIn("cc_re__cc_re__cc_by_4_0", sample1[2]["attributes"])

def test_skip_linearization(self):
"""Test that when skip_linearization is True, the raw HTML content is preserved."""
outputs = self._run_pipeline_with_skip_linearization()
self.assertEqual(len(outputs), 2)
self.assertIn("sample-0000.jsonl.gz", outputs)
self.assertIn("sample-0001.jsonl.gz", outputs)

sample0 = outputs["sample-0000.jsonl.gz"]
sample1 = outputs["sample-0001.jsonl.gz"]

# Check that we got some documents
self.assertGreater(len(sample0), 0)
self.assertGreater(len(sample1), 0)

# For all documents, verify they contain raw HTML instead of linearized text
for sample in chain(sample0, sample1):
# HTML content should be in the text field
self.assertIn("<", sample["text"])
self.assertIn(">", sample["text"])

# Common HTML tags that should be present in raw HTML
html_indicators = ["<html", "<body", "<div", "<p"]
self.assertTrue(any(indicator in sample["text"].lower() for indicator in html_indicators))

# Basic metadata should still be present
self.assertEqual(sample["version"], "v0")
self.assertEqual(sample["source"], "test")
self.assertIn("warc_url", sample["metadata"])
self.assertIn("url", sample["metadata"])
self.assertIn("warc_date", sample["metadata"])
self.assertIn("warc_filename", sample["metadata"])
self.assertIn("content_type", sample["metadata"])

def _run_pipeline_with_skip_linearization(self) -> Dict[str, List[dict]]:
"""Helper method to run pipeline with skip_linearization=True."""
create_and_run_warc_pipeline(
documents=[f"{DATA_PATH}/*.warc.gz"],
destination=[self.tempdir],
num_processes=1,
ignore_existing=False,
debug=True,
source_name="test",
skip_no_pre_taggers=False,
skip_no_post_taggers=False,
store_html_in_metadata=False,
linearizer_name="resiliparse",
skip_linearization=True,
pre_taggers=["cc_re"],
post_taggers=["lingua_1e2"],
)
55 changes: 28 additions & 27 deletions tests/python/utils.py
@@ -70,33 +70,33 @@ def skip_aws_tests() -> bool:
return (dolma_tests_skip or "false").lower() == "true"


def upload_test_documents(local_input: str, test_prefix: str) -> Tuple[str, str]:
remote_input = f"{test_prefix}/input/documents"
remote_output = f"{test_prefix}/output/documents"
# def upload_test_documents(local_input: str, test_prefix: str, document_dir: str = "documents") -> Tuple[str, str]:
# remote_input = f"{test_prefix}/input/{document_dir}"
# remote_output = f"{test_prefix}/output/{document_dir}"

for i, local_fp in enumerate(glob_path(local_input)):
remote_fp = f"{remote_input}/{i:05d}.json.gz"
# for i, local_fp in enumerate(glob_path(local_input)):
# remote_fp = f"{remote_input}/{i:05d}.json.gz"

with open(local_fp, "rb") as f, open(remote_fp, "wb") as g:
g.write(f.read())
# with open(local_fp, "rb") as f, open(remote_fp, "wb") as g:
# g.write(f.read())

return remote_input, remote_output
# return remote_input, remote_output


def upload_test_attributes(local_attributes: str, test_prefix: str):
remote_attributes = f"{test_prefix}/input/attributes"
# def upload_test_attributes(local_attributes: str, test_prefix: str):
# remote_attributes = f"{test_prefix}/input/attributes"

for i, local_fp in enumerate(glob_path(local_attributes)):
matched = re.match(r"^(attributes|duplicate)-(\w+)", local_fp)
if not matched:
raise RuntimeError(f"Unexpected filename: {local_fp}")
# for i, local_fp in enumerate(glob_path(local_attributes)):
# matched = re.match(r"^(attributes|duplicate)-(\w+)", local_fp)
# if not matched:
# raise RuntimeError(f"Unexpected filename: {local_fp}")

_, name = matched.groups()
# _, name = matched.groups()

remote_fp = f"{remote_attributes}/{name}/{i:05d}.json.gz"
# remote_fp = f"{remote_attributes}/{name}/{i:05d}.json.gz"

with open(local_fp, "rb") as f, open(remote_fp, "wb") as g:
g.write(f.read())
# with open(local_fp, "rb") as f, open(remote_fp, "wb") as g:
# g.write(f.read())


def clean_test_data(test_prefix: str):
@@ -127,6 +127,7 @@ def upload_s3_prefix(s3_prefix: str, local_prefix: str):
bucket_name, prefix = parse_s3_path(s3_prefix)

for local_fp in glob_path(local_prefix):
print(f"LOCAL_FP {local_fp}")
name = local_fp.replace(local_prefix, "").lstrip("/")
s3.upload_file(Bucket=bucket_name, Key=f"{prefix}/{name}", Filename=local_fp)

@@ -167,9 +168,9 @@ def writeUnits(

return [str(p) for p in file_paths]

def writeDocs(self, docs: List[str], partitions: int = 1, ext_dir: Optional[Path] = None, unit_type: str = "documents") -> List[str]:
def writeDocs(self, docs: List[str], partitions: int = 1, ext_dir: Optional[Path] = None,unit_type: str = "documents") -> List[str]:
encoded_docs = [{"id": str(i), "text": d, "source": __file__} for i, d in enumerate(docs)]
return self.writeUnits(units=encoded_docs, unit_type="documents", partitions=partitions, ext_dir=ext_dir)
return self.writeUnits(units=encoded_docs, unit_type=unit_type, partitions=partitions, ext_dir=ext_dir)

def writeAttributes(
self,
Expand Down Expand Up @@ -199,10 +200,10 @@ def writeConfig(self, config: dict, ext_dir: Optional[Path] = None) -> str:
def combineIntoDoc(self, *lines: str, join: str = "\n") -> str:
return join.join(lines)

def makeDocsCopy(self, path: Union[str, Path]) -> str:
path = Path(path)
dest = Path(self.makeUniquePath()) / "documents"
dest.mkdir(parents=True)
for fp in path.iterdir():
shutil.copy(fp, dest / fp.name)
return str(dest)
# def makeDocsCopy(self, path: Union[str, Path]) -> str:
# path = Path(path)
# dest = Path(self.makeUniquePath()) / "documents"
# dest.mkdir(parents=True)
# for fp in path.iterdir():
# shutil.copy(fp, dest / fp.name)
# return str(dest)