|
2 | 2 | #
|
3 | 3 | # Each autograd function is represented by `DifferentiabilityInfo` containing
|
4 | 4 | # a list of `Derivative`. See `tools.codegen.api.autograd` for the data models.
|
5 |
| -from collections import defaultdict, Counter |
| 5 | +from collections import defaultdict |
6 | 6 | import re
|
7 |
| -from typing import Sequence, Any, Tuple, List, Set, Dict, Match, Optional |
| 7 | +from typing import Counter, Sequence, Any, Tuple, List, Set, Dict, Match, Optional |
8 | 8 | import yaml
|
9 | 9 |
|
10 | 10 | from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
|
@@ -43,32 +43,15 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque
|
43 | 43 | assert str(function.func) not in functions_by_schema
|
44 | 44 | functions_by_schema[str(function.func)] = function
|
45 | 45 |
|
| 46 | + # Count how many times each op prefix has been seen so that ops sharing |
| 47 | + # a prefix can be disambiguated with a numeric suffix. |
| 48 | + op_counter = Counter[str]() |
| 49 | + |
46 | 50 | infos = [
|
47 |
| - create_differentiability_info(defn, functions_by_signature, functions_by_schema) |
| 51 | + create_differentiability_info(defn, functions_by_signature, functions_by_schema, op_counter) |
48 | 52 | for defn in definitions]
|
49 | 53 |
|
50 |
| - # To keep it byte-for-byte compatible with the old codegen, we assign op names as a separate |
51 |
| - # step. We only assign op names to those with differentiable args, and only append suffix to |
52 |
| - # duplicated op names. This can be simplified if the first of the duplicates can be named |
53 |
| - # 'XyzBackward' instead of 'XyzBackward0' or unconditionally append '0' to singletons. |
54 |
| - op_names = create_op_names(infos) |
55 |
| - res = [ |
56 |
| - DifferentiabilityInfo( |
57 |
| - name=info.name, |
58 |
| - func=info.func, |
59 |
| - op=op_name, |
60 |
| - derivatives=info.derivatives, |
61 |
| - forward_derivatives=info.forward_derivatives, |
62 |
| - all_saved_inputs=info.all_saved_inputs, |
63 |
| - all_saved_outputs=info.all_saved_outputs, |
64 |
| - args_with_derivatives=info.args_with_derivatives, |
65 |
| - non_differentiable_arg_names=info.non_differentiable_arg_names, |
66 |
| - output_differentiability=info.output_differentiability, |
67 |
| - output_differentiability_conditions=info.output_differentiability_conditions, |
68 |
| - ) |
69 |
| - for info, op_name in zip(infos, op_names)] |
70 |
| - |
71 |
| - _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = res |
| 54 | + _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos |
72 | 55 |
|
73 | 56 | return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]
|
74 | 57 |
|
@@ -279,6 +262,7 @@ def create_differentiability_info(
|
279 | 262 | defn: Dict[Any, Any],
|
280 | 263 | functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
|
281 | 264 | functions_by_schema: Dict[str, NativeFunction],
|
| 265 | + op_counter: Counter[str], |
282 | 266 | ) -> DifferentiabilityInfo:
|
283 | 267 | """Processes a single entry `defn` in derivatives.yaml"""
|
284 | 268 |
|
@@ -424,10 +408,17 @@ def set_up_derivatives(f: NativeFunction) -> Tuple[
|
424 | 408 |
|
425 | 409 | derivatives, forward_derivatives, args_with_derivatives, non_differentiable_arg_names = set_up_derivatives(canonical)
|
426 | 410 |
|
| 411 | + # only assign an op name if we are actually going to calculate a derivative |
| 412 | + op = None |
| 413 | + if args_with_derivatives: |
| 414 | + op_prefix = _create_op_prefix(defn_name) |
| 415 | + op = f'{op_prefix}{op_counter[op_prefix]}' |
| 416 | + op_counter[op_prefix] += 1 |
| 417 | + |
427 | 418 | return DifferentiabilityInfo(
|
428 | 419 | name=defn_name,
|
429 | 420 | func=canonical,
|
430 |
| - op=None, |
| 421 | + op=op, |
431 | 422 | derivatives=derivatives,
|
432 | 423 | forward_derivatives=forward_derivatives,
|
433 | 424 | all_saved_inputs=dedup_vars([v for d in derivatives for v in d.saved_inputs]),
|
@@ -566,35 +557,22 @@ def repl(m: Match[str]) -> str:
|
566 | 557 |
|
567 | 558 | return formula, tuple(saved)
|
568 | 559 |
|
569 |
| -def create_op_name(info: DifferentiabilityInfo) -> Optional[str]: |
570 |
| - # only assign an op name if we are actually going to calculate a derivative |
571 |
| - if not info.args_with_derivatives: |
572 |
| - return None |
573 |
| - name = info.name |
| 560 | +def _create_op_prefix(name: str) -> str: |
| 561 | + """Takes a native function name converts to a op prefix name. |
| 562 | +
|
| 563 | + Note that the "name" parameter must be the native function name |
| 564 | + without the optional variant suffix, so "add" instead of |
| 565 | + "add.out". |
| 566 | +
|
| 567 | + OP names correspond to classes, hence the change to title case. |
| 568 | +
|
| 569 | + Example:: |
| 570 | + >>> _create_op_prefix('add') |
| 571 | + 'AddBackward' |
| 572 | + """ |
574 | 573 | camel_case = ''.join([p.title() for p in name.split('_')])
|
575 | 574 | return (camel_case + 'Backward').replace('ForwardBackward', 'Backward')
|
576 | 575 |
|
577 |
| -def create_op_names(infos: Sequence[DifferentiabilityInfo]) -> Sequence[Optional[str]]: |
578 |
| - names = list(map(create_op_name, infos)) |
579 |
| - dups = set(item for item, count in Counter(names).items() if count > 1) |
580 |
| - |
581 |
| - # de-duplicate operation names |
582 |
| - # you end up with something like: |
583 |
| - # AddBackward0 |
584 |
| - # AddBackward1 |
585 |
| - # one for each overload |
586 |
| - counter: Dict[str, int] = Counter() |
587 |
| - dedup: List[Optional[str]] = [] |
588 |
| - for name in names: |
589 |
| - if name is None: |
590 |
| - # Keep a placeholder |
591 |
| - dedup.append(None) |
592 |
| - elif name in dups: |
593 |
| - dedup.append(f'{name}{counter[name]}') |
594 |
| - counter[name] += 1 |
595 |
| - else: |
596 |
| - dedup.append(name) |
597 |
| - return dedup |
598 | 576 |
|
599 | 577 | def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
|
600 | 578 | seen: Set[str] = set()
|
|
0 commit comments