Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/test/test_select.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@ components_cfg_file: src/dataflex/configs/components.yaml
component_name: delta_loss # 选择组件名称,对应 components_cfg_file 中定义的组件
warmup_step: 4
update_step: 3
update_times: 4
update_times: 4 # Updates per Flex epoch; repeats by num_train_epochs
2 changes: 1 addition & 1 deletion examples/test/test_weight.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,4 @@ train_type: dynamic_weight # 选择训练器类型。可选值包括:
components_cfg_file: src/dataflex/configs/components.yaml
component_name: custom # 选择组件名称,对应 components_cfg_file 中定义的组件
warmup_step: 1
train_step: 3 # 总训练步数(包括warm_up)
train_step: 3 # Total steps; overrides num_train_epochs
4 changes: 2 additions & 2 deletions examples/train_lora/selectors/custom.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ components_cfg_file: src/dataflex/configs/components.yaml
component_name: custom # 选择组件名称,对应 components_cfg_file 中定义的组件
warmup_step: 10
update_step: 10
update_times: 2
update_times: 2 # Updates per Flex epoch; repeats by num_train_epochs

## eval
# val_size: 0.001
Expand All @@ -71,4 +71,4 @@ eval_steps: 10
# early_stopping_steps: 3
# early_stopping_min_delta: 0.01

# FORCE_TORCHRUN=1 DISABLE_VERSION_CHECK=1 dataflex-cli train examples/train_lora/selectors/custom.yaml
# FORCE_TORCHRUN=1 DISABLE_VERSION_CHECK=1 dataflex-cli train examples/train_lora/selectors/custom.yaml
4 changes: 2 additions & 2 deletions examples/train_lora/selectors/delta_loss.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ components_cfg_file: src/dataflex/configs/components.yaml
component_name: delta_loss # 选择组件名称,对应 components_cfg_file 中定义的组件
warmup_step: 10
update_step: 10
update_times: 2
update_times: 2 # Updates per Flex epoch; repeats by num_train_epochs

## eval
# val_size: 0.001
Expand All @@ -71,4 +71,4 @@ eval_steps: 10
# early_stopping_steps: 3
# early_stopping_min_delta: 0.01

# FORCE_TORCHRUN=1 DISABLE_VERSION_CHECK=1 dataflex-cli train examples/train_lora/selectors/delta_loss.yaml
# FORCE_TORCHRUN=1 DISABLE_VERSION_CHECK=1 dataflex-cli train examples/train_lora/selectors/delta_loss.yaml
4 changes: 2 additions & 2 deletions examples/train_lora/selectors/less.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ components_cfg_file: src/dataflex/configs/components.yaml
component_name: less # 选择组件名称,对应 components_cfg_file 中定义的组件
warmup_step: 10
update_step: 10
update_times: 2
update_times: 2 # Updates per Flex epoch; repeats by num_train_epochs

## eval
# val_size: 0.001
Expand All @@ -71,4 +71,4 @@ eval_steps: 10
# early_stopping_steps: 3
# early_stopping_min_delta: 0.01

# FORCE_TORCHRUN=1 DISABLE_VERSION_CHECK=1 dataflex-cli train examples/train_lora/selectors/less.yaml
# FORCE_TORCHRUN=1 DISABLE_VERSION_CHECK=1 dataflex-cli train examples/train_lora/selectors/less.yaml
4 changes: 2 additions & 2 deletions examples/train_lora/selectors/loss.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ components_cfg_file: src/dataflex/configs/components.yaml
component_name: loss # 选择组件名称,对应 components_cfg_file 中定义的组件
warmup_step: 10
update_step: 10
update_times: 2
update_times: 2 # Updates per Flex epoch; repeats by num_train_epochs

## eval
# val_size: 0.001
Expand All @@ -71,4 +71,4 @@ eval_steps: 10
# early_stopping_steps: 3
# early_stopping_min_delta: 0.01

