Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions strands_robots/policies/lerobot_local/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,17 +886,31 @@ def _build_batch_from_strands_format(
state_values.append(float(value))

if state_values:
# Validate state dimension matches what the model expects.
# Dimension mismatches cause shape errors in the forward pass.
# Auto-adapt state dimension to match what the model expects.
# Robots may expose more joints than the policy was trained on
# (e.g. aloha has 16 joints but ACT expects 14). Truncate excess
# or zero-pad if fewer, rather than raising an error.
state_feature = self._input_features.get("observation.state")
if state_feature:
expected_dim = state_feature.shape[0] if hasattr(state_feature, "shape") else len(state_values)
if len(state_values) != expected_dim:
raise ValueError(
f"State dimension mismatch: got {len(state_values)} values from "
f"robot_state_keys but model expects {expected_dim}. "
f"Check that robot_state_keys matches your robot's actual joint count."
if len(state_values) > expected_dim:
logger.warning(
"State dim %d > model expects %d — truncating to first %d values. "
"Check that robot_state_keys matches your robot's actual joint count.",
len(state_values),
expected_dim,
expected_dim,
)
state_values = state_values[:expected_dim]
elif len(state_values) < expected_dim:
logger.warning(
"State dim %d < model expects %d — zero-padding with %d zeros. "
"Check that robot_state_keys matches your robot's actual joint count.",
len(state_values),
expected_dim,
expected_dim - len(state_values),
)
state_values.extend([0.0] * (expected_dim - len(state_values)))
batch["observation.state"] = torch.tensor(state_values, dtype=torch.float32).unsqueeze(0).to(self._device)

# Map camera images to model's image input features.
Expand Down
10 changes: 7 additions & 3 deletions tests/test_lerobot_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,12 +575,16 @@ def test_numpy_floating_state(self):
np.testing.assert_allclose(batch["observation.state"][0].numpy(), [1.5, 2.5], atol=1e-5)

def test_state_padded_to_expected_dim(self):
"""State dimension mismatch should raise ValueError (fail-fast)."""
"""State dimension mismatch should auto-pad (not raise)."""
policy = _make_loaded_policy(state_dim=4, include_images=False)
policy.set_robot_state_keys(["a", "b"])
observation = {"a": 1.0, "b": 2.0}
with pytest.raises(ValueError, match="State dimension mismatch"):
policy._build_batch_from_strands_format(observation, {})
# After bug fix: auto-pads with zeros instead of raising
batch = policy._build_batch_from_strands_format(observation, {})
state = batch["observation.state"][0].numpy()
assert len(state) == 4
np.testing.assert_allclose(state[:2], [1.0, 2.0], atol=1e-5)
np.testing.assert_allclose(state[2:], [0.0, 0.0], atol=1e-5)

def test_empty_state_keys_raises(self):
"""Empty robot_state_keys should raise ValueError."""
Expand Down
Loading