diff --git a/examples/manipulation-demo-streamlit.py b/examples/manipulation-demo-streamlit.py index e226c1f88..4c438f93b 100644 --- a/examples/manipulation-demo-streamlit.py +++ b/examples/manipulation-demo-streamlit.py @@ -13,30 +13,118 @@ # limitations under the License. import importlib +import sys +from pathlib import Path +import rclpy import streamlit as st from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from launch import LaunchDescription +from launch.actions import ( + IncludeLaunchDescription, +) +from launch.launch_description_sources import PythonLaunchDescriptionSource +from launch_ros.actions import Node +from launch_ros.substitutions import FindPackageShare from rai.agents.integrations.streamlit import get_streamlit_cb, streamlit_invoke +from rai.communication.ros2.connectors.ros2_connector import ROS2Connector from rai.messages import HumanMultimodalMessage +from rai_bench.manipulation_o3de import get_scenarios +from rai_bench.manipulation_o3de.benchmark import Scenario +from rai_sim.o3de.o3de_bridge import ( + O3DEngineArmManipulationBridge, + O3DExROS2SimulationConfig, +) +from rai_sim.simulation_bridge import SceneConfig + manipulation_demo = importlib.import_module("manipulation-demo") +def launch_description(): + launch_moveit = IncludeLaunchDescription( + PythonLaunchDescriptionSource( + [ + "src/examples/rai-manipulation-demo/Project/Examples/panda_moveit_config_demo.launch.py", + ] + ) + ) + + launch_robotic_manipulation = Node( + package="robotic_manipulation", + executable="robotic_manipulation", + output="screen", + parameters=[ + {"use_sim_time": True}, + ], + ) + + launch_openset = IncludeLaunchDescription( + PythonLaunchDescriptionSource( + [ + FindPackageShare("rai_bringup"), + "/launch/openset.launch.py", + ] + ), + ) + + return LaunchDescription( + [ + launch_openset, + launch_moveit, + launch_robotic_manipulation, + ] + ) + + +@st.cache_resource +def init_ros(): + rclpy.init() + return "ros" + + @st.cache_resource def initialize_graph(): return manipulation_demo.create_agent() -def main(): +@st.cache_resource +def initialize_o3de(scenario_path: str, o3de_config_path: str): + simulation_config = O3DExROS2SimulationConfig.load_config( + config_path=Path(o3de_config_path) + ) + scene_config = SceneConfig.load_base_config(Path(scenario_path)) + scenario = Scenario( + task=None, + scene_config=scene_config, + scene_config_path=scenario_path, + ) + o3de = O3DEngineArmManipulationBridge(ROS2Connector()) + o3de.init_simulation(simulation_config=simulation_config) + o3de.launch_robotic_stack( + required_robotic_ros2_interfaces=simulation_config.required_robotic_ros2_interfaces, + launch_description=launch_description(), + ) + o3de.setup_scene(scenario.scene_config) + + +def main(scenario: Scenario, simulation_config: O3DExROS2SimulationConfig): st.set_page_config( page_title="RAI Manipulation Demo", page_icon=":robot:", ) st.title("RAI Manipulation Demo") st.markdown("---") - st.sidebar.header("Tool Calls History") + if "ros" not in st.session_state: + ros = init_ros() + st.session_state["ros"] = ros + + if "o3de" not in st.session_state: + o3de = initialize_o3de(scenario, simulation_config) + st.session_state["o3de"] = o3de + if "graph" not in st.session_state: graph = initialize_graph() st.session_state["graph"] = graph @@ -70,4 +158,26 @@ def main(): if __name__ == "__main__": - main() + levels = [ + "medium", + "hard", + "very_hard", + ] + scenarios: list[Scenario] = get_scenarios(levels=levels) + scenario_names = [Path(s.scene_config_path).stem for s in scenarios] + print(scenario_names) + + if len(sys.argv) > 1: + layout = sys.argv[1] + if layout not in scenario_names: + raise ValueError(f"Invalid layout: {layout}. Select from {scenario_names}") + else: + layout = "3carrots_1a_1t_2bc_2yc" + o3de_config_path = ( + "src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml" + ) + + scenario_idx = scenario_names.index(layout) + scenario = str(scenarios[scenario_idx].scene_config_path) + + main(scenario, o3de_config_path)