@@ -102,15 +102,18 @@ def is_acceptable(tensor):
102
102
return True
103
103
104
104
105
- def set_flags (_enabled = None , _benchmark = None , _deterministic = None , _allow_tf32 = None ):
105
+ def set_flags (_enabled = None , _benchmark = None , _benchmark_limit = None , _deterministic = None , _allow_tf32 = None ):
106
106
orig_flags = (torch ._C ._get_cudnn_enabled (),
107
107
torch ._C ._get_cudnn_benchmark (),
108
+ None if not is_available () else torch ._C ._cuda_get_cudnn_benchmark_limit (),
108
109
torch ._C ._get_cudnn_deterministic (),
109
110
torch ._C ._get_cudnn_allow_tf32 ())
110
111
if _enabled is not None :
111
112
torch ._C ._set_cudnn_enabled (_enabled )
112
113
if _benchmark is not None :
113
114
torch ._C ._set_cudnn_benchmark (_benchmark )
115
+ if _benchmark_limit is not None and is_available ():
116
+ torch ._C ._cuda_set_cudnn_benchmark_limit (_benchmark_limit )
114
117
if _deterministic is not None :
115
118
torch ._C ._set_cudnn_deterministic (_deterministic )
116
119
if _allow_tf32 is not None :
@@ -119,9 +122,9 @@ def set_flags(_enabled=None, _benchmark=None, _deterministic=None, _allow_tf32=N
119
122
120
123
121
124
@contextmanager
122
- def flags (enabled = False , benchmark = False , deterministic = False , allow_tf32 = True ):
125
+ def flags (enabled = False , benchmark = False , benchmark_limit = 10 , deterministic = False , allow_tf32 = True ):
123
126
with __allow_nonbracketed_mutation ():
124
- orig_flags = set_flags (enabled , benchmark , deterministic , allow_tf32 )
127
+ orig_flags = set_flags (enabled , benchmark , benchmark_limit , deterministic , allow_tf32 )
125
128
try :
126
129
yield
127
130
finally :
@@ -141,6 +144,9 @@ def __init__(self, m, name):
141
144
enabled = ContextProp (torch ._C ._get_cudnn_enabled , torch ._C ._set_cudnn_enabled )
142
145
deterministic = ContextProp (torch ._C ._get_cudnn_deterministic , torch ._C ._set_cudnn_deterministic )
143
146
benchmark = ContextProp (torch ._C ._get_cudnn_benchmark , torch ._C ._set_cudnn_benchmark )
147
+ benchmark_limit = None
148
+ if is_available ():
149
+ benchmark_limit = ContextProp (torch ._C ._cuda_get_cudnn_benchmark_limit , torch ._C ._cuda_set_cudnn_benchmark_limit )
144
150
allow_tf32 = ContextProp (torch ._C ._get_cudnn_allow_tf32 , torch ._C ._set_cudnn_allow_tf32 )
145
151
146
152
# This is the sys.modules replacement trick, see
@@ -152,3 +158,4 @@ def __init__(self, m, name):
152
158
deterministic : bool
153
159
benchmark : bool
154
160
allow_tf32 : bool
161
+ benchmark_limit : int
0 commit comments