|
3 | 3 | """
|
4 | 4 | torch.export Tutorial
|
5 | 5 | ===================================================
|
6 | | -**Author:** William Wen, Zhengxu Chen, Angela Yi
| 6 | +**Author:** William Wen, Zhengxu Chen, Angela Yi, Pian Pawakapan
7 | 7 | """
|
8 | 8 |
|
9 | 9 | ######################################################################
|
10 | 10 | #
|
11 | 11 | # .. warning::
|
12 | 12 | #
|
13 | 13 | # ``torch.export`` and its related features are in prototype status and are subject to backwards compatibility
|
14 | | -# breaking changes. This tutorial provides a snapshot of ``torch.export`` usage as of PyTorch 2.3.
| 14 | +# breaking changes. This tutorial provides a snapshot of ``torch.export`` usage as of PyTorch 2.5.
15 | 15 | #
|
16 | 16 | # :func:`torch.export` is the PyTorch 2.X way to export PyTorch models into
|
17 | 17 | # standardized model representations, intended
|
@@ -304,237 +304,330 @@ def false_fn(x):
|
304 | 304 | # Constraints/Dynamic Shapes
|
305 | 305 | # --------------------------
|
306 | 306 | #
|
307 | | -# Ops can have different specializations/behaviors for different tensor shapes, so by default,
308 | | -# ``torch.export`` requires inputs to ``ExportedProgram`` to have the same shape as the respective
309 | | -# example inputs given to the initial ``torch.export.export()`` call.
310 | | -# If we try to run the ``ExportedProgram`` in the example below with a tensor
311 | | -# with a different shape, we get an error:
| 307 | +# This section covers the dynamic behavior and representation of exported programs. Dynamic behavior is
| 308 | +# specific to the particular model being exported, so for most of this tutorial, we'll focus
| 309 | +# on this particular toy model (with the resulting tensor shapes annotated):
312 | 310 |
|
313 | | -class MyModule2(torch.nn.Module):
| 311 | +class DynamicModel(torch.nn.Module):
314 | 312 |     def __init__(self):
315 | 313 |         super().__init__()
316 | | -        self.lin = torch.nn.Linear(100, 10)
| 314 | +        self.l = torch.nn.Linear(5, 3)
| 315 | +
| 316 | +    def forward(
| 317 | +        self,
| 318 | +        w: torch.Tensor,  # [6, 5]
| 319 | +        x: torch.Tensor,  # [4]
| 320 | +        y: torch.Tensor,  # [8, 4]
| 321 | +        z: torch.Tensor,  # [32]
| 322 | +    ):
| 323 | +        x0 = x + y  # [8, 4]
| 324 | +        x1 = self.l(w)  # [6, 3]
| 325 | +        x2 = x0.flatten()  # [32]
| 326 | +        x3 = x2 + z  # [32]
| 327 | +        return x1, x3
| 328 | +
| 329 | +###################################################################### |
| 330 | +# By default, ``torch.export`` produces a static program. One consequence of this is that at runtime, |
| 331 | +# the program won't work on inputs with different shapes, even if they're valid in eager mode. |
| 332 | + |
| 333 | +w = torch.randn(6, 5) |
| 334 | +x = torch.randn(4) |
| 335 | +y = torch.randn(8, 4) |
| 336 | +z = torch.randn(32) |
| 337 | +model = DynamicModel() |
| 338 | +ep = export(model, (w, x, y, z)) |
| 339 | +model(w, x, torch.randn(3, 4), torch.randn(12))  # works in eager mode
| 340 | +ep.module()(w, x, torch.randn(3, 4), torch.randn(12))  # errors: the exported program is static
| 341 | + |
| 342 | +###################################################################### |
| 343 | +# Basic concepts: symbols and guards |
| 344 | +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 345 | +# |
| 346 | +# To enable dynamism, ``export()`` provides a ``dynamic_shapes`` argument. The easiest way to work with |
| 347 | +# dynamic shapes is using ``Dim.AUTO`` and looking at the program that's returned. Dynamic behavior is specified |
| 348 | +# at an input dimension level; for each input, we can specify a tuple of values, one entry per dimension:
| 349 | + |
| 350 | +from torch.export.dynamic_shapes import Dim |
| 351 | + |
| 352 | +dynamic_shapes = { |
| 353 | + "w": (Dim.AUTO, Dim.AUTO), |
| 354 | + "x": (Dim.AUTO,), |
| 355 | + "y": (Dim.AUTO, Dim.AUTO), |
| 356 | + "z": (Dim.AUTO,), |
| 357 | +} |
| 358 | +ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes) |
317 | 359 |
|
318 | | -    def forward(self, x, y):
319 | | -        return torch.nn.functional.relu(self.lin(x + y), inplace=True)
| 360 | +###################################################################### |
| 361 | +# Before we look at the program that's produced, let's understand what specifying ``dynamic_shapes`` entails, |
| 362 | +# and how that interacts with export. For every input dimension where a ``Dim`` object is specified, a symbol is |
| 363 | +# `allocated <https://pytorch.org/docs/main/export.programming_model.html#basics-of-symbolic-shapes>`_, |
| 364 | +# taking on a range of ``[2, inf]`` (why not ``[0, inf]`` or ``[1, inf]``? we'll explain later in the |
| 365 | +# 0/1 specialization section). |
| 366 | +# |
| 367 | +# Export then runs model tracing, looking at each operation that's performed by the model. Each individual operation can emit |
| 368 | +# what are called "guards": boolean conditions that must be true for the program to be valid.
| 369 | +# When guards involve the symbols allocated for input dimensions, they restrict which input shapes are valid;
| 370 | +# i.e. they define the program's dynamic behavior. The symbolic shapes subsystem is responsible for taking in all the emitted guards
| 371 | +# and producing a final program representation that adheres to all of these guards. Before we see this "final representation" in |
| 372 | +# an ``ExportedProgram``, let's look at the guards emitted by the toy model we're tracing. |
| 373 | +# |
| 374 | +# Here, each forward input tensor is annotated with the symbol allocated at the start of tracing: |
320 | 375 |
|
321 | | -mod2 = MyModule2()
322 | | -exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))
| 376 | +class DynamicModel(torch.nn.Module):
| 377 | +    def __init__(self):
| 378 | +        super().__init__()
| 379 | +        self.l = torch.nn.Linear(5, 3)
323 | 380 |
324 | | -try:
325 | | -    exported_mod2.module()(torch.randn(10, 100), torch.randn(10, 100))
326 | | -except Exception:
327 | | -    tb.print_exc()
| 381 | +    def forward(
| 382 | +        self,
| 383 | +        w: torch.Tensor,  # [s0, s1]
| 384 | +        x: torch.Tensor,  # [s2]
| 385 | +        y: torch.Tensor,  # [s3, s4]
| 386 | +        z: torch.Tensor,  # [s5]
| 387 | +    ):
| 388 | +        x0 = x + y  # guard: s2 == s4
| 389 | +        x1 = self.l(w)  # guard: s1 == 5
| 390 | +        x2 = x0.flatten()  # no guard added here
| 391 | +        x3 = x2 + z  # guard: s3 * s4 == s5
| 392 | +        return x1, x3
328 | 393 |
|
329 | 394 | ######################################################################
|
330 | | -# We can relax this constraint using the ``dynamic_shapes`` argument of
331 | | -# ``torch.export.export()``, which allows us to specify, using ``torch.export.Dim``
332 | | -# (`documentation <https://pytorch.org/docs/main/export.html#torch.export.Dim>`__),
333 | | -# which dimensions of the input tensors are dynamic.
| 395 | +# Let's understand each of the operations and the emitted guards: |
334 | 396 | #
|
335 | | -# For each tensor argument of the input callable, we can specify a mapping from the dimension
336 | | -# to a ``torch.export.Dim``.
337 | | -# A ``torch.export.Dim`` is essentially a named symbolic integer with optional
338 | | -# minimum and maximum bounds.
| 397 | +# - ``x0 = x + y``: This is an element-wise add with broadcasting, since ``x`` is a 1-d tensor and ``y`` a 2-d tensor. ``x`` is broadcasted along the last dimension of ``y``, emitting the guard ``s2 == s4``. |
| 398 | +# - ``x1 = self.l(w)``: Calling ``nn.Linear()`` performs a matrix multiplication with model parameters. In export, parameters, buffers, and constants are considered program state, which is static; so this is a matmul between a dynamic input (``w: [s0, s1]``) and a statically-shaped tensor. This emits the guard ``s1 == 5``.
| 399 | +# - ``x2 = x0.flatten()``: This call actually doesn't emit any guards! (at least none relevant to input shapes) |
| 400 | +# - ``x3 = x2 + z``: ``x2`` has shape ``[s3*s4]`` after flattening, and this element-wise add emits ``s3 * s4 == s5``. |
339 | 401 | #
|
340 | | -# Then, the format of ``torch.export.export()``'s ``dynamic_shapes`` argument is a mapping
341 | | -# from the input callable's tensor argument names, to dimension --> dim mappings as described above.
342 | | -# If there is no ``torch.export.Dim`` given to a tensor argument's dimension, then that dimension is
343 | | -# assumed to be static.
| 402 | +# Writing all of these guards down and summarizing is almost like a mathematical proof, which is what the symbolic shapes |
| 403 | +# subsystem tries to do! In summary, we can conclude that the program must have the following input shapes to be valid: |
344 | 404 | #
|
345 | | -# The first argument of ``torch.export.Dim`` is the name for the symbolic integer, used for debugging.
346 | | -# Then we can specify an optional minimum and maximum bound (inclusive). Below, we show a usage example.
| 405 | +# - ``w: [s0, 5]`` |
| 406 | +# - ``x: [s2]`` |
| 407 | +# - ``y: [s3, s2]`` |
| 408 | +# - ``z: [s2*s3]`` |
347 | 409 | #
|
348 | | -# In the example below, our input
349 | | -# ``inp1`` has an unconstrained first dimension, but the size of the second
350 | | -# dimension must be in the interval [4, 18].
351 | | -
352 | | -from torch.export import Dim
353 | | -
354 | | -inp1 = torch.randn(10, 10, 2)
| 410 | +# And when we do finally print out the exported program to see our result, those shapes are what we see annotated on the |
| 411 | +# corresponding inputs: |
355 | 412 |
|
356 | | -class DynamicShapesExample1(torch.nn.Module):
357 | | -    def forward(self, x):
358 | | -        x = x[:, 2:]
359 | | -        return torch.relu(x)
| 413 | +print(ep) |
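######################################################################
# As a quick sanity check (a sketch that assumes the export call above succeeded; the new input sizes below are
# chosen arbitrarily by us), we can run the exported module on fresh inputs that satisfy the summarized shapes:
# ``w: [s0, 5]``, ``x: [s2]``, ``y: [s3, s2]``, ``z: [s2*s3]``.

w_new = torch.randn(10, 5)   # s0 = 10
x_new = torch.randn(6)       # s2 = 6
y_new = torch.randn(7, 6)    # s3 = 7, second dim matches s2 = 6
z_new = torch.randn(42)      # s2 * s3 = 42
out1, out3 = ep.module()(w_new, x_new, y_new, z_new)
print(out1.shape, out3.shape)  # expected: torch.Size([10, 3]) torch.Size([42])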
360 | 414 |
|
361 | | -inp1_dim0 = Dim("inp1_dim0")
362 | | -inp1_dim1 = Dim("inp1_dim1", min=4, max=18)
363 | | -dynamic_shapes1 = {
364 | | -    "x": {0: inp1_dim0, 1: inp1_dim1},
365 | | -}
| 415 | +###################################################################### |
| 416 | +# Another field to notice is ``range_constraints`` above, which contains a valid range for each symbol. This isn't
| 417 | +# so interesting currently, since this export call doesn't emit any guards related to symbol bounds and each base symbol has |
| 418 | +# a generic bound, but this will come up later. |
| 419 | +# |
| 420 | +# So far, because we've been exporting this toy model, the experience has not been representative of how hard
| 421 | +# it typically is to debug dynamic shapes guards and issues. In most cases it isn't obvious what guards are being emitted,
| 422 | +# or which operations and parts of user code are responsible. For this toy model we can pinpoint the exact lines, and the guards
| 423 | +# are rather intuitive.
| 424 | +# |
| 425 | +# In more complicated cases, a helpful first step is always to enable verbose logging. This can be done either with the environment |
| 426 | +# variable ``TORCH_LOGS="+dynamic"``, or interactively with ``torch._logging.set_logs(dynamic=10)``: |
366 | 427 |
|
367 | | -exported_dynamic_shapes_example1 = export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1)
| 428 | +torch._logging.set_logs(dynamic=10) |
| 429 | +ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes) |
368 | 430 |
|
369 | | -print(exported_dynamic_shapes_example1.module()(torch.randn(5, 5, 2)))
| 431 | +###################################################################### |
| 432 | +# This produces quite a lot of output, even for this simple toy model. The log lines here have been truncated at the front and end
| 433 | +# to omit unnecessary info, but looking through the logs we can see the lines relevant to what we described above;
| 434 | +# e.g. the allocation of symbols:
370 | 435 |
|
371 | | -try:
372 | | -    exported_dynamic_shapes_example1.module()(torch.randn(8, 1, 2))
373 | | -except Exception:
374 | | -    tb.print_exc()
| 436 | +""" |
| 437 | +create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>) |
| 438 | +create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>) |
| 439 | +runtime_assert True == True [statically known] |
| 440 | +create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>) |
| 441 | +create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>) |
| 442 | +create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>) |
| 443 | +create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>) |
| 444 | +""" |
375 | 445 |
|
376 | | -try:
377 | | -    exported_dynamic_shapes_example1.module()(torch.randn(8, 20, 2))
378 | | -except Exception:
379 | | -    tb.print_exc()
| 446 | +###################################################################### |
| 447 | +# The lines with ``create_symbol`` show when a new symbol has been allocated, and the logs also identify the tensor variable names
| 448 | +# and dimensions they've been allocated for. In other lines we can also see the guards emitted: |
380 | 449 |
|
381 | | -try:
382 | | -    exported_dynamic_shapes_example1.module()(torch.randn(8, 8, 3))
383 | | -except Exception:
384 | | -    tb.print_exc()
| 450 | +""" |
| 451 | +runtime_assert Eq(s2, s4) [guard added] x0 = x + y # output shape: [8, 4] # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)" |
| 452 | +runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)" |
| 453 | +runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)" |
| 454 | +""" |
385 | 455 |
|
386 | 456 | ######################################################################
|
387 | | -# Note that if our example inputs to ``torch.export`` do not satisfy the constraints
388 | | -# given by ``dynamic_shapes``, then we get an error.
| 457 | +# Next to the ``[guard added]`` messages, we also see the responsible user lines of code - luckily here the model is simple enough. |
| 458 | +# In many real-world cases it's not so straightforward: high-level torch operations can have complicated fake-kernel implementations |
| 459 | +# or operator decompositions that complicate where and what guards are emitted. In such cases the best way to dig deeper and investigate |
| 460 | +# is to follow the logs' suggestion, and re-run with environment variable ``TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="..."``, to further |
| 461 | +# attribute the guard of interest. |
| 462 | +# |
| 463 | +# ``Dim.AUTO`` is just one of the available options for interacting with ``dynamic_shapes``; as of writing, two other options are available:
| 464 | +# ``Dim.DYNAMIC`` and ``Dim.STATIC``. ``Dim.STATIC`` simply marks a dimension static, while ``Dim.DYNAMIC`` is similar to ``Dim.AUTO`` in all
| 465 | +# ways except one: it raises an error when the dimension is specialized to a constant; this is designed to maintain dynamism. See for example what happens when a
| 466 | +# static guard is emitted on a dynamically-marked dimension:
389 | 467 |
|
390 | | -inp1_dim1_bad = Dim("inp1_dim1_bad", min=11, max=18)
391 | | -dynamic_shapes1_bad = {
392 | | -    "x": {0: inp1_dim0, 1: inp1_dim1_bad},
393 | | -}
394 | | -
395 | | -try:
396 | | -    export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1_bad)
397 | | -except Exception:
398 | | -    tb.print_exc()
| 468 | +dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC) |
| 469 | +export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes) |
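######################################################################
# The export call above fails on purpose to demonstrate the error. If you want a script to continue past such a
# demonstration, one option is to catch the exception; this is only a sketch, and we catch a broad ``Exception``
# since the exact error class may differ across PyTorch versions.

import traceback

try:
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    traceback.print_exc()  # specializing a Dim.DYNAMIC dimension raises here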
399 | 470 |
|
400 | 471 | ######################################################################
|
401 | | -# We can enforce that equalities between dimensions of different tensors
402 | | -# by using the same ``torch.export.Dim`` object, for example, in matrix multiplication:
403 | | -
404 | | -inp2 = torch.randn(4, 8)
405 | | -inp3 = torch.randn(8, 2)
| 472 | +# Static guards also aren't always inherent to the model; they can also come from user specifications. In fact, a common pitfall leading to shape |
| 473 | +# specializations is when the user specifies conflicting markers for equivalent dimensions; one dynamic and another static. The same error type is |
| 474 | +# raised when this is the case for ``x.shape[0]`` and ``y.shape[1]``: |
406 | 475 |
|
407 | | -class DynamicShapesExample2(torch.nn.Module):
408 | | -    def forward(self, x, y):
409 | | -        return x @ y
| 476 | +dynamic_shapes["w"] = (Dim.AUTO, Dim.AUTO) |
| 477 | +dynamic_shapes["x"] = (Dim.STATIC,) |
| 478 | +dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC) |
| 479 | +export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes) |
410 | 480 |
|
411 | | -inp2_dim0 = Dim("inp2_dim0")
412 | | -inner_dim = Dim("inner_dim")
413 | | -inp3_dim1 = Dim("inp3_dim1")
| 481 | +###################################################################### |
| 482 | +# Here you might ask why export "specializes", i.e. why we resolve this static/dynamic conflict by going with the static route. The answer lies in
| 483 | +# the symbols-and-guards system described above. When ``x.shape[0]`` is marked static, we don't allocate a symbol, and we compile
| 484 | +# treating this shape as the concrete integer 4. A symbol is allocated for ``y.shape[1]``, and so we finally emit the guard ``s3 == 4``, leading to
| 485 | +# specialization.
| 486 | +# |
| 487 | +# One feature of export is that during tracing, statements like asserts, ``torch._check()``, and ``if/else`` conditions will also emit guards. |
| 488 | +# See what happens when we augment the existing model with such statements: |
414 | 489 |
|
415 | | -dynamic_shapes2 = {
416 | | -    "x": {0: inp2_dim0, 1: inner_dim},
417 | | -    "y": {0: inner_dim, 1: inp3_dim1},
| 490 | +class DynamicModel(torch.nn.Module):
| 491 | +    def __init__(self):
| 492 | +        super().__init__()
| 493 | +        self.l = torch.nn.Linear(5, 3)
| 494 | +
| 495 | +    def forward(self, w, x, y, z):
| 496 | +        assert w.shape[0] <= 512
| 497 | +        torch._check(x.shape[0] >= 16)
| 498 | +        if w.shape[0] == x.shape[0] + 2:
| 499 | +            x0 = x + y
| 500 | +            x1 = self.l(w)
| 501 | +            x2 = x0.flatten()
| 502 | +            x3 = x2 + z
| 503 | +            return x1, x3
| 504 | +        else:
| 505 | +            return w
| 506 | +
| 507 | +dynamic_shapes = {
| 508 | +    "w": (Dim.AUTO, Dim.AUTO),
| 509 | +    "x": (Dim.AUTO,),
| 510 | +    "y": (Dim.AUTO, Dim.AUTO),
| 511 | +    "z": (Dim.AUTO,),
418 | 512 | }
|
419 | | -
420 | | -exported_dynamic_shapes_example2 = export(DynamicShapesExample2(), (inp2, inp3), dynamic_shapes=dynamic_shapes2)
421 | | -
422 | | -print(exported_dynamic_shapes_example2.module()(torch.randn(2, 16), torch.randn(16, 4)))
423 | | -
424 | | -try:
425 | | -    exported_dynamic_shapes_example2.module()(torch.randn(4, 8), torch.randn(4, 2))
426 | | -except Exception:
427 | | -    tb.print_exc()
| 513 | +ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes) |
| 514 | +print(ep) |
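######################################################################
# As a quick check (a sketch assuming the export above succeeded; the sizes below are our own choice), valid inputs
# must now satisfy ``w.shape[0] == x.shape[0] + 2``, ``x.shape[0] >= 16``, and ``w.shape[0] <= 512``:

w_new = torch.randn(18, 5)
x_new = torch.randn(16)
y_new = torch.randn(3, 16)
z_new = torch.randn(48)
out1, out3 = ep.module()(w_new, x_new, y_new, z_new)
print(out1.shape, out3.shape)  # expected: torch.Size([18, 3]) torch.Size([48])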
428 | 515 |
|
429 | 516 | ######################################################################
|
430 | | -# We can also describe one dimension in terms of other. There are some
431 | | -# restrictions to how detailed we can specify one dimension in terms of another,
432 | | -# but generally, those in the form of ``A * Dim + B`` should work.
433 | | -
434 | | -class DerivedDimExample1(torch.nn.Module):
435 | | -    def forward(self, x, y):
436 | | -        return x + y[1:]
437 | | -
438 | | -foo = DerivedDimExample1()
439 | | -
440 | | -x, y = torch.randn(5), torch.randn(6)
441 | | -dimx = torch.export.Dim("dimx", min=3, max=6)
442 | | -dimy = dimx + 1
443 | | -derived_dynamic_shapes1 = ({0: dimx}, {0: dimy})
444 | | -
445 | | -derived_dim_example1 = export(foo, (x, y), dynamic_shapes=derived_dynamic_shapes1)
446 | | -
447 | | -print(derived_dim_example1.module()(torch.randn(4), torch.randn(5)))
448 | | -
449 | | -try:
450 | | -    derived_dim_example1.module()(torch.randn(4), torch.randn(6))
451 | | -except Exception:
452 | | -    tb.print_exc()
453 | | -
454 | | -
455 | | -class DerivedDimExample2(torch.nn.Module):
456 | | -    def forward(self, z, y):
457 | | -        return z[1:] + y[1::3]
458 | | -
459 | | -foo = DerivedDimExample2()
460 | | -
461 | | -z, y = torch.randn(4), torch.randn(10)
462 | | -dx = torch.export.Dim("dx", min=3, max=6)
463 | | -dz = dx + 1
464 | | -dy = dx * 3 + 1
465 | | -derived_dynamic_shapes2 = ({0: dz}, {0: dy})
466 | | -
467 | | -derived_dim_example2 = export(foo, (z, y), dynamic_shapes=derived_dynamic_shapes2)
468 | | -print(derived_dim_example2.module()(torch.randn(7), torch.randn(19)))
| 517 | +# Each of these statements emits an additional guard, and the exported program shows the changes; ``s0`` is eliminated in favor of ``s2 + 2``, |
| 518 | +# and ``s2`` now contains lower and upper bounds, reflected in ``range_constraints``. |
| 519 | +# |
| 520 | +# For the if/else condition, you might ask why the True branch was taken, and why it wasn't the ``w.shape[0] != x.shape[0] + 2`` guard that |
| 521 | +# got emitted from tracing. The answer is that export is guided by the sample inputs provided for tracing, and specializes on the branch taken.
| 522 | +# If different sample input shapes were provided that fail the ``if`` condition, export would trace and emit guards corresponding to the ``else`` branch. |
| 523 | +# Additionally, you might ask why we traced only the ``if`` branch, and if it's possible to maintain control-flow in your program and keep both branches |
| 524 | +# alive. For that, refer to rewriting your model code following the ``Control Flow Ops`` section above. |
469 | 525 |
|
470 | 526 | ######################################################################
|
471 | | -# We can actually use ``torch.export`` to guide us as to which ``dynamic_shapes`` constraints
472 | | -# are necessary. We can do this by relaxing all constraints (recall that if we
473 | | -# do not provide constraints for a dimension, the default behavior is to constrain
474 | | -# to the exact shape value of the example input) and letting ``torch.export``
475 | | -# error out.
476 | | -
477 | | -inp4 = torch.randn(8, 16)
478 | | -inp5 = torch.randn(16, 32)
479 | | -
480 | | -class DynamicShapesExample3(torch.nn.Module):
481 | | -    def forward(self, x, y):
482 | | -        if x.shape[0] <= 16:
483 | | -            return x @ y[:, :16]
484 | | -        return y
485 | | -
486 | | -dynamic_shapes3 = {
487 | | -    "x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())},
488 | | -    "y": {i: Dim(f"inp5_dim{i}") for i in range(inp5.dim())},
489 | | -}
| 527 | +# 0/1 specialization |
| 528 | +# ^^^^^^^^^^^^^^^^^^ |
| 529 | +# |
| 530 | +# Since we're talking about guards and specializations, it's a good time to talk about the 0/1 specialization issue we brought up earlier. |
| 531 | +# The bottom line is that export will specialize on sample input dimensions with value 0 or 1, because these shapes have trace-time properties that |
| 532 | +# don't generalize to other shapes. For example, size 1 tensors can broadcast while other sizes fail, and size 0 tensors hit special-case empty-tensor code paths. This just means that you should
| 533 | +# specify 0/1 sample inputs when you'd like your program to hardcode them, and non-0/1 sample inputs when dynamic behavior is desirable. See what happens |
| 534 | +# at runtime when we export this linear layer: |
490 | 535 |
|
491 | | -try:
492 | | -    export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3)
493 | | -except Exception:
494 | | -    tb.print_exc()
| 536 | +ep = export( |
| 537 | + torch.nn.Linear(4, 3), |
| 538 | + (torch.randn(1, 4),), |
| 539 | + dynamic_shapes={ |
| 540 | + "input": (Dim.AUTO, Dim.STATIC), |
| 541 | + }, |
| 542 | +) |
| 543 | +ep.module()(torch.randn(2, 4)) |
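######################################################################
# The call above fails because the batch dimension was specialized to 1. A sketch of the fix: re-export with a
# non-0/1 sample batch size (2 here, chosen by us) so that the batch dimension stays dynamic.

ep_dynamic_batch = export(
    torch.nn.Linear(4, 3),
    (torch.randn(2, 4),),  # batch size 2 avoids 0/1 specialization
    dynamic_shapes={
        "input": (Dim.AUTO, Dim.STATIC),
    },
)
print(ep_dynamic_batch.module()(torch.randn(64, 4)).shape)  # torch.Size([64, 3])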
495 | 544 |
|
496 | 545 | ######################################################################
|
497 | | -# We can see that the error message gives us suggested fixes to our
498 | | -# dynamic shape constraints. Let us follow those suggestions (exact
499 | | -# suggestions may differ slightly):
500 | | -
501 | | -def suggested_fixes():
502 | | -    inp4_dim1 = Dim('shared_dim')
503 | | -    # suggested fixes below
504 | | -    inp4_dim0 = Dim('inp4_dim0', max=16)
505 | | -    inp5_dim1 = Dim('inp5_dim1', min=17)
506 | | -    inp5_dim0 = inp4_dim1
507 | | -    # end of suggested fixes
508 | | -    return {
509 | | -        "x": {0: inp4_dim0, 1: inp4_dim1},
510 | | -        "y": {0: inp5_dim0, 1: inp5_dim1},
511 | | -    }
| 546 | +# Named Dims |
| 547 | +# ^^^^^^^^^^ |
| 548 | +# |
| 549 | +# So far we've only been talking about three ways to specify dynamic shapes: ``Dim.AUTO``, ``Dim.DYNAMIC``, and ``Dim.STATIC``. The attraction of these is the
| 550 | +# low-friction user experience; all the guards emitted during model tracing are adhered to, and dynamic behavior like min/max ranges, relations, and static/dynamic
| 551 | +# dimensions are inferred automatically under the hood. The dynamic shapes subsystem essentially acts as a "discovery" process, summarizing these guards
| 552 | +# and presenting what export believes is the overall dynamic behavior of the program. The drawback of this design appears once the user has stronger expectations or
| 553 | +# beliefs about the dynamic behavior of the model - maybe there is a strong desire for dynamism, and specializations on particular dimensions are to be avoided at
| 554 | +# all costs, or maybe we want to catch changes in dynamic behavior caused by changes to the original model code, or to underlying decompositions or meta-kernels.
| 555 | +# Such changes won't be detected, and the ``export()`` call will most likely still succeed, unless tests are in place that check the resulting ``ExportedProgram`` representation.
| 556 | +#
| 557 | +# For such cases, our stance is to recommend the "traditional" way of specifying dynamic shapes, which longer-term users of export might be familiar with: named ``Dim`` objects:
512 | 558 |
|
513 | | -dynamic_shapes3_fixed = suggested_fixes()
514 | | -exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
515 | | -print(exported_dynamic_shapes_example3.module()(torch.randn(4, 32), torch.randn(32, 64)))
| 559 | +dx = Dim("dx", min=4, max=256) |
| 560 | +dh = Dim("dh", max=512) |
| 561 | +dynamic_shapes = { |
| 562 | + "x": (dx, None), |
| 563 | + "y": (2 * dx, dh), |
| 564 | +} |
516 | 565 |
|
517 | 566 | ######################################################################
|
518 | | -# Note that in the example above, because we constrained the value of ``x.shape[0]`` in
519 | | -# ``dynamic_shapes_example3``, the exported program is sound even though there is a
520 | | -# raw ``if`` statement.
| 567 | +# This style of dynamic shapes allows the user to specify which symbols are allocated to input dimensions and the min/max bounds on those symbols, and it places restrictions on the
| 568 | +# dynamic behavior of the ``ExportedProgram`` produced; ``ConstraintViolation`` errors will be raised if model tracing emits guards that conflict with the relations or static/dynamic |
| 569 | +# specifications given. For example, in the above specification, the following is asserted: |
521 | 570 | #
|
522 | | -# If you want to see why ``torch.export`` generated these constraints, you can
523 | | -# re-run the script with the environment variable ``TORCH_LOGS=dynamic,dynamo``,
524 | | -# or use ``torch._logging.set_logs``.
525 | | -
526 | | -import logging
527 | | -torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO)
528 | | -exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
| 571 | +# - ``x.shape[0]`` is to have range ``[4, 256]``, and related to ``y.shape[0]`` by ``y.shape[0] == 2 * x.shape[0]``. |
| 572 | +# - ``x.shape[1]`` is static. |
| 573 | +# - ``y.shape[1]`` has range ``[2, 512]``, and is unrelated to any other dimension. |
| 574 | +# |
| 575 | +# In this design, we allow relations between dimensions to be specified with univariate linear expressions: ``A * dim + B`` can be specified for any dimension. This allows users |
| 576 | +# to specify more complex constraints like integer divisibility for dynamic dimensions: |
529 | 577 |
|
530 | | -# reset to previous values
531 | | -torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING)
| 578 | +dx = Dim("dx", min=4, max=512) |
| 579 | +dynamic_shapes = { |
| 580 | + "x": (4 * dx, None) # x.shape[0] has range [16, 2048], and is divisible by 4. |
| 581 | +} |
532 | 582 |
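######################################################################
# To see the divisibility constraint in action, here is a hypothetical module of our own (not from the original
# tutorial) whose reshape only works when the first input dimension is a multiple of 4, exported with the spec above:

class Divisible(torch.nn.Module):
    def forward(self, x):
        # splitting the first dimension into chunks of 4 requires x.shape[0] to be divisible by 4
        return x.reshape(-1, 4, x.shape[1]).sum(dim=1)

ep_div = export(Divisible(), (torch.randn(32, 8),), dynamic_shapes=dynamic_shapes)
print(ep_div)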
|
533 | 583 | ######################################################################
|
534 | | -# We can view an ``ExportedProgram``'s symbolic shape ranges using the
535 | | -# ``range_constraints`` field.
| 584 | +# Constraint violations, suggested fixes |
| 585 | +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 586 | +# |
| 587 | +# One common issue with this specification style (before ``Dim.AUTO`` was introduced) is that the specification would often be mismatched with what was produced by model tracing.
| 588 | +# That would lead to ``ConstraintViolation`` errors and export-suggested fixes - see for example this model and specification, where the model inherently requires equality between
| 589 | +# dimensions 0 of ``x`` and ``y``, and requires dimension 1 to be static.
536 | 590 |
|
537 | | -print(exported_dynamic_shapes_example3.range_constraints)
| 591 | +class Foo(torch.nn.Module):
| 592 | +    def forward(self, x, y):
| 593 | +        w = x + y
| 594 | +        return w + torch.ones(4)
| 595 | +
| 596 | +dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
| 597 | +ep = export(
| 598 | +    Foo(),
| 599 | +    (torch.randn(6, 4), torch.randn(6, 4)),
| 600 | +    dynamic_shapes={
| 601 | +        "x": (dx, d1),
| 602 | +        "y": (dy, d1),
| 603 | +    },
| 604 | +)
| 605 | + |
| 606 | +###################################################################### |
| 607 | +# The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards. |
| 608 | +# |
| 609 | +# Lastly, here are a couple of nice-to-knows about the options for specification:
| 610 | +# |
| 611 | +# - ``None`` is a good option for static behavior: |
| 612 | +# - ``dynamic_shapes=None`` (default) exports with the entire model being static. |
| 613 | +# - specifying ``None`` at an input-level exports with all tensor dimensions static, and is also required for non-tensor inputs. |
| 614 | +# - specifying ``None`` at a dimension-level specializes that dimension, though this is deprecated in favor of ``Dim.STATIC``. |
| 615 | +# - specifying per-dimension integer values also produces static behavior, and will additionally check that the provided sample input matches the specification. |
| 616 | +# |
| 617 | +# These options are combined in the inputs & dynamic shapes spec below: |
| 618 | + |
| 619 | +inputs = ( |
| 620 | + torch.randn(4, 4), |
| 621 | + torch.randn(3, 3), |
| 622 | + 16, |
| 623 | + False, |
| 624 | +) |
| 625 | +dynamic_shapes = { |
| 626 | + "tensor_0": (Dim.AUTO, None), |
| 627 | + "tensor_1": None, |
| 628 | + "int_val": None, |
| 629 | + "bool_val": None, |
| 630 | +} |
538 | 631 |
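######################################################################
# To show how these come together, here is a hypothetical module of our own (not from the original tutorial) whose
# argument names match the specification above, exported with those inputs and dynamic shapes:

class MixedInputs(torch.nn.Module):
    def forward(self, tensor_0, tensor_1, int_val, bool_val):
        out = tensor_0 * int_val   # int_val is specialized to the sample value 16
        if bool_val:               # bool_val is specialized to the sample value False
            out = -out
        return out, tensor_1 + 1

ep_mixed = export(MixedInputs(), inputs, dynamic_shapes=dynamic_shapes)
print(ep_mixed)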
|
539 | 632 | ######################################################################
|
540 | 633 | # Custom Ops
|