Skip to content

Commit 3d619e7

Browse files
committed
Replace mnist-text with rotten-tomatoes in the test case due to the deprecation of HF dataset scripts
1 parent 3985a37 commit 3d619e7

1 file changed

Lines changed: 12 additions & 10 deletions

File tree

stopes/utils/test_hf_shards.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,37 +8,39 @@
88
from stopes.utils.sharding.hf_shards import HFInputConfig, HFShard
99

1010
# TODO: Hard-code this to detect changes in the HF datasets API
11-
first_item_id = 7
11+
expected_first_four = [1, 0, 1, 0] # contemmcm/rotten_tomatoes first 4 reviewState values
1212

1313

1414
def test_shard_iteration():
1515
shard = HFShard(
1616
filter=None,
17-
path_or_name="Fraser/mnist-text-small",
18-
split="test",
17+
path_or_name="contemmcm/rotten_tomatoes",
18+
split="complete",
1919
index=0,
2020
num_shards=50,
2121
)
2222
with shard:
2323
item = next(iter(shard))
2424
assert isinstance(item, dict)
25-
assert "label" in item
26-
assert item["label"] == first_item_id
25+
assert "reviewState" in item
26+
assert item["reviewState"] == expected_first_four[0]
2727

2828
with shard as progress:
2929
batch_iter = progress.to_batches(batch_size=4)
30-
item = next(batch_iter)
31-
assert item["label"][0].as_py() == first_item_id # type: ignore
30+
batch = next(batch_iter)
31+
# Verify first 4 items match expected pattern [1,0,1,0]
32+
for i in range(4):
33+
assert batch["reviewState"][i].as_py() == expected_first_four[i] # type: ignore
3234

3335

3436
def test_input_config():
3537
input_config = HFInputConfig(
36-
input_file="Fraser/mnist-text-small",
37-
split="test",
38+
input_file="contemmcm/rotten_tomatoes",
39+
split="complete",
3840
num_shards=50,
3941
)
4042
shards = input_config.make_shards()
4143
first_shard = shards[0]
4244
with first_shard:
4345
item = next(iter(first_shard))
44-
assert item["label"] == first_item_id
46+
assert item["reviewState"] == expected_first_four[0]

0 commit comments

Comments
 (0)