
Conversation

@Balladie (Contributor) commented Dec 6, 2025

This increases EXTRA_RESERVED_VRAM by 32MB if CudaMallocAsync is enabled.

The recent OOM issue in FLUX.2 (#10891) seems to have mixed causes: CudaMallocAsync behaviour and the VRAM lookahead calculation. For the former, the CudaMallocAsync allocator can reserve more memory than requested, presumably to reduce fragmentation, but the policy is not documented, so as far as I know the gap cannot be predicted without actually allocating (please correct me if that's wrong). To stay close to the current implementation, I profiled the reserved memory and found that the allocator usually grows its reservation in multiples of 32MB, so this PR reserves that much extra in advance.

Tested on an RTX 4090, but the number could differ in other setups, so it should be verified on multiple configurations.

This is certainly a temporary solution, and better improvements are possible. For example, the load function could track this gap on the fly and adjust the offloaded list whenever a weight does not fit into VRAM. Another option would be to replace the current empirical VRAM calculation with more accurate logic. But this change at least resolves some of the OOM issues for now, and the cost is small, so I'm proposing the quick fix here.
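For reference, the change is roughly the following (a minimal sketch, not the exact diff; the backend check via torch.cuda.get_allocator_backend() and the baseline value are assumptions on my side):

```python
import torch

MB = 1024 * 1024
EXTRA_RESERVED_VRAM = 400 * MB  # illustrative baseline, not ComfyUI's actual default

def extra_reserved_vram():
    extra = EXTRA_RESERVED_VRAM
    # cudaMallocAsync appears to grow its reserved pool in 32MB granules,
    # so keep one extra granule of headroom when that backend is active.
    if torch.cuda.is_available() and torch.cuda.get_allocator_backend() == "cudaMallocAsync":
        extra += 32 * MB
    return extra
```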

Trace of reserved and allocated memory on Flux 2:

[DEBUG 0] reserved: 1280.00 MB (+ 1280.00 MB), allocated: 1280.00 MB, module: 1280.00 MB
[DEBUG 1] reserved: 1600.00 MB (+ 320.00 MB), allocated: 1600.00 MB, module: 320.00 MB
[DEBUG 2] reserved: 1920.00 MB (+ 320.00 MB), allocated: 1920.00 MB, module: 320.00 MB
[DEBUG 3] reserved: 2240.00 MB (+ 320.00 MB), allocated: 2240.00 MB, module: 320.00 MB
...
[DEBUG 62] reserved: 21120.00 MB (+ 320.00 MB), allocated: 21120.00 MB, module: 320.00 MB
[DEBUG 63] reserved: 21440.00 MB (+ 320.00 MB), allocated: 21440.00 MB, module: 320.00 MB
[DEBUG 64] reserved: 21504.00 MB (+ 64.00 MB), allocated: 21480.00 MB, module: 40.00 MB
[DEBUG 65] reserved: 21536.00 MB (+ 32.00 MB), allocated: 21520.00 MB, module: 40.00 MB
[DEBUG 66] reserved: 21568.00 MB (+ 32.00 MB), allocated: 21560.00 MB, module: 40.00 MB
[DEBUG 67] reserved: 21600.00 MB (+ 32.00 MB), allocated: 21600.00 MB, module: 40.00 MB
[DEBUG 68] reserved: 21632.00 MB (+ 32.00 MB), allocated: 21610.00 MB, module: 10.00 MB
[DEBUG 69] reserved: 21632.00 MB (+ 0.00 MB), allocated: 21610.01 MB, module: 0.01 MB
[DEBUG 70] reserved: 21632.00 MB (+ 0.00 MB), allocated: 21610.02 MB, module: 0.01 MB
[DEBUG 71] reserved: 21632.00 MB (+ 0.00 MB), allocated: 21610.03 MB, module: 0.01 MB
...
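A trace like the one above can be reproduced with the standard PyTorch counters; here is a minimal sketch of that kind of logging (log_vram_step and its bookkeeping are illustrative, not the actual profiling code used):

```python
import torch

MB = 1024 * 1024
_prev_reserved = 0

def log_vram_step(step, module_bytes, device=0):
    # Compare the allocator's reserved pool against what is actually allocated
    # after each module is moved onto the GPU.
    global _prev_reserved
    reserved = torch.cuda.memory_reserved(device)
    allocated = torch.cuda.memory_allocated(device)
    delta = reserved - _prev_reserved
    _prev_reserved = reserved
    print(f"[DEBUG {step}] reserved: {reserved / MB:.2f} MB (+ {delta / MB:.2f} MB), "
          f"allocated: {allocated / MB:.2f} MB, module: {module_bytes / MB:.2f} MB")
```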

@Balladie (Contributor, Author) commented Dec 6, 2025

cc @rattus128, as this is related to the recent fixes.

@rattus128 (Contributor) commented:

Hey. So 32MB is very small in the context of Flux 2. Comfy defines minimum inference VRAM and headroom on the order of hundreds of MB. If this is making a difference, we should first dig into the math being done at the hundreds-of-MB scale.

I fixed a big OOM last night in #11144

That was an OOM on model load. In that one I could reserve 5GB of extra VRAM and it would still OOM. It has to be an allocator/deallocator race condition in async, and I wonder if the same race is hurting us on a smaller scale during inference too.

I do know that the loader is not accounting for the on-the-fly cost of dequantization while simultaneously overestimating the latent-relative memory consumption. We only account for the O(n) VRAM consumption, while Flux dequant has at least 600MB of O(1) cost to dequant those massive weights. This is the root cause of the weird behaviour where increasing image resolution or batch size saves you from OOM (bigger n). Are you observing this? Is the OOM sensitive to the parameters of the job?

If you have a fresh OOM log and workflow for the ticket, paste them so I can have a look. There have been too many fixes since the OP's report; we need to look at any Flux 2 OOMs fresh.
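To make that accounting gap concrete, here is a toy version of the two terms (names and numbers are illustrative, not ComfyUI's actual formula):

```python
MB = 1024 * 1024

def estimated_vram(latent_elements, bytes_per_element=2, overhead_factor=3.0):
    # O(n) term: scales with latent/activation size, which is what the current
    # estimate tracks.
    return latent_elements * bytes_per_element * overhead_factor

def actual_vram(latent_elements, dequant_scratch=600 * MB, **kwargs):
    # O(1) term: a roughly fixed scratch cost for dequantizing the large weights,
    # which the estimate above omits. With a small latent (small n) the fixed term
    # dominates and the estimate undershoots, so OOM is more likely; a bigger
    # resolution or batch inflates the O(n) estimate enough to cover it.
    return estimated_vram(latent_elements, **kwargs) + dequant_scratch
```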

@Balladie (Contributor, Author) commented Dec 7, 2025

> Hey. So 32MB is very small in the context of Flux 2. Comfy defines minimum inference VRAM and headroom on the order of hundreds of MB. If this is making a difference, we should first dig into the math being done at the hundreds-of-MB scale.
>
> I fixed a big OOM last night in #11144
>
> That was an OOM on model load. In that one I could reserve 5GB of extra VRAM and it would still OOM. It has to be an allocator/deallocator race condition in async, and I wonder if the same race is hurting us on a smaller scale during inference too.
>
> I do know that the loader is not accounting for the on-the-fly cost of dequantization while simultaneously overestimating the latent-relative memory consumption. We only account for the O(n) VRAM consumption, while Flux dequant has at least 600MB of O(1) cost to dequant those massive weights. This is the root cause of the weird behaviour where increasing image resolution or batch size saves you from OOM (bigger n). Are you observing this? Is the OOM sensitive to the parameters of the job?
>
> If you have a fresh OOM log and workflow for the ticket, paste them so I can have a look. There have been too many fixes since the OP's report; we need to look at any Flux 2 OOMs fresh.

Thanks for the context; I don't know much about the history of this logic. I also noticed that the OOM happens when offloaded weights are upcast to match the fp32 dtype of the input, but it only happens with CudaMallocAsync even though the same upcast is performed either way, so I looked into the loading part.

I left the logs here: it occurs during the TE load, and I checked that it is unaffected by image resolution or batch size. I also tried your changes from #11144, but unfortunately they did not help in this case. Please let me know if there's anything more I can help with, as I'd also like this resolved quickly.

@rattus128 (Contributor) commented:

I think this is your root cause:

#11171

@comfyanonymous @Balladie, if we want to do an extra reservation for the sake of cudaMallocAsync, we should probably do something a bit bigger than 32MB.

More generally, I'm confident the primary reason for it spiking is its async nature WRT the CPU's VRAM accounting, more so than fragmentation, and I think it can strike at any time with any weight size. I think @Balladie got a bit lucky in that --disable-cuda-malloc worked along with this 32MB extra reserve, and the numbers back that up. The 330MB that was unaccounted for in the reproducer is pretty much all of our headroom; it would bring the VRAM very close to the ceiling and use up all of our extra reservation, so any small change in memory behaviour will make things different. I also observed intermittence on my 3060 fp8 reproducer, where it only sometimes happened.
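One way to watch for that kind of gap at runtime (a hedged diagnostic sketch on my side, not something from the linked fixes):

```python
import torch

def unaccounted_vram_mb(device=0):
    # Driver-level usage vs. PyTorch's own allocation counter. With cudaMallocAsync
    # the pool can hold memory that torch.cuda.memory_allocated() does not report,
    # so a large gap here suggests the CPU-side headroom math is optimistic.
    # (Note: the driver figure also includes memory used by other processes.)
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    used_by_driver = total_bytes - free_bytes
    allocated = torch.cuda.memory_allocated(device)
    return (used_by_driver - allocated) / (1024 * 1024)
```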

@Balladie (Contributor, Author) commented Dec 7, 2025

Thanks for looking into this in detail, @rattus128. I have checked that change and can confirm it also resolves the issue. I will close this PR once it's merged.

To leave a note for future related issues: I think the cause is not one thing but a mix of the black-box allocator and the underestimation of weight upcasting that @rattus128 worked on. Regardless, the OOM happens even though, from my inspection, the offloading behavior and CUDA calls look similar on both the async and the native backend, so I believe we were sitting right on the boundary between OOM and the ceiling created by the CudaMallocAsync allocator's over-reserved cache memory. Both changes would effectively reserve more memory in advance, but I agree that #11171 is the more explainable one. Still, the async behavior is hard to debug and could cause another issue in the future, so I'm leaving this note here.
