This repository contains the official implementation of the paper:
Prompt4Trust: A Reinforcement Learning Prompt Augmentation Framework for Clinically-Aligned Confidence Calibration in Multimodal Large Language Models
Anita Kriz*, Elizabeth Laura Janes*, Xing Shen*, Tal Arbel
*Equal contribution
IEEE/CVF International Conference on Computer Vision (ICCV) Workshops, 2025
Paper (arXiv preprint), BibTeX
Multimodal large language models (MLLMs) show great potential for healthcare applications, but their clinical deployment is challenged by prompt sensitivity and overconfident incorrect responses. To improve trustworthiness in safety-critical settings, we introduce Prompt4Trust, the first reinforcement learning framework for prompt augmentation focused on confidence calibration in MLLMs. A lightweight LLM is trained to generate context-aware auxiliary prompts that guide a downstream MLLM to produce predictions with confidence scores that better reflect true accuracy. By prioritizing clinically meaningful calibration, Prompt4Trust enhances both reliability and task performance, achieving state-of-the-art results on the PMC-VQA benchmark while enabling efficient zero-shot generalization to larger MLLMs.
Make sure you have at least 4 NVIDIA GPUs with adequate memory (the memory requirement depends on the scale of the LLM/MLLM you want to use) if you wish to use open-source downstream task MLLMs.
We recommend downloading the open-source LLMs/MLLMs using `huggingface-cli` before you start (make sure you have obtained the relevant permissions/agreements to download the models from Hugging Face):
```bash
huggingface-cli login
huggingface-cli download {REPO_NAME} --local-dir {SAVE_FOLDER} --local-dir-use-symlinks False
```
For example, `{REPO_NAME}` can be `Qwen/Qwen2.5-1.5B-Instruct` and `{SAVE_FOLDER}` can be `/usr/local/data/Qwen2.5-1.5B-Instruct`. The downloaded model will be saved in the specified folder `{SAVE_FOLDER}`.
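For instance, with the example values above substituted in, the full download command would be:
```bash
huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct --local-dir /usr/local/data/Qwen2.5-1.5B-Instruct --local-dir-use-symlinks False
```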
You can skip the training by downloading and using our trained CGP Generator.
It is recommended to use a virtual environment (e.g., venv) to avoid package conflicts. Here we assume you are using venv as your virtual environment. If you are using conda, please adjust the commands accordingly.
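If you do not have one set up yet, a typical venv setup (run before the `pip install` step below) looks like this; the environment name `.venv` is just an example:
```bash
python -m venv .venv
source .venv/bin/activate
```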
```bash
git clone https://github.com/xingbpshen/prompt4trust.git
cd prompt4trust/
pip install -r requirements.txt
```
Create a `data/` folder under the project root `prompt4trust/`:
```bash
mkdir data/
```
Download the dataset from the PMC-VQA repository here. Put all files in the newly created `data/` folder.
Our training split can be generated using `dataset/gen_train.py` (after modifying the data path in the script) or can be downloaded here (coming soon).
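If you generate the split yourself, we assume the data path is edited directly inside the script (this README does not document a CLI for it), after which running it is a plain Python invocation:
```bash
# assumes the data path inside dataset/gen_train.py points to your data/ folder
python dataset/gen_train.py
```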
The config files are located at `config/`. You can modify the parameters according to your needs. The default config file `pmcvqa.yml` is for the PMC-VQA dataset.
Here are some important parameters you may want to modify:
- `resources.cache_dir`: This is where vLLM and other Python packages will be cached. Make sure you have enough space.
- `resources.policy_cuda`: This is a string of CUDA devices (e.g., `"3,4"` or `"3"`) used for the policy update/training. Make sure you have enough memory on these devices.
- `resources.action_cuda`: This is a string of CUDA devices used for the TRL with vLLM serving to sample "actions" (in the context of reinforcement learning). Make sure you have enough memory on these devices.
- `resources.downstream_cuda`: This is a string of CUDA devices used for the downstream MLLM (to obtain the reward). Make sure you have enough memory on these devices.
- `model.policy`: This is the policy model name. You can use any repository name supported by Hugging Face or a path to a local model (e.g., `"Qwen/Qwen2.5-1.5B-Instruct"` or `"/usr/local/data/Qwen2.5-1.5B-Instruct"`).
- `model.downstream`: This is the downstream model name. You can use any repository name supported by Hugging Face or a path to a local model.
Please note that `resources.policy_cuda`, `resources.action_cuda`, and `resources.downstream_cuda` must not include any overlapping devices, to avoid CUDA initialization errors.
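For orientation, here is a hypothetical excerpt of `config/pmcvqa.yml` covering the parameters above; the actual key layout of the shipped config may differ, and all device indices and paths are placeholders:
```yaml
resources:
  cache_dir: /usr/local/data/cache      # placeholder path; needs enough free space
  policy_cuda: "0,1"                    # devices for the policy update/training
  action_cuda: "2"                      # devices for TRL + vLLM action sampling
  downstream_cuda: "3"                  # devices for the downstream MLLM (reward)
model:
  policy: Qwen/Qwen2.5-1.5B-Instruct    # HF repo name or local path
  downstream: /usr/local/data/your-mllm # placeholder local path
```
Note that the three device strings in this sketch do not overlap, per the warning above.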
You can skip this step if you have already downloaded our trained CGP Generator from Hugging Face 🤗.
To enable TRL with vLLM serving, we need to start two servers: one for the policy model (to sample actions) and one for the downstream MLLM (to calculate the reward).
These servers will be started automatically, so you do not need to do anything now.
By default, the policy model will be trained with GRPO using TRL support. Run the following command to start training:
```bash
python main.py --config {DATASET}.yml --log_folder {LOG_FOLDER} --trial_name {TRIAL_NAME} --train --ni
```
Running the above command once will start:
- 2 detached subprocesses for the vLLMs, each corresponding to one of the servers. You can observe the GPU memory usage increasing in the terminal. You can use `nvidia-smi` to check the GPU memory usage for your specified CUDA devices `resources.action_cuda` and `resources.downstream_cuda`.
- 1 foreground engine subprocess for TRL, which will be responsible for the training of the policy model. You can observe the GPU memory usage (on your specified CUDA devices `resources.policy_cuda`) increasing in the terminal.
Runtime-related logs will be saved in the `{LOG_FOLDER}/{TRIAL_NAME}/` folder.
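For example, with the default `pmcvqa.yml` config and hypothetical placeholder values for the log folder and trial name:
```bash
python main.py --config pmcvqa.yml --log_folder ./logs --trial_name trial_01 --train --ni
```
Runtime logs would then be written to `./logs/trial_01/`; reuse the same `{LOG_FOLDER}` and `{TRIAL_NAME}` in the evaluation command below to evaluate this trial.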
Run the following command to evaluate the trained policy model:
```bash
python main.py --config {DATASET}.yml --log_folder {LOG_FOLDER} --trial_name {TRIAL_NAME} --test --ni
```

This work was supported in part by the Natural Sciences and Engineering Research Council of Canada, in part by the Canadian Institute for Advanced Research (CIFAR) Artificial Intelligence Chairs Program, in part by the Mila - Quebec Artificial Intelligence Institute, in part by the compute resources provided by Mila (mila.quebec), in part by the Mila-Google Research Grant, in part by the Fonds de recherche du Québec, in part by the Canada First Research Excellence Fund, awarded to the Healthy Brains, Healthy Lives initiative at McGill University, and in part by the Department of Electrical and Computer Engineering at McGill University.
If you find this repository useful for your research, please consider citing our paper:
```bibtex
@inproceedings{kriz2025prompt4trust,
    title={Prompt4Trust: A Reinforcement Learning Prompt Augmentation Framework for Clinically-Aligned Confidence Calibration in Multimodal Large Language Models},
    author={Kriz, Anita and Janes, Elizabeth Laura and Shen, Xing and Arbel, Tal},
    booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) Workshops},
    pages={1320--1329},
    year={2025}
}
```

Please raise a GitHub issue or email us at xing.shen@mail.mcgill.ca (with the email subject starting with "[Prompt4Trust]") if you have any questions or encounter any issues.
