Skip to content

Commit 08ef0e6

Browse files
authored
Update run_sft.py
1 parent 02b64d7 commit 08ef0e6

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

scripts/run_sft.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def main():
113113
model_kwargs = dict(
114114
revision=model_args.model_revision,
115115
trust_remote_code=model_args.trust_remote_code,
116-
attn_implementation='flash_attention_2',
116+
attn_implementation=model_args.attn_implementation,
117117
torch_dtype=torch_dtype,
118118
use_cache=False if training_args.gradient_checkpointing else True,
119119
device_map=get_kbit_device_map() if quantization_config is not None else None,
@@ -230,4 +230,4 @@ def main():
230230

231231

232232
if __name__ == "__main__":
233-
main()
233+
main()

0 commit comments

Comments
 (0)