Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions docs/cpp2/metafunctions.md
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,36 @@ main: () = {

### For computational and functional types

#### `autodiff`

An `autodiff` type is extended so that derivatives can be computed. The metafunction adds for each function and member function a differentiated version. **This is a proof of concept implementation. Expect it to break.**
A simple hello diff example is:
```
ad: @autodiff type = {
func: (x: double) -> (r: double) = {
r = x * x;
}
}

main: (args) = {
x := 3.0;
x_d := 1.0;

r := ad::func_d(x, x_d);

std::cout << "Derivative of 'x*x' at (x)$ is (r.r_d)$" << std::endl;
}
```

The `@autodiff` metafunction mostly supports the forward mode of algorithmic differentiation. The reverse mode is only partly implemented and not yet well tested.
See [Supported autodiff features](../notes/autodiff_status.md) for a list of supported language features.

Options can be given by text template arguments, e.g. `@autodiff<"reverse">` enables the reverse mode.
| Option | Description |
| `"reverse"` | Reverse mode algorithmic differentiation. Default suffix `_b`. |
| `"order=<n>"` | Higher order derivatives. `<n>` can be arbitrary. See `regression-tests/pure2-autodiff-higher-order.cpp2` for examples. |
| `"suffix=<s>"` | Change the forward mode suffix. Can be used to apply autodiff multiple times. E.g. `@autodiff @autodiff<"suffix=_d2">`. |
| `"rws_suffix=<s>"` | Change the reverse mode suffix. |

#### `regex`

Expand Down
34 changes: 34 additions & 0 deletions docs/notes/autodiff_status.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Supported algorithmic differentiation (autodiff) features

The listings might be incomplete. If something is missing, it is not supported. Algorithmic differentiation is applied via the [`autodiff` metafunction](../cpp2/metafunctions.md#autodiff). Maybe the planned features are added in 2026. Do not wait for them. The autodif feature is a proof of concept implementation.

** Reverse mode algorithmic differentiation is very experimental. Expect it to break. **

## Currently supported or planned features

| Description | Status forward | Status reverse |
| --- | --- | --- |
| Type definitions (structures) | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Member values | <span style="color:green">Supported</span> | <span style="color:gray">Planned</span> |
| Member functions | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Function arguments | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Function return arguments | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Addition and multiplication | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Prefix addition and subtraction | <span style="color:green">Supported</span> | <span style="color:gray">Planned</span> |
| Static member function calls | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Member function calls | <span style="color:green">Supported</span> | <span style="color:gray">Planned</span> |
| Function calls | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Math functions (sin, cos, exp, sqrt) | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| If else | <span style="color:green">Supported</span> | <span style="color:gray">Planned</span> |
| Return statement | <span style="color:green">Supported</span> | <span style="color:gray">Planned</span> |
| Intermediate variables | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Passive variables | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| While loop | <span style="color:green">Supported</span> | <span style="color:gray">Planned</span> |
| Do while loop | <span style="color:green">Supported</span> | <span style="color:gray">Planned</span> |
| For loop | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Template arguments | <span style="color:gray">Planned</span> | <span style="color:gray">Planned</span> |
| Lambda functions | <span style="color:gray">Planned</span> | <span style="color:gray">Planned</span> |




Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
x + x = 4.000000
x + x diff order 1 = 2.000000
x + x diff order 2 = 0.000000
x + x diff order 3 = 0.000000
x + x diff order 4 = 0.000000
x + x diff order 5 = 0.000000
x + x diff order 6 = 0.000000
0 - x = -2.000000
0 - x diff order 1 = -1.000000
0 - x diff order 2 = 0.000000
0 - x diff order 3 = 0.000000
0 - x diff order 4 = 0.000000
0 - x diff order 5 = 0.000000
0 - x diff order 6 = 0.000000
x^7 = 128.000000
x^7 diff order 1 = 448.000000
x^7 diff order 2 = 1344.000000
x^7 diff order 3 = 3360.000000
x^7 diff order 4 = 6720.000000
x^7 diff order 5 = 10080.000000
x^7 diff order 6 = 10080.000000
1/x = 0.500000
1/x diff order 1 = -0.250000
1/x diff order 2 = 0.250000
1/x diff order 3 = -0.375000
1/x diff order 4 = 0.750000
1/x diff order 5 = -1.875000
1/x diff order 6 = 5.625000
sqrt(x) = 1.414214
sqrt(x) diff order 1 = 0.353553
sqrt(x) diff order 2 = -0.088388
sqrt(x) diff order 3 = 0.066291
sqrt(x) diff order 4 = -0.082864
sqrt(x) diff order 5 = 0.145012
sqrt(x) diff order 6 = -0.326277
log(x) = 0.693147
log(x) diff order 1 = 0.500000
log(x) diff order 2 = -0.250000
log(x) diff order 3 = 0.250000
log(x) diff order 4 = -0.375000
log(x) diff order 5 = 0.750000
log(x) diff order 6 = -1.875000
exp(x) = 7.389056
exp(x) diff order 1 = 7.389056
exp(x) diff order 2 = 7.389056
exp(x) diff order 3 = 7.389056
exp(x) diff order 4 = 7.389056
exp(x) diff order 5 = 7.389056
exp(x) diff order 6 = 7.389056
sin(x) = 0.909297
sin(x) diff order 1 = -0.416147
sin(x) diff order 2 = -0.909297
sin(x) diff order 3 = 0.416147
sin(x) diff order 4 = 0.909297
sin(x) diff order 5 = -0.416147
sin(x) diff order 6 = -0.909297
cos(x) = -0.416147
cos(x) diff order 1 = -0.909297
cos(x) diff order 2 = 0.416147
cos(x) diff order 3 = 0.909297
cos(x) diff order 4 = -0.416147
cos(x) diff order 5 = -0.909297
cos(x) diff order 6 = 0.416147
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
diff(x + y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x + y + x) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 7.000000
d1 = 4.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x - y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = -1.000000
d1 = -1.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x - y - x) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = -3.000000
d1 = -2.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x + y - x) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 3.000000
d1 = 2.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x * y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 6.000000
d1 = 7.000000
d2 = 4.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x * y * x) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 12.000000
d1 = 20.000000
d2 = 22.000000
d3 = 12.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x / y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 0.666667
d1 = -0.111111
d2 = 0.148148
d3 = -0.296296
d4 = 0.790123
d5 = -2.633745
d6 = 10.534979
diff(x / y / y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 0.222222
d1 = -0.185185
d2 = 0.296296
d3 = -0.691358
d4 = 2.106996
d5 = -7.901235
d6 = 35.116598
diff(x * y / x) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 3.000000
d1 = 2.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x * (x + y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 10.000000
d1 = 11.000000
d2 = 6.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x + x * y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 8.000000
d1 = 8.000000
d2 = 4.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(+x + y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(-x + y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 1.000000
d1 = 1.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x * func(x, y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 10.000000
d1 = 11.000000
d2 = 6.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x * func_outer(x, y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 10.000000
d1 = 11.000000
d2 = 6.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(sin(x - y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = -0.841471
d1 = -0.540302
d2 = 0.841471
d3 = 0.540302
d4 = -0.841471
d5 = -0.540302
d6 = 0.841471
diff(if branch) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 2.000000
d1 = 1.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(if else branch) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 2.000000
d1 = 1.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(direct return) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(intermediate var) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(intermediate passive var) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(intermediate untyped) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(intermediate default init) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(intermediate no init) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(while loop) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 8.000000
d1 = 5.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(do while loop) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 8.000000
d1 = 5.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(for loop) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(tye_outer.a + y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(type_outer.add(y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
Loading
Loading