Skip to content

Conversation

@penelopeysm
Copy link
Member

@penelopeysm penelopeysm commented Dec 10, 2025

Prior to 'fast' LDF (i.e. in DynamicPPL <= 0.38) there were two different functions that were differentiated through when performing LogDensityFunctions.logdensity_and_gradient: one was a multi-argument function f(x, c1, c2, c3, ...) where x is the active argument, and one was a callable struct F(x) where F closed over the inactive arguments c1, c2, c3. See #806 and #922 for previous history, and the old version of src/logdensityfunction.jl here.

This feature was not ported over to FastLDF in #1139, which is probably an example of laziness on my part. The benefits of doing this is mainly improved performance, but for Enzyme it also allows us to avoid setting function_annotation (see also #1048).

Here are some benchmarks: clos means use the closure, func means the original function. As one can see most of the AD backends are faster with the function rather than the closure, hence _use_closure is mostly set to false. Part of me wonders if sticking an @inline on the callable struct might remove this difference, but my industriousness does have a limit.

# Run this once with `_use_closure` set to true and once with it set to false.

using DynamicPPL, Distributions, ForwardDiff, ReverseDiff, Mooncake, Enzyme, ADTypes, DataFrames
using DynamicPPL.TestUtils.AD: run_ad, ADResult

ADTYPES = Dict(
    "ForwardDiff" => AutoForwardDiff(),
    "ReverseDiff" => AutoReverseDiff(; compile=false),
    "ReverseDiffCompiled" => AutoReverseDiff(; compile=true),
    "MooncakeFwd" => AutoMooncakeForward(),
    "MooncakeRvs" => AutoMooncake(),
)
MODELS = DynamicPPL.TestUtils.DEMO_MODELS

RESULTS = []
for model in MODELS
    for (adtype_name, adtype) in ADTYPES
        result = run_ad(model, adtype; benchmark=true)
        push!(RESULTS, (model="$(model.f)", adtype=adtype_name, time=(result.grad_time / result.primal_time)))
    end
end
RESULT_DF = DataFrame(RESULTS)

using CSV
CSV.write("results.csv", RESULT_DF)

Here time refers to the time for gradient divided by time for primal.

