Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion isaaclab_arena/embodiments/franka/franka.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
self,
enable_cameras: bool = False,
initial_pose: Pose | None = None,
initial_joint_pose: list[float] | None = None,
concatenate_observation_terms: bool = False,
arm_mode: ArmMode | None = None,
):
Expand All @@ -55,6 +56,8 @@ def __init__(
self.observation_config = FrankaObservationsCfg()
self.observation_config.policy.concatenate_terms = self.concatenate_observation_terms
self.event_config = FrankaEventCfg()
if initial_joint_pose is not None:
self.set_initial_joint_pose(initial_joint_pose)
self.reward_config = FrankaRewardsCfg()
self.mimic_env = FrankaMimicEnv

Expand All @@ -68,6 +71,9 @@ def _update_scene_cfg_with_robot_initial_pose(self, scene_config: Any, pose: Pos
scene_config.stand.init_state.rot = pose.rotation_wxyz
return scene_config

def set_initial_joint_pose(self, initial_joint_pose: list[float]) -> None:
self.event_config.init_franka_arm_pose.params["default_pose"] = initial_joint_pose

def get_ee_frame_name(self, arm_mode: ArmMode) -> str:
return "ee_frame"

Expand Down Expand Up @@ -176,7 +182,7 @@ def __post_init__(self):

@configclass
class FrankaEventCfg:
"""Configuration for Franek."""
"""Configuration for Franka."""

init_franka_arm_pose = EventTerm(
func=franka_stack_events.set_default_joint_pose,
Expand Down
136 changes: 136 additions & 0 deletions isaaclab_arena/tasks/goal_pose_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

import numpy as np
from dataclasses import MISSING

import isaaclab.envs.mdp as mdp_isaac_lab
from isaaclab.envs.common import ViewerCfg
from isaaclab.managers import EventTermCfg, SceneEntityCfg, TerminationTermCfg
from isaaclab.scene import InteractiveSceneCfg
from isaaclab.utils import configclass

from isaaclab_arena.assets.asset import Asset
from isaaclab_arena.metrics.metric_base import MetricBase
from isaaclab_arena.metrics.object_moved import ObjectMovedRateMetric
from isaaclab_arena.metrics.success_rate import SuccessRateMetric
from isaaclab_arena.tasks.task_base import TaskBase
from isaaclab_arena.tasks.terminations import goal_pose_task_termination
from isaaclab_arena.terms.events import set_object_pose
from isaaclab_arena.utils.cameras import get_viewer_cfg_look_at_object


class GoalPoseTask(TaskBase):
def __init__(
self,
object: Asset,
episode_length_s: float | None = None,
target_x_range: tuple[float, float] | None = None,
target_y_range: tuple[float, float] | None = None,
target_z_range: tuple[float, float] | None = None,
target_orientation_wxyz: tuple[float, float, float, float] | None = None,
target_orientation_tolerance_rad: float | None = None,
):
"""
Args:
object: The object asset for the goal pose task.
episode_length_s: Episode length in seconds.
target_x_range: Success zone x-range [min, max] in meters.
target_y_range: Success zone y-range [min, max] in meters.
target_z_range: Success zone z-range [min, max] in meters.
target_orientation_wxyz: Target quaternion [w, x, y, z].
target_orientation_tolerance_rad: Angular tolerance in radians (default: 0.1).
"""
super().__init__(episode_length_s=episode_length_s)
self.object = object
# this is needed to revise the default env_spacing in arena_env_builder: priority task > embodiment > scene > default
self.scene_config = InteractiveSceneCfg(num_envs=1, env_spacing=3.0, replicate_physics=False)
self.events_cfg = GoalPoseEventCfg(self.object)
self.termination_cfg = self.make_termination_cfg(
target_x_range=target_x_range,
target_y_range=target_y_range,
target_z_range=target_z_range,
target_orientation_wxyz=target_orientation_wxyz,
target_orientation_tolerance_rad=target_orientation_tolerance_rad,
)

def get_scene_cfg(self):
return self.scene_config

def get_termination_cfg(self):
return self.termination_cfg

def make_termination_cfg(
self,
target_x_range: tuple[float, float] | None = None,
target_y_range: tuple[float, float] | None = None,
target_z_range: tuple[float, float] | None = None,
target_orientation_wxyz: tuple[float, float, float, float] | None = None,
target_orientation_tolerance_rad: float | None = None,
):
params: dict = {"object_cfg": SceneEntityCfg(self.object.name)}
if target_x_range is not None:
params["target_x_range"] = target_x_range
if target_y_range is not None:
params["target_y_range"] = target_y_range
if target_z_range is not None:
params["target_z_range"] = target_z_range
if target_orientation_wxyz is not None:
params["target_orientation_wxyz"] = target_orientation_wxyz
if target_orientation_tolerance_rad is not None:
params["target_orientation_tolerance_rad"] = target_orientation_tolerance_rad

success = TerminationTermCfg(
func=goal_pose_task_termination,
params=params,
)
return TerminationsCfg(success=success)

def get_events_cfg(self):
return self.events_cfg

def get_prompt(self):
raise NotImplementedError("Function not implemented yet.")

def get_mimic_env_cfg(self, embodiment_name: str):
raise NotImplementedError("Function not implemented yet.")

def get_metrics(self) -> list[MetricBase]:
return [
SuccessRateMetric(),
ObjectMovedRateMetric(self.object),
]

def get_viewer_cfg(self) -> ViewerCfg:
return get_viewer_cfg_look_at_object(lookat_object=self.object, offset=np.array([1.5, 1.5, 1.5]))


@configclass
class TerminationsCfg:
"""Termination terms for the MDP."""

time_out: TerminationTermCfg = TerminationTermCfg(func=mdp_isaac_lab.time_out)
success: TerminationTermCfg = MISSING


@configclass
class GoalPoseEventCfg:
"""Configuration for goal pose."""

reset_object_pose: EventTermCfg = MISSING

def __init__(self, object: Asset):
initial_pose = object.get_initial_pose()
if initial_pose is not None:
self.reset_object_pose = EventTermCfg(
func=set_object_pose,
mode="reset",
params={
"pose": initial_pose,
"asset_cfg": SceneEntityCfg(object.name),
},
)
else:
raise ValueError(f"Initial pose is not set for the object {object.name}")
66 changes: 66 additions & 0 deletions isaaclab_arena/tasks/terminations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import math
import torch

from isaaclab.assets import RigidObject
Expand Down Expand Up @@ -77,6 +78,71 @@ def objects_in_proximity(
return done


def goal_pose_task_termination(
env: ManagerBasedRLEnv,
object_cfg: SceneEntityCfg = SceneEntityCfg("object"),
target_x_range: tuple[float, float] | None = None,
target_y_range: tuple[float, float] | None = None,
target_z_range: tuple[float, float] | None = None,
target_orientation_wxyz: tuple[float, float, float, float] | None = None,
target_orientation_tolerance_rad: float = 0.1,
) -> torch.Tensor:
"""Terminate when the object's pose is within the thresholds (BBox + Orientation).

Args:
env: The RL environment instance.
object_cfg: The configuration of the object to track.
target_x_range: Success zone x-range [min, max] in meters.
target_y_range: Success zone y-range [min, max] in meters.
target_z_range: Success zone z-range [min, max] in meters.
target_orientation_wxyz: Target quaternion [w, x, y, z].
target_orientation_tolerance_rad: Angular tolerance in radians (default: 0.1).

Returns:
A boolean tensor of shape (num_envs, )
"""
object_instance: RigidObject = env.scene[object_cfg.name]
object_root_pos_w = object_instance.data.root_pos_w
object_root_quat_w = object_instance.data.root_quat_w

device = env.device
num_envs = env.num_envs

has_any_threshold = any([
target_x_range is not None,
target_y_range is not None,
target_z_range is not None,
target_orientation_wxyz is not None,
])

if not has_any_threshold:
return torch.zeros(num_envs, dtype=torch.bool, device=device)

success = torch.ones(num_envs, dtype=torch.bool, device=device)

# Position range checks
ranges = [target_x_range, target_y_range, target_z_range]
for idx, range_val in enumerate(ranges):
if range_val is not None:
range_min, range_max = range_val
in_range = (object_root_pos_w[:, idx] >= range_min) & (object_root_pos_w[:, idx] <= range_max)
success &= in_range

# Orientation check
if target_orientation_wxyz is not None:
target_quat = torch.tensor(target_orientation_wxyz, device=device, dtype=torch.float32).unsqueeze(0)

# Formula: |<q1, q2>| > cos(tolerance / 2)
quat_dot = torch.sum(object_root_quat_w * target_quat, dim=-1)
abs_dot = torch.abs(quat_dot)
min_cos = math.cos(target_orientation_tolerance_rad / 2.0)

ori_success = abs_dot >= min_cos
success &= ori_success

return success


def object_above(
env: ManagerBasedRLEnv,
object_cfg: SceneEntityCfg,
Expand Down
Loading