Commit 982da2e

Functioning FastAPI implementation
FastAPI service exposing BitNet inference framework functions:

* Benchmark BitNet models
* Calculate BitNet model perplexity
* Run BitNet model inference

Includes a Dockerfile for running the FastAPI app in a containerized environment.

1 parent 7bcd27d commit 982da2e

File tree

10 files changed: +539 -1 lines changed


.gitignore

+4

@@ -160,3 +160,7 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+app/models/*/*
+
+# Allow all files in app/lib/
+!app/lib

Dockerfile

+45

@@ -0,0 +1,45 @@
FROM python:3.9

WORKDIR /code

COPY ./app /code

RUN if [ -z "$(ls -A /code/models)" ]; then \
        echo "Error: No models found in /code/models" && exit 1; \
    fi

RUN apt-get update && apt-get install -y \
    wget \
    lsb-release \
    software-properties-common \
    gnupg \
    cmake && \
    bash -c "$(wget -O - https://apt.llvm.org/llvm.sh)" && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

RUN git clone --recursive https://github.com/microsoft/BitNet.git /tmp/BitNet && \
    cp -r /tmp/BitNet/* /code && \
    rm -rf /tmp/BitNet

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt && \
    pip install "fastapi[standard]" "uvicorn[standard]"

RUN if [ -d "/code/models/Llama3-8B-1.58-100B-tokens" ]; then \
        python /code/setup_env.py -md /code/models/Llama3-8B-1.58-100B-tokens -q i2_s --use-pretuned && \
        find /code/models/Llama3-8B-1.58-100B-tokens -type f -name "*f32*.gguf" -delete; \
    fi

RUN if [ -d "/code/models/bitnet_b1_58-large" ]; then \
        python /code/setup_env.py -md /code/models/bitnet_b1_58-large -q i2_s --use-pretuned && \
        find /code/models/bitnet_b1_58-large -type f -name "*f32*.gguf" -delete; \
    fi

RUN if [ -d "/code/models/bitnet_b1_58-3B" ]; then \
        python /code/setup_env.py -md /code/models/bitnet_b1_58-3B -q i2_s --use-pretuned && \
        find /code/models/bitnet_b1_58-3B -type f -name "*f32*.gguf" -delete; \
    fi

EXPOSE 8080

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]

README.md

+44 -1

@@ -1 +1,44 @@

# FastAPI-BitNet

Install Conda: https://anaconda.org/anaconda/conda

Initialize the Python environment:
```
conda init
conda create -n bitnet python=3.9
conda activate bitnet
```

Install the Hugging Face CLI tool to download the models:
```
pip install -U "huggingface_hub[cli]"
```

Download one or more of the 1-bit models from Hugging Face:
```
huggingface-cli download 1bitLLM/bitnet_b1_58-large --local-dir app/models/bitnet_b1_58-large
huggingface-cli download 1bitLLM/bitnet_b1_58-3B --local-dir app/models/bitnet_b1_58-3B
huggingface-cli download HF1BitLLM/Llama3-8B-1.58-100B-tokens --local-dir app/models/Llama3-8B-1.58-100B-tokens
```

Build the Docker image:
```
docker build -t fastapi_bitnet .
```

Run the Docker image:
```
docker run -d --name ai_container -p 8080:8080 fastapi_bitnet
```

Once it's running, navigate to http://127.0.0.1:8080/docs

---

Note:

If you intend to use this in production, make sure to extend the Docker image with additional [authentication security](https://github.com/mjhea0/awesome-fastapi?tab=readme-ov-file#auth) steps. In its current state it is intended for local use only.

Building the Docker image requires upwards of 40GB of RAM for `Llama3-8B-1.58-100B-tokens`; if you have less than 64GB of RAM you will probably run into issues.

The Dockerfile deletes the larger f32 `.gguf` files to reduce the image build time. If you want the f32 files included, comment out the `find /code/models/....` lines.
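A minimal Python client sketch for exercising the API is shown below. The route path and model value are placeholders chosen for illustration (the routes are registered in `main.py`, which is not shown in this diff); check `/docs` or `/openapi.json` for the actual paths, parameter names, and model identifiers.

```
# Hypothetical client sketch -- the route path and model value are placeholders;
# confirm both against http://127.0.0.1:8080/docs before using.
import requests

BASE_URL = "http://127.0.0.1:8080"

# Discover the real routes from the OpenAPI schema FastAPI serves.
schema = requests.get(f"{BASE_URL}/openapi.json", timeout=30).json()
print(sorted(schema["paths"].keys()))

# Example inference request, assuming a GET route backed by run_inference_endpoint
# (defined in app/lib/endpoints.py below).
resp = requests.get(
    f"{BASE_URL}/inference",
    params={
        "model": "models/bitnet_b1_58-large/ggml-model-i2_s.gguf",  # placeholder; pick a value from /docs
        "prompt": "What is 1-bit quantization?",
        "n_predict": 64,
        "threads": 2,
    },
    timeout=600,
)
print(resp.status_code, resp.json())
```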

app/__init__.py

Whitespace-only changes.

app/lib/__init__.py

Whitespace-only changes.

app/lib/endpoints.py

+153

@@ -0,0 +1,153 @@
from fastapi import FastAPI, HTTPException, Query, Depends
from .models import ModelEnum, BenchmarkRequest, PerplexityRequest, InferenceRequest
from .utils import run_command, parse_benchmark_data, parse_perplexity_data
import os
import subprocess

async def run_benchmark(
    model: ModelEnum,
    n_token: int = Query(128, gt=0),
    threads: int = Query(2, gt=0, le=os.cpu_count()),
    n_prompt: int = Query(32, gt=0)
):
    """Run benchmark on specified model"""
    request = BenchmarkRequest(model=model, n_token=n_token, threads=threads, n_prompt=n_prompt)

    build_dir = os.getenv("BUILD_DIR", "build")
    bench_path = os.path.join(build_dir, "bin", "llama-bench")

    if not os.path.exists(bench_path):
        raise HTTPException(status_code=500, detail="Benchmark binary not found")

    command = [
        bench_path,
        '-m', request.model.value,
        '-n', str(request.n_token),
        '-ngl', '0',
        '-b', '1',
        '-t', str(request.threads),
        '-p', str(request.n_prompt),
        '-r', '5'
    ]

    try:
        result = subprocess.run(command, capture_output=True, text=True, check=True)
        parsed_data = parse_benchmark_data(result.stdout)
        return parsed_data
    except subprocess.CalledProcessError as e:
        raise HTTPException(status_code=500, detail=f"Benchmark failed: {str(e)}")

def validate_prompt_length(prompt: str = Query(..., description="Input text for perplexity calculation"), ctx_size: int = Query(10, gt=0)) -> str:
    token_count = len(prompt.split())
    min_tokens = 2 * ctx_size

    if token_count < min_tokens:
        raise HTTPException(
            status_code=400,
            detail=f"Prompt too short. Needs at least {min_tokens} tokens, got {token_count}"
        )
    return prompt

async def run_perplexity(
    model: ModelEnum,
    prompt: str = Depends(validate_prompt_length),
    threads: int = Query(2, gt=0, le=os.cpu_count()),
    ctx_size: int = Query(10, gt=3),
    ppl_stride: int = Query(0, ge=0)
):
    """Calculate perplexity for given text and model"""
    try:
        request = PerplexityRequest(
            model=model,
            prompt=prompt,
            threads=threads,
            ctx_size=ctx_size,
            ppl_stride=ppl_stride
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    build_dir = os.getenv("BUILD_DIR", "build")
    ppl_path = os.path.join(build_dir, "bin", "llama-perplexity")

    if not os.path.exists(ppl_path):
        raise HTTPException(status_code=500, detail="Perplexity binary not found")

    command = [
        ppl_path,
        '--model', request.model.value,
        '--prompt', request.prompt,
        '--threads', str(request.threads),
        '--ctx-size', str(request.ctx_size),
        '--perplexity',
        '--ppl-stride', str(request.ppl_stride)
    ]

    try:
        result = subprocess.run(command, capture_output=True, text=True, check=True)
        parsed_data = parse_perplexity_data(result.stderr)
        return parsed_data
    except subprocess.CalledProcessError as e:
        raise HTTPException(status_code=500, detail=str(e))

def get_model_sizes():
    """Endpoint to get the file sizes of supported .gguf models."""
    model_sizes = {}
    models_dir = "models"
    for subdir in os.listdir(models_dir):
        subdir_path = os.path.join(models_dir, subdir)
        if os.path.isdir(subdir_path):
            for file in os.listdir(subdir_path):
                if file.endswith(".gguf"):
                    file_path = os.path.join(subdir_path, file)
                    file_size_bytes = os.path.getsize(file_path)
                    file_size_mb = round(file_size_bytes / (1024 * 1024), 3)
                    file_size_gb = round(file_size_bytes / (1024 * 1024 * 1024), 3)
                    model_sizes[file] = {
                        "bytes": file_size_bytes,
                        "MB": file_size_mb,
                        "GB": file_size_gb
                    }
    return model_sizes

async def run_inference_endpoint(
    model: ModelEnum,
    n_predict: int = Query(128, gt=0, le=100000),
    prompt: str = "",
    threads: int = Query(2, gt=0, le=os.cpu_count()),
    ctx_size: int = Query(2048, gt=0),
    temperature: float = Query(0.8, gt=0.0, le=2.0)
):
    """Endpoint to run inference with the given parameters."""
    request = InferenceRequest(
        model=model,
        n_predict=n_predict,
        prompt=prompt,
        threads=threads,
        ctx_size=ctx_size,
        temperature=temperature
    )
    output = run_inference(request)
    return {"result": output}

def run_inference(args: InferenceRequest) -> str:
    """Run the inference command with the given arguments."""
    build_dir = os.getenv("BUILD_DIR", "build")
    main_path = os.path.join(build_dir, "bin", "llama-cli")

    if not os.path.exists(main_path):
        raise HTTPException(status_code=500, detail="Inference binary not found")

    command = [
        main_path,
        '-m', args.model.value,
        '-n', str(args.n_predict),
        '-t', str(args.threads),
        '-p', args.prompt,
        '-ngl', '0',
        '-c', str(args.ctx_size),
        '--temp', str(args.temperature),
        "-b", "1"
    ]
    output = run_command(command)
    return output
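These handlers are plain (async) functions rather than decorated routes, so they are presumably registered on the FastAPI app elsewhere; the commit's `main.py` is not included in the diff above. A minimal wiring sketch under that assumption, with the route paths chosen purely for illustration:

```
# Hypothetical main.py wiring sketch -- the route paths and module layout are
# assumptions; the actual main.py in this commit is not shown in the diff.
from fastapi import FastAPI
from lib.endpoints import (
    run_benchmark,
    run_perplexity,
    get_model_sizes,
    run_inference_endpoint,
)

app = FastAPI(title="FastAPI-BitNet")

# add_api_route lets the existing functions serve as endpoints without
# editing endpoints.py; their Query/Depends defaults are resolved as usual.
app.add_api_route("/benchmark", run_benchmark, methods=["GET"])
app.add_api_route("/perplexity", run_perplexity, methods=["GET"])
app.add_api_route("/model-sizes", get_model_sizes, methods=["GET"])
app.add_api_route("/inference", run_inference_endpoint, methods=["GET"])
```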

app/lib/models.py

+103

@@ -0,0 +1,103 @@
from typing import Dict, Any
from pydantic import BaseModel, validator, root_validator
from enum import Enum
import os

def create_model_enum(directory: str):
    """Dynamically create an Enum for models based on files in the directory."""
    models = {}
    for subdir in os.listdir(directory):
        subdir_path = os.path.join(directory, subdir)
        if os.path.isdir(subdir_path):
            for file in os.listdir(subdir_path):
                if file.endswith(".gguf"):
                    model_name = f"{subdir}_{file.replace('-', '_').replace('.', '_')}"
                    models[model_name] = os.path.join(subdir_path, file)
    return Enum("ModelEnum", models)

# Create the ModelEnum based on the models directory
ModelEnum = create_model_enum("models")

max_n_predict = 100000

class BenchmarkRequest(BaseModel):
    model: ModelEnum
    n_token: int = 128
    threads: int = 2
    n_prompt: int = 32

    @validator('threads')
    def validate_threads(cls, v):
        max_threads = os.cpu_count()
        if v > max_threads:
            raise ValueError(f"Number of threads cannot exceed {max_threads}")
        return v

    @validator('n_token', 'n_prompt', 'threads')
    def validate_positive(cls, v):
        if v <= 0:
            raise ValueError("Value must be positive")
        return v

class PerplexityRequest(BaseModel):
    model: ModelEnum
    prompt: str
    threads: int = 2
    ctx_size: int = 3
    ppl_stride: int = 0

    @validator('threads')
    def validate_threads(cls, v):
        max_threads = os.cpu_count()
        if v > max_threads:
            raise ValueError(f"Number of threads cannot exceed {max_threads}")
        elif v <= 0:
            raise ValueError("Value must be positive")
        return v

    @validator('ctx_size')
    def validate_positive(cls, v):
        if v < 3:
            raise ValueError("Value must be greater than 3")
        return v

    @root_validator(pre=True)
    def validate_prompt_length(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        prompt = values.get('prompt')
        ctx_size = values.get('ctx_size')

        if prompt and ctx_size:
            token_count = len(prompt.split())
            min_tokens = 2 * ctx_size

            if token_count < min_tokens:
                raise ValueError(f"Prompt too short. Needs at least {min_tokens} tokens, got {token_count}")

        return values

class InferenceRequest(BaseModel):
    model: ModelEnum
    n_predict: int = 128
    prompt: str
    threads: int = 2
    ctx_size: int = 2048
    temperature: float = 0.8

    @validator('threads')
    def validate_threads(cls, v):
        max_threads = os.cpu_count()
        if v > max_threads:
            raise ValueError(f"Number of threads cannot exceed {max_threads}")
        return v

    @validator('n_predict')
    def validate_n_predict(cls, v):
        if v > max_n_predict:
            raise ValueError(f"Number of predictions cannot exceed {max_n_predict}")
        return v

    @validator('threads', 'ctx_size', 'temperature', 'n_predict')
    def validate_positive(cls, v):
        if v <= 0:
            raise ValueError("Value must be positive")
        return v
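For a sense of how `create_model_enum` and the pydantic validators behave, here is a small illustrative sketch. The directory layout and `.gguf` file name are fabricated for the example, and it assumes the `app/` directory is on `sys.path` so `lib.models` imports cleanly.

```
# Illustration only: fabricate a models/ tree so the import-time scan in
# lib.models finds one placeholder .gguf file, then trigger a validator.
import os
import tempfile

os.chdir(tempfile.mkdtemp())
os.makedirs("models/bitnet_b1_58-large", exist_ok=True)
open("models/bitnet_b1_58-large/ggml-model-i2_s.gguf", "wb").close()  # placeholder name

from lib.models import ModelEnum, PerplexityRequest  # assumes app/ is on sys.path

print([m.name for m in ModelEnum])  # e.g. ['bitnet_b1_58-large_ggml_model_i2_s_gguf']

# The root_validator requires at least 2 * ctx_size whitespace-separated tokens.
try:
    PerplexityRequest(model=list(ModelEnum)[0], prompt="too short", ctx_size=10)
except ValueError as err:
    print(err)
```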