# FORCE_TORCHRUN=1 DISABLE_VERSION_CHECK=1 dataflex-cli train examples/train_lora/selectors/loss.yaml
# FORCE_TORCHRUN=1 DISABLE_VERSION_CHECK=1 dataflex-cli train examples/train_lora/selectors/loss.yaml
4 changes: 2 additions & 2 deletions examples/train_lora/selectors/near.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ components_cfg_file: src/dataflex/configs/components.yaml
component_name: near # 选择组件名称,对应 components_cfg_file 中定义的组件
warmup_step: 10
update_step: 10
update_times: 2
update_times: 2 # Updates per Flex epoch; repeats by num_train_epochs

## eval
# val_size: 0.001
Expand All @@ -71,4 +71,4 @@ eval_steps: 10
# early_stopping_steps: 3
# early_stopping_min_delta: 0.01

# FORCE_TORCHRUN=1 DISABLE_VERSION_CHECK=1 dataflex-cli train examples/train_lora/selectors/near.yaml
# FORCE_TORCHRUN=1 DISABLE_VERSION_CHECK=1 dataflex-cli train examples/train_lora/selectors/near.yaml
4 changes: 2 additions & 2 deletions examples/train_lora/selectors/nice.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ components_cfg_file: src/dataflex/configs/components.yaml
component_name: nice # 选择组件名称,对应 components_cfg_file 中定义的组件
warmup_step: 10
update_step: 10
update_times: 2
update_times: 2 # Updates per Flex epoch; repeats by num_train_epochs

## eval
# val_size: 0.001
Expand All @@ -71,4 +71,4 @@ eval_steps: 10
# early_stopping_steps: 3
# early_stopping_min_delta: 0.01

# FORCE_TORCHRUN=1 DISABLE_VERSION_CHECK=1 dataflex-cli train examples/train_lora/selectors/nice.yaml
# FORCE_TORCHRUN=1 DISABLE_VERSION_CHECK=1 dataflex-cli train examples/train_lora/selectors/nice.yaml
4 changes: 2 additions & 2 deletions examples/train_lora/selectors/random.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ components_cfg_file: src/dataflex/configs/components.yaml
component_name: random # 选择组件名称,对应 components_cfg_file 中定义的组件
warmup_step: 10
update_step: 10
update_times: 2
update_times: 2 # Updates per Flex epoch; repeats by num_train_epochs

## eval
# val_size: 0.001
Expand All @@ -71,4 +71,4 @@ eval_steps: 10
# early_stopping_steps: 3
# early_stopping_min_delta: 0.01

# FORCE_TORCHRUN=1 DISABLE_VERSION_CHECK=1 dataflex-cli train examples/train_lora/selectors/random.yaml
# FORCE_TORCHRUN=1 DISABLE_VERSION_CHECK=1 dataflex-cli train examples/train_lora/selectors/random.yaml
2 changes: 1 addition & 1 deletion examples/train_lora/selectors/tsds.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ components_cfg_file: src/dataflex/configs/components.yaml
component_name: tsds # 选择组件名称,对应 components_cfg_file 中定义的组件
warmup_step: 10
update_step: 10
update_times: 2
update_times: 2 # Updates per Flex epoch; repeats by num_train_epochs

