Upgrade MLX framework, and refactor create_generator to use stream_generate (#44)

* Upgrade MLX framework, and refactor create_generator to use mlx_lm.utils.stream_generate

* Fix vision kit argument

* working

* small fixes

* Improve comments and rename next_y

* Stop strings with stream generate (#45)

* Update vision_model_wrapper.py image resize defaults

---------

Co-authored-by: Matt Clayton <[email protected]>
neilmehta24 and mattjcly authored Nov 26, 2024
1 parent 9acb8c2 commit 1ccce42
Showing 12 changed files with 2,711 additions and 802 deletions.
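
Since the large diff for mlx_engine/generate.py is not rendered below, here is a minimal, hedged sketch of the general idea behind one of the changes described above: enforcing stop strings on top of a streaming generator. This is illustrative Python only, not the code from this commit; every name is hypothetical, and the real implementation builds on mlx_lm.utils.stream_generate.

```python
# Illustrative sketch only: NOT the code added in this commit. The real logic
# lives in mlx_engine/generate.py and builds on mlx_lm.utils.stream_generate;
# every name below is hypothetical.
from typing import Iterable, Iterator, List


def stream_with_stop_strings(chunks: Iterable[str], stop_strings: List[str]) -> Iterator[str]:
    """Yield text from `chunks`, stopping (and trimming) at the first stop string.

    Text is held back whenever the tail of the buffer could be the start of a
    stop string, so stop strings split across chunks are still detected.
    """
    buffer = ""
    for chunk in chunks:
        buffer += chunk
        # Full match: emit everything before the stop string, then finish.
        hit = min((buffer.find(s) for s in stop_strings if s in buffer), default=-1)
        if hit != -1:
            if hit > 0:
                yield buffer[:hit]
            return
        # Partial match at the tail: hold back just enough text to re-check later.
        held = 0
        for s in stop_strings:
            for k in range(len(s) - 1, 0, -1):
                if buffer.endswith(s[:k]):
                    held = max(held, k)
                    break
        if len(buffer) > held:
            yield buffer[: len(buffer) - held]
            buffer = buffer[len(buffer) - held :]
    if buffer:
        yield buffer  # stream ended without hitting a stop string


if __name__ == "__main__":
    pieces = ["Hello, wor", "ld! STO", "P and ignored text"]
    print("".join(stream_with_stop_strings(pieces, ["STOP"])))  # -> "Hello, world! "
```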
7 changes: 7 additions & 0 deletions README.md
@@ -80,3 +80,10 @@ Currently supported vision models and download links:
 - [mlx-community/Qwen2-VL-7B-Instruct-4bit](https://model.lmstudio.ai/download/mlx-community/Qwen2-VL-7B-Instruct-4bit) - 4.68 GB
 - Llava-v1.6
 - [mlx-community/llava-v1.6-mistral-7b-4bit](https://model.lmstudio.ai/download/mlx-community/llava-v1.6-mistral-7b-4bit) - 4.26 GB
+
+## Testing
+
+To run tests, run the following command from the root of this repo:
+```
+python -m unittest discover tests
+```
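
As a side note (not part of the diff): unittest discovery also accepts a verbosity flag, which can make failures easier to trace:

```
python -m unittest discover tests -v
```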
7 changes: 3 additions & 4 deletions demo.py
@@ -82,10 +82,9 @@ def image_to_base64(image_path):
 generator = create_generator(
     model_kit,
     prompt_tokens,
-    None,
-    images_base64,
-    args.stop_strings,
-    {"max_tokens": 1024},
+    images_b64=images_base64,
+    stop_strings=args.stop_strings,
+    max_tokens=1024,
     top_logprobs=args.top_logprobs,
 )
 for generation_result in generator:
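
For context, a rough sketch of how the new keyword-argument call style above might be used from a standalone script. The load_model and tokenize helpers, the import path, the model path, and the result's .text field are assumptions made for illustration; they are not verified against this commit.

```python
# Hypothetical usage of the keyword-argument style shown in the demo.py diff.
# load_model/tokenize, the import path, and GenerationResult.text are assumptions.
from mlx_engine.generate import create_generator, load_model, tokenize

model_kit = load_model("/path/to/mlx-model", max_kv_size=4096)  # placeholder path
prompt_tokens = tokenize(model_kit, "Describe the weather in one sentence.")

generator = create_generator(
    model_kit,
    prompt_tokens,
    stop_strings=["</s>"],  # optional: halt generation at a stop string
    max_tokens=1024,
    top_logprobs=3,         # optional: report per-token log probabilities
)
for generation_result in generator:
    print(generation_result.text, end="", flush=True)
```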
330 changes: 243 additions & 87 deletions mlx_engine/generate.py

Large diffs are not rendered by default.

17 changes: 9 additions & 8 deletions mlx_engine/model_kit.py
@@ -18,12 +18,14 @@ class ModelKit:
     tokenizer: TokenizerWrapper = None
     detokenizer: StreamingDetokenizer = None
     cache_wrapper: Optional[CacheWrapper] = None
+    max_kv_size: int = None
 
     def __init__(self, model_path: Path, max_kv_size: int):
         self.model_path = model_path
         self.model, self.tokenizer = mlx_lm.utils.load(self.model_path)
         self.detokenizer = self.tokenizer.detokenizer
         self.cache_wrapper = CacheWrapper(self.model, max_kv_size)
+        self.max_kv_size = max_kv_size
 
     def tokenize(self, prompt: str) -> List[int]:
         ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt))
@@ -32,7 +34,12 @@ def tokenize(self, prompt: str) -> List[int]:
         return ids
 
     def process_prompt(
-        self, prompt_tokens, img_b64, prompt_progress_callback, generate_args
+        self,
+        prompt_tokens,
+        img_b64,
+        prompt_progress_callback,
+        repetition_context_size,
+        generate_args,
     ) -> mx.array:
         """
         This method processes the prompt, adding its tokens to the cache history
@@ -45,12 +52,6 @@ def process_prompt(
         if len(prompt_tokens) == 0:
             raise ValueError("Prompt tokens must be non-empty")
 
-        if "repetition_context_size" not in generate_args:
-            generate_args["repetition_context_size"] = (
-                20  # default value for mlx_lm.utils.generate_step
-            )
-        repetition_context_size = generate_args["repetition_context_size"]
-
         # Check for common tokens with the previous cache and re-use the cache if possible
         prompt_tokens = self.cache_wrapper.update_cache(
             mx.array(prompt_tokens),
@@ -61,7 +62,7 @@

         return prompt_tokens
 
-    def record_generated_token(self, token: int) -> None:
+    def update_cache_wrapper(self, token: int) -> None:
         self.cache_wrapper.record_generated_token(token)
 
     @property
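
A hypothetical caller-side view of the process_prompt signature change: repetition_context_size (previously defaulted to 20 inside the method) now has to be supplied by the caller. Values below are illustrative only.

```python
# Hypothetical call site for the new signature; argument values are illustrative.
prompt_tokens = model_kit.process_prompt(
    prompt_tokens,
    img_b64=None,                   # text-only prompt in this example
    prompt_progress_callback=None,  # or a callable that reports progress
    repetition_context_size=20,     # formerly hard-coded inside process_prompt
    generate_args=generate_args,    # remaining sampling settings
)
```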
159 changes: 0 additions & 159 deletions mlx_engine/stop_processor.py

This file was deleted.
