ShiElding with Control barrier fUnctions in inverse REinforcement learning (SECURE)

Codebase for SECURE

Implemented with Garage and TensorFlow.

Please consider citing our work if you are using the codebase:

@inproceedings{secure_hri2024,
  title={Enhancing Safety in Learning from Demonstration Algorithms via Control Barrier Function Shielding},
  author={Yang*, Yue and Chen*, Letian and Zaidi*, Zulfiqar and van Waveren, Sanne and Krishna, Arjun and Gombolay, Matthew},
  booktitle={Proceedings of International Conference on Human-Robot Interaction (HRI)},
  year={2024}
}

The repo is currently open-sourced under the CC-BY-NC license.

Setup: Dependencies and Environment Preparation

The code has been tested with Python 3.7 in an Anaconda environment.

Required packages:

pip install numpy joblib==0.11 tensorflow==2.9.0 scipy path PyMC3 cached-property pyprind gym matplotlib dowel akro ray psutil setproctitle cma Box2D gymnasium==0.28.1 torch opencv-python

gym==0.14.0 and tensorflow-probability==0.14.0 have conflicting dependency requirements, so install tensorflow-probability separately:

pip install tensorflow-probability==0.14.0
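After installing, a quick check can confirm that the pinned versions actually ended up in the environment. The sketch below is a hypothetical helper, not part of the repository: it checks only the packages pinned in the commands above (joblib, tensorflow, gymnasium, tensorflow-probability) and falls back to pkg_resources on Python 3.7, where importlib.metadata is not available.

```python
"""Sanity-check the pinned package versions from the setup instructions.

Hypothetical helper; not part of the SECURE repository.
"""
from typing import Callable, Dict, Optional

# Pins taken from the pip commands above; unpinned packages are not checked.
PINNED = {
    "joblib": "0.11",
    "tensorflow": "2.9.0",
    "gymnasium": "0.28.1",
    "tensorflow-probability": "0.14.0",
}


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        from importlib.metadata import version, PackageNotFoundError  # Python 3.8+
    except ImportError:  # Python 3.7: fall back to setuptools' pkg_resources
        import pkg_resources
        try:
            return pkg_resources.get_distribution(dist_name).version
        except pkg_resources.DistributionNotFound:
            return None
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def check_pins(
    pins: Dict[str, str],
    get_version: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, str]:
    """Map each missing or mismatched package to a short description."""
    problems = {}
    for name, wanted in pins.items():
        found = get_version(name)
        if found is None:
            problems[name] = f"missing (want {wanted})"
        elif found != wanted:
            problems[name] = f"found {found}, want {wanted}"
    return problems


if __name__ == "__main__":
    issues = check_pins(PINNED)
    for pkg, msg in issues.items():
        print(f"{pkg}: {msg}")
    print("all pinned versions match" if not issues else f"{len(issues)} mismatch(es)")
```

The version lookup is injectable (get_version) so the comparison logic can be exercised without the real packages installed.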

If you run the Python scripts directly, add the repository's src directory to your PYTHONPATH:

export PYTHONPATH=/path/to/this/repo/src
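As an alternative to exporting PYTHONPATH in every shell, a script can prepend the src directory to sys.path itself before importing project modules. This is a minimal sketch; the function name add_src_to_path is hypothetical and not part of the repository.

```python
"""Hypothetical alternative to exporting PYTHONPATH: prepend <repo>/src to sys.path."""
import sys
from pathlib import Path


def add_src_to_path(repo_root) -> str:
    """Prepend <repo_root>/src to sys.path (only once) and return that entry."""
    src = str(Path(repo_root).resolve() / "src")
    if src not in sys.path:
        sys.path.insert(0, src)
    return src
```

For a script in scripts/, calling add_src_to_path(Path(__file__).resolve().parent.parent) at the top would point at the repository root, assuming the script sits one directory below it.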

To use the Panda arm push domain, first install the bundled panda-gym package by running the following command from the repository root:

pip install -e panda-gym

Running SECURE

Example 1: Demolition derby

  1. Collect demonstrations and states: We have prepared the demonstrations and states for Demolition derby. Please find them in the following locations:
  • Demonstrations: src/demonstrations/demos_demolition_derby.pkl
  • States: src/states/states_demolition_derby.pkl
  2. Train CBF NN: A trained CBF NN is provided at data/demolition_derby/cbf_model. You can also train it yourself:
python src/models/train_cbf_nn_demolition_derby.py
  3. Run the AIRL script on Demolition derby: A trained policy model is provided at data/demolition_derby/airl_model. You can also train it yourself:
# AIRL
python scripts/train_airl_demolition_derby.py
  4. Evaluate AIRL on Demolition derby:
python scripts/evaluate_airl_demolition_derby.py

The results will be saved in data/demolition_derby/airl_model/share/eval_results_just_airl.txt.

  5. Evaluate SECURE on Demolition derby:
python scripts/secure_demolition_derby.py

The results will be saved in data/demolition_derby/airl_model/share/eval_results_secure.txt.
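To see what the provided demonstration and state files from step 1 contain (for example, before preparing files for a new domain), a small inspection script can help. The internal layout of these .pkl files is not documented in this README, so the hypothetical helper below only reports generic structure: the top-level type, its length, and the first item's type or the first few dict keys.

```python
"""Peek at a pickled demonstrations or states file (hypothetical helper)."""
import pickle
from pathlib import Path


def describe_pickle(path) -> str:
    """Load a pickle file and return a one-line summary of its top-level object."""
    with open(Path(path), "rb") as f:
        obj = pickle.load(f)  # only load pickles from sources you trust
    summary = f"{Path(path).name}: {type(obj).__name__}"
    try:
        summary += f", len={len(obj)}"
    except TypeError:
        pass  # object has no length (e.g. a scalar)
    if isinstance(obj, (list, tuple)) and obj:
        summary += f", first item type={type(obj[0]).__name__}"
    elif isinstance(obj, dict) and obj:
        summary += f", keys={sorted(map(str, obj))[:5]}"
    return summary


if __name__ == "__main__":
    for p in ("src/demonstrations/demos_demolition_derby.pkl",
              "src/states/states_demolition_derby.pkl"):
        if Path(p).exists():
            print(describe_pickle(p))
        else:
            print(f"{p}: not found (run from the repository root)")
```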

Example 2: Panda arm push

  1. Collect demonstrations and states: We have prepared the demonstrations and states for Panda arm push. Please find them in the following locations:
  • Demonstrations: src/demonstrations/demos_panda_arm_push.pkl
  • States: src/states/states_panda_arm_push.pkl
  2. Train CBF NN: A trained CBF NN is provided at data/panda_arm_push/cbf_model. You can also train it yourself:
python src/models/train_cbf_nn_panda_arm_push.py
  3. Run the AIRL script on Panda arm push: Because this domain is more complex, we warm-start AIRL with behavioral cloning (BC) to learn a better policy. The BC policy is provided at data/panda_arm_push/bc. You can also train it yourself:
# BC warm-up
python scripts/bc_airl_panda_arm_push.py

The resulting AIRL policy is provided at data/panda_arm_push/airl_model. You can also train it yourself:

# AIRL
python scripts/train_airl_panda_arm_push.py
  4. Evaluate AIRL on Panda arm push:
python scripts/evaluate_airl_panda_arm_push.py

The results will be saved in data/panda_arm_push/airl_model/share/eval_results_just_airl.txt.

  5. Evaluate SECURE on Panda arm push:
python scripts/secure_panda_arm_push.py

The results will be saved in data/panda_arm_push/airl_model/share/eval_results_secure.txt.
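Since the Panda arm push steps must run in order (in particular, the BC warm-up before AIRL training), a small wrapper can chain them and stop at the first failure. This is a hypothetical convenience script, not part of the repository; it reuses exactly the commands listed above, assumes it is launched from the repository root with PYTHONPATH set, and defaults to a dry run unless --run is passed.

```python
"""Run the Panda arm push pipeline in order (hypothetical convenience wrapper)."""
import subprocess
import sys
from typing import List

# Same commands as in the steps above; sys.executable reuses the current interpreter.
PANDA_STEPS: List[List[str]] = [
    [sys.executable, "src/models/train_cbf_nn_panda_arm_push.py"],   # 2. CBF NN
    [sys.executable, "scripts/bc_airl_panda_arm_push.py"],           # 3. BC warm-up
    [sys.executable, "scripts/train_airl_panda_arm_push.py"],        # 3. AIRL
    [sys.executable, "scripts/evaluate_airl_panda_arm_push.py"],     # 4. AIRL alone
    [sys.executable, "scripts/secure_panda_arm_push.py"],            # 5. SECURE
]


def run_steps(steps: List[List[str]], dry_run: bool = False) -> List[str]:
    """Run each command in order, raising CalledProcessError at the first failure.

    With dry_run=True nothing is executed; commands are only printed.
    Returns the command lines that were (or would have been) run.
    """
    ran = []
    for cmd in steps:
        line = " ".join(cmd)
        print(("[dry run] " if dry_run else "$ ") + line)
        if not dry_run:
            subprocess.run(cmd, check=True)
        ran.append(line)
    return ran


if __name__ == "__main__":
    run_steps(PANDA_STEPS, dry_run="--run" not in sys.argv)
```

Skip the first three entries if you want to reuse the provided models in data/panda_arm_push.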

Other domains

  1. Save demonstrations to src/demonstrations.
  2. Save states to src/states.
  3. Create a script modeled after scripts/secure_demolition_derby.py.
  4. Change the environment, file locations, and log prefix. For a 2D environment, use circle(); for a 3D environment, use fibonacci_sphere() as in scripts/secure_panda_arm_push.py.
  5. Run the script as in the Demolition derby example:
python scripts/{your_script}.py
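Step 4 mentions circle() for 2D and fibonacci_sphere() for 3D environments. The sketch below assumes, without confirmation from this README, that both produce evenly spread unit vectors used as candidate action directions; the signatures here are hypothetical and may differ from the helpers in scripts/secure_*.py. The shield() function is a generic discrete-time control barrier function filter (keep the action if h(x') >= (1 - alpha) h(x), else pick the closest safe candidate), shown only to illustrate how such candidates could be screened; it is not necessarily SECURE's exact rule, and in this pipeline h would presumably come from the trained CBF NN.

```python
"""Hypothetical stand-ins for circle() / fibonacci_sphere() plus a generic CBF filter."""
import math

import numpy as np


def circle(n: int) -> np.ndarray:
    """Return n unit vectors evenly spaced on the unit circle, shape (n, 2)."""
    angles = 2.0 * math.pi * np.arange(n) / n
    return np.stack([np.cos(angles), np.sin(angles)], axis=1)


def fibonacci_sphere(n: int) -> np.ndarray:
    """Return n near-uniform unit vectors on the sphere (Fibonacci lattice), shape (n, 3)."""
    golden_angle = math.pi * (3.0 - math.sqrt(5.0))
    i = np.arange(n)
    z = 1.0 - 2.0 * (i + 0.5) / n   # heights evenly spaced in (-1, 1)
    r = np.sqrt(1.0 - z * z)        # radius of the horizontal slice at height z
    theta = golden_angle * i
    return np.stack([r * np.cos(theta), r * np.sin(theta), z], axis=1)


def shield(x, u_nominal, candidates, h, step, alpha=0.1):
    """Generic discrete-time CBF filter (illustrative only).

    Keeps u_nominal if h(step(x, u)) >= (1 - alpha) * h(x); otherwise returns
    the safe candidate closest to u_nominal, or None if no candidate is safe.
    """
    def is_safe(u):
        return h(step(x, u)) >= (1.0 - alpha) * h(x)

    if is_safe(u_nominal):
        return u_nominal
    safe = [u for u in candidates if is_safe(u)]
    if not safe:
        return None
    return min(safe, key=lambda u: float(np.linalg.norm(u - u_nominal)))
```

Candidates are unit vectors, so scale them to the action magnitude your environment expects (for example, 0.5 * circle(16)).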

Code Structure

The SECURE code is adapted from the original AIRL codebase and from MACBF.