feat: ARA, but it's LoRA#332
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces Arbitrary-Rank Ablation with LoRA (ARA-LoRA), enabling compatibility with quantization and eliminating the need for full model reloads. Key changes include adding configuration settings (use_ara_lora and ara_lora_rank), updating model initialization and reset routines, and implementing the ara_lora_abliterate optimization loop. The reviewer feedback highlights critical improvements: enhancing numerical stability in ara_lora_abliterate by running LBFGS optimization on float32 clones, correcting a misleading print statement during model resetting, ensuring new settings are documented in config.default.toml, and resolving style guide violations regarding type annotations, comment formatting, and setting descriptions.
|
Hey, you're back! As you may have heard, I am currently battling legal threats from megacorporations, as well as negative portrayals in the global press, so it may take a few weeks before I can give this PR the attention it deserves. 😆 I had already planned to implement optional LoRA support for ARA myself, but by factoring the matrix after the fact (like we do with MPOA), which is inferior to your approach of optimizing the factors directly. During tests, I had observed effective ranks around 10-30 typically, with the corresponding singular values explaining 99+% of the variance with the models I tested. This is still interesting from a research PoV, because it hints at a highly non-linear (if not full-rank) refusal manifold. Given how many advantages direct LoRA optimization has, I think we should probably remove the option to do full-rank optimization entirely. It just adds complexity and I doubt there will be a difference in practice. As I said, I will get back to this soon. There is also still a stability issue with ARA, which I believe is caused by the fact that the gradient is not directionally continuous because only the nearest In the meantime, please keep backups of everything locally, in case this repository becomes unavailable. |
|
Sure, thank you for review! I'm going to address Gemini's comments later in the day and make more empirical tests of your suggestions. (Including high KL instabilities) Still, merging it before into the ARA branch as an experimental option would be helpful as many people are using that branch and the model reload / lack of quantization are adding high overhead. May I ping you after I make enough tests anyway? (Like, on multiple models and edge cases + Gemini comments) I saw the unfolding of the stuff on Reddit a few hours ago. I always advice to store as much as one can: repos, models and datasets. TBH, the HF datasets are probably even more endangered species (copyright, unfiltered NSFW, or both) and there have already been widely scandalous cases of them being removed. Sadly, most of the datasets are too large for local storage, I wish more people torrented them, because unlike heretic models (they can be made at home), there are almost no solutions of dataset restoration |
|
As for research, just before this pull request, I found absolutely no references to the ARA technique Heretic uses nowadays, having searched Semantic Scholar and Googled arXiv. Although there were some SOM papers, referencing my previous pull request :) About ARA I found no mentions, showing that it either has low explainability, or high explainability, or it's working too well and is not defeatable on the benchmarks |
|
You had referenced this in a thread I am subscribed to. I'm quite curious to see if we can finally and easily produce a good K2.5 or 2.6 model with less censorship (without using the method that ruins the models, apparently...) I'll look in to this further, and will watch this thread. Again, I'd be happy to help with funding or providing some compute if possible when it's needed. I can under limited circumstance, and it's def not technically out of my own pocket right now (I wish that were possible at the moment, I'd love to find that hah), but I'll try to see if I can make it work. @p-e-w regarding your mention of the mongoloids trying to sue you and drag you through the dirt, fuck them. Legally, you're sound, brother. If you need me to pass something by my lawyer or family within law, let me know. Keep it up, and if you need to "pass the torch" meaning grant another repository that "definitely isn't maintained or owned by you" ownership of the repo, never be afraid to make a second account and clear yourself of that burden ;) |
|
Also, second little bit of commentary regarding the unfortunate situation regarding archiving things of this nature: MEGA has a MADLY cheap S3 compatible storage offering. Definitely worth a look. I'm pretty sure it's the cheapest option out there. |
I mean, I invented ARA from scratch just a few months ago, haven't described it in detail yet, and it's only implemented in an unmerged pull request, not in mainline, so I wouldn't expect papers to reference it already. I'm subscribed to relevant SemanticScholar streams and some recent papers still use the bog standard all-layer directional ablation from 2 years ago. As for project resilience, we are going for IPFS for release distribution (already implemented), the upcoming Heretic Grimoire system for model preservation (more details soon), and decentralized infrastructure wherever possible. Please contribute ideas to #330. Also, unofficial public Git mirrors are always welcome, especially if they are located in non-US (and ideally non-Western) jurisdictions!
That's true in general, but fortunately not for the standard prompt datasets used by Heretic (especially not if it's only the first few hundred prompts like we use). We need backups of those that aren't on HF. |
|
I found one (general) issue. If there is a high KL, low refusal (e.g. KL > 1.0, refusal < 10) region, after the 60th turn (random testing's end), the optimizer might think the refusals reduction outweights the KL and get stuck there, with KL being > 1.0 on these trials. We need to penalize these regions even harder somehow or to exclude them. |
|
We need to understand under what conditions those trials with KLD = 15 or so happen. I'm sure it's some kind of runaway behavior where the gradient pushing away from the bad residuals outweighs the gradient holding the good residuals in place. This problem does not seem to occur with
So with worse precision we get better results? That's suspicious to say the least. |
|
I've been running the tests whole day right now, so I'm going to update this. My new suspicion is the high KL low refusal outweight logic, so for the next run I'm patching it to count KL > 1.0 as 100/100 refusals to cut-off that region completely. I will update here as the tests complete. If it works more or less in bf16, I can try fp32 again |
|
Results after a BF16 run. I'd say, pretty optimistic.
To prevent the optimizer stuck at KL > 1, I added a reset to put the refusals count to 100 in evaluator.py and it helped much if kl_divergence > 1.01:
refusals = len(self.bad_prompts)
print(" * KL OVERFLOW")
else:
print(" * Counting model refusals...")
refusals = self.count_refusals()For comparison, coder3101's gemma 31b it heretic has 15/100 refusals and 0.0434 KL. (we have 16/100 and 0.07 KL) We have greater KL, but on the other side we did it with only 4bit precision in ~36 GB VRAM fully at home. And I ran about 160 Heretic trials, with more optimization and patience it could hit better values. For subjective chatting tests, the uncensored aspect is fully here, so it's not just numbers on screen. If I'll have time, I'm also going to launch the FP32 version, maybe it can be even better (though I'm skeptical about it) |
Implement Arbitrary-Rank Ablation (ARA) and its LoRA variant as alternative abliteration methods alongside the existing directional approach. ARA LoRA optimises low-rank adapter factors directly rather than modifying full weight matrices, making it compatible with quantized models (bnb 4-bit, EXL3) and eliminating the need for full model reloads between trials. ARA LoRA implementation based on work by kabachuha (p-e-w#332). New config options: use_ara, use_ara_lora, ara_lora_rank.
|
Btw, in the current code the adapter model (PEFT LoRA format) is saved in the folder at the end for some reason, despite the "Merge LoRA into the model" option selected. I will look at the code to truely merge the LoRA. Meanwhile, the adapter can be merged with a simple "offline" script if you want to test it sooner |
This is tempting, but it's the wrong approach. TPE relies on two fundamental, intuitive assumptions, which follow from its formulation via Gaussian Mixture Models:
Your approach breaks both of these assumptions and will ON AVERAGE lead to worse convergence behavior. You're introducing an artificial cliff in the objective landscape that TPE cannot comprehend, because the returned refusal count is actually incorrect. Here's an incredibly important thing to understand: TPE is highly random. It's almost impossible to conclude anything about its behavior from empirical observations when doing only a few studies. You'd need hundreds of complete studies, with potentially hundreds of runs each, to do that. Running a study, changing something, then running the study again and comparing the Pareto fronts tells you basically nothing. I have fallen into this trap myself many, many times.
So what is the correct solution?Optuna has a built-in mechanism for telling the optimizer that a trial is "unacceptable": raise optuna.TrialPruned()This is a special exception designed for this purpose. You're telling Optuna that this was a bad trial without lying about the objective value. However, this is just a hack in this case. It's still not a proper solution. That's because the real problem is that those ultra-high-KLD parameters exist in the first place. The objective landscape is extremely unstable in some regions. Such landscapes are unsuitable for optimization with TPE. There is only one proper solution: We need to understand what causes this runaway behavior, and fix it in ARA. |
LoRA export has been fixed on master, but merging is not trivial because of the reproducibility system. |
|
I didn't really look into the internals of optuna or TPE, this is just for reference in the PR context. LoRA and how to prune properly are the future goals, I think. For now, I'd like to test multiple architectures like Qwen or previous Gemmas for more stats If/when you have time to test the PR, feel free to do it and write suggestions! Because currently with Gemma 4 31b I have ~8 hours to use for a full run (a whole day/night) |
|
I recommend using much smaller models for testing. 4B models are usually intelligent enough to have robust refusal behavior. Anything larger is just wasting time during development. |
|
Here I target large models specifically to test the 4bit bitsandbytes compatibility, LoRA and quantization effects. As for 4b models, I indeed want to test Qwen3 4b with LoRA, but no quant. |
|
FYI, the currently planned merge order is as follows:
|
|
Got it. I think, I'm going to be dumping the results/code corrections/observations in the PR anyway, if it would be helpful Btw, what about SOMA? Are we discarding it or will leave as an option / plugin? I read at least one paper citing my implementation of it in Heretic, so it may be helpful for researchers |
|
SOMA should also be rewritten as a modifier plugin once the plugin infrastructure is in place, and I'd be happy to include it as a default plugin in Heretic. The existing implementations (abliteration, SOMA, ARA) will inform the API of the modifier plugin system. |
|
Thanks, very interesting! Using a decision tree is a really good idea for analyzing this. Am I reading those trees right, and explosion can occur even when What exactly is the definition of the "explosion" class? |
|
KLD > 9 is definitely the correct criterion and KLD > 1.5 is not, because the latter also happens with traditional abliteration while the former doesn't. We need to concentrate on the unique conditions that only arise with ARA.
That doesn't follow from the decision tree though. The root node has the criterion |
|
I put the decision tree here as it was generated. Btw, can you quickly launch this Pull-request on your end? To see how ARA-LoRA works on your hardware and what results you are getting |
|
I'm not on a GPU system right now, but I'll try it out in the next couple of days. |








ARA, but it's LoRA
This replaces the inplace training of a layer's matrix with optimization of a product of two low rank matrices which make an additive to the weight.
It eliminates the need for model reload from disk, accelerating the process, and, crucially, 4 bit bitsandbytes compatibility (impossible for normal ARA), making it available for heretic to work upon Gemma 4 31b it at only ~36 GB VRAM, accessible for most duo-gpu prosumer setups.
Due to it being a preset-rank LoRA, it's technically not "arbitrary rank", but this can be simulated with the LoRA having a high enough rank. This is also compatible with row normalisation preservation because the weight is corrected with the norm of its dequantization.
During optimization, there can be high KLs, especially at the start, but they drift to adequate values later.
I think this is a must-have for heretic for three reasons:
Independence - you don't need to depend on the few HF abliterators who might disappear at any moment.
Locality - no need to rent high end hardware and download a huge model back and forth to a server.
Faster iteration - if you are testing or making a models merge, the obscure HF component models might not be abliterated and you need to rely just on yourself to do this.
3*. Speed of the process itself - the mode doesn't require re-reading the model from the storage, giving a massive speed-boost for slower file systems.
Just as shown with normal LoRA-based abliteration, the quantization/low-rank trade-off should be neiligible.
Here is the eighth trial of my experiment. (because I'll go to sleep soon, so it's a very quick demo)
Feel free to test the code, make critiques and suggestions. The mode is fully additive and doesn't conflict with the default LoRA and ARA switches.
cc @p-e-w