[Optimization] Support context_parallel_spliter for cp
zhangyuqin1998 committed Feb 19, 2025
1 parent 08183d7 commit b1fb8b8
Showing 1 changed file with 10 additions and 1 deletion.
paddlenlp/trainer/trainer.py (10 additions, 1 deletion)
@@ -460,13 +460,17 @@ def fn(layer):
             else ["labels"]
         )
         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
+        self.context_parallel_spliter = None
 
         self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
         self.print_config()
 
         # very last
         self._memory_tracker.stop_and_update_metrics()
 
+    def set_context_parallel_spliter(self, context_parallel_spliter):
+        self.context_parallel_spliter = context_parallel_spliter
+
     def _wrap_amp_model(self, args, model):
         logger.info("Using half precision")
         self.enable_autocast_context_manager = True
@@ -1020,7 +1024,12 @@ def _inner_training_loop(
                 if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1:
                     inputs = split_inputs_sequence_dim(inputs)
                 if self.args.use_hybrid_parallel and self.args.context_parallel_degree > 1:
-                    inputs = split_inputs_sequence_dim_load_balance(inputs)
+                    context_parallel_spliter = (
+                        split_inputs_sequence_dim_load_balance
+                        if self.context_parallel_spliter is None
+                        else self.context_parallel_spliter
+                    )
+                    inputs = context_parallel_spliter(inputs)
                 if self.args.ignore_data_skip:
                     self.timers and self.timers("read-data").stop()

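The new hook lets callers register their own sequence splitter for context parallelism; when none is set, the training loop falls back to the existing split_inputs_sequence_dim_load_balance. Below is a minimal usage sketch, not part of the commit: the diff only implies that a splitter is a callable taking the inputs dict and returning it sharded along the sequence dimension for the local cp rank. The helper make_naive_cp_spliter and its cp_rank/cp_degree parameters are invented here for illustration.

import paddle
from paddlenlp.trainer import Trainer


def make_naive_cp_spliter(cp_rank, cp_degree):
    # Hypothetical splitter: give each cp rank a contiguous slice of the
    # sequence dimension (axis 1), skipping the load balancing that the
    # default split_inputs_sequence_dim_load_balance performs.
    def naive_cp_spliter(inputs):
        sharded = {}
        for name, value in inputs.items():
            if isinstance(value, paddle.Tensor) and value.ndim >= 2:
                chunk = value.shape[1] // cp_degree
                sharded[name] = value[:, cp_rank * chunk : (cp_rank + 1) * chunk]
            else:
                sharded[name] = value
        return sharded

    return naive_cp_spliter


# model, training_args, and train_dataset are assumed to be defined elsewhere.
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
# A registered splitter overrides the default inside _inner_training_loop.
trainer.set_context_parallel_spliter(make_naive_cp_spliter(cp_rank=0, cp_degree=2))

Because _inner_training_loop re-checks self.context_parallel_spliter is None on every step, resetting the attribute to None restores the load-balanced default.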
