This is the official code for the paper "Booster: Tackling Harmful Fine-tuning for Large Language Models via Attenuating Harmful Perturbation" (ICLR2025 Oral).


Booster: Tackling Harmful Fine-tuning for Large Language Models via Attenuating Harmful Perturbation

[📕 Paper] [Homepage] [🤗 Alignment dataset] [🤗 Harmful dataset] [Slide] [Poster]

Fine-tuning-as-a-service

Fine-tuning-as-a-service allows users to upload data to a service provider (e.g., OpenAI), which fine-tunes its base model on that data. The fine-tuned model is then deployed on the provider's server to serve the customized user need. Such a procedure usually contains two sequential stages: i) safety alignment stage -- the model is safety-aligned with safety data; ii) fine-tuning stage -- the aligned model produced by the first stage is fine-tuned on user-provided data.

Harmful fine-tuning Attack

However, such a scenario exposes a serious safety issue, because users might intentionally or unintentionally upload harmful data that breaks down the safety alignment of the victim LLM. Specifically, under a harmful fine-tuning attack, the customized LLM forgets the alignment knowledge and exhibits harmful behavior after being fine-tuned on partially harmful data. See the following figure for an illustration.

Harmful fine-tuning Defense

Booster is the proposed alignment-stage defense against harmful fine-tuning attacks. Booster strengthens the aligned model's robustness by sufficiently exploiting the alignment/harmful datasets. The high-level idea is to simulate the harmful perturbation at the alignment stage and attenuate its impact on the aligned model. The Booster algorithm is sketched below.
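In equation form (a paraphrase of what the training step in the next section computes, not the paper's exact notation), Booster solves

\min_{w} \; f(w) + \lambda \Big[ h(w) - h\Big(w - \alpha \frac{\nabla h(w)}{\lVert \nabla h(w) \rVert}\Big) \Big]

where f is the alignment loss, h is the harmful loss, and \alpha, \lambda correspond to alpha and lamb in the code below. Besides minimizing the alignment loss, the second term penalizes how much the harmful loss would drop after one normalized harmful gradient step, i.e., it attenuates the harmful perturbation.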

Main code logic

We implement a customized trainer (BoosterAlignmentTrainer) on top of the original Hugging Face Trainer. To realize Booster, we append several forward/backward passes according to the pseudo-algorithm.
Specifically, in training_step(), we use the following logic:

# first backward pass: gradient of the harmful loss at the current weights
with self.compute_loss_context_manager():
    loss = self.compute_loss(model, harmful_inputs)
if self.use_apex:
    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    self.accelerator.backward(loss)
stored_grads = {name: param.grad.data.clone() for name, param in model.named_parameters() if param.requires_grad}
# clear the gradient buffers so the next backward pass yields its own gradient
model.zero_grad()

# take one normalized step along the harmful gradient (the simulated harmful perturbation)
with torch.no_grad():
    grad_norm = self._grad_norm(stored_grads) + 1e-7
    # perturb the weights
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data -= self.args.alpha * stored_grads[name] / grad_norm

# second backward pass: gradient of the harmful loss at the perturbed weights
with self.compute_loss_context_manager():
    loss2 = self.compute_loss(model, harmful_inputs)
if self.use_apex:
    with amp.scale_loss(loss2, self.optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    self.accelerator.backward(loss2)
perturb_grads = {name: param.grad.data.clone() for name, param in model.named_parameters() if param.requires_grad}
model.zero_grad()

# undo the perturbation so the alignment gradient is computed at the original weights
with torch.no_grad():
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data += self.args.alpha * stored_grads[name] / grad_norm

# third backward pass: gradient of the alignment loss
with self.compute_loss_context_manager():
    loss3 = self.compute_loss(model, inputs)
if self.use_apex:
    with amp.scale_loss(loss3, self.optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    self.accelerator.backward(loss3)

# finally, combine the gradients: alignment + lamb * (harmful - perturbed harmful)
for name, param in model.named_parameters():
    if param.requires_grad:
        param.grad.data = param.grad.data + self.args.lamb * (stored_grads[name] - perturb_grads[name])

Of note, we strictly follow the pseudo-algorithm without adding any extra tricks in the code. Copying and pasting this code into BoosterAlignmentTrainer should be sufficient if you want to merge Booster into your own testbed. Please open an issue if you run into any problems while reproducing.
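If you want to sanity-check the gradient combination outside of the Trainer, the following self-contained toy sketch walks through the same three passes on a tiny model (the model, data, and hyperparameter values here are purely illustrative and not taken from the repo):

import torch
import torch.nn as nn

# toy stand-ins for the alignment batch and the harmful batch
model = nn.Linear(4, 2)
alignment_x, alignment_y = torch.randn(8, 4), torch.randint(0, 2, (8,))
harmful_x, harmful_y = torch.randn(8, 4), torch.randint(0, 2, (8,))
criterion = nn.CrossEntropyLoss()
alpha, lamb = 0.1, 5.0  # inner step size and regularizer weight (illustrative values)

def grads_of(loss):
    # return a fresh copy of the per-parameter gradients of the given loss
    model.zero_grad()
    loss.backward()
    return {n: p.grad.clone() for n, p in model.named_parameters()}

# 1) harmful gradient at the current weights
stored = grads_of(criterion(model(harmful_x), harmful_y))
norm = torch.norm(torch.stack([g.norm() for g in stored.values()])) + 1e-7

# 2) harmful gradient after one normalized harmful step (the simulated perturbation)
with torch.no_grad():
    for n, p in model.named_parameters():
        p -= alpha * stored[n] / norm
perturbed = grads_of(criterion(model(harmful_x), harmful_y))
with torch.no_grad():  # undo the perturbation
    for n, p in model.named_parameters():
        p += alpha * stored[n] / norm

# 3) alignment gradient at the original weights, then the Booster combination
align = grads_of(criterion(model(alignment_x), alignment_y))
booster_grad = {n: align[n] + lamb * (stored[n] - perturbed[n]) for n in align}
print({n: g.norm().item() for n, g in booster_grad.items()})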

Package requirement

The package requirements are listed in booster.yml and booster_pip.txt. Run the following commands to install the packages with Anaconda and pip.

conda env create -f booster.yml
pip install -r booster_pip.txt

Data preparation

For safety alignment, please download the safety alignment dataset from this link and put the json file under the data directory.
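If you prefer to fetch it from the command line, a sketch like the following would work (the repo id and filename below are placeholders; substitute the actual ones from the link above):

import os
from huggingface_hub import hf_hub_download

os.makedirs("data", exist_ok=True)
# placeholders -- fill in the dataset repo id and json filename from the link above
hf_hub_download(
    repo_id="<alignment-dataset-repo-id>",
    filename="<alignment-data>.json",
    repo_type="dataset",
    local_dir="data",
)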

For the fine-tuning tasks, we first need to run the following scripts to prepare the supervised fine-tuning data.

cd sst2
python build_dataset.py
cd ../gsm8k
python build_dataset.py
cd ../ag_news
python build_dataset.py
cd ..

Huggingface Llama2 access

Llama2-7B is a gated repo, which requires a formal request to get access to the model. Check out https://huggingface.co/meta-llama/Llama-2-7b-hf. After Meta approves your application, you should be able to access the model, but you first need to put your access token in the file huggingface_token.txt.
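For reference, the token file is typically consumed along these lines (a generic sketch of the usual Hugging Face pattern, not necessarily the exact loading code used in this repo):

from transformers import AutoModelForCausalLM, AutoTokenizer

# read the access token stored in huggingface_token.txt
with open("huggingface_token.txt") as f:
    access_token = f.read().strip()

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", token=access_token)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", token=access_token)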

Example command to run

We provide scripts for reproducing all the experiments in the paper (check out the script directory). We recommend using Slurm to reproduce the results, as the logging files will be automatically organized into the script directory (if you don't use Slurm, just replace sbatch with bash in our examples).

We first run SFT to produce the aligned model.

cd script/alignment
sbatch  smooth_align.sh

Then we fine-tune the model on 1000 samples from the SST2 dataset, 10% of which are harmful data.

cd ../finetune
sbatch  smooth_poison_ratio.sh 0.1

A line of attack/defense designs

We are committed to designing attacks and defenses from different angles on the topic of harmful fine-tuning. The attack/defense designs currently available from the disl group include Booster as well as Vaccine, Lisa, Antidote, and Virus (see the Citation section below).

We always welcome different forms of collaboration. If you are interested, please reach out to Tiansheng Huang ([email protected]) for discussion.

Papers on harmful fine-tuning attacks/defenses in ICLR2025

Of note, 13 papers on harmful fine-tuning attacks/defenses (Booster among them) were accepted at ICLR2025. Please consider checking them out if you are interested.

  • Tamper-Resistant Safeguards for Open-Weight LLMs
  • Booster: Tackling harmful fine-tuning for large language models via attenuating harmful perturbation
  • Identifying and Tuning Safety Neurons in Large Language Models
  • Safety alignment should be made more than just a few tokens deep
  • Do as I do (Safely): Mitigating Task-Specific Fine-tuning Risks in Large Language Models
  • Bi-Factorial Preference Optimization: Balancing Safety-Helpfulness in Language Models
  • Safety Layers in Aligned Large Language Models: The Key to LLM Security
  • SEAL: Safety-enhanced Aligned LLM Fine-tuning via Bilevel Data Selection
  • SaLoRA: Safety-Alignment Preserved Low-Rank Adaptation
  • Towards Secure Tuning: Mitigating Security Risks Arising from Benign Instruction Fine-Tuning
  • Probe before You Talk: Towards Black-box Defense against Backdoor Unalignment for Large Language Models
  • On Evaluating the Durability of Safeguards for Open-Weight LLMs
  • Emerging Safety Attack and Defense in Federated Instruction Tuning of Large Language Models

Citation

If you find our research interesting, you may cite the following papers.

@article{huang2024booster,
  title={Booster: Tackling Harmful Fine-tuning for Large Language Models via Attenuating Harmful Perturbation},
  author={Huang, Tiansheng and Hu, Sihao and Ilhan, Fatih and Tekin, Selim Furkan and Liu, Ling},
  journal={arXiv preprint arXiv:2409.01586},
  year={2024}
}

@article{huang2025virus,
  title={Virus: Harmful Fine-tuning Attack for Large Language Models Bypassing Guardrail Moderation},
  author={Huang, Tiansheng and Hu, Sihao and Ilhan, Fatih and Tekin, Selim Furkan and Liu, Ling},
  journal={arXiv preprint arXiv:2501.17433},
  year={2025}
}

@article{huang2024harmful,
  title={Harmful fine-tuning attacks and defenses for large language models: A survey},
  author={Huang, Tiansheng and Hu, Sihao and Ilhan, Fatih and Tekin, Selim Furkan and Liu, Ling},
  journal={arXiv preprint arXiv:2409.18169},
  year={2024}
}

@article{huang2024antidote,
  title={Antidote: Post-fine-tuning Safety Alignment for Large Language Models against Harmful Fine-tuning},
  author={Huang, Tiansheng and Bhattacharya, Gautam and Joshi, Pratik and Kimball, Josh and Liu, Ling},
  journal={arXiv preprint arXiv:2408.09600},
  year={2024}
}

@inproceedings{huanglisa,
  title={Lisa: Lazy Safety Alignment for Large Language Models against Harmful Fine-tuning Attack},
  author={Huang, Tiansheng and Hu, Sihao and Ilhan, Fatih and Tekin, Selim Furkan and Liu, Ling},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}
}

@inproceedings{huangvaccine,
  title={Vaccine: Perturbation-aware Alignment for Large Language Models against Harmful Fine-tuning Attack},
  author={Huang, Tiansheng and Hu, Sihao and Liu, Ling},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}
}