model                                       adtype               time(clos)   time(func)   clos/func    mean(clos/func)  std(clos/func)
demo_dot_assume_observe                     ForwardDiff          1.525467599  1.274712187  1.196715317  1.194930597      0.142726317
demo_assume_index_observe                   ForwardDiff          1.44429541   1.367799257  1.055926448                   
demo_assume_multivariate_observe            ForwardDiff          1.679705404  1.458947368  1.151313228                   
demo_dot_assume_observe_index               ForwardDiff          2.003910968  1.5681335    1.277895644                   
demo_assume_dot_observe                     ForwardDiff          3.160890644  2.269975335  1.392477969                   
demo_assume_multivariate_observe_literal    ForwardDiff          1.526315789  1.463927463  1.042617089                   
demo_dot_assume_observe_index_literal       ForwardDiff          1.490406868  1.100337767  1.354499421                   
demo_assume_dot_observe_literal             ForwardDiff          2.172935945  2.269538142  0.957435306                   
demo_assume_observe_literal                 ForwardDiff          1.739594118  1.306940931  1.331042648                   
demo_assume_submodel_observe_index_literal  ForwardDiff          1.597721311  1.290655327  1.237914785                   
demo_dot_assume_observe_submodel            ForwardDiff          1.497499745  1.131985731  1.322896308                   
demo_dot_assume_observe_matrix_index        ForwardDiff          1.375612421  1.140178153  1.206489019                   
demo_assume_matrix_observe_matrix_index     ForwardDiff          1.403399928  1.393818007  1.006874585                   
demo_dot_assume_observe                     MooncakeFwd          13.71284271  12.28016878  1.116665655  1.117348426      0.138978196
demo_assume_index_observe                   MooncakeFwd          9.806866033  10.16455696  0.964809983                   
demo_assume_multivariate_observe            MooncakeFwd          14.32729477  13.44351464  1.065740258                   
demo_dot_assume_observe_index               MooncakeFwd          12.99918633  10.98645643  1.183201009                   
demo_assume_dot_observe                     MooncakeFwd          8.9453125    6.336938553  1.411614209                   
demo_assume_multivariate_observe_literal    MooncakeFwd          14.36744828  13.78410311  1.042320139                   
demo_dot_assume_observe_index_literal       MooncakeFwd          12.39047587  11.19519899  1.106766917                   
demo_assume_dot_observe_literal             MooncakeFwd          7.90678468   8.648456658  0.914242274                   
demo_assume_observe_literal                 MooncakeFwd          8.14370993   6.063852475  1.342992753                   
demo_assume_submodel_observe_index_literal  MooncakeFwd          12.37424316  11.59245604  1.0674393                     
demo_dot_assume_observe_submodel            MooncakeFwd          13.58812441  11.94641598  1.137422674                   
demo_dot_assume_observe_matrix_index        MooncakeFwd          14.33099908  12.31012     1.164164044                   
demo_assume_matrix_observe_matrix_index     MooncakeFwd          14.38441882  14.26812892  1.008150326                   
demo_dot_assume_observe                     MooncakeRvs          4.975314511  3.923043303  1.268228293  1.161985086      0.1449571
demo_assume_index_observe                   MooncakeRvs          3.276771572  3.252789532  1.007372761                   
demo_assume_multivariate_observe            MooncakeRvs          5.260503186  4.572033898  1.150582717                   
demo_dot_assume_observe_index               MooncakeRvs          4.566469869  3.630591631  1.257775683                   
demo_assume_dot_observe                     MooncakeRvs          5.616897585  4.911655364  1.143585445                   
demo_assume_multivariate_observe_literal    MooncakeRvs          4.942171437  4.519455416  1.093532513                   
demo_dot_assume_observe_index_literal       MooncakeRvs          4.740758124  3.331779452  1.422890738                   
demo_assume_dot_observe_literal             MooncakeRvs          4.985933104  5.047140831  0.987872792                   
demo_assume_observe_literal                 MooncakeRvs          4.938183576  4.422866152  1.1165121                     
demo_assume_submodel_observe_index_literal  MooncakeRvs          4.572583408  3.873515664  1.18047371                    
demo_dot_assume_observe_submodel            MooncakeRvs          4.389248431  3.8731542    1.133249079                   
demo_dot_assume_observe_matrix_index        MooncakeRvs          5.60766652   4.015870574  1.396376307                   
demo_assume_matrix_observe_matrix_index     MooncakeRvs          4.727606079  4.990326923  0.947353981                   
demo_dot_assume_observe                     ReverseDiff          67.84729114  52.74588235  1.28630498   1.178560687      0.130455839
demo_assume_index_observe                   ReverseDiff          58.17253605  54.37883925  1.069764211                   
demo_assume_multivariate_observe            ReverseDiff          66.15684785  59.30163603  1.115599034                   
demo_dot_assume_observe_index               ReverseDiff          71.56960413  63.23089475  1.131877137                   
demo_assume_dot_observe                     ReverseDiff          157.984494   110.0992542  1.434927922                   
demo_assume_multivariate_observe_literal    ReverseDiff          55.93370698  52.77155172  1.059921589                   
demo_dot_assume_observe_index_literal       ReverseDiff          72.94905858  55.06082648  1.324881286                   
demo_assume_dot_observe_literal             ReverseDiff          103.8512397  96.80268542  1.072813623                   
demo_assume_observe_literal                 ReverseDiff          125.2098441  104.3174019  1.200277632                   
demo_assume_submodel_observe_index_literal  ReverseDiff          67.71928065  57.96516872  1.1682754                     
demo_dot_assume_observe_submodel            ReverseDiff          61.50066487  52.23269152  1.177436258                   
demo_dot_assume_observe_matrix_index        ReverseDiff          60.39138037  46.0713928   1.310821677                   
demo_assume_matrix_observe_matrix_index     ReverseDiff          45.24193548  46.71880176  0.968388182                   
demo_dot_assume_observe                     ReverseDiffCompiled  20.90625     51.28954839  0.407612285  0.367591864      0.052660854
demo_assume_index_observe                   ReverseDiffCompiled  19.2457378   53.7984      0.357738107                   
demo_assume_multivariate_observe            ReverseDiffCompiled  18.49863146  58.35986883  0.316975206                   
demo_dot_assume_observe_index               ReverseDiffCompiled  22.21232631  59.22746781  0.375034205                   
demo_assume_dot_observe                     ReverseDiffCompiled  48.54647089  105.2648943  0.461183866                   
demo_assume_multivariate_observe_literal    ReverseDiffCompiled  15.95453127  53.94781274  0.295740095                   
demo_dot_assume_observe_index_literal       ReverseDiffCompiled  22.12678788  56.35138968  0.39265736                    
demo_assume_dot_observe_literal             ReverseDiffCompiled  36.52301255  96.41645676  0.378804758                   
demo_assume_observe_literal                 ReverseDiffCompiled  42.81327191  95.05131878  0.450422703                   
demo_assume_submodel_observe_index_literal  ReverseDiffCompiled  20.31091181  59.84733874  0.339378696                   
demo_dot_assume_observe_submodel            ReverseDiffCompiled  17.81327801  53.81377064  0.331017095                   
demo_dot_assume_observe_matrix_index        ReverseDiffCompiled  17.93991416  47.62331839  0.376704412                   
demo_assume_matrix_observe_matrix_index     ReverseDiffCompiled  13.98849123  47.35032558  0.29542545                    

@github-actions
Copy link
Contributor

github-actions bot commented Dec 10, 2025

Benchmark Report

  • this PR's head: 94e62b87abf3f3ebb20033cc2b3a77315fc4b3d3
  • base branch: e3fbd2d856d1f2cbc39e7ddc11ee21f68737776f

