|
| 1 | +<h1 align="center">vLLM + HAT</h1> |
| 2 | + |
| 3 | +<p align="center"> |
| 4 | +🤗 <a href="https://huggingface.co/Aleph-Alpha">Hugging Face</a>   |   📑 <a href="https://arxiv.org/abs/2501.10322">HAT ICLR25 Paper</a>    |    📑 Upcoming Research Paper |
| 5 | +</p> |
| 6 | + |
| 7 | +This branch provides a batched inference implementation of HAT (Hierarchical Autoregressive Transformer). This fork integrates HAT into vLLM v1 so you can run or serve HAT models with the same low-latency engine you know from vLLM. HAT wraps a standard Llama-style word-level transformer (referred to as the backbone) with two small byte-level modules: an encoder and a decoder. For a comprehensive architectural and training deep-dive, including a closer look at each component discussed below, an accompanying research paper will soon be released; which will also provide more information on the challenges behind batched inference for such a model. |
| 8 | + |
| 9 | +The encoder processes the input text as raw UTF-8 bytes, and produces a sequence of activations of the same length. The splitter is then in charge of splitting this text into words or semantically meaningful chunks. In the encoder connector layer, for each word, a learned latent vector attends to the encoder activations of the bytes which compose the word. The backbone then processes this word-level sequence to produce a sequence of word-level representations which guide the decoding process. Thus, to generate bytes auto-regressively, the decoder uses the encoder activations of the current word and the word-level representation of the previous word. |
| 10 | + |
| 11 | +Next Steps: |
| 12 | +- Currently, our CUDA graph implementation for HAT is still based on the vLLM v0 approach. When [PR 20059](https://github.com/vllm-project/vllm/pull/20059) gets merged, we will update our implementation and perform an upstream MR to vLLM. |
| 13 | + |
| 14 | +--- |
| 15 | +--- |
| 16 | +# Environment Setup |
| 17 | + |
| 18 | +### 1. Prerequisites |
| 19 | +* **GPU**: NVIDIA GPU |
| 20 | +* **Python**: 3.12. |
| 21 | +### 2. Clone and install |
| 22 | +```bash |
| 23 | +git clone <this-repository> vllm-hat |
| 24 | +cd vllm-hat |
| 25 | + |
| 26 | +# Create and activate a 3.12 virtual env |
| 27 | +uv venv -p 3.12 |
| 28 | +source .venv/bin/activate |
| 29 | + |
| 30 | +# Tell vLLM to skip local compilation and use prebuilt CUDA wheels |
| 31 | +export VLLM_USE_PRECOMPILED=1 |
| 32 | + |
| 33 | +# Finally, install in editable mode |
| 34 | +uv pip install -e . |
| 35 | +``` |
| 36 | + |
| 37 | +--- |
| 38 | +--- |
| 39 | +# Using HAT |
| 40 | + |
| 41 | +Points to keep in mind |
| 42 | +- If you want to test out the 70B model, please make sure to specify tensor parallel size. If testing on GPUs with 80GB VRAM, we recommend setting tensor parallel size to 4. |
| 43 | +- Currently, HAT only works with Flash Attention 2. Thus, if testing this model on Hopper architecture or newer, please make sure to export the environment variable `VLLM_FLASH_ATTN_VERSION = 2`. |
| 44 | +- Additionally, running the 70B on H100 or newer currently does not work. |
| 45 | + |
| 46 | +The supported HAT models are the following: |
| 47 | +- `Aleph-Alpha/llama-3_1-8b-tfree-hat-dpo` |
| 48 | +- `Aleph-Alpha/llama-3_1-8b-tfree-hat-sft` |
| 49 | +- `Aleph-Alpha/llama-3_1-8b-tfree-hat-base` |
| 50 | +- `Aleph-Alpha/llama-3_1-70b-tfree-hat-sft` |
| 51 | +--- |
| 52 | +## Offline Inference |
| 53 | + |
| 54 | +We have included an example script to run offline inference. |
| 55 | + |
| 56 | +```bash |
| 57 | +python hat_scripts/hat_offline_inference.py [OPTIONS] |
| 58 | +``` |
| 59 | + |
| 60 | +**Optional Parameters:** |
| 61 | +- `--model` - Path to the HAT model (default: Aleph-Alpha/llama-3_1-8b-tfree-hat-dpo) |
| 62 | +- `--batch-size` - Batch size for inference (default: 16) |
| 63 | +- `--max-bytes-per-req` - Output bytes (default: 1000) |
| 64 | +- `--tensor-parallel-size` - Tensor parallelism size (default: 1) |
| 65 | + |
| 66 | +--- |
| 67 | +## Serving Scenario (OpenAI-compatible API) |
| 68 | + |
| 69 | +### Starting the server |
| 70 | + |
| 71 | +```bash |
| 72 | +vllm serve [MODEL] [OPTIONS] |
| 73 | +``` |
| 74 | + |
| 75 | +**Example:** |
| 76 | +```bash |
| 77 | +vllm serve "Aleph-Alpha/llama-3_1-8b-tfree-hat-dpo" \ |
| 78 | + --trust-remote-code \ |
| 79 | + --dtype bfloat16 \ |
| 80 | + --compilation-config '{"full_cuda_graph": true, "level": 0}' |
| 81 | + --max-num-batched-tokens 100000 \ |
| 82 | + --tensor-parallel-size 1 \ |
| 83 | + --gpu-memory-utilization 0.9 \ |
| 84 | +``` |
| 85 | + |
| 86 | +**Required Options:** |
| 87 | +- `--trust-remote-code` - Required for HAT models |
| 88 | +- `--dtype bfloat16` - Required data type for HAT models |
| 89 | +- `--compilation-config '{"full_cuda_graph": true, "level": 0}'` - Required compilation settings for HAT models |
| 90 | + |
| 91 | +**Optional Parameters:** |
| 92 | +- `--max-num-batched-tokens` - Maximum number of batched tokens (default: varies) |
| 93 | +- `--tensor-parallel-size` - Tensor parallelism size (default: 1) |
| 94 | +- `--gpu-memory-utilization` - GPU memory utilization fraction (default: 0.9) |
| 95 | + |
| 96 | +### Sending requests |
| 97 | + |
| 98 | +Any OpenAI-compatible client works (curl, python, etc.). For convenience, we include a script that asynchronously sends multiple requests to the server: |
| 99 | + |
| 100 | +```bash |
| 101 | +python hat_scripts/send_async_prompts.py [OPTIONS] |
| 102 | +``` |
| 103 | + |
| 104 | +**Optional Parameters:** |
| 105 | +- `--api-url` - URL of the OpenAI-compatible chat completions API endpoint (default: http://localhost:8000/v1/chat/completions) |
| 106 | +- `--num-concurrent-requests` - Number of concurrent requests to send (default: 16) |
| 107 | +- `--max-bytes-per-req` - Output bytes (default: 1000) |
| 108 | + |
| 109 | + |
| 110 | +--- |
| 111 | +--- |
| 112 | +--- |
| 113 | +--- |
| 114 | + |
1 | 115 | <p align="center"> |
2 | 116 | <picture> |
3 | 117 | <source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-dark.png"> |
|
0 commit comments