diff --git a/src/norecompile.jl b/src/norecompile.jl index 8446c8bcc..d9288547d 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -57,14 +57,14 @@ function wrapfun_iip(ff, dualT = dualgen(T) dualT1 = ArrayInterface.promote_eltype(T1, dualT) dualT2 = ArrayInterface.promote_eltype(T2, dualT) - dualT4 = dualgen(promote_type(T, T4)) + dualT4 = promote_dual(dualgen(T4), dualT) - iip_arglists = (Tuple{T1, T2, T3, T4}, - Tuple{dualT1, dualT2, T3, T4}, - Tuple{dualT1, T2, T3, dualT4}, - Tuple{dualT1, dualT2, T3, dualT4}) + iip_arglists = (Tuple{T1, T2, T3, T4}, # primal + Tuple{dualT1, dualT2, T3, T4}, # vjp + Tuple{dualT1, T2, T3, dualT4}, # tgrad + ) - iip_returnlists = ntuple(x -> Nothing, 4) + iip_returnlists = ntuple(x -> Nothing, length(iip_arglists)) fwt = map(iip_arglists, iip_returnlists) do A, R FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff))