We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 256e403 commit 20eec35Copy full SHA for 20eec35
thinc/shims/pytorch.py
@@ -111,7 +111,7 @@ def predict(self, inputs: ArgsKwargs) -> Any:
111
"""
112
self._model.eval()
113
with torch.no_grad():
114
- with torch.amp.autocast("cuda", self._mixed_precision):
+ with torch.amp.autocast("cuda", enabled=self._mixed_precision):
115
outputs = self._model(*inputs.args, **inputs.kwargs)
116
self._model.train()
117
return outputs
@@ -125,7 +125,7 @@ def begin_update(self, inputs: ArgsKwargs):
125
126
127
# Note: mixed-precision autocast must not be applied to backprop.
128
129
output = self._model(*inputs.args, **inputs.kwargs)
130
131
def backprop(grads):
0 commit comments