3 files changed: +19 -1 lines changed

@@ -152,7 +152,7 @@ def _compute_attention(
         attention_mask = ops.expand_dims(attention_mask, axis=1)
         attention_mask = ops.cast(attention_mask, dtype="bool")
         # Only pass soft cap if needed as not all keras versions support.
-        if self.logit_soft_cap:
+        if self.logit_soft_cap is not None:
             kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
         else:
             kwargs = {}
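Review note (not part of the diff): the switch from a truthiness test to `is not None` matters because a configured soft cap of `0.0` is falsy, and together with the new `self.logit_soft_cap = None` default below, `None` becomes the only "unset" value. A minimal standalone sketch of the difference:

# Standalone sketch, not from the PR: shows how the truthiness check can
# silently drop a soft cap of 0.0 while the identity check keeps it.
logit_soft_cap = 0.0

kwargs_old = {"attn_logits_soft_cap": logit_soft_cap} if logit_soft_cap else {}
kwargs_new = (
    {"attn_logits_soft_cap": logit_soft_cap}
    if logit_soft_cap is not None
    else {}
)

print(kwargs_old)  # {} -- the cap is lost
print(kwargs_new)  # {'attn_logits_soft_cap': 0.0} -- the cap is passed through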
@@ -67,6 +67,7 @@ def __init__(
         self.rope_scaling_factor = rope_scaling_factor
         self.use_sliding_window_attention = use_sliding_window_attention
         self.sliding_window_size = sliding_window_size
+        self.logit_soft_cap = None
 
     def build(self, inputs_shape):
         # Einsum variables:
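Review note (not part of the diff): defining `self.logit_soft_cap = None` in `__init__` guarantees the attribute exists even for models that never configure a cap, so the `is not None` check in `_compute_attention` cannot raise an `AttributeError`. A small sketch of the pattern, with illustrative names rather than the PR's classes:

# Illustrative sketch of the "default the attribute, override when configured"
# pattern; the class and method names are placeholders, not keras-hub APIs.
class AttentionConfig:
    def __init__(self, logit_soft_cap=None):
        # Always define the attribute so later checks never need hasattr().
        self.logit_soft_cap = None
        if logit_soft_cap is not None:
            self.logit_soft_cap = logit_soft_cap

    def attention_kwargs(self):
        if self.logit_soft_cap is not None:
            return {"attn_logits_soft_cap": self.logit_soft_cap}
        return {}


print(AttentionConfig().attention_kwargs())  # {}
print(AttentionConfig(logit_soft_cap=50.0).attention_kwargs())
# {'attn_logits_soft_cap': 50.0}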
@@ -71,6 +71,23 @@ def fused_attention_op_available():
             )
             return False
         return True
+    elif (
+        hasattr(keras.config, "is_flash_attention_enabled")
+        and keras.config.backend() == "torch"
+    ):
+        try:
+            from torch.backends.cuda import SDPAParams as SDPAParams
+            from torch.backends.cuda import (
+                can_use_flash_attention as can_use_flash_attention,
+            )
+        except ImportError:
+            logging.warning(
+                "Flash attention is not supported in your current PyTorch "
+                "version. Please update it by following the official guide: "
+                "https://pytorch.org/get-started/locally/"
+            )
+            return False
+        return True
     else:
         return False
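Review note (not part of the diff): the new `elif` mirrors the existing JAX branch for the torch backend. It probes for the CUDA SDPA helpers (`SDPAParams`, `can_use_flash_attention`) and logs a warning instead of failing on older PyTorch builds. Below is a hedged sketch of how a caller might use the helper to choose between the fused op and a plain softmax fallback; the function name, shapes, and dispatch around it are assumptions, not the PR's call site:

from keras import ops

# Assumes fused_attention_op_available() (the helper patched above) is in scope,
# e.g. defined in or imported from the same module.
def attend(query, key, value, mask=None):
    # query/key/value: (batch, seq_len, num_heads, head_dim)
    # mask: (batch, q_len, kv_len), True where attention is allowed
    if mask is not None:
        mask = ops.expand_dims(ops.cast(mask, "bool"), axis=1)  # (batch, 1, q, kv)
    if fused_attention_op_available():
        # Keras 3 can lower this to a fused flash/SDPA kernel on supported backends.
        return ops.dot_product_attention(query, key, value, mask=mask)
    # Fallback: unfused softmax attention built from primitive ops.
    scale = ops.rsqrt(ops.cast(ops.shape(query)[-1], query.dtype))
    scores = ops.einsum("bqhd,bkhd->bhqk", query, key) * scale
    if mask is not None:
        scores = ops.where(mask, scores, ops.full_like(scores, -1e9))
    probs = ops.softmax(scores, axis=-1)
    return ops.einsum("bhqk,bkhd->bqhd", probs, value)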