
Commit 1cdf053
doc: add fine-tuning config section in design details chapter.
Signed-off-by: Electronic-Waste <[email protected]>
Electronic-Waste committed Feb 1, 2025
1 parent c2f7307 commit 1cdf053
Showing 1 changed file with 175 additions and 1 deletion: docs/proposals/2401-llm-trainer-v2/README.md
@@ -66,7 +66,7 @@ TrainingClient().train(
     ),
     trainer=Trainer(
         fine_tuning_config=FineTuningConfig(
-            backend="huggingface",
+            framework="huggingface",
             dataset_class="Instruction",
             peft_config=LoraConfig(r=4),
             sharding_config=FsdpConfig(...),
@@ -184,6 +184,180 @@ class InstructionDataset(Dataset, InitMethod):

```

### Fine-Tuning Config

We will add the fine-tuning configuration to the `fine_tuning_config` field of the `Trainer` dataclass.

| Parameters | What is it? |
| - | - |
| framework | The framework used for fine-tuning. |
| dataset_class | The dataset class used to fine-tune the LLM. |
| peft_config | Configuration for PEFT (Parameter-Efficient Fine-Tuning), including LoRA, AdapterPrompt, PrefixTuning, etc. |
| sharding_config | Configuration of the sharding policy for distributed training, such as FSDP (Fully Sharded Data Parallel) and ZeRO (Zero Redundancy Optimizer). |
| kwargs | Other framework-specific and launch-CLI-specific parameters. |

```python
# FineTuningConfig DataClass
from dataclasses import dataclass
from typing import Dict, Optional, Union

@dataclass
class FineTuningConfig:
    framework: str = "huggingface"
    dataset_class: Union[str, Dataset] = "InstructionDataset"
    peft_config: Optional[Union[LoraConfig, QLoraConfig, AdapterConfig, PrefixConfig]] = None
    sharding_config: Optional[Union[FsdpConfig, ZeroConfig]] = None
    kwargs: Optional[Dict[str, str]] = None

```

The Python SDK will look like:

```python
job_id = TrainingClient().train(
    dataset_config=HuggingFaceDatasetConfig(
        storage_uri="tatsu-lab/alpaca",
    ),
    trainer=Trainer(
        fine_tuning_config=FineTuningConfig(
            framework="huggingface",
            dataset_class="InstructionDataset",
            peft_config=LoraConfig(r=4),
            sharding_config=FsdpConfig(...),
            kwargs={},
        ),
        num_nodes=5,
    ),
    runtime_ref=llm_runtime,
)

```

#### LoRA Config

The *LoraConfig* represents the LoRA configuration used to fine-tune the model; a sketch of the corresponding dataclass follows the table below.

| Parameters | What is it? |
| - | - |
| r | The rank of the low-rank decomposition. |
| lora_alpha | The scaling factor that adjusts the magnitude of the low-rank matrices' output. |
| lora_dropout | The probability of applying dropout to the low-rank updates. |
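
A minimal sketch of the corresponding dataclass, written in the same style as the other config dataclasses in this section (the field names follow the table above; the defaults are illustrative, not final):

```python
# LoraConfig DataClass (sketch; defaults are illustrative)
from dataclasses import dataclass
from typing import Optional

@dataclass
class LoraConfig:
    r: Optional[int] = None               # rank of the low-rank decomposition
    lora_alpha: Optional[int] = None      # scaling factor for the low-rank update
    lora_dropout: Optional[float] = None  # dropout probability on the low-rank updates

```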

#### QLoRA Config

The *QLoraConfig* represents the QLoRA configuration used to fine-tune the model.

| Parameters | What is it? |
| - | - |
| r | The rank of the low-rank decomposition. |
| lora_alpha | The scaling factor that adjusts the magnitude of the low-rank matrices' output. |
| lora_dropout | The probability of applying dropout to the low-rank updates. |
| quant_type | The quantization type, either `nf4` or `fp4`. |
| use_double_quant | Whether to enable double quantization. |
| compute_dtype | The data type used during computation. |
| quant_storage | The data type used to store the quantized weights. |

```python
# QLoraConfig DataClass
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class QLoraConfig:
    r: Optional[int] = None
    lora_alpha: Optional[int] = None
    lora_dropout: Optional[float] = None
    quant_type: str = "fp4"  # fp4 or nf4
    use_double_quant: bool = False
    compute_dtype: torch.dtype = torch.bfloat16
    quant_storage: torch.dtype = torch.bfloat16

```
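
For illustration, a QLoRA setup with NF4 quantization could be passed through the `FineTuningConfig` sketched above like this (the values are examples, not recommended defaults):

```python
# Illustrative only: QLoRA with NF4 quantization and double quantization enabled,
# using the FineTuningConfig and QLoraConfig dataclasses sketched in this proposal.
fine_tuning_config = FineTuningConfig(
    framework="huggingface",
    peft_config=QLoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        quant_type="nf4",
        use_double_quant=True,
    ),
)
```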

#### AdapterPrompt Config (TBD)

The *AdapterConfig* represents the AdapterPrompt configuration used to fine-tune the model.

| Parameters | What is it? |
| - | - |
| adapter_len | The length of the adapter prompt. |
| adapter_layers | The number of layers into which adapters are inserted. |

```python
# AdapterConfig DataClass
from dataclasses import dataclass

@dataclass
class AdapterConfig:
    adapter_len: int = 10
    adapter_layers: int = 30

```

#### PrefixTuning Config

The *PrefixConfig* represents the PrefixTuning configuration used to fine-tune the model.

| Parameters | What is it? |
| - | - |
| num_virtual_tokens | The number of virtual tokens. |

```python
# PrefixConfig DataClass
from dataclasses import dataclass

@dataclass
class PrefixConfig:
    num_virtual_tokens: int = 30

```

#### FSDP Config

The *FsdpConfig* represents the FSDP configuration used to fine-tune the model.

| Parameters | What is it? |
| - | - |
| mixed_precision | Whether to enable mixed-precision training. |
| use_fp16 | Whether to use FP16 during mixed-precision training. |
| fsdp_cpu_offload | Whether to offload some weights and optimizer states to the CPU. |
| sharding_strategy | The sharding strategy for FSDP, e.g., FULL_SHARD (default), HYBRID_SHARD, SHARD_GRAD_OP, NO_SHARD. |
| hsdp | Whether to enable Hybrid Sharded Data Parallel (HSDP). |
| sharding_group_size | The number of GPUs in each sharding group; requires `hsdp` to be set to true. |
| replica_group_size | The number of sharding groups; requires `hsdp` to be set to true. |
| checkpoint_type | The type of model checkpoint to save. |
| fsdp_activation_checkpointing | Whether to enable activation checkpointing. |

```python
# FsdpConfig DataClass
from dataclasses import dataclass

from torch.distributed.fsdp import ShardingStrategy, StateDictType

@dataclass
class FsdpConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    fsdp_cpu_offload: bool = False
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
    hsdp: bool = False
    sharding_group_size: int = 0  # requires hsdp to be set.
    replica_group_size: int = 0  # requires hsdp to be set.
    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT
    fsdp_activation_checkpointing: bool = True

```
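
To illustrate how the HSDP-related fields interact, the sketch below shards the model within each node and replicates the sharded groups across nodes; the 4-node, 8-GPU-per-node topology is an assumed example:

```python
# Illustrative only: hybrid sharding on an assumed topology of 4 nodes x 8 GPUs,
# using the FsdpConfig dataclass sketched above.
from torch.distributed.fsdp import ShardingStrategy

sharding_config = FsdpConfig(
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    hsdp=True,
    sharding_group_size=8,  # GPUs in each sharding group (one group per node)
    replica_group_size=4,   # number of sharding groups that replicate each other
)
```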

#### ZeRO Config

The *ZeroConfig* represents the DeepSpeed ZeRO configuration used to fine-tune the model.

| Parameters | What is it? |
| - | - |
| stage | The DeepSpeed ZeRO stage. |
| zero_cpu_offload | Whether to offload some weights and optimizer states to the CPU. |
| checkpoint_type | The type of model checkpoint to save. |
| mixed_precision | Whether to enable mixed-precision training. |
| use_fp16 | Whether to use FP16 during mixed-precision training. |

```python
# ZeroConfig DataClass
from dataclasses import dataclass

from torch.distributed.fsdp import StateDictType

@dataclass
class ZeroConfig:
    stage: int = 0
    zero_cpu_offload: bool = False
    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT
    mixed_precision: bool = True
    use_fp16: bool = False

```
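
For example, a ZeRO stage 3 setup with CPU offload could be requested as follows (illustrative values, using the `ZeroConfig` dataclass sketched above):

```python
# Illustrative only: ZeRO stage 3 with weights and optimizer states offloaded to CPU.
sharding_config = ZeroConfig(
    stage=3,
    zero_cpu_offload=True,
    mixed_precision=True,
)
```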

## Implementation History

- 2025-01-31: Create KEP-2401 doc
