diff --git a/test/test_datasets.py b/test/test_datasets.py index 48d08b846de..2c74332a833 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2746,15 +2746,15 @@ def inject_fake_data(self, tmpdir, config): datasets_utils.create_image_folder( root=split_dir, name="image_2", - file_name_fn=lambda i: f"{i:06d}_10.png", - num_examples=num_examples, + file_name_fn=lambda i: f"{i // 2:06d}_1{i % 2}.png", + num_examples=num_examples * 2, size=(3, 100, 200), ) datasets_utils.create_image_folder( root=split_dir, name="image_3", - file_name_fn=lambda i: f"{i:06d}_10.png", - num_examples=num_examples, + file_name_fn=lambda i: f"{i // 2:06d}_1{i % 2}.png", + num_examples=num_examples * 2, size=(3, 100, 200), ) @@ -2762,7 +2762,7 @@ def inject_fake_data(self, tmpdir, config): datasets_utils.create_image_folder( root=split_dir, name="disp_occ_0", - file_name_fn=lambda i: f"{i:06d}.png", + file_name_fn=lambda i: f"{i:06d}_10.png", num_examples=num_examples, # Kitti2015 uses a single channel image for disparities size=(1, 100, 200), @@ -2771,7 +2771,7 @@ def inject_fake_data(self, tmpdir, config): datasets_utils.create_image_folder( root=split_dir, name="disp_occ_1", - file_name_fn=lambda i: f"{i:06d}.png", + file_name_fn=lambda i: f"{i:06d}_10.png", num_examples=num_examples, # Kitti2015 uses a single channel image for disparities size=(1, 100, 200), diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py index b07161d277c..feed695565c 100644 --- a/torchvision/datasets/_stereo_matching.py +++ b/torchvision/datasets/_stereo_matching.py @@ -334,13 +334,13 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl verify_str_arg(split, "split", valid_values=("train", "test")) root = Path(root) / "Kitti2015" / (split + "ing") - left_img_pattern = str(root / "image_2" / "*.png") - right_img_pattern = str(root / "image_3" / "*.png") + left_img_pattern = str(root / "image_2" / "*_10.png") + right_img_pattern = str(root / "image_3" / "*_10.png") self._images = self._scan_pairs(left_img_pattern, right_img_pattern) if split == "train": - left_disparity_pattern = str(root / "disp_occ_0" / "*.png") - right_disparity_pattern = str(root / "disp_occ_1" / "*.png") + left_disparity_pattern = str(root / "disp_occ_0" / "*_10.png") + right_disparity_pattern = str(root / "disp_occ_1" / "*_10.png") self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) else: self._disparities = list((None, None) for _ in self._images)