Skip to content

Commit 2e523ed

Browse files
davidberard98 authored and facebook-github-bot committed
[JIT] additional support for CallMethod with autocasting (pytorch#67925)
Summary: Pull Request resolved: pytorch#67925 Previously, the following would always fail, because autocasting would not be enabled in the called method: ``` torch.jit.script def fn(x, y): with autocast(): # CallMethod() to some method fn(x, y) ``` This allows the above, if autocasting is globally enabled, e.g. ``` torch.jit.script def fn(x, y): with autocast(): # CallMethod() to some method with autocast(): fn(x, y) # now ``` ghstack-source-id: 142667351 Test Plan: added test in test_jit_autocast.py Reviewed By: navahgar Differential Revision: D32214439 fbshipit-source-id: bb7db054e25e18f5e3d2fdb449c35b5942ab303e
1 parent f57c630 commit 2e523ed

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

test/test_jit_autocast.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,5 +623,42 @@ def t(t0, t1):
623623
self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
624624
self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)
625625

626-
if __name__ == '__main__':
626+
@unittest.skipIf(not TEST_CUDA, "No cuda")
627+
def test_jit_call_method_under_autocast(self):
628+
@torch.jit.interface
629+
class Iface(torch.nn.Module):
630+
def forward(self, x, y) -> torch.Tensor:
631+
pass
632+
633+
class Impl(Iface):
634+
def forward(self, x, y):
635+
return torch.mm(x, y)
636+
637+
class Thing1(torch.nn.Module):
638+
impl: Iface
639+
640+
def forward(self, x, y):
641+
with torch.cuda.amp.autocast():
642+
a = torch.mm(x, y)
643+
b = self.impl.forward(a, x)
644+
return b
645+
646+
scripted_impl = torch.jit.script(Impl())
647+
thing1 = Thing1()
648+
thing1.impl = scripted_impl
649+
scripted_thing1 = torch.jit.script(thing1)
650+
x = torch.rand([2, 2])
651+
y = torch.rand([2, 2])
652+
653+
# make sure this doesn't throw an error
654+
with torch.cuda.amp.autocast():
655+
ans = scripted_thing1.forward(x, y)
656+
self.assertEqual(torch.mm(torch.mm(x, y), x), ans)
657+
658+
# sanity check: this isn't supported currently when global autocasting
659+
# isn't enabled
660+
self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))
661+
662+
663+
if __name__ == "__main__":
627664
run_tests()

torch/csrc/jit/passes/autocast.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,16 @@ void handleBlock(Block* block, AutocastContext initial_state) {
242242
switch (node->kind()) {
243243
case prim::CallFunction:
244244
// TODO: limit it only to amp related node;
245+
if (current_state() == initial_state) {
246+
// if the current autocasting state is the same as the global state,
247+
// then autocasting will be done correctly on subsequent method and
248+
// function calls
249+
if (current_state()) {
250+
castTensorInputs(
251+
node, aten::_autocast_to_full_precision, current_state());
252+
}
253+
break;
254+
}
245255
TORCH_INTERNAL_ASSERT(
246256
!incompatible_amp.has_value() || incompatible_amp.value(),
247257
"Calls are not expected with AMP & JIT");
@@ -250,6 +260,16 @@ void handleBlock(Block* block, AutocastContext initial_state) {
250260

251261
case prim::CallMethod:
252262
// TODO: limit it only to amp related node;
263+
if (current_state() == initial_state) {
264+
// if the current autocasting state is the same as the global state,
265+
// then autocasting will be done correctly on subsequent method and
266+
// function calls
267+
if (current_state()) {
268+
castTensorInputs(
269+
node, aten::_autocast_to_full_precision, current_state());
270+
}
271+
break;
272+
}
253273
if (auto class_type = node->input(0)->type()->cast<ClassType>()) {
254274
const auto& name = node->s(attr::name);
255275
const auto& function = class_type->getMethod(name);

0 commit comments

Comments
 (0)