|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import time |
| 16 | + |
15 | 17 | import pytest |
16 | 18 |
|
17 | 19 | try: |
@@ -149,70 +151,42 @@ def test_point_cloud_filter(): |
149 | 151 |
|
150 | 152 |
|
151 | 153 | def test_get_gripping_point_tool_timeout(): |
152 | | - """Test GetGrippingPointTool timeout behavior.""" |
153 | | - # Mock the connector and components |
| 154 | + # Complete mock setup |
154 | 155 | mock_connector = Mock(spec=ROS2Connector) |
155 | | - |
156 | | - # Create mock components that will simulate timeout |
157 | 156 | mock_pcl_gen = Mock(spec=PointCloudFromSegmentation) |
158 | | - mock_pcl_gen.run.side_effect = lambda x: [] # Return empty to simulate no detection |
159 | | - |
160 | 157 | mock_filter = Mock(spec=PointCloudFilter) |
161 | | - mock_filter.run.return_value = [] |
162 | | - |
163 | 158 | mock_estimator = Mock(spec=GrippingPointEstimator) |
| 159 | + |
| 160 | + # Test 1: No timeout (fast execution) |
| 161 | + mock_pcl_gen.run.return_value = [] |
| 162 | + mock_filter.run.return_value = [] |
164 | 163 | mock_estimator.run.return_value = [] |
165 | 164 |
|
166 | | - # Create tool with short timeout |
167 | 165 | tool = GetGrippingPointTool( |
168 | 166 | connector=mock_connector, |
169 | | - point_cloud_from_segmentation=mock_pcl_gen, |
170 | | - point_cloud_filter=mock_filter, |
| 167 | + target_frame="base", |
| 168 | + source_frame="camera", |
| 169 | + camera_topic="/image", |
| 170 | + depth_topic="/depth", |
| 171 | + camera_info_topic="/info", |
171 | 172 | gripping_point_estimator=mock_estimator, |
172 | | - timeout_sec=0.1, |
| 173 | + point_cloud_filter=mock_filter, |
| 174 | + timeout_sec=5.0, |
173 | 175 | ) |
| 176 | + tool.point_cloud_from_segmentation = mock_pcl_gen # Connect the mock |
174 | 177 |
|
175 | | - # Test successful run with no gripping points found |
176 | | - result = tool._run("test_object") |
177 | | - assert "No gripping point found" in result |
178 | | - assert "test_object" in result |
179 | | - |
180 | | - # Test with mock that simulates found gripping points |
181 | | - mock_estimator.run.return_value = [np.array([1.0, 2.0, 3.0], dtype=np.float32)] |
182 | | - result = tool._run("test_object") |
183 | | - assert "gripping point of the object test_object is" in result |
184 | | - assert "[1. 2. 3.]" in result |
185 | | - |
186 | | - # Test with multiple gripping points |
187 | | - mock_estimator.run.return_value = [ |
188 | | - np.array([1.0, 2.0, 3.0], dtype=np.float32), |
189 | | - np.array([4.0, 5.0, 6.0], dtype=np.float32), |
190 | | - ] |
| 178 | + # Test fast execution - should complete without timeout |
191 | 179 | result = tool._run("test_object") |
192 | | - assert "Multiple gripping points found" in result |
193 | | - |
194 | | - |
195 | | -def test_get_gripping_point_tool_validation(): |
196 | | - """Test GetGrippingPointTool input validation.""" |
197 | | - mock_connector = Mock(spec=ROS2Connector) |
198 | | - mock_pcl_gen = Mock(spec=PointCloudFromSegmentation) |
199 | | - mock_filter = Mock(spec=PointCloudFilter) |
200 | | - mock_estimator = Mock(spec=GrippingPointEstimator) |
201 | | - |
202 | | - # Test tool creation |
203 | | - tool = GetGrippingPointTool( |
204 | | - connector=mock_connector, |
205 | | - point_cloud_from_segmentation=mock_pcl_gen, |
206 | | - point_cloud_filter=mock_filter, |
207 | | - gripping_point_estimator=mock_estimator, |
208 | | - ) |
| 180 | + assert "No test_objects detected" in result |
| 181 | + assert "timed out" not in result.lower() |
209 | 182 |
|
210 | | - # Verify tool properties |
211 | | - assert tool.name == "get_gripping_point" |
212 | | - assert "gripping points" in tool.description |
213 | | - assert tool.timeout_sec == 10.0 # default value |
| 183 | + # Test 2: Actual timeout behavior |
| 184 | + def slow_operation(obj_name): |
| 185 | + time.sleep(2.0) # Longer than timeout |
| 186 | + return [] |
214 | 187 |
|
215 | | - # Test args schema |
216 | | - from rai.tools.ros2.detection.tools import GetGrippingPointToolInput |
| 188 | + mock_pcl_gen.run.side_effect = slow_operation |
| 189 | + tool.timeout_sec = 1.0 # Short timeout |
217 | 190 |
|
218 | | - assert tool.args_schema == GetGrippingPointToolInput |
| 191 | + result = tool._run("test") |
| 192 | + assert "timed out" in result.lower() or "timeout" in result.lower() |
0 commit comments