2 changes: 2 additions & 0 deletions lmdeploy/messages.py
@@ -223,6 +223,7 @@ class TurbomindEngineConfig:
    max_prefill_token_num: int = 8192
    num_tokens_per_iter: int = 0
    max_prefill_iters: int = 1
    devices: List[int] = field(default_factory=lambda: [0])
Collaborator:
Can we specify the cuda devices by the env var CUDA_VISIBLE_DEVICES?

Author:
No, trl specifies which GPU to load the model on:
https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L416

lvhan028 (Collaborator) · Feb 19, 2025:
https://github.com/vllm-project/vllm/blob/d0a7a2769d92619afdcdc3b91c78098eaa9e38c0/vllm/engine/arg_utils.py#L718
According to vllm's EngineArgs definition, the value of device can be one of the following:

DEVICE_OPTIONS = [
    "auto",
    "cuda",
    "neuron",
    "cpu",
    "openvino",
    "tpu",
    "xpu",
    "hpu",
]

I haven't found a case that builds the vllm engine with specific device ids.
Could you please provide an example?


    def __post_init__(self):
        """Check input validation."""

@@ -297,6 +298,7 @@ class PytorchEngineConfig:
    download_dir: str = None
    revision: str = None
    quant_policy: Literal[0, 4, 8] = 0
    devices: List[int] = field(default_factory=lambda: [0])

    def __post_init__(self):
        """Check input validation."""
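
A minimal usage sketch for the new `devices` field, assuming a trl-style multi-process training setup where each process loads its own engine on a single GPU; the `LOCAL_RANK` wiring and the model id below are illustrative assumptions, not part of this PR:

# Hypothetical sketch: pin each worker process to one GPU via `devices`
# instead of CUDA_VISIBLE_DEVICES (assumption based on the discussion above).
import os

from lmdeploy import TurbomindEngineConfig, pipeline

local_rank = int(os.environ.get('LOCAL_RANK', '0'))  # set by torchrun/accelerate

engine_config = TurbomindEngineConfig(
    tp=1,
    devices=[local_rank],  # load the model only on this process's GPU
)
pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=engine_config)
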
28 changes: 27 additions & 1 deletion lmdeploy/turbomind/deploy/loader.py
@@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import os.path as osp
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from collections import OrderedDict, defaultdict
from functools import partial
from glob import glob
from typing import Iterator, Tuple
@@ -145,6 +146,31 @@ def items(self):


def create_loader(model_path: str, pattern: str) -> BaseLoader:
    if not isinstance(model_path, (str, os.PathLike)):

        def generate():
            generator = OrderedDict()
            model_dict = {}
            if not isinstance(model_path, dict):
                for key, value in list(model_path):
                    model_dict[key] = value
            else:
                model_dict = model_path
            for key, value in model_dict.items():
Collaborator:
Where is model_dict stored? Is it in CPU memory or GPU memory?
                match = re.findall(pattern, key)
                if not match:
                    if -1 not in generator:
                        generator[-1] = {}
                    generator[-1][key] = value
                else:
                    layer = int(match[0])
                    if layer not in generator:
                        generator[layer] = {}
                    generator[layer][key] = value
            return generator

        return generate()

    args = (model_path, pattern)

    if osp.exists(osp.join(model_path, SAFE_WEIGHT_INDEX_NAME)):
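
To make the new non-path branch concrete: when `create_loader` is handed a state dict (or an iterable of `(name, tensor)` pairs) rather than a checkpoint directory, it groups the tensors by the layer index matched with `pattern`, collecting non-layer weights under the key -1. A small sketch, with made-up tensor names and an illustrative pattern:

import torch

from lmdeploy.turbomind.deploy.loader import create_loader

# Toy "state dict"; real keys would come from the HF model being trained.
state_dict = {
    'model.embed_tokens.weight': torch.zeros(4, 4),
    'model.layers.0.attention.wq.weight': torch.zeros(4, 4),
    'model.layers.1.attention.wq.weight': torch.zeros(4, 4),
}

grouped = create_loader(state_dict, pattern=r'layers\.([0-9]+)\.')
# grouped[-1] holds the non-layer weights: {'model.embed_tokens.weight': ...}
# grouped[0] and grouped[1] hold the per-layer weights keyed by full name.
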
42 changes: 28 additions & 14 deletions lmdeploy/turbomind/turbomind.py
@@ -99,6 +99,7 @@ def __init__(self,
            f' greater than 0, but got {_engine_config.max_batch_size}'

        self.gpu_count = _engine_config.tp
        self.gpu_list = _engine_config.devices

        self.tokenizer = tokenizer
        if model_source == ModelSource.WORKSPACE:

@@ -112,10 +113,13 @@

        with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
            ranks = [self.node_id * self.gpu_count + device_id for device_id in range(self.gpu_count)]
            for _ in e.map(self.model_comm.process_weight, range(self.gpu_count), ranks):
                pass
            # This is for load_state_dict
            # process_weight will optimize the kernel by col-major matrix and pack_b
            # This will result in the failure of get_params
            # for _ in e.map(self.model_comm.process_weight, range(self.gpu_count), ranks):
            #     pass
            # implicit synchronization
            for _ in e.map(self.model_comm.create_engine, range(self.gpu_count), ranks, repeat(self.nccl_params)):
            for _ in e.map(self.model_comm.create_engine, self.gpu_list, ranks, repeat(self.nccl_params)):
                pass

        self.session_len = self.config.session_len

@@ -130,30 +134,30 @@ def _create_weight(self, model_comm):
        torch.cuda.synchronize()

        # create weight
        def _create_weight_func(device_id):
            rank = self.node_id * self.gpu_count + device_id
        def _create_weight_func(index, device_id):
            rank = self.node_id * self.gpu_count + index
            model_comm.create_shared_weights(device_id, rank)

        with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:
            futures = []
            for device_id in range(self.gpu_count):
                futures.append(executor.submit(_create_weight_func, device_id))
            for idx, device_id in enumerate(self.gpu_list):
                futures.append(executor.submit(_create_weight_func, idx, device_id))
            for future in futures:
                future.result()

    def _get_model_params(self, model_comm, tm_params):
        """Get turbomind model params when loading from hf."""

        def _get_params(device_id, que):
            rank = self.node_id * self.gpu_count + device_id
        def _get_params(idx, device_id, que):
            rank = self.node_id * self.gpu_count + idx
            out = model_comm.get_params(device_id, rank)
            que.put(out)

        que = Queue()
        with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:
            futures = []
            for device_id in range(self.gpu_count):
                futures.append(executor.submit(_get_params, device_id, que))
            for idx, device_id in enumerate(self.gpu_list):
                futures.append(executor.submit(_get_params, idx, device_id, que))
            for future in futures:
                future.result()

@@ -215,13 +219,23 @@ def _from_hf(self, model_source: ModelSource, model_path: str, engine_config: Tu
        self._get_model_params(model_comm, tm_params)
        logger.warning(f'get {len(tm_params)} model params')
        tm_model.export()
        self.tm_model = tm_model
        # there should be no leftover turbomind params.
        if len(tm_params) > 0:
            uninitialized = list(tm_params.keys())
            logger.warning('the model may not be loaded successfully '
                           f'with {len(tm_params)} uninitialized params:\n{uninitialized}')
        return model_comm

    def load_weights(self, state_dict):
        tm_params = self.tm_model.tm_params
        self._get_model_params(self.model_comm, tm_params)
        input_model = self.tm_model.input_model
        model_path = input_model.model_path
        input_model.model_path = state_dict
        self.tm_model.export()
        input_model.model_path = model_path
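
A rough sketch of how the new `load_weights` entry point might be driven from a trainer loop to refresh a live engine without reloading from disk; the construction calls and model id are illustrative assumptions, and only the `load_weights(state_dict)` call itself comes from this PR:

# Hypothetical usage sketch (model id, tokenizer wiring and integration point
# are assumptions; only load_weights() is introduced by this PR).
from transformers import AutoModelForCausalLM

from lmdeploy import TurbomindEngineConfig
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.turbomind import TurboMind

model_id = 'internlm/internlm2_5-7b-chat'  # placeholder
tokenizer = Tokenizer(model_id)
engine = TurboMind.from_pretrained(model_id,
                                   tokenizer=tokenizer,
                                   engine_config=TurbomindEngineConfig(tp=1, devices=[0]))
trained = AutoModelForCausalLM.from_pretrained(model_id)

# ... optimizer steps update `trained` in place ...

engine.load_weights(trained.state_dict())  # re-export HF tensors into turbomind params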

    def _from_workspace(self, model_path: str, engine_config: TurbomindEngineConfig):
        """Load model which is converted by `lmdeploy convert`"""
        config_path = osp.join(model_path, 'triton_models', 'weights', 'config.yaml')

@@ -302,7 +316,7 @@ def create_instance(self, cuda_stream_id=0):
        Returns:
            TurboMindInstance: an instance of turbomind
        """
        return TurboMindInstance(self, self.config, cuda_stream_id)
        return TurboMindInstance(self, self.config, cuda_stream_id, self.gpu_list[0])


def _get_logits(outputs, offset: int):
@@ -396,7 +410,7 @@ class TurboMindInstance:
        cuda_stream_id(int): identity of a cuda stream
    """

    def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_stream_id: int = 0):
    def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_stream_id: int = 0, device_id: int = 0):
        self.tm_model = tm_model
        self.cuda_stream_id = cuda_stream_id

@@ -408,7 +422,7 @@ def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_strea
        self.nccl_params = tm_model.nccl_params

        # create model instances
        self.model_inst = self._create_model_instance(0)
        self.model_inst = self._create_model_instance(device_id)

        self.config = config
        self.lock = None