InferenceObjects integration #464

Open
sethaxen opened this issue Feb 17, 2023 · 12 comments · May be fixed by #465

Comments

@sethaxen
Member

The plan is to allow AbstractMCMC.sample to return InferenceObjects.InferenceData as a chain_type and to move toward that being a default return type in Turing. There's a mostly functional proof of concept of this integration at https://github.com/sethaxen/DynamicPPLInferenceObjects.jl. @yebai suggested moving this code into DynamicPPL directly and adding InferenceObjects as a dependency, which would increase DynamicPPL load time by 20%. I've opened this issue to discuss whether we want to take this approach or a different one for this integration.

From DynamicPPLInferenceObjects, it seems the integration may be entirely implementable just by overloading methods from DynamicPPL and AbstractMCMC. So alternatively, on Julia v1.9 it could be implemented as an extension, and on earlier versions it could be loaded with Requires.

Related issues:

@devmotion
Member

> earlier versions it could be loaded with Requires.

Could we just make it a proper dependency on older Julia versions? It's still possible to make it a weak dependency on Julia 1.9 at the same time.

@devmotion
Member

It seems you also implement chainstack and bundle samples? That should rather be an extension of AbstractMCMC, I assume?

@sethaxen
Member Author

> Could we just make it a proper dependency on older Julia versions? It's still possible to make it a weak dependency on Julia 1.9 at the same time.

Ah, really? Is that just making it a full dependency on 1.9 but only loading it in an extension? Or something more fancy?

> It seems you also implement chainstack and bundle samples? That should rather be an extension of AbstractMCMC, I assume?

For chainstack, yes, but bundle_samples relies on some DynamicPPL functionality. I suppose we can restrict the model type to DynamicPPL.Model to avoid type piracy.

For Chains the corresponding methods live in Turing proper. That might be cleaner. Currently it relies on utility functions get_params and get_sample_stats, which would need to be added to the DynamicPPL API so that Turing could overload them for its samplers.

@devmotion
Member

It's a weak dependency on Julia >= 1.9 if you declare it both as a strong dependency and a weak dependency. See https://pkgdocs.julialang.org/dev/creating-packages/#Transition-from-normal-dependency-to-extension.
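
As a rough sketch of what that double declaration could look like in DynamicPPL's Project.toml: the UUID below is a placeholder rather than the real one, and the extension name DynamicPPLInferenceObjectsExt is an assumption borrowed from the name proposed further down in this thread.

```toml
[deps]
# Listed here so that Julia < 1.9 treats it as a regular dependency.
InferenceObjects = "<UUID of InferenceObjects>"  # placeholder, not the real UUID

[weakdeps]
# Listed here as well so that Julia >= 1.9 treats it as a weak dependency.
InferenceObjects = "<UUID of InferenceObjects>"  # placeholder, same UUID as above

[extensions]
# Assumed extension module name; on Julia >= 1.9 it is loaded once InferenceObjects is loaded.
DynamicPPLInferenceObjectsExt = "InferenceObjects"
```

On Julia versions without extension support, the linked documentation suggests including the extension file directly from the package, guarded by a check such as `isdefined(Base, :get_extension)`.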

@devmotion
Member

> For Chains the corresponding methods live in Turing proper. That might be cleaner.

There have been long discussions (and even issues IIRC, maybe even in DynamicPPL?) about how messy the current situation is. For example, in many places in DynamicPPL we rely on functionality that is only implemented in MCMCChains but avoid having it as a dependency and instead allow AbstractChains. Similarly, I think not all code in Turing should actually be there.

@sethaxen
Member Author

> For Chains the corresponding methods live in Turing proper. That might be cleaner.

> There have been long discussions (and even issues IIRC, maybe even in DynamicPPL?) about how messy the current situation is - e.g., in many places in DynamicPPL we rely on functionality that is only implemented in MCMCChains but avoid having it as a dependency and instead allow AbstractChains. Similarly, I think not all code in Turing should actually be there.

How would you recommend proceeding then for this integration?

@torfjelde
Member

Is there a particular reason why we don't just add it to Turing for now? I agree a weak dep might make sense, but it's a bit annoying to make it an explicit dependency pre-1.9, no? In Turing.jl I'm guessing the increased compilation time will be minor in comparison.

@sethaxen
Member Author

> Is there a particular reason why we don't just add it to Turing for now?

If the code lived in Turing it would entirely be type piracy. Other than that, I don't see a good reason.

@yebai
Member

yebai commented Feb 17, 2023

A 20% increase in load time is probably not a big deal, I think.

@devmotion
Member

devmotion commented Feb 17, 2023

> How would you recommend proceeding then for this integration?

I guess short-term basically anything goes - it can just take a long time to move away and improve a suboptimal but somewhat working setup, in my experience.

In the longer term, I think a better solution would be

  • to implement chainstack etc. in the InferenceObjects package itself by depending on AbstractMCMC or, see below, an even more lightweight chains package (similar to sampler packages)
  • to generalize the methods in DynamicPPL such as pointwise_loglikelihood, loglikelihood etc. in such a way that they can work with arbitrary AbstractChains as input, similar to how arrays with dimensions are supported by MCMCDiagnosticTools
  • to make Turing also AbstractChains/chain_type agnostic

The last two points probably require some additions to the AbstractChains interface (well, there isn't one yet), in AbstractMCMC or some other, even more lightweight, package.
For instance, I have thought for a while that something like eachsample, eachchain, etc. (similar to eachslice) could be useful and be used instead of the explicit 1:size(chain, 1) etc. in the current code in DynamicPPL.

@torfjelde
Member

> If the code lived in Turing it would entirely be type piracy. Other than that, I don't see a good reason.

That is basically the entirety of Turing.jl though haha 😅

> I guess short-term basically anything goes - it can just take a long time to move away and improve a suboptimal but somewhat working setup, in my experience.

But this is a fair point 😕 We have quite a lot of examples of that.

@sethaxen
Member Author

sethaxen commented Feb 17, 2023

> I guess short-term basically anything goes - it can just take a long time to move away and improve a suboptimal but somewhat working setup, in my experience.

Okay, I think then the approach I will take in the short term is:

  1. Adapt all the code from DynamicPPLInferenceObjects, except for the code in bundle_samples.jl, into an extension module DynamicPPLInferenceObjectsExt.
  2. Add the code in bundle_samples.jl to Turing.

> In the longer term, I think a better solution would be

These all sound good, but at the moment I lack the bandwidth to tackle them.

> For instance, I have thought for a while that something like eachsample, eachchain, etc. (similar to eachslice) could be useful and be used instead of the explicit 1:size(chain, 1) etc. in the current code in DynamicPPL.

Agreed! Actually, this would be automatically supported in InferenceObjects once DimensionalData adds eachslice support for AbstractDimStack (see rafaqz/DimensionalData.jl#418), but in the meantime this works:

julia> using InferenceObjects, DimensionalData

julia> function _eachslice(ds::InferenceObjects.Dataset; dims)
           concrete_dims = DimensionalData.dims(ds, dims)
           return (view(ds, d...) for d in DimensionalData.DimIndices(concrete_dims))
       end;

julia> eachchain(ds::InferenceObjects.Dataset) = _eachslice(ds; dims=DimensionalData.Dim{:chain});

julia> function eachsample(ds::InferenceObjects.Dataset)
           sample_dims = (DimensionalData.Dim{:chain}, DimensionalData.Dim{:draw})
           return _eachslice(ds; dims=sample_dims)
       end;

julia> using ArviZExampleData;

julia> ds = load_example_data("centered_eight").posterior
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
  Dim{:chain} Sampled{Int64} Int64[0, 1, 2, 3] ForwardOrdered Irregular Points,
  Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
and 3 layers:
  :mu    Float64 dims: Dim{:draw}, Dim{:chain} (500×4)
  :theta Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×500×4)
  :tau   Float64 dims: Dim{:draw}, Dim{:chain} (500×4)

with metadata Dict{String, Any} with 6 entries:
  "created_at"                => "2022-10-13T14:37:37.315398"
  "inference_library_version" => "4.2.2"
  "sampling_time"             => 7.48011
  "tuning_steps"              => 1000
  "arviz_version"             => "0.13.0.dev0"
  "inference_library"         => "pymc"

julia> collect(eachchain(ds))[1]
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
  Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
and 3 layers:
  :mu    Float64 dims: Dim{:draw} (500)
  :theta Float64 dims: Dim{:school}, Dim{:draw} (8×500)
  :tau   Float64 dims: Dim{:draw} (500)

with metadata Dict{String, Any} with 6 entries:
  "created_at"                => "2022-10-13T14:37:37.315398"
  "inference_library_version" => "4.2.2"
  "sampling_time"             => 7.48011
  "tuning_steps"              => 1000
  "arviz_version"             => "0.13.0.dev0"
  "inference_library"         => "pymc"

julia> collect(eachsample(ds))[1]
Dataset with dimensions: 
  Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
and 3 layers:
  :mu    Float64 dims: 
  :theta Float64 dims: Dim{:school} (8)
  :tau   Float64 dims: 

with metadata Dict{String, Any} with 6 entries:
  "created_at"                => "2022-10-13T14:37:37.315398"
  "inference_library_version" => "4.2.2"
  "sampling_time"             => 7.48011
  "tuning_steps"              => 1000
  "arviz_version"             => "0.13.0.dev0"
  "inference_library"         => "pymc"

Edit: adapted for rafaqz/DimensionalData.jl#462
