We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 02b64d7 · commit 08ef0e6 — Copy full SHA for 08ef0e6
scripts/run_sft.py
@@ -113,7 +113,7 @@ def main():
113
model_kwargs = dict(
114
revision=model_args.model_revision,
115
trust_remote_code=model_args.trust_remote_code,
116
- attn_implementation='flash_attention_2',
+ attn_implementation=model_args.attn_implementation,
117
torch_dtype=torch_dtype,
118
use_cache=False if training_args.gradient_checkpointing else True,
119
device_map=get_kbit_device_map() if quantization_config is not None else None,
@@ -230,4 +230,4 @@ def main():
230
231
232
if __name__ == "__main__":
233
- main()
+ main()
0 commit comments