
Commit b737629

Michael Dagitses authored and facebook-github-bot committed
simplify op name determination into a single forward pass (pytorch#64261)
Summary:
Pull Request resolved: pytorch#64261

Note that this does not preserve byte-for-byte compatibility with existing names.

Test Plan:
* Rely on CI to catch gross errors.
* Merge after release cut to catch subtle issues.

Reviewed By: albanD

Differential Revision: D30700647

Pulled By: dagitses

fbshipit-source-id: 7b02f34b8fae3041240cc78fbc6bcae498c3acd4
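The user-visible effect is that the names reported by autograd nodes now always carry a numeric suffix, even when an op has only one overload. Below is a minimal sketch of how to observe this from Python; it is an illustration added here for context, not part of the commit, and assumes a build that includes this change.

import torch

x = torch.randn(3, requires_grad=True)
y = x.exp()

# With the single-pass scheme every generated node name ends in a number,
# so this is expected to print "ExpBackward0" where older builds printed "ExpBackward".
print(y.grad_fn.name())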
1 parent b2c7c1d commit b737629

File tree: 6 files changed (+38, −60 lines changed)

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-CopyBackwards(None, AddBackward0(ExpandBackward(AccumulateGrad()), MulBackward0(ExpandBackward(AccumulateGrad()), AccumulateGrad())))
+CopyBackwards(None, AddBackward0(ExpandBackward0(AccumulateGrad()), MulBackward0(ExpandBackward0(AccumulateGrad()), AccumulateGrad())))
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-CopyBackwards(None, AddBackward0(MulBackward0(ExpandBackward(AccumulateGrad()), None), MulBackward0(ExpandBackward(AccumulateGrad()), AccumulateGrad())))
+CopyBackwards(None, AddBackward0(MulBackward0(ExpandBackward0(AccumulateGrad()), None), MulBackward0(ExpandBackward0(AccumulateGrad()), AccumulateGrad())))

test/test_autograd.py

Lines changed: 2 additions & 2 deletions
@@ -3450,7 +3450,7 @@ def test_inplace_on_view_backward(self):
        gradient_penalty.backward()

        fn = gradient_penalty.grad_fn.next_functions[0][0].next_functions[1][0]
-       self.assertEqual(fn.name(), "ThresholdBackwardBackward")
+       self.assertEqual(fn.name(), "ThresholdBackwardBackward0")

    def test_inplace_on_view_weak_grad_fn(self):
        # Issue 23502: Test that b's grad_fn is preserved.
@@ -4859,7 +4859,7 @@ def maybe_check_raise(fn, should_raise):
        # The 3 elements are for view_as, first output of unbind and second output of unbind
        run_test(grad_mode=True, requires_grad=False, is_view=True,
                 should_raise_tuple=(None, None, None))
-       inp_change_err = "Output {} of UnbindBackward is a view and is being modified inplace."
+       inp_change_err = "Output {} of UnbindBackward0 is a view and is being modified inplace."
        run_test(grad_mode=True, requires_grad=True, is_view=True,
                 should_raise_tuple=(None, inp_change_err.format("0"), inp_change_err.format("1")))
        leaf_grad_err = "A view was created in no_grad mode and is being modified inplace"
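For reference, the renamed message in this hunk is the one autograd raises when an output of a multi-output view such as unbind is modified in place. A rough sketch of code that would trigger it (illustration only; the exact wording is whatever the assertion above checks):

import torch

x = torch.randn(3, requires_grad=True)
a, b, c = x.unbind()  # each output is a view of x
a.add_(1.0)           # expected to raise: "Output 0 of UnbindBackward0 is a view
                      # and is being modified inplace. ..."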

test/test_cuda.py

Lines changed: 1 addition & 1 deletion
@@ -3049,7 +3049,7 @@ def test_autocast_rnn(self):
            # Autocast wrapper requires at::_cudnn_rnn is autograd-exposed. This check can't guarantee
            # at::_cudnn_rnn is autograd-exposed, but if it fires, it indicates some funny business has
            # occurred and we should double check that at::_cudnn_rnn remains autograd-exposed.
-           self.assertEqual(out.grad_fn.name(), "CudnnRnnBackward")
+           self.assertEqual(out.grad_fn.name(), "CudnnRnnBackward0")
            out.sum().backward()
            grads = [p.grad.clone() for p in rnn.parameters()]

tools/autograd/load_derivatives.py

Lines changed: 30 additions & 52 deletions
@@ -2,9 +2,9 @@
 #
 # Each autograd function is represented by `DifferentiabilityInfo` containing
 # a list of `Derivative`. See `tools.codegen.api.autograd` for the data models.
-from collections import defaultdict, Counter
+from collections import defaultdict
 import re
-from typing import Sequence, Any, Tuple, List, Set, Dict, Match, Optional
+from typing import Counter, Sequence, Any, Tuple, List, Set, Dict, Match, Optional
 import yaml

 from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
@@ -43,32 +43,15 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque
        assert str(function.func) not in functions_by_schema
        functions_by_schema[str(function.func)] = function

+    # Keep track of how many of which ops we've seen so we can
+    # disambiguate them with a numeric suffix.
+    op_counter = Counter[str]()
+
    infos = [
-        create_differentiability_info(defn, functions_by_signature, functions_by_schema)
+        create_differentiability_info(defn, functions_by_signature, functions_by_schema, op_counter)
        for defn in definitions]

-    # To keep it byte-for-byte compatible with the old codegen, we assign op names as a separate
-    # step. We only assign op names to those with differentiable args, and only append suffix to
-    # duplicated op names. This can be simplified if the first of the duplicates can be named
-    # 'XyzBackward' instead of 'XyzBackward0' or unconditionally append '0' to singletons.
-    op_names = create_op_names(infos)
-    res = [
-        DifferentiabilityInfo(
-            name=info.name,
-            func=info.func,
-            op=op_name,
-            derivatives=info.derivatives,
-            forward_derivatives=info.forward_derivatives,
-            all_saved_inputs=info.all_saved_inputs,
-            all_saved_outputs=info.all_saved_outputs,
-            args_with_derivatives=info.args_with_derivatives,
-            non_differentiable_arg_names=info.non_differentiable_arg_names,
-            output_differentiability=info.output_differentiability,
-            output_differentiability_conditions=info.output_differentiability_conditions,
-        )
-        for info, op_name in zip(infos, op_names)]
-
-    _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = res
+    _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos

    return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]

@@ -279,6 +262,7 @@ def create_differentiability_info(
    defn: Dict[Any, Any],
    functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
    functions_by_schema: Dict[str, NativeFunction],
+    op_counter: Counter[str],
 ) -> DifferentiabilityInfo:
    """Processes a single entry `defn` in derivatives.yaml"""

@@ -424,10 +408,17 @@ def set_up_derivatives(f: NativeFunction) -> Tuple[

    derivatives, forward_derivatives, args_with_derivatives, non_differentiable_arg_names = set_up_derivatives(canonical)

+    # only assign an op name if we are actually going to calculate a derivative
+    op = None
+    if args_with_derivatives:
+        op_prefix = _create_op_prefix(defn_name)
+        op = f'{op_prefix}{op_counter[op_prefix]}'
+        op_counter[op_prefix] += 1
+
    return DifferentiabilityInfo(
        name=defn_name,
        func=canonical,
-        op=None,
+        op=op,
        derivatives=derivatives,
        forward_derivatives=forward_derivatives,
        all_saved_inputs=dedup_vars([v for d in derivatives for v in d.saved_inputs]),
@@ -566,35 +557,22 @@ def repl(m: Match[str]) -> str:

    return formula, tuple(saved)

-def create_op_name(info: DifferentiabilityInfo) -> Optional[str]:
-    # only assign an op name if we are actually going to calculate a derivative
-    if not info.args_with_derivatives:
-        return None
-    name = info.name
+def _create_op_prefix(name: str) -> str:
+    """Takes a native function name converts to a op prefix name.
+
+    Note that the "name" parameter must be the native function name
+    without the optional variant suffix, so "add" instead of
+    "add.out".
+
+    OP names correspond to classes, hence the change to title case.
+
+    Example::
+    >>> _create_op_prefix('add')
+    'AddBackward'
+    """
    camel_case = ''.join([p.title() for p in name.split('_')])
    return (camel_case + 'Backward').replace('ForwardBackward', 'Backward')

-def create_op_names(infos: Sequence[DifferentiabilityInfo]) -> Sequence[Optional[str]]:
-    names = list(map(create_op_name, infos))
-    dups = set(item for item, count in Counter(names).items() if count > 1)
-
-    # de-duplicate operation names
-    # you end up with something like:
-    #   AddBackward0
-    #   AddBackward1
-    # one for each overload
-    counter: Dict[str, int] = Counter()
-    dedup: List[Optional[str]] = []
-    for name in names:
-        if name is None:
-            # Keep a placeholder
-            dedup.append(None)
-        elif name in dups:
-            dedup.append(f'{name}{counter[name]}')
-            counter[name] += 1
-        else:
-            dedup.append(name)
-    return dedup

 def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
    seen: Set[str] = set()
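The naming scheme that replaces create_op_name/create_op_names is small enough to sketch standalone. The helper below mirrors _create_op_prefix plus the op_counter suffixing shown above; make_backward_name is a hypothetical wrapper introduced here for illustration, not a function in the codegen.

from typing import Counter

def op_prefix(name: str) -> str:
    # Title-case the snake_case native name and append 'Backward',
    # e.g. 'threshold_backward' -> 'ThresholdBackwardBackward'.
    camel_case = ''.join(p.title() for p in name.split('_'))
    return (camel_case + 'Backward').replace('ForwardBackward', 'Backward')

op_counter = Counter[str]()

def make_backward_name(native_name: str) -> str:
    # Single forward pass: every op gets its numeric suffix immediately,
    # instead of a second dedup pass over all names.
    prefix = op_prefix(native_name)
    op = f'{prefix}{op_counter[prefix]}'
    op_counter[prefix] += 1
    return op

print(make_backward_name('add'))     # AddBackward0
print(make_backward_name('add'))     # AddBackward1
print(make_backward_name('expand'))  # ExpandBackward0

Because every name gets a suffix in one pass, even singleton ops end in '0' (ExpandBackward0), which is why the expect files and tests above change and why the commit message notes the names are not byte-for-byte compatible.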

torch/csrc/autograd/variable.cpp

Lines changed: 3 additions & 3 deletions
@@ -551,10 +551,10 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const Tenso
    // self = view_op_n(view_n-1)
    // self = inplace_op(self)
    //
-   // For CPU/CUDA backends, we employ one AsStridedBackward Node to represent the chain of
+   // For CPU/CUDA backends, we employ one AsStridedBackward0 Node to represent the chain of
    // view backward ops for effienciency.
    //
-   // However in XLA backend we don't have full support of AsStridedBackward, we instead run a full
+   // However in XLA backend we don't have full support of AsStridedBackward0, we instead run a full
    // forward pass with a tensor that requires gradient to get proper grad_fn setup,
    // then save it to DifferentiableViewMeta for future use.
    // This is fairly cheap for XLA lazy tensor approach (but would be really expensive for CPU/CUDA).
@@ -572,7 +572,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const Tenso
      auto diff_view = view_fn(view_info.base_);
      diff_view_meta->grad_fn_ = diff_view.grad_fn();
    } else {
-     auto fn = std::make_shared<torch::autograd::generated::AsStridedBackward>();
+     auto fn = std::make_shared<torch::autograd::generated::AsStridedBackward0>();
      fn->self_geometry = at::TensorGeometry(view_info.base_);
      fn->size = self.sizes().vec();
      fn->stride = self.strides().vec();