## eval
# val_size: 0.001
Expand Down
2 changes: 1 addition & 1 deletion examples/train_lora/weighters/custom.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,4 @@ train_type: dynamic_weight # 选择训练器类型。可选值包括:
components_cfg_file: src/dataflex/configs/components.yaml
component_name: custom # 选择组件名称,对应 components_cfg_file 中定义的组件
warmup_step: 100
train_step: 500 # 总训练步数(包括warm_up)
train_step: 500 # Total steps; overrides num_train_epochs
2 changes: 1 addition & 1 deletion examples/train_lora/weighters/loss.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,4 @@ train_type: dynamic_weight # 选择训练器类型。可选值包括:
components_cfg_file: src/dataflex/configs/components.yaml
component_name: loss # 选择组件名称,对应 components_cfg_file 中定义的组件
warmup_step: 100
train_step: 500 # 总训练步数(包括warm_up)
train_step: 500 # Total steps; overrides num_train_epochs
43 changes: 38 additions & 5 deletions skills/how_to_use.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,39 @@ A DataFlex training config is a standard LlamaFactory YAML with additional DataF
| `component_name` | str | Which algorithm to use, matching a key in `components_cfg_file` |
| `warmup_step` | int | Number of warmup steps before dynamic behavior kicks in |
| `update_step` | int | Interval (in steps) between dynamic updates |
| `update_times` | int | Total number of dynamic updates. Use `-1` for continuous updates until training ends |
| `update_times` | int | For `dynamic_select`, number of dynamic updates per Flex epoch. Use `-1` for dataset-sized epochs with continuous updates |
| `static_mix` | bool | If `true` with `dynamic_mix`, use fixed proportions (no dynamic updates). Used in DoReMi Step 1 & 3 |
| `train_step` | int | Total training steps. Required for `dynamic_weight` and `dynamic_mix` with `static_mix: true` |
| `train_step` | int | Optional total training steps. If set to a positive value, it overrides `num_train_epochs`-derived steps |

### Step and Epoch Semantics

DataFlex dynamic trainers run a step-based training loop internally. Prefer `eval_strategy: "steps"` / `save_strategy: "steps"` or disable them with `"no"` when using dynamic training. Epoch-based evaluation or saving depends on the internal step-to-epoch bookkeeping and may not align with Flex epoch boundaries.

For `dynamic_select`, `num_train_epochs` repeats Flex epochs. One Flex epoch contains:

```text
warmup_step + update_step * update_times
```

For example:

```yaml
warmup_step: 10
update_step: 10
update_times: 2
num_train_epochs: 3
```

This runs `3 * (10 + 10 * 2) = 90` optimization steps. To keep the old single-Flex-epoch behavior, set:

```yaml
num_train_epochs: 1.0
```

If `train_step > 0`, DataFlex uses `train_step` as the exact total number of optimization steps and does not derive total steps from `num_train_epochs`.
For multi-epoch tests, make sure example configs do not leave a positive `train_step`; pass `train_step=0` on the CLI if you want `num_train_epochs` to control training length.

For `dynamic_weight`, `warmup_step` is a global step threshold. Reweighting starts when `global_step >= warmup_step` and does not reset at epoch boundaries.

### Data Mixture Fields (for `dynamic_mix` only)

Expand All @@ -83,12 +113,14 @@ component_name: less # choices: less, nice, loss, delta_loss, tsds, near,
warmup_step: 10
update_step: 10
update_times: 2
num_train_epochs: 1.0
```

**How it works:**
1. Warmup phase: train on randomly sampled data for `warmup_step` steps.
2. At `warmup_step` and every `update_step` steps: pause training, run the selector to pick new samples, rebuild the dataloader.
3. Total steps = `warmup_step + update_step * update_times`.
3. One Flex epoch has `warmup_step + update_step * update_times` steps.
4. Total steps are derived from `num_train_epochs` unless `train_step > 0`.

**Example:**
```bash
Expand Down Expand Up @@ -138,13 +170,14 @@ train_type: dynamic_weight
components_cfg_file: src/dataflex/configs/components.yaml
component_name: loss # choices: loss, custom
warmup_step: 100
train_step: 500
train_step: 500 # fixed-step example; set to 0 for num_train_epochs-based multi-epoch runs
```

**How it works:**
1. Standard training for `warmup_step` steps (no reweighting).
2. After warmup: each training step computes per-sample losses and applies the weighting strategy.
3. Total steps = `train_step`.
3. `warmup_step` is measured in global optimization steps and does not reset per epoch.
4. If `train_step > 0`, total steps = `train_step`; otherwise total steps follow the standard `num_train_epochs` calculation.

**Example:**
```bash
Expand Down
6 changes: 3 additions & 3 deletions src/dataflex/train/hparams/dynamic_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,15 +547,15 @@ class DynamicFinetuningArguments(
)
update_times: int = field(
default=1,
metadata={"help": "Total update times during the whole dynamic training progress for dynamic select or mix training"},
metadata={"help": "Update times per Flex epoch for dynamic selection. Use <= 0 for no fixed update count."},
)
static_mix: bool = field(
default=False,
metadata={"help": "Whether or not to fix the static mix ratio in dynamic mix training."},
)
train_step: int = field(
default=0,
metadata={"help": "Only used in dynamic weight trainer and mix trainer (static_mix=True). Total training steps (including warmup)."},
metadata={"help": "Optional total training steps. If set, overrides num_train_epochs."},
)
freeze_gate: bool = field(
default=False,
Expand Down Expand Up @@ -646,4 +646,4 @@ def __init__(self, early_stopping_patience: int):
early_stopping_threshold=min_delta,
)

