A companion library to the Thinking Machines blog post https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/. It contains a set of batch-invariant kernels and an example of achieving deterministic vLLM inference.
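This library primarily uses `torch.library.Library` to swap out existing PyTorch kernels for "batch-invariant" ones, meaning kernels whose result for a given input row is bitwise identical no matter how many other rows are in the batch. Because the swap happens at the operator level, many existing PyTorch models can use the batch-invariant ops with low overhead and minimal code changes.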
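The following is a minimal sketch of the kernel-substitution mechanism: an `"IMPL"`-mode `torch.library.Library` overrides an existing ATen kernel. It makes several assumptions. It overrides `aten::mm` on the `CPU` dispatch key so it runs without a GPU. It replaces the kernel with a deliberately slow, fixed-order reduction written with elementwise ops. The names `fixed_order_mm` and `fixed_order_mm_override` are hypothetical and are not this library's API or kernels.

```python
import contextlib

import torch


def fixed_order_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical batch-invariant mm (illustration only).

    Every output element is accumulated over k = 0..K-1 in the same sequential
    order, whatever the number of rows. Only elementwise ops are used, so the
    kernel never re-enters aten::mm (which would recurse into itself).
    """
    out = a.new_zeros((a.shape[0], b.shape[1]))
    for k in range(a.shape[1]):
        # Outer product of column k of `a` and row k of `b`, added in place.
        out += a[:, k : k + 1] * b[k : k + 1, :]
    return out


@contextlib.contextmanager
def fixed_order_mm_override():
    """Route torch.mm on CPU tensors to fixed_order_mm while the block is active."""
    lib = torch.library.Library("aten", "IMPL")
    lib.impl("mm", fixed_order_mm, "CPU")  # PyTorch warns that a kernel is being overridden
    try:
        yield
    finally:
        # Dropping the last reference to the Library removes its registrations,
        # so the stock CPU kernel for aten::mm is used again.
        del lib


torch.manual_seed(0)
a = torch.randn(8, 64)
b = torch.randn(64, 16)
ref = torch.mm(a, b)  # stock kernel, computed before the override

with fixed_order_mm_override():
    full = torch.mm(a, b)        # whole batch
    single = torch.mm(a[:1], b)  # first row only

print(torch.equal(full[:1], single))
```

A real replacement kernel also has to be fast. This loop only shows where such a kernel plugs in and what property it must guarantee.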
Install the package in editable mode from the repository root:

```bash
pip install -e .
```
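To use it, wrap your inference code in the `set_batch_invariant_mode()` context manager:

```python
import torch
from batch_invariant_ops import set_batch_invariant_mode

# Enable batch-invariant mode for everything inside the `with` block
with set_batch_invariant_mode():
    # Your inference code here (YourModel and input_tensor are placeholders)
    model = YourModel()
    output = model(input_tensor)
```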
The following example shows how batch size can affect results in standard PyTorch:
```python
import torch
from batch_invariant_ops import set_batch_invariant_mode

torch.set_default_device('cuda')

# Enter the mode once up front so one-time logging doesn't clutter the output below
with set_batch_invariant_mode(True):
    pass

def test_batch_invariance():
    B, D = 2048, 4096
    a = torch.linspace(-100, 100, B * D).reshape(B, D)
    b = torch.linspace(-100, 100, D * D).reshape(D, D)

    # Method 1: vector-matrix multiplication on the first row only (batch size 1)
    out1 = torch.mm(a[:1], b)

    # Method 2: matrix-matrix multiplication on the full batch, then slice out the first row
    out2 = torch.mm(a, b)[:1]

    # Batch invariance means the two results are bitwise identical
    diff = (out1 - out2).abs().max()
    print(f"Difference: {diff.item()}")
    return diff.item() == 0

# Test with standard PyTorch (likely to show differences)
print("Standard PyTorch:")
with set_batch_invariant_mode(False):
    is_deterministic = test_batch_invariance()
    print(f"Deterministic: {is_deterministic}")

# Test with batch-invariant operations
print("\nBatch-Invariant Mode:")
with set_batch_invariant_mode(True):
    is_deterministic = test_batch_invariance()
    print(f"Deterministic: {is_deterministic}")
```
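Why can the two methods disagree? Floating-point addition is not associative. A matmul kernel that splits the reduction dimension differently depending on batch size (for example, using split-K only for small batches) can therefore round the same row differently. The pure-Python sketch below emulates that effect. The split heuristics `batch_dependent` and `batch_invariant` are hypothetical, and the input values were chosen so that the rounding difference is easy to see.

```python
from typing import Callable, List

Matrix = List[List[float]]


def dot_split(x: List[float], y: List[float], num_splits: int) -> float:
    """Dot product reduced as `num_splits` partial sums that are then added together."""
    k = len(x)
    chunk = -(-k // num_splits)  # ceiling division
    partials = []
    for start in range(0, k, chunk):
        acc = 0.0
        for i in range(start, min(start + chunk, k)):
            acc += x[i] * y[i]
        partials.append(acc)
    total = 0.0
    for p in partials:
        total += p
    return total


def matmul(a: Matrix, b_cols: Matrix, split_policy: Callable[[int], int]) -> Matrix:
    """Multiply each row of `a` with each column in `b_cols`.

    The number of splits is chosen from the batch size (number of rows).
    """
    num_splits = split_policy(len(a))
    return [[dot_split(row, col, num_splits) for col in b_cols] for row in a]


def batch_dependent(batch_size: int) -> int:
    # Hypothetical heuristic: split the reduction for small batches to expose more parallelism.
    return 2 if batch_size < 4 else 1


def batch_invariant(batch_size: int) -> int:
    # The same reduction strategy for every batch size.
    return 1


# Values chosen so that the grouping of the additions changes the rounded result.
row = [1e16, 1.0, -1e16, 1.0]
a = [row] * 4
b_cols = [[1.0, 1.0, 1.0, 1.0]]  # a single column, stored as a list

print(matmul(a[:1], b_cols, batch_dependent)[0], matmul(a, b_cols, batch_dependent)[0])  # → [0.0] [1.0]
print(matmul(a[:1], b_cols, batch_invariant)[0], matmul(a, b_cols, batch_invariant)[0])  # → [1.0] [1.0]
```

Neither answer equals the exact sum (2.0). Batch invariance is about getting the same answer for a row regardless of batch size, not about getting a more accurate one.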
`deterministic_vllm_inference.py` is a proof of concept showing that vLLM inference can be made deterministic with a minor upstream PR that uses this library. Without the upstream PR, 1000 random completions of length 100 yield 18 unique samples. With the PR applied, all 1000 completions are identical (one unique sample).
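The validation above comes down to sampling the same prompt many times and counting the distinct outputs. The following is a minimal sketch of that metric. It assumes a hypothetical `generate(prompt)` callable that returns a list of token IDs. Wiring it to an actual vLLM engine is not shown, and this is not the code in `deterministic_vllm_inference.py`.

```python
from itertools import cycle
from typing import Callable, List


def count_unique_completions(
    generate: Callable[[str], List[int]], prompt: str, n: int = 1000
) -> int:
    """Call `generate` n times on the same prompt and count distinct token sequences."""
    # Tuples are hashable, so identical completions collapse into a single set entry.
    return len({tuple(generate(prompt)) for _ in range(n)})


# Hypothetical stand-ins for an inference engine, for illustration only.
def deterministic_generate(prompt: str) -> List[int]:
    return [1, 2, 3, 4]


_variants = cycle([[1, 2, 3, 4], [1, 2, 3, 5], [1, 2, 6, 7]])


def flaky_generate(prompt: str) -> List[int]:
    # Simulates outputs that drift between calls, e.g. because of batch composition.
    return next(_variants)


print(count_unique_completions(deterministic_generate, "hello"))  # → 1
print(count_unique_completions(flaky_generate, "hello"))  # → 3
```

A deterministic engine should report 1 for any `n`. Any larger count means that the same request produced different outputs.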
Batch-invariant implementations are provided for the following operations:

- `torch.mm()` - Matrix multiplication
- `torch.addmm()` - Matrix multiplication with bias addition
- `torch.log_softmax()` - Log-softmax activation
- `torch.mean()` - Mean computation along specified dimensions
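For the reduction-based ops in this list, batch invariance means computing each row with a reduction strategy that never depends on how many rows are in the batch. The pure-Python sketch below shows that contract for log-softmax and a row-wise mean. It assumes a hypothetical fixed chunk size of 4 for the per-row sum and does not reproduce the library's actual kernels. In sequential Python the property holds trivially. The point is the rule a parallel kernel has to follow: fix the per-row reduction order and do not tune it to the batch size.

```python
import math
from typing import List

CHUNK = 4  # hypothetical fixed chunk size; it must not depend on the batch size


def fixed_order_sum(xs: List[float]) -> float:
    """Sum fixed-size chunks left to right, then add the chunk totals in order."""
    total = 0.0
    for start in range(0, len(xs), CHUNK):
        partial = 0.0
        for x in xs[start : start + CHUNK]:
            partial += x
        total += partial
    return total


def log_softmax_rows(rows: List[List[float]]) -> List[List[float]]:
    out = []
    for row in rows:  # each row is handled independently of the others
        m = max(row)  # subtract the row max for numerical stability
        lse = m + math.log(fixed_order_sum([math.exp(x - m) for x in row]))
        out.append([x - lse for x in row])
    return out


def mean_rows(rows: List[List[float]]) -> List[float]:
    return [fixed_order_sum(row) / len(row) for row in rows]


batch = [[0.1 * (i + j) for j in range(10)] for i in range(8)]
single = batch[:1]

print(log_softmax_rows(single)[0] == log_softmax_rows(batch)[0])  # → True
print(mean_rows(single)[0] == mean_rows(batch)[0])  # → True
```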