diff --git a/python/dolma/cli/deduper.py b/python/dolma/cli/deduper.py
index de6a43d5..d263d4ca 100644
--- a/python/dolma/cli/deduper.py
+++ b/python/dolma/cli/deduper.py
@@ -192,7 +192,6 @@ def run(cls, parsed_config: DeduperConfig):
         # perform some path validation to make sure we don't call the mixer with invalid config
         total_matching_documents = 0
         for document in parsed_config.documents:
-
             if not any(
                 fnmatch.fnmatch(dict_config["dedupe"]["document_dir"], part) for part in document.split(os.sep)
             ):
diff --git a/python/dolma/cli/mixer.py b/python/dolma/cli/mixer.py
index 2ac6c5c5..41f632c6 100644
--- a/python/dolma/cli/mixer.py
+++ b/python/dolma/cli/mixer.py
@@ -66,6 +66,9 @@ class StreamConfig:
             "from the file extension."
         ),
     )
+    document_dir: str = field(
+        default="documents", help="Folder in source path to replace with 'attributes' when looking for attributes"
+    )


 @dataclass
@@ -145,7 +148,6 @@ def run(cls, parsed_config: MixerConfig):
             # perform some path validation to make sure we don't call the mixer with invalid config
             total_matching_documents = 0
             for document in stream_config.documents:
-
                 current_matching_documents = sum(1 for _ in glob_path(document))
                 if current_matching_documents == 0:
                     # only raise a warning if no documents are found for a single path
@@ -159,6 +161,7 @@ def run(cls, parsed_config: MixerConfig):
             # populate the stream config dict
             stream_config_dict["name"] = stream_config.name
             stream_config_dict["documents"] = [str(d) for d in stream_config.documents]
+            stream_config_dict["document_dir"] = stream_config.document_dir
             stream_config_dict["attributes"] = [str(a) for a in list(stream_config.attributes)]
             stream_config_dict["output"] = {
                 "path": str(stream_config.output.path),
diff --git a/python/dolma/cli/tagger.py b/python/dolma/cli/tagger.py
index 9982ec05..9d29eafe 100644
--- a/python/dolma/cli/tagger.py
+++ b/python/dolma/cli/tagger.py
@@ -91,6 +91,10 @@ class TaggerConfig:
         default=False,
         help="If true, only print the configuration and exit without running the taggers.",
     )
+    document_dir: str = field(
+        default="documents",
+        help="The folder in source paths to replace with 'attributes' to store results, if not 'documents'",
+    )


 class TaggerCli(BaseCli):
@@ -140,6 +144,7 @@ def run(cls, parsed_config: TaggerConfig):
             profile_output=parsed_config.profile.output,
             profile_steps=parsed_config.profile.steps,
             profile_sort_key=parsed_config.profile.sort_key,
+            document_dir=parsed_config.document_dir,
         )
