102 changes: 25 additions & 77 deletions README.md
@@ -1,96 +1,44 @@
# Loki
# Loki (Naive Sparsity)

This repository contains the code related to the experiments in the paper [Loki: Low-Rank Keys for Efficient Sparse Attention](https://arxiv.org/abs/2406.02542).
We provide the code to compute the PCA of the keys for various models, implementations of the baseline methods and the Loki kernels used in the paper, along with scripts to evaluate these methods on perplexity and downstream tasks.

## Installation
You need to install the requirements as follows:
## Attention Query-Key Sparsity Mode

```
pip install -r requirements.txt
```

Note: The code requires the specific versions of the Hugging Face `transformers` library pinned in `requirements.txt`; it will not work with other versions.

## Usage

#### Compute PCA of keys for a model
Say you want to compute the PCA transform for the keys of the Llama-2-7b model. You can do this by following the steps below:

- Run perplexity evaluation on the model on a target dataset to save the key, query, and value tensors.
```bash
# The --use-axonn flag is optional and is used to shard the model over multiple GPUs using AxoNN
When running with `sparsity_type = "attention-query-key"` (enabled via the `--run-attention-query-key-sparsity` flag), the implementation:

python -u evaluate_tasks.py --sequence-length 4096 --model-id meta-llama/Llama-2-7b-hf --model-type llama --dataset wikitext-valid --save-tensors --tensor-dir <Directory to save tensors> --use-topk --top-k 1 [--use-axonn]
```
List of possible datasets: wikitext-valid, bookcorpus, c4
1. Applies magnitude-based pruning to the linear projection layers, keeping only the top 50% of weights by magnitude (see the sketch below)
2. Optimizes the attention mechanism's query-key operations using PyTorch's native sparse matrix multiplication (more below)
3. Computes the sparse representations once and reuses them across forward passes
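
For illustration, here is a minimal sketch of the magnitude-based pruning step in item 1. The helper name `prune_top_half` and the threshold logic are assumptions for this sketch, not the repository's actual code; only the use of `torch.sparse_coo_tensor` mirrors the description below.

```python
import torch

def prune_top_half(weight: torch.Tensor, keep_ratio: float = 0.5) -> torch.Tensor:
    """Zero out all but the largest-magnitude `keep_ratio` fraction of entries
    and return the result as a COO sparse tensor."""
    k = int(weight.numel() * keep_ratio)                      # entries to keep
    flat = weight.abs().flatten()
    threshold = flat.kthvalue(weight.numel() - k + 1).values  # k-th largest |w|
    mask = weight.abs() >= threshold
    indices = mask.nonzero().t()                              # [ndim, nnz]
    values = weight[mask]
    return torch.sparse_coo_tensor(indices, values, weight.shape)

# Example: prune a projection weight to 50% density (float32, since sparse ops lack fp16 support).
w = torch.randn(1024, 1024)
w_sparse = prune_top_half(w)
```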

- Compute the PCA of the generated keys: In the `pca_analysis` directory, run the following command:

```bash
python pca.py key <NUM_LAYERS> <Path to saved key tensors> <Path to output the PCA transforms>
```
Verify that the PCA transforms are saved in the output directory. Do not modify the subdirectory structure of the output directory, as it is used by the downstream task evaluation code.
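
For reference, here is a minimal sketch of what the PCA step computes, assuming the saved keys for a single layer and head form a `[num_tokens, head_dim]` tensor; the actual `pca.py` interface and on-disk layout may differ.

```python
import torch

def key_pca(keys: torch.Tensor):
    """keys: [num_tokens, head_dim]. Returns (mean, components) such that
    low-rank keys are (keys - mean) @ components[:, :top_r]."""
    mean = keys.mean(dim=0, keepdim=True)
    # SVD of the mean-centered keys; rows of Vh are the principal directions,
    # ordered by decreasing singular value (explained variance).
    _, _, vh = torch.linalg.svd(keys - mean, full_matrices=False)
    components = vh.T  # [head_dim, head_dim]
    return mean, components
```
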
## Modified Files
- `test_attention_benchmark_fixed.py`: Main benchmark script
- `./methods/pca_topk/attention_benchmark_apex.py`: Implementation of attention mechanisms
- `./methods/pca_topk/sparsity_utils.py`: Utilities for sparse operations

#### Running the ML evaluations
Once the PCA transform is computed, we can run the ML evaluations using Loki. The following command runs the evaluation on the downstream tasks using the PCA transform computed in the previous step:

```bash
python -u evaluate_tasks.py \
--sequence-length 4096 \
--model-id meta-llama/Llama-2-7b-hf \
--model-type llama \
--use-pca-topk \
--top-r <16/32/64> \
--top-k <0.125/0.25/0.5> \
--rotary-type <prerotary/postrotary> \
--dataset <Dataset to compute perplexity on, Default: wikitext-test> \
--transform-dataset <Dataset used to compute PCA: wikitext/bookcorpus/c4, Default:wikitext> \
[--lm-harness-eval] \ # Flag to evaluate the model on the LM Harness Tasks
[--use-wandb] \ # Optional flag to log the results to wandb
[--use-axonn] # Optional flag to shard the model over multiple GPUs using AxoNN
```


#### Running compute evaluation
To run the compute evaluation, you can use the following command:
## Usage Example

Benchmark comparing vanilla attention, Loki, and the naive sparse variant:
```bash
python evaluate_compute.py

```
This runs the attention benchmark with Loki and vanilla attention, assuming a Llama-2-13B-style model, and saves the results to the `compute_files` directory.

<!---
#### Reproducing the results
We have provided slurm scripts to evaluate the baseline methods and Loki on the downstream tasks. You can run the scripts as follows:

- Generating the keys for the models:
First, you need to modify the template saver script based on your machine slurm configuration. You also need to modify the OUT_TENSOR_DATA_PATH in the script to save the keys to the desired location. Then you can run the script as follows:

```
# Generate batch scripts for the "saver" experiment
<>

# Run the batch scripts for the particular model
<>
Run the benchmark with sparse attention query-key optimization:

```

- Compute the PCA transform for the generated keys:
```

```bash
python test_attention_benchmark_fixed.py --orig-pca-dir /cmlscratch/sukriti5/pca_stuff/pca_components --cache-seq-len 3500 --num-gen-steps 10 --top-d 32 --num-heads 16 --run-loki-without-sparsity --run-attention-query-key-sparsity --output-csv ./attention_query_key_results.csv
```

### With AxoNN's tensor parallelism

Additionally, you can add `--use-axonn` flag to shard a large model like llama-13b over multiple GPUs.
For this you will need to launch the code using mpirun


```
mpirun -np 2 python -u evaluate_tasks.py --sequence-length 4096 --model-id meta-llama/Llama-2-13b-hf --model-type llama --use-h2o --heavy-ratio 0.1 --use-axonn
```
--->
### Sparse Matrix Multiplication
- For 2D tensors: Using `torch.sparse.mm` for efficient sparse matrix multiplication
- For 3D tensors (batched attention operations): Using `torch.bmm` for batched multiplication
- Precision handling: Converting half-precision tensors to float32 temporarily for sparse operations, as PyTorch's sparse operations don't directly support half-precision
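
The following is a hedged sketch of this dispatch; the wrapper name and signature are illustrative assumptions, and only the `torch.sparse.mm` / `torch.bmm` calls and the temporary float32 upcast mirror the description above.

```python
import torch

def sparse_matmul(lhs_sparse: torch.Tensor, rhs_dense: torch.Tensor) -> torch.Tensor:
    """Sparse-times-dense product covering the 2D and 3D (batched) cases,
    temporarily upcasting to float32 since sparse kernels lack fp16 support."""
    out_dtype = rhs_dense.dtype
    if out_dtype == torch.float16:
        lhs_sparse = lhs_sparse.float()
        rhs_dense = rhs_dense.float()
    if rhs_dense.dim() == 2:
        out = torch.sparse.mm(lhs_sparse, rhs_dense)   # 2D: sparse @ dense
    else:
        out = torch.bmm(lhs_sparse, rhs_dense)         # 3D: batched multiply
    return out.to(out_dtype)
```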

### Sparsity Configuration
- Default sparsity level: 50% (i.e., keeping the top 50% of weights by magnitude for now)
- Implementation: Magnitude-based pruning applied to linear projection layers
- Format: Using `torch.sparse_coo_tensor` to create COO format sparse representations
- One-time sparse conversion: In the `SparseLinear` class, the sparse representation is created only once via the `make_sparse()` method. For caching, the sparse weight (`self._sparse_weight`) is stored as a class attribute and reused across forward passes
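
Below is a minimal sketch of this one-time-conversion pattern, assuming a `SparseLinear` module roughly along the lines described above; the real class in `sparsity_utils.py` may differ in structure and handle more cases.

```python
import torch
import torch.nn as nn

class SparseLinear(nn.Module):
    """Wraps an nn.Linear, caching a pruned COO copy of its weight that is
    built once by make_sparse() and reused on every forward pass (sketch)."""
    def __init__(self, linear: nn.Linear, keep_ratio: float = 0.5):
        super().__init__()
        self.linear = linear
        self.keep_ratio = keep_ratio
        self._sparse_weight = None  # filled in lazily by make_sparse()

    def make_sparse(self):
        w = self.linear.weight.detach().float()  # [out, in], fp32 for sparse ops
        k = int(w.numel() * self.keep_ratio)
        thresh = w.abs().flatten().kthvalue(w.numel() - k + 1).values
        self._sparse_weight = (w * (w.abs() >= thresh)).to_sparse()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, in_features]; higher-rank inputs would need reshaping.
        if self._sparse_weight is None:
            self.make_sparse()  # one-time conversion
        y = torch.sparse.mm(self._sparse_weight, x.float().t()).t()
        if self.linear.bias is not None:
            y = y + self.linear.bias.float()
        return y.to(x.dtype)
```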

I've followed the same benchmark setup and attributes as in your and Connor's benchmark files.
125 changes: 109 additions & 16 deletions evaluate_compute.py
@@ -1,26 +1,119 @@
from methods.pca_topk.attention_benchmark import benchmark_attention
from methods.pca_topk.attention_benchmark_apex import benchmark_attention
import json
import torch
import os
import gc
import sys
import traceback


# Disable TF32 for matmul and cuDNN so the benchmark timings use full float32 precision.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

os.environ["OMP_NUM_THREADS"] = "1" # Set OpenMP threads
os.environ["MKL_NUM_THREADS"] = "1" # Set MKL threads
os.environ["NUMEXPR_NUM_THREADS"] = "1" # Set numexpr threads
os.environ["OPENBLAS_NUM_THREADS"] = "1" # Set OpenBLAS threads

def free_gpu_memory():
gc.collect()
torch.cuda.empty_cache()
print(f"GPU memory after cleanup: {torch.cuda.memory_allocated() / (1024**3):.2f} GB used, "
f"{torch.cuda.memory_reserved() / (1024**3):.2f} GB reserved")

if __name__ == "__main__":
os.makedirs("compute_files", exist_ok=True)
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU count: {torch.cuda.device_count()}")
print(f"GPU name: {torch.cuda.get_device_name(0)}")
print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB")

with torch.no_grad():
prompt_length = 3500
for num_gen_steps in [512]:
for topk in [2, 4, 8]:
for topr in [2, 4, 8]:
print(f"prompt length = {prompt_length}, gen length = {num_gen_steps}, batch_size={16}, topk={topk} and topr={topr}")
times_pca_topk, _ = benchmark_attention(prompt_length=prompt_length, num_gen_steps=num_gen_steps, batch_size=16, topk=prompt_length // topk, topr=128 // topr, vanilla=False)
#with open(f"prompt_{prompt_length}_gen_{num_gen_steps}_pca_topk_opt_first_matmul.json", "w") as f:
with open(f"compute_files/prompt_{prompt_length}_gen_{num_gen_steps}_topk_{topk}_topr_{topr}.json", "w") as f:
json.dump(times_pca_topk, f, indent=2)
for prompt_length in [512, 1024, 2048]:
for num_gen_steps in [64, 128, 256]:
# 1. Vanilla Attention (without Loki)
print(f"\nRunning Vanilla Attention benchmark...")
print(f"prompt length = {prompt_length}, gen length = {num_gen_steps}, batch_size=16")
try:
free_gpu_memory()
_, times_vanilla = benchmark_attention(
prompt_length=prompt_length,
num_gen_steps=num_gen_steps,
batch_size=16,
vanilla=True,
pcatopk=False,
dtype=torch.float16
)
vanilla_filename = f"compute_files/vanilla_prompt_{prompt_length}_gen_{num_gen_steps}.json"
print(f"Saving to {vanilla_filename}")
with open(vanilla_filename, "w") as f:
json.dump(times_vanilla, f, indent=2)
print("Vanilla Attention Times:")
for key, value in times_vanilla.items():
print(f" {key}: {value:.6f} s")
print(f"Net time (minus cache updates): {times_vanilla.get('total', 0) - times_vanilla.get('cache-update', 0):.6f} s")
except Exception as e:
print(f"Error in Vanilla benchmark: {str(e)}")
traceback.print_exc()

for topk in [4, 8]:
for topr in [4]:
# 2. Loki without sparsity
print(f"\nRunning Loki without Sparsity benchmark...")
print(f"prompt length = {prompt_length}, gen length = {num_gen_steps}, batch_size=16, topk={topk} and topr={topr}")
try:
free_gpu_memory()
times_loki, _ = benchmark_attention(
prompt_length=prompt_length,
num_gen_steps=num_gen_steps,
batch_size=16,
topk=prompt_length // topk,
topr=128 // topr,
vanilla=False,
pcatopk=True,
# No sparsity_type parameter
dtype=torch.float16
)
loki_filename = f"compute_files/loki_prompt_{prompt_length}_gen_{num_gen_steps}_topk_{topk}_topr_{topr}.json"
print(f"Saving to {loki_filename}")
with open(loki_filename, "w") as f:
json.dump(times_loki, f, indent=2)
print("Loki without Sparsity Times:")
for key, value in times_loki.items():
print(f" {key}: {value:.6f} s")
print(f"Net time (minus cache updates): {times_loki.get('total', 0) - times_loki.get('cache-update', 0):.6f} s")
except Exception as e:
print(f"Error in Loki benchmark: {str(e)}")
traceback.print_exc()

# 3. Loki with attention-query-key sparsity
print(f"\nRunning Loki with Attention-Query-Key Sparsity benchmark...")
print(f"prompt length = {prompt_length}, gen length = {num_gen_steps}, batch_size=16, topk={topk} and topr={topr}")
try:
free_gpu_memory()
times_sparsity, _ = benchmark_attention(
prompt_length=prompt_length,
num_gen_steps=num_gen_steps,
batch_size=16,
topk=prompt_length // topk,
topr=128 // topr,
vanilla=False,
pcatopk=True,
sparsity_type="attention-query-key",
dtype=torch.float16
)
sparsity_filename = f"compute_files/attention_query_key_prompt_{prompt_length}_gen_{num_gen_steps}_topk_{topk}_topr_{topr}.json"
print(f"Saving to {sparsity_filename}")
with open(sparsity_filename, "w") as f:
json.dump(times_sparsity, f, indent=2)
print("Loki with Attention-Query-Key Sparsity Times:")
for key, value in times_sparsity.items():
print(f" {key}: {value:.6f} s")
print(f"Net time (minus cache updates): {times_sparsity.get('total', 0) - times_sparsity.get('cache-update', 0):.6f} s")
except Exception as e:
print(f"Error in Attention-Query-Key Sparsity benchmark: {str(e)}")
traceback.print_exc()

print("\nAll benchmarks completed!")
free_gpu_memory()

_, times_vanilla = benchmark_attention(prompt_length=prompt_length, num_gen_steps=num_gen_steps, batch_size=16, topk=prompt_length // topk, topr=128 // topr, pcatopk=False)
with open(f"compute_files/prompt_{prompt_length}_gen_{num_gen_steps}_vanilla.json", "w") as f:
json.dump(times_vanilla, f, indent=2)

