---
# ── asic_quant_combinations.yaml ──────────────────────────────────────────────
# Hyperparameter-combination spec: `named_static_groups` define reusable setting
# bundles, `common_group` holds settings shared by every run, and
# `named_variation_groups` / `parameter_groups` select which bundles are swept.
# NOTE(review): exact sweep semantics depend on the exploration runner that
# consumes this file — confirm against that tool's schema.

named_static_groups:
  # Residual-combination variants (mlp + attn use the same combinator).
  - named_group: "slerp"
    named_group_settings:
      mlp_residual_combination: ["slerp"]
      attn_residual_combination: ["slerp"]
  - named_group: "add"
    named_group_settings:
      mlp_residual_combination: ["add"]
      attn_residual_combination: ["add"]
  - named_group: "lerp"
    named_group_settings:
      mlp_residual_combination: ["lerp"]
      attn_residual_combination: ["lerp"]
  # RMSNorm variants, with and without normalizing the token embedding (wte).
  - named_group: "rmsnorm_no_wte"
    named_group_settings:
      norm_variant_attn: ["rmsnorm"]
      norm_variant_output: ["rmsnorm"]
  - named_group: "rmsnorm_with_wte"
    named_group_settings:
      norm_variant_wte: ["rmsnorm"]
      norm_variant_attn: ["rmsnorm"]
      norm_variant_output: ["rmsnorm"]
  # Peri-LayerNorm placement variants.
  - named_group: "offchip_peri_ln"
    named_group_settings:
      use_offchip_peri_ln: [true]
  - named_group: "onchip_peri_ln"
    named_group_settings:
      use_peri_ln: [true]
  - named_group: "no_peri_ln"
    named_group_settings:
      use_peri_ln: [false]

common_group:
  use_edgellm_asic: [true]
  mlp_variant: ["edgellm_asic_mlp"]
  attention_variant: ["edgellm_asic_attn"]

  # --- Training & dataset ---
  dataset: ["minipile"]
  max_iters: [20000]
  full_quant_iteration: [10000]

  # --- Model architecture ---
  n_layer: [8]
  n_head: [8]
  n_embd: [512]
  block_size: [256]
  batch_size: [64]
  bias: [false]
  dtype: ["bfloat16"]

  # --- Quantization settings ---
  quantization_warmup_iters: [0]
  quant_scheduler: ["linear"]
  linear_variant_attn: ["quantized_linear"]
  linear_variant_mlp: ["quantized_linear"]
  quantize_linear_method: ["symmetric_quant"]
  activations_quant_method: ["symmetric_quant"]

  # --- ASIC quantization controls ---
  quantize_attn_act: [true]
  quantize_mlp_act: [true]
  quantize_asic_prenorm: [true]
  quantize_asic_attn_softmax_denom: [true]
  quantize_asic_attn_softmax_denom_bits: [16]
  quantize_asic_attn_softmax_numerator: [true]
  quantize_asic_attn_softmax_numerator_bits: [8]

  # --- Normalization / regularization ---
  use_flash_norm: [true]
  use_pre_ln: [false]
  dropout: [0.0]

  # --- Optimization ---
  grad_clip: [1.0]
  beta1: [0.95]
  beta2: [0.95]
  weight_decay: [0.05]
  learning_rate: [0.00075]

  # --- Sampling / evaluation ---
  max_sample_tokens: [100]
  sample_each_eval: [true]

# Each variation group sweeps over the alternates listed; alternate names must
# match the `named_group` entries declared in `named_static_groups` above.
named_variation_groups:
  - named_group: "residual_combinations"
    named_group_alternates: ["lerp", "add", "slerp"]

  - named_group: "wte_norms"
    named_group_alternates: ["rmsnorm_no_wte", "rmsnorm_with_wte"]

  - named_group: "peri_ln_types"
    named_group_alternates: ["offchip_peri_ln", "onchip_peri_ln", "no_peri_ln"]

parameter_groups:
  - out_dir: ["asic_quant_combinations"]

    named_group_variations:
      - "residual_combinations"
      - "wte_norms"
      - "peri_ln_types"

    # --- Activation transition ---
    use_gradual_activation: [true]
    activation_start: ["gelu"]
    activation_end: ["relu"]
    activation_transition_start_iter: [0]
    activation_transition_end_iter: [20000]