|
8 | 8 | from stopes.utils.sharding.hf_shards import HFInputConfig, HFShard |
9 | 9 |
|
10 | 10 | # TODO: Hard code this to test if there are changes in HF datasets API |
11 | | -first_item_id = 7 |
| 11 | +expected_first_four = [1, 0, 1, 0] # contemmcm/rotten_tomatoes first 4 reviewState values |
12 | 12 |
|
13 | 13 |
|
14 | 14 | def test_shard_iteration(): |
15 | 15 | shard = HFShard( |
16 | 16 | filter=None, |
17 | | - path_or_name="Fraser/mnist-text-small", |
18 | | - split="test", |
| 17 | + path_or_name="contemmcm/rotten_tomatoes", |
| 18 | + split="complete", |
19 | 19 | index=0, |
20 | 20 | num_shards=50, |
21 | 21 | ) |
22 | 22 | with shard: |
23 | 23 | item = next(iter(shard)) |
24 | 24 | assert isinstance(item, dict) |
25 | | - assert "label" in item |
26 | | - assert item["label"] == first_item_id |
| 25 | + assert "reviewState" in item |
| 26 | + assert item["reviewState"] == expected_first_four[0] |
27 | 27 |
|
28 | 28 | with shard as progress: |
29 | 29 | batch_iter = progress.to_batches(batch_size=4) |
30 | | - item = next(batch_iter) |
31 | | - assert item["label"][0].as_py() == first_item_id # type: ignore |
| 30 | + batch = next(batch_iter) |
| 31 | + # Verify first 4 items match expected pattern [1,0,1,0] |
| 32 | + for i in range(4): |
| 33 | + assert batch["reviewState"][i].as_py() == expected_first_four[i] # type: ignore |
32 | 34 |
|
33 | 35 |
|
34 | 36 | def test_input_config(): |
35 | 37 | input_config = HFInputConfig( |
36 | | - input_file="Fraser/mnist-text-small", |
37 | | - split="test", |
| 38 | + input_file="contemmcm/rotten_tomatoes", |
| 39 | + split="complete", |
38 | 40 | num_shards=50, |
39 | 41 | ) |
40 | 42 | shards = input_config.make_shards() |
41 | 43 | first_shard = shards[0] |
42 | 44 | with first_shard: |
43 | 45 | item = next(iter(first_shard)) |
44 | | - assert item["label"] == first_item_id |
| 46 | + assert item["reviewState"] == expected_first_four[0] |
0 commit comments