Computer Information

Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬───────────────────┬────────┬────────────────────────────────┬────────────────────────────┬─────────────────────────────────┐
│                       │       │             │                   │        │        t(eval) / t(ref)        │     t(grad) / t(eval)      │        t(grad) / t(ref)         │
│                       │       │             │                   │        │ ─────────┬───────────┬──────── │ ───────┬─────────┬──────── │ ──────────┬───────────┬──────── │
│                 Model │   Dim │  AD Backend │           VarInfo │ Linked │     base │   this PR │ speedup │   base │ this PR │ speedup │      base │   this PR │ speedup │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│               Dynamic │    10 │    mooncake │             typed │   true │   333.44 │    389.69 │    0.86 │  11.25 │    9.19 │    1.22 │   3751.03 │   3582.89 │    1.05 │
│                   LDA │    12 │ reversediff │             typed │   true │  2404.63 │   2602.07 │    0.92 │   4.96 │    5.06 │    0.98 │  11929.01 │  13155.88 │    0.91 │
│   Loop univariate 10k │ 10000 │    mooncake │             typed │   true │ 94196.88 │ 105415.60 │    0.89 │   3.85 │    3.96 │    0.97 │ 363084.13 │ 417270.29 │    0.87 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│    Loop univariate 1k │  1000 │    mooncake │             typed │   true │  7193.26 │   7909.19 │    0.91 │   4.57 │    4.83 │    0.95 │  32899.32 │  38187.41 │    0.86 │
│      Multivariate 10k │ 10000 │    mooncake │             typed │   true │ 28073.27 │  32195.10 │    0.87 │  11.11 │   10.29 │    1.08 │ 311961.11 │ 331259.84 │    0.94 │
│       Multivariate 1k │  1000 │    mooncake │             typed │   true │  3215.47 │   3668.46 │    0.88 │  12.80 │    9.18 │    1.40 │  41167.81 │  33661.58 │    1.22 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│ Simple assume observe │     1 │ forwarddiff │             typed │  false │     2.48 │      2.69 │    0.92 │   3.79 │    3.98 │    0.95 │      9.40 │     10.71 │    0.88 │
│           Smorgasbord │   201 │ forwarddiff │             typed │  false │  1067.20 │   1225.98 │    0.87 │  64.03 │  126.03 │    0.51 │  68332.23 │ 154506.93 │    0.44 │
│           Smorgasbord │   201 │ forwarddiff │       simple_dict │   true │      err │       err │     err │    err │     err │     err │       err │       err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │ simple_namedtuple │   true │      err │       err │     err │    err │     err │     err │       err │       err │     err │
│           Smorgasbord │   201 │      enzyme │             typed │   true │  1461.07 │   1664.75 │    0.88 │   5.82 │    6.11 │    0.95 │   8499.90 │  10165.74 │    0.84 │
│           Smorgasbord │   201 │    mooncake │             typed │   true │  1465.24 │   1671.79 │    0.88 │   5.44 │    5.44 │    1.00 │   7972.11 │   9100.85 │    0.88 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ reversediff │             typed │   true │  1458.92 │   1679.78 │    0.87 │  91.66 │   89.51 │    1.02 │ 133728.43 │ 150358.42 │    0.89 │
│           Smorgasbord │   201 │ forwarddiff │      typed_vector │   true │  1476.34 │   1671.08 │    0.88 │  58.46 │   56.46 │    1.04 │  86303.55 │  94356.82 │    0.91 │
│           Smorgasbord │   201 │ forwarddiff │           untyped │   true │  1471.48 │   1711.49 │    0.86 │ 119.95 │   55.75 │    2.15 │ 176499.74 │  95413.17 │    1.85 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │    untyped_vector │   true │  1470.10 │   1672.58 │    0.88 │  57.21 │   56.07 │    1.02 │  84099.73 │  93779.42 │    0.90 │
│              Submodel │     1 │    mooncake │             typed │   true │     6.33 │      7.17 │    0.88 │   5.49 │    5.24 │    1.05 │     34.75 │     37.55 │    0.93 │
└───────────────────────┴───────┴─────────────┴───────────────────┴────────┴──────────┴───────────┴─────────┴────────┴─────────┴─────────┴───────────┴───────────┴─────────┘

@codecov
Copy link

codecov bot commented Dec 10, 2025

Codecov Report

❌ Patch coverage is 85.00000% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 80.00%. Comparing base (e3fbd2d) to head (94e62b8).

Files with missing lines Patch % Lines
src/logdensityfunction.jl 85.00% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1172      +/-   ##
==========================================
- Coverage   80.01%   80.00%   -0.02%     
==========================================
  Files          41       41              
  Lines        3877     3890      +13     
==========================================
+ Hits         3102     3112      +10     
- Misses        775      778       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@penelopeysm penelopeysm requested a review from mhauru December 10, 2025 18:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants