Skip to content

Commit 41c2974

Browse files
committed
Expand ForwardDiff thing
1 parent 7dcbb90 commit 41c2974

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

core-functionality/index.qmd

+2-6
Original file line numberDiff line numberDiff line change
@@ -344,13 +344,9 @@ For example, let `c` be a `Chain`:
344344

345345
#### Variable Types and Type Parameters
346346

347-
The element type of a vector (or matrix) of random variables should match the `eltype` of its prior distribution, `<: Integer` for discrete distributions and `<: AbstractFloat` for continuous distributions. Moreover, if the continuous random variable is to be sampled using a Hamiltonian sampler, the vector's element type needs to either be:
347+
The element type of a vector (or matrix) of random variables should match the `eltype` of its prior distribution, `<: Integer` for discrete distributions and `<: AbstractFloat` for continuous distributions.
348348

349-
1. `Real` to enable auto-differentiation through the model which uses special number types that are sub-types of `Real`, or
350-
351-
2. Some type parameter `T` defined in the model header using the type parameter syntax, e.g. `function gdemo(x, ::Type{T} = Float64) where {T}`.
352-
353-
Similarly, when using a particle sampler, the Julia variable used should either be:
349+
Moreover, when using a particle sampler, the Julia variable used should either be:
354350

355351
1. An `Array`, or
356352

usage/automatic-differentiation/index.qmd

+18-4
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ A common error with ForwardDiff looks like this:
175175
x[1] = a
176176
b ~ MvNormal(x, I)
177177
end
178-
sample(forwarddiff_fail(), NUTS(; adtype=AutoForwardDiff()), 100)
178+
sample(forwarddiff_fail(), NUTS(; adtype=AutoForwardDiff()), 10)
179179
```
180180

181181
The problem here is the line `x[1] = a`.
@@ -187,16 +187,30 @@ In more depth: the basic premise of ForwardDiff is that functions have to accept
187187
Here, the line `x[1] = a` is equivalent to `setindex!(x, a, 1)`, and although the method `setindex!(::Vector{Float64}, ::Real, ...)` does exist, it attempts to convert the `Real` into a `Float64`, which is where it fails.
188188
:::
189189

190-
The way around this is to pass a type as a parameter to the model:
190+
There are two ways around this.
191+
192+
Firstly, you could broaden the type of the container:
193+
194+
```{julia}
195+
@model function forwarddiff_working1()
196+
x = Real[0.0, 1.0]
197+
a ~ Normal()
198+
x[1] = a
199+
b ~ MvNormal(x, I)
200+
end
201+
sample(forwarddiff_working1(), NUTS(; adtype=AutoForwardDiff()), 10)
202+
```
203+
204+
Or, you can pass a type as a parameter to the model:
191205

192206
```{julia}
193-
@model function forwarddiff_working(::Type{T}=Float64) where T
207+
@model function forwarddiff_working2(::Type{T}=Float64) where T
194208
x = T[0.0, 1.0]
195209
a ~ Normal()
196210
x[1] = a
197211
b ~ MvNormal(x, I)
198212
end
199-
sample(forwarddiff_working(), NUTS(), 100) # runs fine
213+
sample(forwarddiff_working2(), NUTS(; adtype=AutoForwardDiff()), 10)
200214
```
201215

202216
## For AD Backend Developers

0 commit comments

Comments
 (0)