@@ -398,11 +398,15 @@ def delight_transformer_lm_wiki103(args):
398
398
scale_attn_drop = 0.1
399
399
scale_attn_drop_d_m = round (scale_attn_drop / delta_model_dimension , 2 )
400
400
scale_attn_drop_d_m = bound_function (0 , 0.1 , scale_attn_drop_d_m )
401
+
402
+ scale_delight_drop = 0.1
403
+ scale_delight_drop_d_m = round (scale_delight_drop / delta_model_dimension , 2 )
404
+ scale_delight_drop_d_m = bound_function (0 , 0.1 , scale_delight_drop_d_m )
401
405
402
406
args .dropout = getattr (args , "dropout" , scale_dropout_d_m )
403
407
args .delight_emb_dropout = getattr (args , "delight_emb_dropout" , 0.1 ) # We used a fixed value
404
408
args .attention_dropout = getattr (args , "attention_dropout" , scale_attn_drop_d_m )
405
- args .delight_dropout = getattr (args , "delight_dropout" , 0.0 )
409
+ args .delight_dropout = getattr (args , "delight_dropout" , scale_delight_drop_d_m )
406
410
args .pe_dropout = getattr (args , "pe_dropout" , 0.1 ) # We used a fixed value
407
411
args .activation_dropout = getattr (args , "activation_dropout" , 0.0 ) # we didn't use it
408
412
args .ffn_dropout = getattr (args , "ffn_dropout" , scale_dropout_d_m )
0 commit comments