[ckpt] feat: add kimi ckpt engine backend #4954
Conversation
Code Review
This pull request introduces a new checkpoint engine backend, kimi_ckpt_engine, designed to support both GPU and Huawei Ascend NPU environments. The implementation is comprehensive, including the core engine logic, integration with the existing checkpointing framework, and a new test suite. My review focuses on ensuring correctness and maintainability. I've identified a critical thread-safety issue in the weight sending logic that could lead to data corruption. Additionally, I've suggested a minor but important rename in the test suite to improve clarity.
```python
def offload_cpu(named_tensors: dict[str, torch.Tensor], name: str, tensor: torch.Tensor):
    named_tensors[name] = tensor.to("cpu", non_blocking=True)


start_time = time.time()
named_tensors = {}
for named_tensors_gpu in ckpt_get_named_tensor_buckets(
    weights, self.bucket_size, self.train_world_size, self.rank, self.rollout_dtype
):
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        futures = [
            executor.submit(
                offload_cpu,
                named_tensors,
                name,
                tensor,
            )
            for name, tensor in named_tensors_gpu.items()
        ]
        for future in concurrent.futures.as_completed(futures):
            future.result()
```
The use of ThreadPoolExecutor to concurrently update the named_tensors dictionary is not thread-safe. Dictionaries in Python are not safe for concurrent writes from multiple threads. Even if the keys are unique, internal operations like resizing can lead to race conditions, data corruption, or intermittent crashes.
To ensure thread safety, it's better to have each thread return its result independently and then collect and update the dictionary in the main thread.
Suggested change:

```diff
-def offload_cpu(named_tensors: dict[str, torch.Tensor], name: str, tensor: torch.Tensor):
-    named_tensors[name] = tensor.to("cpu", non_blocking=True)
+def offload_cpu(name: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
+    return name, tensor.to("cpu", non_blocking=True)

 start_time = time.time()
 named_tensors = {}
 for named_tensors_gpu in ckpt_get_named_tensor_buckets(
     weights, self.bucket_size, self.train_world_size, self.rank, self.rollout_dtype
 ):
     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
         futures = [
             executor.submit(
                 offload_cpu,
-                named_tensors,
                 name,
                 tensor,
             )
             for name, tensor in named_tensors_gpu.items()
         ]
         for future in concurrent.futures.as_completed(futures):
-            future.result()
+            name, tensor_cpu = future.result()
+            named_tensors[name] = tensor_cpu
```
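The thread-safe pattern recommended here can be sketched in plain Python. This is a minimal, torch-free illustration (the `simulate_offload` helper and the doubling of values are hypothetical stand-ins for `tensor.to("cpu", non_blocking=True)`): each worker returns its result, and only the main thread mutates the shared dictionary.

```python
import concurrent.futures


def simulate_offload(name: str, value: int) -> tuple[str, int]:
    # Stand-in for the real CPU offload: return the result
    # instead of writing into shared state from a worker thread.
    return name, value * 2


# Hypothetical stand-in for one bucket of named GPU tensors.
gpu_bucket = {"layer.0.weight": 1, "layer.0.bias": 2, "layer.1.weight": 3}

named_tensors = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    futures = [
        executor.submit(simulate_offload, name, value)
        for name, value in gpu_bucket.items()
    ]
    # Only the main thread writes to the dict, so no lock is needed.
    for future in concurrent.futures.as_completed(futures):
        name, result = future.result()
        named_tensors[name] = result
```

The worker functions are pure (they read their arguments and return a tuple), so the collection loop in the main thread is the single writer and the race on the dictionary disappears.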
What does this PR do?
Building on the checkpoint-engine abstraction, this PR adds a kimi_ckpt_engine backend that supports both GPU and Huawei Ascend NPU.
Since communication domains must be established across trainer and rollout workers, this PR also depends on the newly added communication domain support in kimi_ckpt_engine.
TODO:
Checklist Before Starting

- Title follows the format `[{modules}] {type}: {description}` (this will be checked by the CI).
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`; multiple modules are listed like `[megatron, fsdp, doc]`.
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`.
  - For breaking changes, prepend `[BREAKING]` to the beginning of the title, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`.

Test
We have verified the functionality on both GPU and NPU. Performance benchmarks in a 32-NPU environment show promising results; however, due to a lack of available GPU resources, GPU performance data is still pending.
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`.
- Request CI in the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If the PR modifies the `recipe` submodule, also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.