tuner.EarlyStoppingCallback = DataFlexEarlyStoppingCallback
tuner.EarlyStoppingCallback = DataFlexEarlyStoppingCallback
79 changes: 66 additions & 13 deletions src/dataflex/train/trainer/select_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,23 @@ def _inner_training_loop(
len_dataloader, # 等于数据集长度/worldsize/micro_batchsize
max_steps,
) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size)
max_steps = (self.finetuning_args.warmup_step + self.finetuning_args.update_step * self.finetuning_args.update_times)
# Issue #49: support num_train_epochs while keeping dynamic-step training
epoch_update_steps = None
if self.finetuning_args.update_times > 0:
epoch_update_steps = max(
1,
self.finetuning_args.warmup_step + self.finetuning_args.update_step * self.finetuning_args.update_times,
)
if self.finetuning_args.train_step > 0:
max_steps = self.finetuning_args.train_step
elif epoch_update_steps is not None:
max_steps = int(np.ceil(args.num_train_epochs * epoch_update_steps))
else:
steps_per_epoch = max(1, int(np.ceil(len(self.train_dataset) / total_train_batch_size)))
max_steps = int(np.ceil(args.num_train_epochs * steps_per_epoch))
epoch_update_steps = steps_per_epoch
if self.finetuning_args.train_step > 0 and args.num_train_epochs != 1:
logger.warning("[Dataflex] train_step is set; num_train_epochs will be ignored.")
epoch_based = False
logger.info(f"[Dataflex]Set max train steps to {max_steps}")
logger.info(f"[Dataflex]Set epoch_based = False")
Expand Down Expand Up @@ -569,7 +585,6 @@ def _inner_training_loop(
if args.eval_on_start:
self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)

