diff --git a/docs/proposals/2401-llm-trainer-v2/README.md b/docs/proposals/2401-llm-trainer-v2/README.md
index 811e3019a2..39e1b1d5ec 100644
--- a/docs/proposals/2401-llm-trainer-v2/README.md
+++ b/docs/proposals/2401-llm-trainer-v2/README.md
@@ -66,7 +66,7 @@ TrainingClient().train(
     ),
     trainer=Trainer(
         fine_tuning_config=FineTuningConfig(
-            backend="huggingface",
+            framework="huggingface",
             dataset_class="Instruction",
             peft_config=LoraConfig(r=4),
             sharding_config=FsdpConfig(...),
@@ -184,6 +184,180 @@ class InstructionDataset(Dataset, InitMethod):
 ```
 
+### Fine-Tuning Config
+
+We will add the fine-tuning configuration via the `fine_tuning_config` field in the `Trainer` dataclass.
+
+| Parameters | What is it? |
+| - | - |
+| framework | The framework used for fine-tuning. |
+| dataset_class | The dataset class used to fine-tune the LLM. |
+| peft_config | Configuration for PEFT (Parameter-Efficient Fine-Tuning), including LoRA, AdapterPrompt, PrefixTuning, etc. |
+| sharding_config | Configuration of the sharding policy for distributed training, such as FSDP (Fully Sharded Data Parallel) and ZeRO (Zero Redundancy Optimizer). |
+| kwargs | Additional framework-specific and launch-CLI-specific parameters. |
+
+```python
+# FineTuningConfig DataClass
+@dataclass
+class FineTuningConfig:
+    framework: str = "huggingface"
+    dataset_class: Union[str, Dataset] = "InstructionDataset"
+    peft_config: Optional[Union[LoraConfig, QLoraConfig, AdapterConfig, PrefixConfig]] = None
+    sharding_config: Optional[Union[FsdpConfig, ZeroConfig]] = None
+    kwargs: Optional[Dict[str, str]] = None
+```
+
+The Python SDK call will look like this:
+
+```python
+job_id = TrainingClient().train(
+    dataset_config=HuggingFaceDatasetConfig(
+        storage_uri="tatsu-lab/alpaca",
+    ),
+    trainer=Trainer(
+        fine_tuning_config=FineTuningConfig(
+            framework="huggingface",
+            dataset_class="InstructionDataset",
+            peft_config=LoraConfig(r=4),
+            sharding_config=FsdpConfig(...),
+            kwargs={},
+        ),
+        num_nodes=5,
+    ),
+    runtime_ref=llm_runtime,
+)
+```
+
+#### LoRA Config
+
+The *LoraConfig* represents the LoRA configuration used to fine-tune the model.
+
+| Parameters | What is it? |
+| - | - |
+| r | The rank of the low-rank decomposition. |
+| lora_alpha | The scaling factor that adjusts the magnitude of the low-rank matrices’ output. |
+| lora_dropout | The dropout probability applied to the low-rank updates. |
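+
+For consistency with the other PEFT configs in this section, the `LoraConfig` dataclass could look like the following sketch. The fields mirror the table above, and the `Optional` defaults follow the pattern of the `QLoraConfig` below; the exact defaults are illustrative assumptions rather than values fixed by this proposal.
+
+```python
+# LoraConfig DataClass (sketch; defaults are illustrative, not fixed by this proposal)
+@dataclass
+class LoraConfig:
+    r: Optional[int] = None               # rank of the low-rank decomposition
+    lora_alpha: Optional[int] = None      # scaling factor for the low-rank update
+    lora_dropout: Optional[float] = None  # dropout probability applied to the low-rank update
+```
+
+With such a definition, `peft_config=LoraConfig(r=4)` in the SDK example above selects LoRA fine-tuning with rank 4 and leaves `lora_alpha` and `lora_dropout` to the framework defaults.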
| +| - | - | +| adapter_len | The length of adapter | +| adapter_layers | The number of layers that we insert adapter | + +```python +# AdapterConfig DataClass +@dataclass +class AdapterConfig: + adapter_len: int = 10 + adapter_layers: int = 30 + +``` + +#### PrefixTuning Config + +The *PrefixConfig* represents the config of PrefixTuning we use to fine-tune the model. + +| Parameters | What is it? | +| - | - | +| num_virtual_tokens | The number of virtual tokens | + +```python +# PrefixConfig DataClass +@dataclass +class PrefixConfig: + num_virtual_tokens: int = 30 + +``` + +#### FSDP Config + +The *FsdpConfig* represents the config of FSDP we use to fine-tune the model. + +| Parameters | What is it? | +| - | - | +| mixed_precision | Whether to enable mixed precision training | +| use_fp16 | Whether to use FP16 during the mixed precision training | +| fsdp_cpu_offload | Whether to offload some weights and optimizer states to cpu | +| sharding_strategy | The sharding strategy for FSDP, e.g. FULL_SHARD (default), HYBRID_SHARD, SHARD_GRAD_OP, NO_SHARD. | +| hsdp | Whether to enable Hybrid Shard Data Parallel (HSDP) | +| sharding_group_size | Specify the GPU num in the sharding group when hsdp set to true | +| replica_group_size | The number of sharding groups | +| checkpoint_type | Specify the type of model checkpoints | +| fsdp_activation_checkpointing | Whether to enable Activation Checkpointing | + +```python +# FsdpConfig DataClass +@dataclass +class FsdpConfig: + mixed_precision: bool = True + use_fp16: bool = False + fsdp_cpu_offload: bool=False + sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD + hsdp: bool = False + sharding_group_size: int = 0 # requires hsdp to be set. + replica_group_size: int = 0 #requires hsdp to be set. + checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT + fsdp_activation_checkpointing: bool = True + +``` + +#### ZeRO Config + +The *ZeroConfig* represents the config of DeepSeed ZeRO we use to fine-tune the model. + +| Parameters | What is it? | +| - | - | +| stage | The stage of DeepSeed ZeRO. | +| zero_cpu_offload | Whether to offload some weights and optimizer states to cpu | +| checkpoint_type | Specify the type of model checkpoints | +| mixed_precision | Whether to enable mixed precision training | +| use_fp16 | Whether to use FP16 during the mixed precision training | + +```python +# ZeroConfig DataClass +@dataclass +class ZeroConfig: + stage: int = 0 + zero_cpu_offload: bool = False + checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT + mixed_precision: bool = True + use_fp16 : bool = False + +``` + ## Implementation History - 2025-01-31: Create KEP-2401 doc