
Commit 82f449a

pianpwk and svekars authored Jan 9, 2025

[export] update dynamic shapes section (#3214)

* Update torch_export_tutorial.py

Co-authored-by: Svetlana Karslioglu <[email protected]>

1 parent 33a52a5 commit 82f449a

File tree

1 file changed: +283 −190 lines changed

intermediate_source/torch_export_tutorial.py
Lines changed: 283 additions & 190 deletions
@@ -3,15 +3,15 @@
 """
 torch.export Tutorial
 ===================================================
-**Author:** William Wen, Zhengxu Chen, Angela Yi
+**Author:** William Wen, Zhengxu Chen, Angela Yi, Pian Pawakapan
 """

 ######################################################################
 #
 # .. warning::
 #
 #     ``torch.export`` and its related features are in prototype status and are subject to backwards compatibility
-#     breaking changes. This tutorial provides a snapshot of ``torch.export`` usage as of PyTorch 2.3.
+#     breaking changes. This tutorial provides a snapshot of ``torch.export`` usage as of PyTorch 2.5.
 #
 # :func:`torch.export` is the PyTorch 2.X way to export PyTorch models into
 # standardized model representations, intended
@@ -304,237 +304,330 @@ def false_fn(x):
 # Constraints/Dynamic Shapes
 # --------------------------
 #
-# Ops can have different specializations/behaviors for different tensor shapes, so by default,
-# ``torch.export`` requires inputs to ``ExportedProgram`` to have the same shape as the respective
-# example inputs given to the initial ``torch.export.export()`` call.
-# If we try to run the ``ExportedProgram`` in the example below with a tensor
-# with a different shape, we get an error:
+# This section covers the dynamic behavior and representation of exported programs. Dynamic behavior is
+# specific to the particular model being exported, so for most of this tutorial, we'll focus
+# on this particular toy model (with the resulting tensor shapes annotated):

-class MyModule2(torch.nn.Module):
+class DynamicModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.lin = torch.nn.Linear(100, 10)
+        self.l = torch.nn.Linear(5, 3)

+    def forward(
+        self,
+        w: torch.Tensor,  # [6, 5]
+        x: torch.Tensor,  # [4]
+        y: torch.Tensor,  # [8, 4]
+        z: torch.Tensor,  # [32]
+    ):
+        x0 = x + y  # [8, 4]
+        x1 = self.l(w)  # [6, 3]
+        x2 = x0.flatten()  # [32]
+        x3 = x2 + z  # [32]
+        return x1, x3

+######################################################################
+# By default, ``torch.export`` produces a static program. One consequence of this is that at runtime,
+# the program won't work on inputs with different shapes, even if they're valid in eager mode.

+w = torch.randn(6, 5)
+x = torch.randn(4)
+y = torch.randn(8, 4)
+z = torch.randn(32)
+model = DynamicModel()
+ep = export(model, (w, x, y, z))
+model(w, x, torch.randn(3, 4), torch.randn(12))
+ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
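
As a quick sanity check (a minimal sketch, assuming the export above), the exported static program does still agree with eager mode on the original sample shapes; only the mismatched shapes above fail:

# Sketch: the static ExportedProgram accepts the original sample shapes.
torch.testing.assert_close(ep.module()(w, x, y, z), model(w, x, y, z))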

+######################################################################
+# Basic concepts: symbols and guards
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# To enable dynamism, ``export()`` provides a ``dynamic_shapes`` argument. The easiest way to work with
+# dynamic shapes is using ``Dim.AUTO`` and looking at the program that's returned. Dynamic behavior is specified
+# at the input-dimension level; for each input, we can specify a tuple of values:

+from torch.export.dynamic_shapes import Dim

+dynamic_shapes = {
+    "w": (Dim.AUTO, Dim.AUTO),
+    "x": (Dim.AUTO,),
+    "y": (Dim.AUTO, Dim.AUTO),
+    "z": (Dim.AUTO,),
+}
+ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

-    def forward(self, x, y):
-        return torch.nn.functional.relu(self.lin(x + y), inplace=True)
+######################################################################
+# Before we look at the program that's produced, let's understand what specifying ``dynamic_shapes`` entails,
+# and how that interacts with export. For every input dimension where a ``Dim`` object is specified, a symbol is
+# `allocated <https://pytorch.org/docs/main/export.programming_model.html#basics-of-symbolic-shapes>`_,
+# taking on a range of ``[2, inf]`` (why not ``[0, inf]`` or ``[1, inf]``? we'll explain later in the
+# 0/1 specialization section).
+#
+# Export then runs model tracing, looking at each operation that's performed by the model. Each individual operation can emit
+# what are called "guards": boolean conditions that are required to be true for the program to be valid.
+# When guards involve symbols allocated for input dimensions, the program contains restrictions on what input shapes are valid;
+# i.e. the program's dynamic behavior. The symbolic shapes subsystem is the part responsible for taking in all the emitted guards
+# and producing a final program representation that adheres to all of these guards. Before we see this "final representation" in
+# an ``ExportedProgram``, let's look at the guards emitted by the toy model we're tracing.
+#
+# Here, each forward input tensor is annotated with the symbol allocated at the start of tracing:

-mod2 = MyModule2()
-exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))
+class DynamicModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.l = torch.nn.Linear(5, 3)

-try:
-    exported_mod2.module()(torch.randn(10, 100), torch.randn(10, 100))
-except Exception:
-    tb.print_exc()
+    def forward(
+        self,
+        w: torch.Tensor,  # [s0, s1]
+        x: torch.Tensor,  # [s2]
+        y: torch.Tensor,  # [s3, s4]
+        z: torch.Tensor,  # [s5]
+    ):
+        x0 = x + y  # guard: s2 == s4
+        x1 = self.l(w)  # guard: s1 == 5
+        x2 = x0.flatten()  # no guard added here
+        x3 = x2 + z  # guard: s3 * s4 == s5
+        return x1, x3

 ######################################################################
-# We can relax this constraint using the ``dynamic_shapes`` argument of
-# ``torch.export.export()``, which allows us to specify, using ``torch.export.Dim``
-# (`documentation <https://pytorch.org/docs/main/export.html#torch.export.Dim>`__),
-# which dimensions of the input tensors are dynamic.
+# Let's understand each of the operations and the emitted guards:
 #
-# For each tensor argument of the input callable, we can specify a mapping from the dimension
-# to a ``torch.export.Dim``.
-# A ``torch.export.Dim`` is essentially a named symbolic integer with optional
-# minimum and maximum bounds.
+# - ``x0 = x + y``: This is an element-wise add with broadcasting, since ``x`` is a 1-d tensor and ``y`` a 2-d tensor. ``x`` is broadcasted along the last dimension of ``y``, emitting the guard ``s2 == s4``.
+# - ``x1 = self.l(w)``: Calling ``nn.Linear()`` performs a matrix multiplication with model parameters. In export, parameters, buffers, and constants are considered program state, which is treated as static, and so this is a matmul between a dynamic input (``w: [s0, s1]``) and a statically-shaped tensor. This emits the guard ``s1 == 5``.
+# - ``x2 = x0.flatten()``: This call actually doesn't emit any guards! (at least none relevant to input shapes)
+# - ``x3 = x2 + z``: ``x2`` has shape ``[s3*s4]`` after flattening, and this element-wise add emits ``s3 * s4 == s5``.
 #
-# Then, the format of ``torch.export.export()``'s ``dynamic_shapes`` argument is a mapping
-# from the input callable's tensor argument names, to dimension --> dim mappings as described above.
-# If there is no ``torch.export.Dim`` given to a tensor argument's dimension, then that dimension is
-# assumed to be static.
+# Writing all of these guards down and summarizing is almost like a mathematical proof, which is what the symbolic shapes
+# subsystem tries to do! In summary, we can conclude that the program must have the following input shapes to be valid:
 #
-# The first argument of ``torch.export.Dim`` is the name for the symbolic integer, used for debugging.
-# Then we can specify an optional minimum and maximum bound (inclusive). Below, we show a usage example.
+# - ``w: [s0, 5]``
+# - ``x: [s2]``
+# - ``y: [s3, s2]``
+# - ``z: [s2*s3]``
 #
-# In the example below, our input
-# ``inp1`` has an unconstrained first dimension, but the size of the second
-# dimension must be in the interval [4, 18].
-
-from torch.export import Dim
-
-inp1 = torch.randn(10, 10, 2)
+# And when we do finally print out the exported program to see our result, those shapes are what we see annotated on the
+# corresponding inputs:

-class DynamicShapesExample1(torch.nn.Module):
-    def forward(self, x):
-        x = x[:, 2:]
-        return torch.relu(x)
+print(ep)
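
Since the program is now dynamic, any input shapes satisfying those constraints should run; a minimal sketch (shapes chosen here just to satisfy the guards above):

ep.module()(
    torch.randn(10, 5),  # w: [s0, 5]
    torch.randn(6),      # x: [s2]
    torch.randn(3, 6),   # y: [s3, s2]
    torch.randn(18),     # z: [s2 * s3]
)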

-inp1_dim0 = Dim("inp1_dim0")
-inp1_dim1 = Dim("inp1_dim1", min=4, max=18)
-dynamic_shapes1 = {
-    "x": {0: inp1_dim0, 1: inp1_dim1},
-}
+######################################################################
+# Another feature to notice is the ``range_constraints`` field above, which contains a valid range for each symbol. This isn't
+# so interesting currently, since this export call doesn't emit any guards related to symbol bounds and each base symbol has
+# a generic bound, but this will come up later.
+#
+# So far, because we've been exporting this toy model, this experience has not been representative of how hard
+# it typically is to debug dynamic shapes guards & issues. In most cases it isn't obvious what guards are being emitted,
+# and which operations and parts of user code are responsible. For this toy model we can pinpoint the exact lines responsible, and the guards
+# are rather intuitive.
+#
+# In more complicated cases, a helpful first step is always to enable verbose logging. This can be done either with the environment
+# variable ``TORCH_LOGS="+dynamic"``, or interactively with ``torch._logging.set_logs(dynamic=10)``:

-exported_dynamic_shapes_example1 = export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1)
+torch._logging.set_logs(dynamic=10)
+ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

-print(exported_dynamic_shapes_example1.module()(torch.randn(5, 5, 2)))
+######################################################################
+# This spits out quite a lot, even with this simple toy model. The log lines here have been cut short at front and end
+# to ignore unnecessary info, but looking through the logs we can see the lines relevant to what we described above;
+# e.g. the allocation of symbols:

-try:
-    exported_dynamic_shapes_example1.module()(torch.randn(8, 1, 2))
-except Exception:
-    tb.print_exc()
+"""
+create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
+create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
+runtime_assert True == True [statically known]
+create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
+create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
+create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
+create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
+"""

-try:
-    exported_dynamic_shapes_example1.module()(torch.randn(8, 20, 2))
-except Exception:
-    tb.print_exc()
+######################################################################
+# The lines with ``create_symbol`` show when a new symbol has been allocated, and the logs also identify the tensor variable names
+# and dimensions they've been allocated for. In other lines we can also see the guards emitted:

-try:
-    exported_dynamic_shapes_example1.module()(torch.randn(8, 8, 3))
-except Exception:
-    tb.print_exc()
+"""
+runtime_assert Eq(s2, s4) [guard added] x0 = x + y # output shape: [8, 4] # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
+runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
+runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
+"""

 ######################################################################
-# Note that if our example inputs to ``torch.export`` do not satisfy the constraints
-# given by ``dynamic_shapes``, then we get an error.
+# Next to the ``[guard added]`` messages, we also see the responsible user lines of code - luckily here the model is simple enough.
+# In many real-world cases it's not so straightforward: high-level torch operations can have complicated fake-kernel implementations
+# or operator decompositions that complicate where and what guards are emitted. In such cases the best way to dig deeper and investigate
+# is to follow the logs' suggestion, and re-run with environment variable ``TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="..."``, to further
+# attribute the guard of interest.
+#
+# ``Dim.AUTO`` is just one of the available options for interacting with ``dynamic_shapes``; as of writing, two other options are available:
+# ``Dim.DYNAMIC`` and ``Dim.STATIC``. ``Dim.STATIC`` simply marks a dimension static, while ``Dim.DYNAMIC`` is similar to ``Dim.AUTO`` in all
+# ways except one: it raises an error when specializing to a constant; this is designed to maintain dynamism. See for example what happens when a
+# static guard is emitted on a dynamically-marked dimension:

-inp1_dim1_bad = Dim("inp1_dim1_bad", min=11, max=18)
-dynamic_shapes1_bad = {
-    "x": {0: inp1_dim0, 1: inp1_dim1_bad},
-}
-
-try:
-    export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1_bad)
-except Exception:
-    tb.print_exc()
+dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
+export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

 ######################################################################
-# We can enforce that equalities between dimensions of different tensors
-# by using the same ``torch.export.Dim`` object, for example, in matrix multiplication:
-
-inp2 = torch.randn(4, 8)
-inp3 = torch.randn(8, 2)
+# Static guards also aren't always inherent to the model; they can also come from user specifications. In fact, a common pitfall leading to shape
+# specializations is when the user specifies conflicting markers for equivalent dimensions; one dynamic and another static. The same error type is
+# raised when this is the case for ``x.shape[0]`` and ``y.shape[1]``:

-class DynamicShapesExample2(torch.nn.Module):
-    def forward(self, x, y):
-        return x @ y
+dynamic_shapes["w"] = (Dim.AUTO, Dim.AUTO)
+dynamic_shapes["x"] = (Dim.STATIC,)
+dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC)
+export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

-inp2_dim0 = Dim("inp2_dim0")
-inner_dim = Dim("inner_dim")
-inp3_dim1 = Dim("inp3_dim1")
+######################################################################
+# Here you might ask why export "specializes", i.e. why we resolve this static/dynamic conflict by going with the static route. The answer lies in
+# the symbolic shapes system of symbols and guards described above. When ``x.shape[0]`` is marked static, we don't allocate a symbol, and compile
+# treating this shape as the concrete integer 4. A symbol is allocated for ``y.shape[1]``, and so we finally emit the guard ``s3 == 4``, leading to
+# specialization.
+#
+# One feature of export is that during tracing, statements like asserts, ``torch._check()``, and ``if/else`` conditions will also emit guards.
+# See what happens when we augment the existing model with such statements:

-dynamic_shapes2 = {
-    "x": {0: inp2_dim0, 1: inner_dim},
-    "y": {0: inner_dim, 1: inp3_dim1},
+class DynamicModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.l = torch.nn.Linear(5, 3)

+    def forward(self, w, x, y, z):
+        assert w.shape[0] <= 512
+        torch._check(x.shape[0] >= 16)
+        if w.shape[0] == x.shape[0] + 2:
+            x0 = x + y
+            x1 = self.l(w)
+            x2 = x0.flatten()
+            x3 = x2 + z
+            return x1, x3
+        else:
+            return w

+dynamic_shapes = {
+    "w": (Dim.AUTO, Dim.AUTO),
+    "x": (Dim.AUTO,),
+    "y": (Dim.AUTO, Dim.AUTO),
+    "z": (Dim.AUTO,),
 }
-
-exported_dynamic_shapes_example2 = export(DynamicShapesExample2(), (inp2, inp3), dynamic_shapes=dynamic_shapes2)
-
-print(exported_dynamic_shapes_example2.module()(torch.randn(2, 16), torch.randn(16, 4)))
-
-try:
-    exported_dynamic_shapes_example2.module()(torch.randn(4, 8), torch.randn(4, 2))
-except Exception:
-    tb.print_exc()
+ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
+print(ep)

 ######################################################################
-# We can also describe one dimension in terms of other. There are some
-# restrictions to how detailed we can specify one dimension in terms of another,
-# but generally, those in the form of ``A * Dim + B`` should work.
-
-class DerivedDimExample1(torch.nn.Module):
-    def forward(self, x, y):
-        return x + y[1:]
-
-foo = DerivedDimExample1()
-
-x, y = torch.randn(5), torch.randn(6)
-dimx = torch.export.Dim("dimx", min=3, max=6)
-dimy = dimx + 1
-derived_dynamic_shapes1 = ({0: dimx}, {0: dimy})
-
-derived_dim_example1 = export(foo, (x, y), dynamic_shapes=derived_dynamic_shapes1)
-
-print(derived_dim_example1.module()(torch.randn(4), torch.randn(5)))
-
-try:
-    derived_dim_example1.module()(torch.randn(4), torch.randn(6))
-except Exception:
-    tb.print_exc()
-
-
-class DerivedDimExample2(torch.nn.Module):
-    def forward(self, z, y):
-        return z[1:] + y[1::3]
-
-foo = DerivedDimExample2()
-
-z, y = torch.randn(4), torch.randn(10)
-dx = torch.export.Dim("dx", min=3, max=6)
-dz = dx + 1
-dy = dx * 3 + 1
-derived_dynamic_shapes2 = ({0: dz}, {0: dy})
-
-derived_dim_example2 = export(foo, (z, y), dynamic_shapes=derived_dynamic_shapes2)
-print(derived_dim_example2.module()(torch.randn(7), torch.randn(19)))
+# Each of these statements emits an additional guard, and the exported program shows the changes; ``s0`` is eliminated in favor of ``s2 + 2``,
+# and ``s2`` now contains lower and upper bounds, reflected in ``range_constraints``.
+#
+# For the if/else condition, you might ask why the True branch was taken, and why it wasn't the ``w.shape[0] != x.shape[0] + 2`` guard that
+# got emitted from tracing. The answer is that export is guided by the sample inputs provided to tracing, and specializes on the branches taken.
+# If different sample input shapes were provided that fail the ``if`` condition, export would trace and emit guards corresponding to the ``else`` branch.
+# Additionally, you might ask why we traced only the ``if`` branch, and if it's possible to maintain control-flow in your program and keep both branches
+# alive. For that, refer to rewriting your model code following the ``Control Flow Ops`` section above.
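
A minimal sketch of what these bounds and relations mean at runtime, assuming the export above: inputs must satisfy ``w.shape[0] == x.shape[0] + 2``, ``x.shape[0] >= 16``, and ``w.shape[0] <= 512``, along with the earlier relations:

ep.module()(
    torch.randn(20, 5),  # w: s2 + 2 = 20, and <= 512
    torch.randn(18),     # x: s2 = 18, >= 16
    torch.randn(4, 18),  # y: [s3, s2]
    torch.randn(72),     # z: [s2 * s3]
)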

 ######################################################################
-# We can actually use ``torch.export`` to guide us as to which ``dynamic_shapes`` constraints
-# are necessary. We can do this by relaxing all constraints (recall that if we
-# do not provide constraints for a dimension, the default behavior is to constrain
-# to the exact shape value of the example input) and letting ``torch.export``
-# error out.
-
-inp4 = torch.randn(8, 16)
-inp5 = torch.randn(16, 32)
-
-class DynamicShapesExample3(torch.nn.Module):
-    def forward(self, x, y):
-        if x.shape[0] <= 16:
-            return x @ y[:, :16]
-        return y
-
-dynamic_shapes3 = {
-    "x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())},
-    "y": {i: Dim(f"inp5_dim{i}") for i in range(inp5.dim())},
-}
+# 0/1 specialization
+# ^^^^^^^^^^^^^^^^^^
+#
+# Since we're talking about guards and specializations, it's a good time to talk about the 0/1 specialization issue we brought up earlier.
+# The bottom line is that export will specialize on sample input dimensions with value 0 or 1, because these shapes have trace-time properties that
+# don't generalize to other shapes. For example, size 1 tensors can broadcast while other sizes fail; and size 0 ... . This just means that you should
+# specify 0/1 sample inputs when you'd like your program to hardcode them, and non-0/1 sample inputs when dynamic behavior is desirable. See what happens
+# at runtime when we export this linear layer:

-try:
-    export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3)
-except Exception:
-    tb.print_exc()
+ep = export(
+    torch.nn.Linear(4, 3),
+    (torch.randn(1, 4),),
+    dynamic_shapes={
+        "input": (Dim.AUTO, Dim.STATIC),
+    },
+)
+ep.module()(torch.randn(2, 4))
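
By contrast, a sketch of the same layer exported with a non-0/1 sample batch size, which keeps that dimension dynamic at runtime (other batch sizes of 2 or more then work):

ep = export(
    torch.nn.Linear(4, 3),
    (torch.randn(2, 4),),  # non-0/1 sample keeps the batch dimension dynamic
    dynamic_shapes={
        "input": (Dim.AUTO, Dim.STATIC),
    },
)
ep.module()(torch.randn(8, 4))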

 ######################################################################
-# We can see that the error message gives us suggested fixes to our
-# dynamic shape constraints. Let us follow those suggestions (exact
-# suggestions may differ slightly):
-
-def suggested_fixes():
-    inp4_dim1 = Dim('shared_dim')
-    # suggested fixes below
-    inp4_dim0 = Dim('inp4_dim0', max=16)
-    inp5_dim1 = Dim('inp5_dim1', min=17)
-    inp5_dim0 = inp4_dim1
-    # end of suggested fixes
-    return {
-        "x": {0: inp4_dim0, 1: inp4_dim1},
-        "y": {0: inp5_dim0, 1: inp5_dim1},
-    }
+# Named Dims
+# ^^^^^^^^^^
+#
+# So far we've only been talking about three ways to specify dynamic shapes: ``Dim.AUTO``, ``Dim.DYNAMIC``, and ``Dim.STATIC``. The attraction of these is the
+# low-friction user experience; all the guards emitted during model tracing are adhered to, and dynamic behavior like min/max ranges, relations, and static/dynamic
+# dimensions are automatically figured out underneath export. The dynamic shapes subsystem essentially acts as a "discovery" process, summarizing these guards
+# and presenting what export believes is the overall dynamic behavior of the program. The drawback of this design appears once the user has stronger expectations or
+# beliefs about the dynamic behavior of these models - maybe there is a strong desire for dynamism and specializations on particular dimensions are to be avoided at
+# all costs, or maybe we just want to catch changes in dynamic behavior caused by changes to the original model code, or to underlying decompositions or meta-kernels.
+# Such changes won't be detected, and the ``export()`` call will most likely succeed, unless tests are in place that check the resulting ``ExportedProgram`` representation.
+#
+# For such cases, our stance is to recommend the "traditional" way of specifying dynamic shapes, which longer-term users of export might be familiar with: named ``Dims``:

-dynamic_shapes3_fixed = suggested_fixes()
-exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
-print(exported_dynamic_shapes_example3.module()(torch.randn(4, 32), torch.randn(32, 64)))
+dx = Dim("dx", min=4, max=256)
+dh = Dim("dh", max=512)
+dynamic_shapes = {
+    "x": (dx, None),
+    "y": (2 * dx, dh),
+}

 ######################################################################
-# Note that in the example above, because we constrained the value of ``x.shape[0]`` in
-# ``dynamic_shapes_example3``, the exported program is sound even though there is a
-# raw ``if`` statement.
+# This style of dynamic shapes allows the user to specify what symbols are allocated for input dimensions, min/max bounds on those symbols, and places restrictions on the
+# dynamic behavior of the ``ExportedProgram`` produced; ``ConstraintViolation`` errors will be raised if model tracing emits guards that conflict with the relations or static/dynamic
+# specifications given. For example, in the above specification, the following is asserted:
 #
-# If you want to see why ``torch.export`` generated these constraints, you can
-# re-run the script with the environment variable ``TORCH_LOGS=dynamic,dynamo``,
-# or use ``torch._logging.set_logs``.
-
-import logging
-torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO)
-exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
+# - ``x.shape[0]`` has range ``[4, 256]``, and is related to ``y.shape[0]`` by ``y.shape[0] == 2 * x.shape[0]``.
+# - ``x.shape[1]`` is static.
+# - ``y.shape[1]`` has range ``[2, 512]``, and is unrelated to any other dimension.
+#
+# In this design, we allow relations between dimensions to be specified with univariate linear expressions: ``A * dim + B`` can be specified for any dimension. This allows users
+# to specify more complex constraints like integer divisibility for dynamic dimensions:

-# reset to previous values
-torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING)
+dx = Dim("dx", min=4, max=512)
+dynamic_shapes = {
+    "x": (4 * dx, None)  # x.shape[0] has range [16, 2048], and is divisible by 4.
+}
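
For instance, here is a minimal end-to-end sketch of a relation spec (the model and shapes are purely illustrative): the traced guard ``y.shape[0] == 2 * x.shape[0]`` matches the specification, so the export succeeds and consistent shape pairs run at runtime:

class Stack2(torch.nn.Module):
    def forward(self, x, y):
        # requires y.shape[0] == 2 * x.shape[0]
        return torch.cat([x, x], dim=0) + y

dx = Dim("dx", min=4, max=256)
ep = export(
    Stack2(),
    (torch.randn(8, 4), torch.randn(16, 4)),
    dynamic_shapes={"x": (dx, None), "y": (2 * dx, None)},
)
ep.module()(torch.randn(5, 4), torch.randn(10, 4))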

 ######################################################################
-# We can view an ``ExportedProgram``'s symbolic shape ranges using the
-# ``range_constraints`` field.
+# Constraint violations, suggested fixes
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# One common issue with this specification style (before ``Dim.AUTO`` was introduced) is that the specification would often be mismatched with what was produced by model tracing.
+# That would lead to ``ConstraintViolation`` errors and suggested fixes from export - see for example this model & specification, where the model inherently requires equality between
+# dimensions 0 of ``x`` and ``y``, and requires dimension 1 to be static.

-print(exported_dynamic_shapes_example3.range_constraints)
+class Foo(torch.nn.Module):
+    def forward(self, x, y):
+        w = x + y
+        return w + torch.ones(4)

+dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
+ep = export(
+    Foo(),
+    (torch.randn(6, 4), torch.randn(6, 4)),
+    dynamic_shapes={
+        "x": (dx, d1),
+        "y": (dy, d1),
+    },
+)
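
Applying the kind of fixes that the ``ConstraintViolation`` error suggests (the exact text may differ), a corrected specification would tie the two batch dimensions together and mark dimension 1 static; a sketch:

dx = Dim("dx")
ep = export(
    Foo(),
    (torch.randn(6, 4), torch.randn(6, 4)),
    dynamic_shapes={
        "x": (dx, None),
        "y": (dx, None),
    },
)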

+######################################################################
+# The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards.
+#
+# Lastly, there are a couple of nice-to-knows about the options for specification:
+#
+# - ``None`` is a good option for static behavior:
+#   - ``dynamic_shapes=None`` (default) exports with the entire model being static.
+#   - specifying ``None`` at an input-level exports with all tensor dimensions static, and is also required for non-tensor inputs.
+#   - specifying ``None`` at a dimension-level specializes that dimension, though this is deprecated in favor of ``Dim.STATIC``.
+# - specifying per-dimension integer values also produces static behavior, and will additionally check that the provided sample input matches the specification.
+#
+# These options are combined in the inputs & dynamic shapes spec below:

+inputs = (
+    torch.randn(4, 4),
+    torch.randn(3, 3),
+    16,
+    False,
+)
+dynamic_shapes = {
+    "tensor_0": (Dim.AUTO, None),
+    "tensor_1": None,
+    "int_val": None,
+    "bool_val": None,
+}
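
To make this concrete, a hypothetical model whose forward signature matches those spec keys (``tensor_0``, ``tensor_1``, ``int_val``, and ``bool_val`` are illustrative names, not from the file above) could be exported with that spec; a sketch:

class Mixed(torch.nn.Module):
    def forward(self, tensor_0, tensor_1, int_val, bool_val):
        out = tensor_0 * int_val  # tensor_0 dim 0 stays dynamic; int_val specializes
        if bool_val:              # plain Python bool: export specializes on the branch taken
            out = out + tensor_1.sum()
        return out

ep = export(Mixed(), inputs, dynamic_shapes=dynamic_shapes)
print(ep)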

 ######################################################################
 # Custom Ops
