documentation for Enzyme Type Trees #2385

Open. Wants to merge 2 commits into base: master
1 change: 1 addition & 0 deletions src/SUMMARY.md
@@ -105,6 +105,7 @@
- [Installation](./autodiff/installation.md)
- [How to debug](./autodiff/debugging.md)
- [Autodiff flags](./autodiff/flags.md)
- [Type Trees](./autodiff/type-trees.md)
- [Current limitations](./autodiff/limitations.md)

# Source Code Representation
118 changes: 118 additions & 0 deletions src/autodiff/type-trees.md
@@ -0,0 +1,118 @@
# Type Trees in Enzyme

This document describes type trees as used by Enzyme for automatic differentiation.

## What are Type Trees?

Type trees in Enzyme are a way to represent the types of variables, including their activity (e.g., whether they are active, duplicated, or contain duplicated data) for automatic differentiation. They provide a structured way for Enzyme to understand how to handle different data types during the differentiation process.

Member

Are there any docs or code which shows them including activity?

Author

My bad, I saw this and got confused, thinking they were both the same. Later, when I was going through the Enzyme codebase, I realised they are different. Fixing this.
[image attachment]


## Representing Rust Types as Type Trees

Enzyme needs to understand the structure and properties of Rust types to perform automatic differentiation correctly. This is where type trees come in. They provide a detailed map of a type, including pointer indirections and the underlying concrete data types.

The `-enzyme-rust-type` flag in Enzyme helps in interpreting types more accurately in the context of Rust's memory layout and type system.

Member @ZuseZ4, May 14, 2025

The flag just tells Enzyme to parse (Rust) debug "DWARF" information.
A lot of type information is not encoded in such debug metadata, and the flag hasn't been re-evaluated or used in years. It's good to mention it here (with the corrected description), but I wouldn't mention it in the following sections. We should generate type trees based on Rust types even without debug metadata. But how to use debug metadata is something we'll also discuss in one of the meetings with Oli.

Author

Got this. The flag was there in most of the test files related to Rust.


### Primitive Types

#### Floating-Point Types (`f32`, `f64`)

Consider a Rust reference to a 32-bit floating-point number, `&f32`.

In LLVM IR, this might be represented, for instance, as an `i8*` (a generic byte pointer) that is then `bitcast` to a `float*`. Consider the following LLVM IR function:

Member

Ah, you took that from rustmutpointer.ll I guess? Unfortunately those tests are too outdated, I just realized.
Typed pointers were removed (see my comment about this flag being very outdated), so `i8*` isn't a thing anymore.
Instead, we now have `ptr` ("opaque pointers") in LLVM.

You can look for the PRs which introduced these tests a few years ago. They should have instructions on how to reproduce them, so you can re-generate newer tests once you have a working setup.

Author

Yes, I found the PR. It will be a great reference for understanding how things were done:

EnzymeAD/Enzyme#307


```llvm
define internal void @callee(i8* %x) {
start:
  %x.dbg.spill = bitcast i8* %x to float*
  ; ...
  ret void
}
```

When Enzyme analyzes this function (with appropriate flags like `-enzyme-rust-type`), it might produce the following type information for the argument `%x` and the result of the bitcast:

```llvm
i8* %x: {[-1]:Pointer, [-1,0]:Float@float}
%x.dbg.spill = bitcast i8* %x to float*: {[-1]:Pointer, [-1,0]:Float@float}
```

**Understanding the Type Tree: `{[-1]:Pointer, [-1,0]:Float@float}`**

This string is the type tree representation. Let's break it down:

* **`{ ... }`**: This encloses the set of type information for the variable.
* **`[-1]:Pointer`**:
    * `[-1]` is an index or path. In this context, `-1` often refers to the base memory location or the immediate value pointed to.

Member @ZuseZ4, May 14, 2025

`-1` has a slightly different meaning, something along the lines of "everything accessible from here, without dereferencing a pointer".
So `[f64;32]` could be represented as `[-1]:Float@double`, or as `[0]:Float@double, [8]:Float@double, ...`
Afaik we usually prefer `-1` in such cases since it's shorter, but IIRC there were some gotchas.
@wsmoses can you share the private YouTube video with him (in a Zulip DM if you prefer), so that he has more information on how to write these docs?

Member

Or e.g. `x: *const [f64;32]` could be represented as
`[[-1]:Pointer, [-1:-1]:Float@double]`, or as `[[0]:Pointer, [0:0]:Float@double, [0:8]:Float@double, ...]`,
or as `[[0]:Pointer, [0:-1]:Float@double]`, or as `[[-1]:Pointer, [-1:0]:Float@double, [-1:8]:Float@double, ...]`
(I think)

That also all is under the assumption that we won't access e.g. `x[34]`, which is out of bounds of the original array.
I am not 100% sure if there are cases where this could be valid. I think with raw pointers it might be valid to access other elements, e.g. with `struct { x: [f64;32], y: i32 }`: if you derive a raw pointer to `x`, you might be able to use it to access `y` (legally), in which case `-1:Float@double` would be wrong. I'd need a refresher on pointer provenance and other things; hopefully Oli will know more about it. You can add this as an open question at the end.

    * `Pointer` indicates that the variable `%x` itself is treated as a pointer.
* **`[-1,0]:Float@float`**:
    * `[-1,0]` is a path. It means: start with the base item `[-1]` (the pointer), and then look at offset `0` from the memory location it points to.
    * `Float` is the `CConcreteType` (from `enzyme_ffi.rs`, corresponding to `DT_Float`). It signifies that the data at this location is a floating-point number.
    * `@float` is a subtype or specific variant of `Float`. In this case, it specifies a single-precision float (like Rust's `f32`).
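
To make the notation concrete, here is a minimal Rust sketch that splits a type tree string of the form shown above into its entries. The `Entry` struct and the `parse_type_tree` function are hypothetical helpers written for this illustration. They are not part of Enzyme or of `enzyme_ffi.rs`, and the parser only handles the simple `[path]:Kind@subtype` shape used in this section.

```rust
/// One entry of a type tree: a path of offsets plus the type found there.
/// Hypothetical helper type for illustration; not Enzyme's own representation.
#[derive(Debug, PartialEq)]
struct Entry {
    path: Vec<i64>,          // e.g. [-1, 0]
    kind: String,            // e.g. "Pointer" or "Float"
    subtype: Option<String>, // e.g. "float" or "double"
}

/// Parse a string such as `{[-1]:Pointer, [-1,0]:Float@float}`.
/// Returns `None` if the string does not have the expected shape.
fn parse_type_tree(s: &str) -> Option<Vec<Entry>> {
    let inner = s.trim().strip_prefix('{')?.strip_suffix('}')?;
    let mut entries = Vec::new();
    // Paths contain commas too, so split on '[' (which starts every entry)
    // instead of splitting on the ", " separator.
    for chunk in inner.split('[').skip(1) {
        let (path_str, rest) = chunk.split_once("]:")?;
        let path = path_str
            .split(',')
            .map(|p| p.trim().parse::<i64>().ok())
            .collect::<Option<Vec<_>>>()?;
        // Drop the trailing ", " that separates this entry from the next one.
        let ty = rest.trim().trim_end_matches(',').trim();
        let (kind, subtype) = match ty.split_once('@') {
            Some((k, st)) => (k.to_string(), Some(st.to_string())),
            None => (ty.to_string(), None),
        };
        entries.push(Entry { path, kind, subtype });
    }
    Some(entries)
}

fn main() {
    let tree = parse_type_tree("{[-1]:Pointer, [-1,0]:Float@float}")
        .expect("well-formed type tree string");
    for e in &tree {
        println!("{:?} -> {} {:?}", e.path, e.kind, e.subtype);
    }
}
```

Each parsed entry pairs a path with a concrete type, which mirrors how the breakdown above reads the string: the path says where to look, and the part after the colon says what is stored there.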

A reference to an `f64` (e.g., `&f64`) is handled very similarly. The LLVM IR might cast to `double*`:
```llvm
define internal void @callee(i8* %x) {
start:
  %x.dbg.spill = bitcast i8* %x to double*
  ; ...
  ret void
}
```

And the type tree would be:

```llvm
i8* %x: {[-1]:Pointer, [-1,0]:Float@double}
```
The key difference is `@double`, indicating a double-precision float.

This level of detail allows Enzyme to know, for example, that if `x` is an active variable in differentiation, the floating-point value it points to needs to be handled according to AD rules for its specific precision.
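
The review discussion on this section describes `-1` as roughly "every offset reachable from here without dereferencing a pointer", so that `[f64;32]` can be written either as a single `[-1]:Float@double` entry or as one entry per 8-byte offset. The following Rust sketch illustrates that reading only. `path_covers` and `explicit_offsets` are hypothetical names, and the matching rule is a simplification assumed for this example, not Enzyme's actual lookup logic.

```rust
/// Returns true if `pattern` covers the concrete path `concrete`,
/// treating -1 in `pattern` as "any offset at this level".
/// Simplified, assumed semantics for illustration only.
fn path_covers(pattern: &[i64], concrete: &[i64]) -> bool {
    pattern.len() == concrete.len()
        && pattern.iter().zip(concrete).all(|(p, c)| *p == -1 || p == c)
}

/// Byte offsets of the elements of an array of `len` f64 values,
/// i.e. the explicit form `[0], [8], [16], ...`.
fn explicit_offsets(len: usize) -> Vec<i64> {
    (0..len)
        .map(|i| (i * std::mem::size_of::<f64>()) as i64)
        .collect()
}

fn main() {
    let offsets = explicit_offsets(32);
    // Every explicit offset of the array is covered by the single
    // wildcard entry `[-1]`.
    let all_covered = offsets.iter().all(|&o| path_covers(&[-1], &[o]));
    println!("{} offsets, all covered by [-1]: {}", offsets.len(), all_covered);
}
```

Under this reading the wildcard form is simply a shorter way of writing the explicit list, which is why it is only appropriate when every reachable offset really holds the same type.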

### Compound Types

#### Structs

Consider a Rust struct `T` with two `f32` fields (e.g., a reference `&T`):

```rust
struct T {
    x: f32,
    y: f32,
}

// And a function taking a reference to it:
// fn callee(t: &T) { /* ... */ }
```

In LLVM IR, a pointer to this struct might be initially represented as `i8*` and then cast to the specific struct type, like `{ float, float }*`:

```llvm
define internal void @callee(i8* %t) {
start:
  %t.dbg.spill = bitcast i8* %t to { float, float }*
  ; ...
  ret void
}
```

The Enzyme type analysis output for `%t` would be:

```llvm
i8* %t: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float}
```

**Understanding the Struct Type Tree: `{[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float}`**

* **`[-1]:Pointer`**: As before, this indicates that `%t` is a pointer.
* **`[-1,0]:Float@float`**:
    * This describes the first field of the struct (`x`).
    * `[-1,0]` means: from the memory location pointed to by `%t` (`-1`), at offset `0` bytes.
    * `Float@float` indicates this field is an `f32`.
* **`[-1,4]:Float@float`**:
    * This describes the second field of the struct (`y`).
    * `[-1,4]` means: from the memory location pointed to by `%t` (`-1`), at offset `4` bytes.
    * `Float@float` indicates this field is also an `f32`.

The offset `4` comes from the size of the first field (`f32` is 4 bytes). If the first field were, for example, an `f64` (8 bytes), the second field might be at offset `[-1,8]`. Enzyme uses these offsets to pinpoint the exact memory location of each field within the struct.
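
The offsets can be checked from the Rust side. The sketch below uses `std::mem::offset_of!` (stable since Rust 1.77) to compute field offsets and then formats them in the notation used above. Two things here are assumptions made for the example: the structs are marked `#[repr(C)]` so that field order and offsets are well defined (the default Rust layout does not guarantee declaration order), and `float_struct_tree` is a hypothetical helper, not an Enzyme function.

```rust
use std::mem::{offset_of, size_of};

// `#[repr(C)]` is an assumption of this sketch; it fixes the field order.
#[allow(dead_code)]
#[repr(C)]
struct T {
    x: f32,
    y: f32,
}

// Variant with an `f64` first field, as mentioned in the text above.
#[allow(dead_code)]
#[repr(C)]
struct U {
    a: f64,
    b: f32,
}

/// Format the type tree of a reference to a struct whose fields are all
/// floats. `fields` holds (byte offset, size in bytes) pairs.
/// Hypothetical helper: only 4-byte (f32) and 8-byte (f64) fields are handled.
fn float_struct_tree(fields: &[(usize, usize)]) -> String {
    let mut parts = vec!["[-1]:Pointer".to_string()];
    for &(offset, size) in fields {
        let sub = match size {
            4 => "float",
            _ => "double", // the sketch assumes anything else is an f64
        };
        parts.push(format!("[-1,{offset}]:Float@{sub}"));
    }
    format!("{{{}}}", parts.join(", "))
}

fn main() {
    let t = [
        (offset_of!(T, x), size_of::<f32>()),
        (offset_of!(T, y), size_of::<f32>()),
    ];
    let u = [
        (offset_of!(U, a), size_of::<f64>()),
        (offset_of!(U, b), size_of::<f32>()),
    ];
    println!("&T: {}", float_struct_tree(&t));
    println!("&U: {}", float_struct_tree(&u));
}
```

For `T` the second field lands at offset 4, and for `U` it lands at offset 8, matching the reasoning about field sizes above.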

This detailed mapping is crucial for Enzyme to correctly track the activity of individual struct fields during automatic differentiation.