Afd demo #12
base: afd-dev-hsliu
Conversation
Signed-off-by: czrz <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only a small subset of essential tests runs automatically. You can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
hsliuustc0106
left a comment
please check
@@ -0,0 +1,16 @@
from vllm import LLM, SamplingParams
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--temperature", type=float, default=0.8)
parser.add_argument("--top_p", type=float, default=0.95)
args = parser.parse_args()
sampling_params = SamplingParams(temperature=args.temperature, top_p=args.top_p)
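A minimal, self-contained sketch of the CLI surface the reviewer suggests above (only the two flags shown in the suggestion; the vLLM-specific parts are omitted so this runs standalone):

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    # Mirrors the review suggestion: expose the sampling knobs as CLI flags
    # instead of hard-coding SamplingParams(temperature=0.8, top_p=0.95).
    parser = argparse.ArgumentParser()
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_p", type=float, default=0.95)
    return parser

# Parsing an empty argv falls back to the same defaults the inline call used.
args = build_parser().parse_args([])
```

In the example script, `args.temperature` and `args.top_p` would then feed `SamplingParams` directly.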
examples/afd/offline_attn.py
Outdated
]

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
#llm = LLM(model="/data2/models/Qwen3-0.6B")
del
from vllm.distributed.afd.AFDconnector import AFDConnectorBase


class ncclconnector(AFDConnectorBase):
XcclConnector
NcclConnector
change to p2pconnector
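A sketch of the rename the reviewers converge on. The class name `P2pConnector` and the stubbed `AFDConnectorBase` below are assumptions for illustration (the real base class lives in `vllm.distributed.afd` and is not reproduced here):

```python
from abc import ABC, abstractmethod

class AFDConnectorBase(ABC):
    """Stub standing in for vllm.distributed.afd.AFDconnector.AFDConnectorBase."""
    @abstractmethod
    def send_attn_output(self, tensors): ...
    @abstractmethod
    def recv_attn_output(self): ...

class P2pConnector(AFDConnectorBase):
    """PascalCase rename suggested in review ('change to p2pconnector')."""
    def __init__(self, process_group):
        self.process_group = process_group

    def send_attn_output(self, tensors):
        # Placeholder for process_group.send_tensor_dict(...)
        return ("send", tensors)

    def recv_attn_output(self):
        # Placeholder for process_group.recv_tensor_dict(...)
        return ("recv", None)
```

Keeping the concrete transport in a subclass name like this makes later swaps (nccl vs. p2p) a one-line change at the construction site.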
if weights_not_loaded:
    raise ValueError("Following weights were not initialized from "
                     f"checkpoint: {weights_not_loaded}")
# if weights_not_loaded:
del
done
@@ -31,11 +31,12 @@
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config

import vllm.distributed.parallel_state as ps
This probably won't pass pre-commit; imports must follow a required order.
intermediate_tensors = ps._AFD_CONNECTOR.recv_attn_output()
hidden_states = intermediate_tensors["hidden_states"]

# ae_group = get_afd_group()
Apply this to all files: delete every unnecessary comment.
expert_params_mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    num_experts=self.config.n_routed_experts,
-   num_redundant_experts=self.num_redundant_experts)
+   num_redundant_experts=vllm_config.parallel_config.num_redundant_experts)
Why do we need to change this?
self.num_redundant_experts comes from an example_moe, which does not exist in the attention part; we have to get num_redundant_experts from the vLLM config to avoid an error.
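The explanation above can be sketched as a plain fallback lookup (the config objects here are stand-ins; attribute names follow the diff):

```python
from types import SimpleNamespace

def get_num_redundant_experts(vllm_config, example_moe=None):
    # Attention-only workers have no example_moe module, so prefer the value
    # carried on parallel_config and use the MoE module only when present.
    if example_moe is not None:
        return example_moe.num_redundant_experts
    return vllm_config.parallel_config.num_redundant_experts

# Stand-in configs for illustration only.
cfg = SimpleNamespace(parallel_config=SimpleNamespace(num_redundant_experts=2))
```

Reading from `vllm_config.parallel_config` keeps the attention workers and FFN workers on a single source of truth for this value.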
return group_metadata


class FFNModelRunner(GPUModelRunner):
import threading
import time

import torch

class FFNModelRunner(GPUModelRunner):
    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        super().__init__(vllm_config=vllm_config, device=device)
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self._shutdown_event = threading.Event()  # Add shutdown mechanism

    def execute_model(self):
        print('ffn forward begin')
        try:
            with set_forward_context(None, self.vllm_config):
                while not self._shutdown_event.is_set():  # Add exit condition
                    try:
                        layers_num = len(self.model.model.layers)
                        for i in range(layers_num):
                            if self._shutdown_event.is_set():  # Check for shutdown
                                break
                            self.model.model.layers[i].forward_ffn()
                        # Add small delay to prevent busy waiting
                        time.sleep(0.001)
                    except Exception as e:
                        logger.error(f"Error in FFN execution: {e}")
                        break
        except Exception as e:
            logger.error(f"FFN Model Runner failed: {e}")
            raise

    def shutdown(self):
        """Gracefully shut down the FFN model runner."""
        self._shutdown_event.set()
try-except doesn't align with vllm's model runner
super().__init__(process_group)
self.process_group = process_group

def send_attn_output(self, intermediate_tensors: IntermediateTensors):
class ncclconnector(AFDConnectorBase):
    def send_attn_output(self, intermediate_tensors: IntermediateTensors):
        """Send attention output with proper error handling."""
        try:
            self.process_group.send_tensor_dict(
                intermediate_tensors.tensors,
                dst=0,  # Uncomment and fix destination
                all_gather_group=None,
                timeout=timedelta(seconds=30)  # Add timeout
            )
        except Exception as e:
            logger.error(f"Failed to send attention output: {e}")
            raise RuntimeError(f"Communication error: {e}")

    def recv_attn_output(self) -> IntermediateTensors:
        """Receive attention output with proper error handling."""
        try:
            intermediate_tensors = self.process_group.recv_tensor_dict(
                src=0,  # Uncomment and fix source
                all_gather_group=None,
                timeout=timedelta(seconds=30)  # Add timeout
            )
            return IntermediateTensors(intermediate_tensors)
        except Exception as e:
            logger.error(f"Failed to receive attention output: {e}")
            raise RuntimeError(f"Communication error: {e}")
vllm/v1/worker/gpu_worker.py
Outdated
role = self.vllm_config.additional_config.get("role", None)
logger.info("AFD worker building")

ffn_size = self.vllm_config.additional_config.get("ffn_size")
def create_worker(vllm_config, rank, distributed_init_method, is_driver_worker: bool = True):
    # Add configuration validation
    additional_config = vllm_config.additional_config
    ffn_size = additional_config.get("ffn_size")
    attn_size = additional_config.get("attn_size")
    if ffn_size is None or attn_size is None:
        raise ValueError("ffn_size and attn_size must be specified in additional_config")
    if not isinstance(ffn_size, int) or not isinstance(attn_size, int):
        raise ValueError("ffn_size and attn_size must be integers")
    if ffn_size <= 0 or attn_size <= 0:
        raise ValueError("ffn_size and attn_size must be positive integers")
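The validation suggested above can be exercised in isolation. The helper below replicates the same three checks (presence, type, positivity) against a plain dict standing in for `additional_config`:

```python
def validate_afd_sizes(additional_config: dict) -> tuple:
    """Replicates the create_worker checks: presence, type, positivity."""
    ffn_size = additional_config.get("ffn_size")
    attn_size = additional_config.get("attn_size")
    if ffn_size is None or attn_size is None:
        raise ValueError("ffn_size and attn_size must be specified in additional_config")
    if not isinstance(ffn_size, int) or not isinstance(attn_size, int):
        raise ValueError("ffn_size and attn_size must be integers")
    if ffn_size <= 0 or attn_size <= 0:
        raise ValueError("ffn_size and attn_size must be positive integers")
    return ffn_size, attn_size
```

Failing fast here, before any process groups are created, avoids the harder-to-debug hangs that a missing or malformed size would otherwise cause during distributed init.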
Signed-off-by: czrz <[email protected]>
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!
Purpose
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.