diff --git a/python/dolma/core/runtime.py b/python/dolma/core/runtime.py
index ac5e2a23..8ebfa0d3 100644
--- a/python/dolma/core/runtime.py
+++ b/python/dolma/core/runtime.py
@@ -392,6 +392,7 @@ def create_and_run_tagger(
     profile_steps: Optional[int] = None,
     profile_sort_key: str = "tottime",
     profile_lines: int = 100,
+    document_dir: str = "documents",
 ):
     """This function creates a tagger and runs it on a list of documents.
@@ -444,7 +445,7 @@ def create_and_run_tagger(
     if destination is None:
         try:
-            destination = _make_paths_from_substitution(documents, "documents", f"attributes/{experiment}")
+            destination = _make_paths_from_substitution(documents, document_dir, f"attributes/{experiment}")
         except Exception as exp:
             raise RuntimeError("Could not make destination paths from documents paths") from exp
     elif isinstance(destination, str):
diff --git a/python/dolma/warc/linearizers.py b/python/dolma/warc/linearizers.py
index a99c0775..c4d588cd 100644
--- a/python/dolma/warc/linearizers.py
+++ b/python/dolma/warc/linearizers.py
@@ -143,3 +143,9 @@ def linearize(self, content: Union[str, bytes]) -> str:
         )
         self._flush()
         return output or ""
+
+
+@LinearizerRegistry.add("no-op")
+class NoOpLinearizer(BaseLinearizer):
+    def linearize(self, content: Union[str, bytes]) -> str:
+        return str(content)
diff --git a/python/dolma/warc/processor.py b/python/dolma/warc/processor.py
index 474c6ca9..a3d949d1 100644
--- a/python/dolma/warc/processor.py
+++ b/python/dolma/warc/processor.py
@@ -247,6 +247,7 @@ def create_and_run_warc_pipeline(
     store_html_in_metadata: bool = False,
     skip_no_pre_taggers: bool = False,
     skip_no_post_taggers: bool = False,
+    skip_linearization: bool = False,
 ):
     with ExitStack() as stack:
         if metadata is None:
diff --git a/src/shard.rs b/src/shard.rs
index 226ba194..c66d6511 100644
--- a/src/shard.rs
+++ b/src/shard.rs
@@ -40,6 +40,10 @@ impl Shard {
     pub fn split_streams(streams: &Vec<StreamConfig>) -> Result<Vec<Shard>, IoError> {
         let mut shards: Vec<Shard> = Vec::new();
         for stream_config in streams {
+            let document_dir = format!(
+                "/{}/",
+                stream_config.document_dir.as_deref().unwrap_or("documents")
+            );
             let mut stream_shard_count = 0;
             log::info!("Computing shards for stream {}...", stream_config.name);
             let stream_inputs = find_objects_matching_patterns(&stream_config.documents)?;
@@ -50,7 +54,7 @@ impl Shard {
                 let mut attr_paths = Vec::new();
                 for prefix in stream_config.attributes.iter() {
                     let attr_prefix = format!("/attributes/{}/", prefix);
-                    let attr_path = input.replace("/documents/", &attr_prefix);
+                    let attr_path = input.replace(&document_dir, &attr_prefix);
                     attr_paths.push(attr_path);
                 }
                 (
@@ -135,13 +139,17 @@ impl Shard {
         // dataset is a strict subset of the original and is intended to be unshuffled and unsharded.
         let mut shards: Vec<Shard> = Vec::new();
         for stream_config in streams {
+            let document_dir = format!(
+                "/{}/",
+                stream_config.document_dir.as_deref().unwrap_or("documents")
+            );
             let stream_inputs = find_objects_matching_patterns(&stream_config.documents)?;
             let input_count = stream_inputs.len();
             let inputs = stream_inputs.into_iter().map(|input| {
                 let mut attr_paths = Vec::new();
                 for prefix in stream_config.attributes.iter() {
                     let attr_prefix = format!("/attributes/{}/", prefix);
-                    let attr_path = input.replace("/documents/", &attr_prefix);
+                    let attr_path = input.replace(&document_dir, &attr_prefix);
                     attr_paths.push(attr_path);
                 }
                 DocumentPaths {
@@ -152,10 +160,11 @@ impl Shard {
             for input in inputs {
                 let doc_path_clone = input.doc_path.clone();
-                let output_suffix = doc_path_clone.split("/documents/").last().unwrap();
+                let output_suffix = doc_path_clone.split(&document_dir).last().unwrap();
                 let output = format!(
-                    "{}/documents/{}",
+                    "{}{}{}",
                     stream_config.output.path.clone(),
+                    document_dir,
                     output_suffix
                 );
                 log::info!("Creating shard for {}", output);
@@ -543,6 +552,7 @@ pub mod shard_config {
        pub span_replacement: Option<Vec<SpanReplacementConfig>>,
        pub output: StreamOutputConfig,
        pub compression: Option<CompressionConfig>,
+       pub document_dir: Option<String>,
     }

     #[derive(Serialize, Deserialize, Clone)]
diff --git a/tests/config/alt-path-mixer.json b/tests/config/alt-path-mixer.json
new file mode 100644
index 00000000..cdcbe596
--- /dev/null
+++ b/tests/config/alt-path-mixer.json
@@ -0,0 +1,34 @@
+{
+  "streams": [
+    {
+      "name": "mixer-test",
+      "documents": [
+        "tests/data/provided/alternative_term/*.gz"
+      ],
+      "document_dir": "alternative_term",
+      "output": {
+        "path": "tests/work/output/mixer",
+        "max_size_in_bytes": 100000
+      },
+      "attributes": [
+        "pii",
+        "toxicity"
+      ],
+      "filter": {
+        "include": [
+          "$.metadata[?(@.length < 10000)]"
+        ],
+        "exclude": [
+          "$.metadata[?(@.length < 500)]",
+          "$.attributes[?(@.pii.too_much_pii == true)]",
+          "$.attributes[?(@.toxicity > 0.8)]"
+        ]
+      }
+    }
+  ],
+  "work_dir": {
+    "input": "tests/work/temp/mixer/input",
+    "output": "tests/work/temp/mixer/output"
+  },
+  "processes": 1
+}
diff --git a/tests/data/provided/alternative_term/000.json.gz b/tests/data/provided/alternative_term/000.json.gz
new file mode 100644
index 00000000..f5419508
Binary files /dev/null and b/tests/data/provided/alternative_term/000.json.gz differ
diff --git a/tests/python/test_mixer.py b/tests/python/test_mixer.py
index 68ea1721..5c6d4718 100644
--- a/tests/python/test_mixer.py
+++ b/tests/python/test_mixer.py
@@ -22,6 +22,8 @@
 EMAIL_SPANS_JQ = Path(__file__).parent.parent / "config/email-spans-jq.yaml"
 FILTER_BY_SPANS = Path(__file__).parent.parent / "config/filter-by-spans.json"
 MIXER = Path(__file__).parent.parent / "config/mixer.json"
+ALT_DOC_PATH_MIXER = Path(__file__).parent.parent / "config/alt-path-mixer.json"
+
 PARAGRAPH_SPANS = Path(__file__).parent.parent / "config/paragraph-spans.json"
@@ -150,6 +152,35 @@ def test_remote_input_remote_output(self):
         provided = self.checkAndRemoveProvenance(provided)
         self.assertEqual(expected, provided)

+    def test_alt_doc_path_mixer(self):
+        if self.remote_test_prefix is None:
+            return self.skipTest("Skipping AWS tests")
+
+        with open(ALT_DOC_PATH_MIXER, mode="r", encoding="utf8") as f:
+            config = json.load(f)
+
+        # keep track of local input and output paths
+        local_input = config["streams"][0]["documents"][0]
+        local_output = config["streams"][0]["output"]["path"]
+
+        # replace results path with s3 path
+        config["streams"][0]["output"]["path"] = f"{self.remote_test_prefix}/{local_output}"
+
+        # upload local input to s3, replace local input with s3 path
+        config["streams"][0]["documents"][0] = f"{self.remote_test_prefix}/{local_input}"
+
+        with NamedTemporaryFile("w") as f:
+            json.dump(config, f)
+            f.flush()
+
+            main(argv=["-c", f.name, "mix"])
+
+        download_s3_prefix(f"{self.remote_test_prefix}/tests/work", "tests/work/remote")
+        expected = load_jsonl("tests/data/expected/mixer.json.gz")
+        provided = load_jsonl("tests/work/remote/output/mixer/mixer-test-0000.json.gz")
+        provided = self.checkAndRemoveProvenance(provided)
+        self.assertEqual(expected, provided)
+
     def test_remote_input_local_output(self):
         if self.remote_test_prefix is None:
             return self.skipTest("Skipping AWS tests")
diff --git a/tests/python/test_paths.py b/tests/python/test_paths.py
index e920af74..df758e22 100644
--- a/tests/python/test_paths.py
+++ b/tests/python/test_paths.py
@@ -295,7 +295,6 @@ def test_split_glob(self):

 class TestSplitExt(TestCase):
     def test_file(self):
-
         prot, parts, ext = split_ext("file.txt")

         self.assertEqual(prot, "")
@@ -318,7 +317,6 @@ def test_file(self):
         self.assertEqual(ext, ".")

     def test_path(self):
-
         prot, parts, ext = split_ext("path/to/file.txt")

         self.assertEqual(prot, "")
diff --git a/tests/python/test_warc.py b/tests/python/test_warc.py
index 04f0e9a7..bd70da91 100644
--- a/tests/python/test_warc.py
+++ b/tests/python/test_warc.py
@@ -103,3 +103,59 @@ def test_pretag_html(self):
             {"by_4_0", "by_3_0"},
         )
         self.assertIn("cc_re__cc_re__cc_by_4_0", sample1[2]["attributes"])
+
+    def test_skip_linearization(self):
+        """Test that when skip_linearization is True, the raw HTML content is preserved."""
+        outputs = self._run_pipeline_with_skip_linearization()
+        self.assertEqual(len(outputs), 2)
+        self.assertIn("sample-0000.jsonl.gz", outputs)
+        self.assertIn("sample-0001.jsonl.gz", outputs)
+
+        sample0 = outputs["sample-0000.jsonl.gz"]
+        sample1 = outputs["sample-0001.jsonl.gz"]
+
+        # Check that we got some documents
+        self.assertGreater(len(sample0), 0)
+        self.assertGreater(len(sample1), 0)
+
+        # For all documents, verify they contain raw HTML instead of linearized text
+        for sample in chain(sample0, sample1):
+            # HTML content should be in the text field
+            self.assertIn("<", sample["text"])
+            self.assertIn(">", sample["text"])
+
+            # Common HTML tags that should be present in raw HTML
+            html_indicators = ["
+
+    def _run_pipeline_with_skip_linearization(self) -> Dict[str, List[dict]]:
+        """Helper method to run pipeline with skip_linearization=True."""
+        create_and_run_warc_pipeline(
+            documents=[f"{DATA_PATH}/*.warc.gz"],
+            destination=[self.tempdir],
+            num_processes=1,
+            ignore_existing=False,
+            debug=True,
+            source_name="test",
+            skip_no_pre_taggers=False,
+            skip_no_post_taggers=False,
+            store_html_in_metadata=False,
+            linearizer_name="no-op",
+            pre_taggers=["cc_re"],
+            post_taggers=["lingua_1e2"],
+        )
+        outputs: Dict[str, List[dict]] = {}
+        for fn in os.listdir(self.tempdir):
+            with smart_open.open(os.path.join(self.tempdir, fn), mode="rt", encoding="utf-8") as f:
+                for ln in f:
+                    outputs.setdefault(fn, []).append(json.loads(ln))
+        return outputs
diff --git a/tests/python/utils.py b/tests/python/utils.py
index 9813f2d3..a96c24df 100644
--- a/tests/python/utils.py
+++ b/tests/python/utils.py
@@ -70,9 +70,9 @@ def skip_aws_tests() -> bool:
     return (dolma_tests_skip or "false").lower() == "true"


-def upload_test_documents(local_input: str, test_prefix: str) -> Tuple[str, str]:
-    remote_input = f"{test_prefix}/input/documents"
-    remote_output = f"{test_prefix}/output/documents"
+def upload_test_documents(local_input: str, test_prefix: str, document_dir: str = "documents") -> Tuple[str, str]:
+    remote_input = f"{test_prefix}/input/{document_dir}"
+    remote_output = f"{test_prefix}/output/{document_dir}"

     for i, local_fp in enumerate(glob_path(local_input)):
         remote_fp = f"{remote_input}/{i:05d}.json.gz"
@@ -127,6 +127,7 @@ def upload_s3_prefix(s3_prefix: str, local_prefix: str):
     bucket_name, prefix = parse_s3_path(s3_prefix)

     for local_fp in glob_path(local_prefix):
+        print(f"LOCAL_FP {local_fp}")
         name = local_fp.replace(local_prefix, "").lstrip("/")
         s3.upload_file(Bucket=bucket_name, Key=f"{prefix}/{name}", Filename=local_fp)
@@ -167,9 +168,11 @@ def writeUnits(
         return [str(p) for p in file_paths]

-    def writeDocs(self, docs: List[str], partitions: int = 1, ext_dir: Optional[Path] = None) -> List[str]:
+    def writeDocs(
+        self, docs: List[str], partitions: int = 1, ext_dir: Optional[Path] = None, unit_type: str = "documents"
+    ) -> List[str]:
         encoded_docs = [{"id": str(i), "text": d, "source": __file__} for i, d in enumerate(docs)]
-        return self.writeUnits(units=encoded_docs, unit_type="documents", partitions=partitions, ext_dir=ext_dir)
+        return self.writeUnits(units=encoded_docs, unit_type=unit_type, partitions=partitions, ext_dir=ext_dir)

     def writeAttributes(
         self,