Commit 20eec35

Fix PyTorch autocast/mixed-precision (#952)
1 parent 256e403 · commit 20eec35

File tree

1 file changed: +2 −2 lines changed

thinc/shims/pytorch.py

Lines changed: 2 additions & 2 deletions
@@ -111,7 +111,7 @@ def predict(self, inputs: ArgsKwargs) -> Any:
         """
         self._model.eval()
         with torch.no_grad():
-            with torch.amp.autocast("cuda", self._mixed_precision):
+            with torch.amp.autocast("cuda", enabled=self._mixed_precision):
                 outputs = self._model(*inputs.args, **inputs.kwargs)
         self._model.train()
         return outputs
@@ -125,7 +125,7 @@ def begin_update(self, inputs: ArgsKwargs):
         self._model.train()
 
         # Note: mixed-precision autocast must not be applied to backprop.
-        with torch.amp.autocast("cuda", self._mixed_precision):
+        with torch.amp.autocast("cuda", enabled=self._mixed_precision):
             output = self._model(*inputs.args, **inputs.kwargs)
 
         def backprop(grads):
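
Why the change matters: torch.amp.autocast takes dtype as its second positional parameter, so passing the boolean mixed-precision flag positionally binds it to dtype instead of toggling autocast; it must go through the enabled keyword. A minimal sketch of the corrected pattern outside the shim, where predict_with_optional_autocast, model, x, and mixed_precision are placeholder names rather than thinc API:

import torch

def predict_with_optional_autocast(
    model: torch.nn.Module, x: torch.Tensor, mixed_precision: bool
) -> torch.Tensor:
    # Passing the flag positionally would bind it to autocast's `dtype`
    # parameter; `enabled=` is the correct way to switch autocast on or off.
    model.eval()
    with torch.no_grad():
        with torch.amp.autocast("cuda", enabled=mixed_precision):
            outputs = model(x)
    model.train()
    return outputs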
