@@ -106,7 +106,9 @@ class ModelArguments:
106
106
107
107
s2 : Optional [bool ] = field (default = False )
108
108
s2_scales : Optional [str ] = field (default = "336,672,1008" )
109
-
109
+
110
+ use_pos_skipping : Optional [bool ] = field (default = False )
111
+ pos_skipping_range : Optional [int ] = field (default = 4096 )
110
112
111
113
@dataclass
112
114
class DataArguments :
@@ -1222,11 +1224,24 @@ def get_model(model_args, training_args, bnb_model_from_pretrained_args):
1222
1224
1223
1225
customized_kwargs = dict ()
1224
1226
customized_kwargs .update (bnb_model_from_pretrained_args )
1225
-
1226
- overwrite_config = {}
1227
1227
cfg_pretrained = None
1228
- if model_args .rope_scaling_factor is not None and model_args .rope_scaling_type is not None :
1228
+
1229
+ overwrite_config = {}
1230
+ if any ([
1231
+ model_args .rope_scaling_factor is not None ,
1232
+ model_args .rope_scaling_type is not None ,
1233
+ model_args .mm_spatial_pool_stride is not None ,
1234
+ model_args .mm_spatial_pool_out_channels is not None ,
1235
+ model_args .mm_spatial_pool_mode is not None ,
1236
+ model_args .mm_resampler_type is not None
1237
+ ]):
1229
1238
cfg_pretrained = AutoConfig .from_pretrained (model_args .model_name_or_path )
1239
+
1240
+ if model_args .use_pos_skipping is not None and model_args .pos_skipping_range is not None :
1241
+ overwrite_config ["use_pos_skipping" ] = model_args .use_pos_skipping
1242
+ overwrite_config ["pos_skipping_range" ] = model_args .pos_skipping_range
1243
+
1244
+ if model_args .rope_scaling_factor is not None and model_args .rope_scaling_type is not None :
1230
1245
overwrite_config ["rope_scaling" ] = {
1231
1246
"factor" : model_args .rope_scaling_factor ,
1232
1247
"type" : model_args .rope_scaling_type ,
@@ -1247,8 +1262,7 @@ def get_model(model_args, training_args, bnb_model_from_pretrained_args):
1247
1262
overwrite_config ["mm_spatial_pool_mode" ] = model_args .mm_spatial_pool_mode
1248
1263
1249
1264
if overwrite_config :
1250
- if cfg_pretrained is None :
1251
- cfg_pretrained = AutoConfig .from_pretrained (model_args .model_name_or_path )
1265
+ assert cfg_pretrained is not None , "cfg_pretrained is None"
1252
1266
1253
1267
rank0_print (f"Overwriting config with { overwrite_config } " )
1254
1268
for k , v in overwrite_config .items ():
0 commit comments