Skip to content

Commit 44132cc

Browse files
Revert "Add --timing flag, phase timing to @dynamo_timed (pytorch#92637)"
This reverts commit 773b513. Reverted pytorch#92637 on behalf of https://github.com/malfet due to Broke lint
1 parent 5ac2278 commit 44132cc

File tree

8 files changed

+19
-101
lines changed

8 files changed

+19
-101
lines changed

benchmarks/dynamo/common.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1425,10 +1425,6 @@ def run_one_model(
14251425
name, model, example_inputs, optimize_ctx, experiment
14261426
)
14271427
print(status)
1428-
if self.args.timing:
1429-
from torch._dynamo.utils import print_time_report
1430-
print_time_report()
1431-
14321428
end_calls_captured = torch._dynamo.utils.counters["stats"]["calls_captured"]
14331429
end_unique_graphs = torch._dynamo.utils.counters["stats"]["unique_graphs"]
14341430
if explain:
@@ -1687,9 +1683,6 @@ def get_example_inputs(self):
16871683
want to verify the numerical correctness of graidents. But that may
16881684
cause time measurement not accurate""",
16891685
)
1690-
parser.add_argument(
1691-
"--timing", action="store_true", help="Emits phase timing"
1692-
)
16931686

16941687
group_fuser = parser.add_mutually_exclusive_group()
16951688
# --nvfuser is now the default, keep the option to not break scripts

benchmarks/dynamo/run_all.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ if getent hosts fwdproxy; then
2626
fi
2727

2828
# Feel free to edit these, but we expect most users not to need to modify this
29-
BASE_FLAGS=( --accuracy --explain --timing)
29+
BASE_FLAGS=( --accuracy --explain )
3030
DATE="$(date)"
3131
WORK="$PWD"
3232

torch/_dynamo/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
skip,
1414
)
1515
from .external_utils import is_compiling
16-
from .utils import compilation_metrics, guard_failures, orig_code_map, reset_frame_count
16+
from .utils import compilation_metrics, guard_failures, orig_code_map
1717

1818
__all__ = [
1919
"allow_in_graph",
@@ -48,7 +48,6 @@ def reset():
4848
resume_execution.ContinueExecutionCache.cache.clear()
4949
eval_frame.most_recent_backend = None
5050
compilation_metrics.clear()
51-
reset_frame_count()
5251

5352

5453
def list_backends():

torch/_dynamo/convert_frame.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
format_bytecode,
3636
gen_record_file_name,
3737
guard_failures,
38-
increment_frame,
3938
init_logging,
4039
is_namedtuple,
4140
istype,
@@ -195,8 +194,8 @@ def convert_frame_assert(
195194
"""Fully convert a frame into an FX graph"""
196195
init_logging()
197196

197+
@dynamo_timed
198198
def _convert_frame_assert(frame: types.FrameType, cache_size: int, hooks: Hooks):
199-
increment_frame()
200199
code = frame.f_code
201200
input_codes.add(code)
202201
if code in output_codes:
@@ -274,7 +273,6 @@ def format_guard_failures(code):
274273
return wrap_convert_context(_convert_frame_assert)
275274

276275

277-
@dynamo_timed(phase_name="entire_frame_compile")
278276
def _compile(
279277
code: types.CodeType,
280278
globals: Dict[str, object],

torch/_dynamo/eval_frame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from .mutation_guard import install_generation_tagging_init
4242
from .output_graph import CompilerFn
4343
from .types import DynamoCallback
44-
from .utils import compile_times, dynamo_timed
44+
from .utils import compile_times
4545

4646
log = logging.getLogger(__name__)
4747

torch/_dynamo/output_graph.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
clone_inputs,
4949
count_calls,
5050
counters,
51-
dynamo_timed,
5251
format_graph_tabular,
5352
same,
5453
)
@@ -625,7 +624,6 @@ def compile_and_call_fx_graph(self, tx, rv, root):
625624
cg.make_call_generated_code(name)
626625
return cg.get_instructions()
627626

628-
@dynamo_timed(phase_name="backend_compile")
629627
def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
630628
try:
631629
name = (

torch/_dynamo/symbolic_convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
GlobalWeakRefSource,
5151
LocalSource,
5252
)
53-
from .utils import counters, graph_break_dup_warning_checker, istype, proxy_args_kwargs, dynamo_timed
53+
from .utils import counters, graph_break_dup_warning_checker, istype, proxy_args_kwargs
5454
from .variables.base import MutableLocal, typestr, VariableTracker
5555
from .variables.builder import VariableBuilder, wrap_fx_proxy
5656
from .variables.builtin import BuiltinVariable

torch/_dynamo/utils.py

Lines changed: 14 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -86,90 +86,20 @@ def profile_wrapper(*args, **kwargs):
8686
return profile_wrapper
8787

8888

89-
frame_phase_timing = collections.OrderedDict()
90-
91-
curr_frame = 0
92-
93-
# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
94-
def increment_frame():
95-
global curr_frame
96-
curr_frame = curr_frame + 1
97-
98-
99-
# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
100-
def reset_frame_count():
101-
global curr_frame
102-
frame_phase_timing.clear()
103-
curr_frame = 0
104-
105-
106-
# Print a report of time spent so far
107-
# Ex:
108-
# TIMING:
109-
# entire_frame_compile:8.574629999999999
110-
# backend_compile:5.26806
111-
def print_time_report():
112-
total = 0
113-
total_by_key = {}
114-
print("TIMING:")
115-
for frame, timings in frame_phase_timing.items():
116-
for key, timing in timings.items():
117-
total += timing
118-
if key not in total_by_key:
119-
total_by_key[key] = timing
120-
else:
121-
total_by_key[key] += timing
122-
123-
out = ""
124-
for key, value in total_by_key.items():
125-
out = f"{out} {key}:{round(value, 5)}"
126-
127-
print(out)
128-
129-
130-
# dynamo_timed API works as a function decorator
131-
# By wrapping a function in dynamo_timed, we can store a record in compilation_metrics
132-
# where the key is the functions name.
133-
# For example:
134-
#
135-
# @dynamo_timed
136-
# def _foo(...):
137-
#
138-
# Would show up as an entry in our timing dict:
139-
# OrderedDict([('bar.<locals>._foo', [0.083690, 0.23949, 3.1425e-05])])
140-
# This is extremely useful for granular debugging.
141-
#
142-
# For a higher-level mode, pass a phase_name into dynamo_timed
143-
# phase_names record an extra record into a separate compilation timing structure,
144-
# one keyed on frame+name rather than function.
145-
# The frame is incremented outside of this function, in def increment_frame() above.
146-
def dynamo_timed(original_function=None, phase_name=None):
147-
def dynamo_timed_inner(func):
148-
@wraps(func)
149-
def time_wrapper(*args, **kwargs):
150-
key = func.__qualname__
151-
if key not in compilation_metrics:
152-
compilation_metrics[key] = []
153-
t0 = time.time()
154-
r = func(*args, **kwargs)
155-
time_spent = time.time() - t0
156-
# print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")
157-
compilation_metrics[key].append(time_spent)
158-
if phase_name:
159-
frame_key = str(curr_frame)
160-
if frame_key not in frame_phase_timing:
161-
frame_phase_timing[frame_key] = {}
162-
assert (
163-
phase_name not in frame_phase_timing[frame_key]
164-
), f"Duplicate phase name {phase_name} for frame {frame_key}"
165-
frame_phase_timing[frame_key][phase_name] = time_spent
166-
return r
167-
168-
return time_wrapper
169-
170-
if original_function:
171-
return dynamo_timed_inner(original_function)
172-
return dynamo_timed_inner
89+
def dynamo_timed(func):
90+
@wraps(func)
91+
def time_wrapper(*args, **kwargs):
92+
key = func.__qualname__
93+
if key not in compilation_metrics:
94+
compilation_metrics[key] = []
95+
t0 = time.time()
96+
r = func(*args, **kwargs)
97+
latency = time.time() - t0
98+
# print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")
99+
compilation_metrics[key].append(latency)
100+
return r
101+
102+
return time_wrapper
173103

174104

175105
def compile_times(repr="str", aggregate=False):

0 commit comments

Comments
 (0)