Skip to content

Commit 22bf48c

Browse files
author
The TensorFlow Datasets Authors
committed
Fix the multi_news dataset.
PiperOrigin-RevId: 794084822
1 parent 9666187 commit 22bf48c

File tree

8 files changed

+33
-27
lines changed

8 files changed

+33
-27
lines changed

tensorflow_datasets/summarization/multi_news.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
"""Multi-News dataset."""
1717

1818
import os
19-
20-
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
19+
from etils import epath
2120
import tensorflow_datasets.public_api as tfds
2221

2322
_CITATION = """
@@ -42,7 +41,8 @@
4241
- summary: news summary.
4342
"""
4443

45-
_URL = "https://drive.google.com/uc?export=download&id=1vRY2wM6rlOZrf9exGTm5pXj5ExlVwJ0C"
44+
_URL_PATH = "https://huggingface.co/datasets/multi_news/resolve/main/data"
45+
4646

4747
_DOCUMENT = "document"
4848
_SUMMARY = "summary"
@@ -51,7 +51,7 @@
5151
class MultiNews(tfds.core.GeneratorBasedBuilder):
5252
"""Multi-News dataset."""
5353

54-
# Bumped 1.0.0 -> 2.0.0 in this change: the download source moved from
# Google Drive to the Hugging Face mirror, and target lines are no longer
# stripped of a leading "- " token, so generated examples can differ from
# the 1.0.0 records.
VERSION = tfds.core.Version("2.0.0")
5555

5656
def _info(self):
5757
return tfds.core.DatasetInfo(
@@ -67,35 +67,35 @@ def _info(self):
6767

6868
def _split_generators(self, dl_manager):
  """Returns SplitGenerators.

  Downloads one source (documents) and one target (summaries) file per
  split from the Hugging Face mirror, then wires each pair into
  `_generate_examples`.
  """
  # File name on the mirror, keyed by the download-manager key we use below.
  file_names = {
      "train_src": "train.src.cleaned",
      "train_tgt": "train.tgt",
      "val_src": "val.src.cleaned",
      "val_tgt": "val.tgt",
      "test_src": "test.src.cleaned",
      "test_tgt": "test.tgt",
  }
  urls = {key: _URL_PATH + name for key, name in file_names.items()}
  files = dl_manager.download_and_extract(urls)
  return {
      "train": self._generate_examples(
          files["train_src"], files["train_tgt"]
      ),
      "validation": self._generate_examples(
          files["val_src"], files["val_tgt"]
      ),
      "test": self._generate_examples(files["test_src"], files["test_tgt"]),
  }
8788

88-
def _generate_examples(self, src_file, tgt_file):
  """Yields examples.

  Args:
    src_file: path to the source file; one document per line, with the
      document's natural newlines encoded as the literal token
      "NEWLINE_CHAR".
    tgt_file: path to the target file; one summary per line, paired with
      the same-numbered line of `src_file`.

  Yields:
    `(index, example)` tuples where `example` maps `_DOCUMENT` to the
    restored document text and `_SUMMARY` to the stripped summary line.
  """
  with epath.Path(src_file).open() as src_f, epath.Path(
      tgt_file
  ).open() as tgt_f:
    for i, (src_line, tgt_line) in enumerate(zip(src_f, tgt_f)):
      yield i, {
          # In the original file, each line has one example and natural
          # newline tokens "\n" were replaced with "NEWLINE_CHAR". Restore
          # the natural newline token to avoid the special vocab item
          # "NEWLINE_CHAR".
          _DOCUMENT: src_line.strip().replace("NEWLINE_CHAR", "\n"),
          # .strip() already removes leading whitespace, so the previous
          # trailing .lstrip() was a no-op (a leftover from the old
          # .lstrip("- ")) — dropped with no behavior change.
          _SUMMARY: tgt_line.strip(),
      }

tensorflow_datasets/summarization/multi_news_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,14 @@ class MultiNewsTest(testing.DatasetBuilderTestCase):
2626
"validation": 1, # Number of fake validation example
2727
"test": 1, # Number of fake test example
2828
}
29-
# Fake extracted-file name per download key: "<split>.src.cleaned" for
# sources, "<split>.tgt" for targets — mirrors the keys used by
# MultiNews._split_generators.
DL_EXTRACT_RESULT = {
    f"{split}_{kind}": f"{split}.{suffix}"
    for split in ("train", "val", "test")
    for kind, suffix in (("src", "src.cleaned"), ("tgt", "tgt"))
}
3137

3238
if __name__ == "__main__":
3339
testing.test_main()

0 commit comments

Comments
 (0)