# 放弃epoch逻辑,相当于只训练一个epoch,通过step来训练
current_dataloader = train_dataloader
epoch = 0
if self.state.global_step < self.finetuning_args.warmup_step:
Expand All @@ -581,7 +596,11 @@ def _inner_training_loop(
if args.past_index >= 0:
self._past = None

steps_in_epoch = max_steps * args.gradient_accumulation_steps
total_training_batches = max_steps * args.gradient_accumulation_steps
if epoch_update_steps is not None:
steps_in_epoch = epoch_update_steps * args.gradient_accumulation_steps
else:
steps_in_epoch = len_dataloader if len_dataloader is not None else total_training_batches
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

if resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
Expand All @@ -603,7 +622,7 @@ def _inner_training_loop(
remainder = args.gradient_accumulation_steps
update_step = -1
# 一个epoch中的模型总更新次数
total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
total_updates = total_training_batches // args.gradient_accumulation_steps + 1
if args.gradient_accumulation_steps == 1:
total_updates -= 1
for _ in range(total_updates):
Expand All @@ -612,12 +631,22 @@ def _inner_training_loop(
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder

batch_samples, num_items_in_batch = self.get_batch_samples(current_iterator, num_batches, args.device)
if len(batch_samples) == 0:
epoch += 1
if hasattr(current_dataloader, "set_epoch"):
current_dataloader.set_epoch(epoch)
current_iterator = iter(current_dataloader)
batch_samples, num_items_in_batch = self.get_batch_samples(current_iterator, num_batches, args.device)
if len(batch_samples) == 0:
self.control.should_training_stop = True
break
# 遍历当前批次的样本
for i, inputs in enumerate(batch_samples):
step += 1 # 每次迭代时增加全局步数

# 判断是否达到同步步数,或者是当前epoch的最后一个步数
do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
is_epoch_end = (step + 1 + steps_skipped) % steps_in_epoch == 0
do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or is_epoch_end

# 由于我们使用了预取(prefetching),我们需要手动设置同步梯度
self.accelerator.gradient_state._set_sync_gradients(do_sync_step)
Expand Down Expand Up @@ -750,26 +779,50 @@ def _inner_training_loop(
# learning_rate=learning_rate,
)

# 动态训练更新
step_in_epoch = (
self.state.global_step % epoch_update_steps
if epoch_update_steps is not None
else self.state.global_step
)
if (
step_in_epoch == 0
and epoch_update_steps is not None
and self.state.global_step > 0
and self.state.global_step < max_steps
):
self.accelerator.wait_for_everyone()
torch.cuda.empty_cache()
if dist.is_initialized():
dist.barrier()

warmup_indices = self.selector.warmup(total_warmup_samples, replacement=True)
current_dataloader = self.get_train_dataloader(warmup_indices)
current_iterator = iter(current_dataloader)
elif (
self.state.global_step < max_steps and (
self.state.global_step == self.finetuning_args.warmup_step or
(self.state.global_step > self.finetuning_args.warmup_step and
(self.state.global_step - self.finetuning_args.warmup_step) % self.finetuning_args.update_step == 0))
step_in_epoch == self.finetuning_args.warmup_step or
(step_in_epoch > self.finetuning_args.warmup_step and
(step_in_epoch - self.finetuning_args.warmup_step) % self.finetuning_args.update_step == 0))
):
self.accelerator.wait_for_everyone()
torch.cuda.empty_cache()
if dist.is_initialized():
dist.barrier()

current_update_times = (self.state.global_step - self.finetuning_args.warmup_step) // self.finetuning_args.update_step + 1
current_update_times = (step_in_epoch - self.finetuning_args.warmup_step) // self.finetuning_args.update_step + 1
effective_update_times = self.finetuning_args.update_times
if effective_update_times <= 0 and epoch_update_steps is not None:
effective_update_times = max(
1,
int(np.ceil(max(epoch_update_steps - self.finetuning_args.warmup_step, 0) / self.finetuning_args.update_step)),
)
logger.info(f"[Dataflex] Model training paused, starting the {current_update_times}th dynamic data selection...")
# 这里传一些特定的selector参数
extra_args = dict(
optimizer_state=self.optimizer.state,
scheduler_state=self.lr_scheduler.state_dict(),
current_update_times=current_update_times,
update_times=self.finetuning_args.update_times,
update_times=effective_update_times,
tokenizer=self.tokenizer,
)
new_indices = self.selector.select(
Expand All @@ -780,8 +833,8 @@ def _inner_training_loop(
)

# 每个进程根据 local_indices 构造 dataloader
train_loader = self.get_train_dataloader(indices=new_indices)
current_iterator = iter(train_loader)
current_dataloader = self.get_train_dataloader(indices=new_indices)
current_iterator = iter(current_dataloader)

if self.accelerator.is_main_process:
logger.info(f"[Dataflex] Updated dataloader at step {self.state.global_step}, {len(new_indices)} samples in total.")
Expand Down
Loading
Loading