Skip to content

Commit

Permalink
Fix remaining method ambiguities (#2304)
Browse files Browse the repository at this point in the history
* Enabling aqua ambiguity testing for Turing


We test ambiguities only for Turing and not its dependencies.

* Format

* Fix bundle_samples method ambiguity

Concretely:

1. Creating an `AbstractTransition` type which all the Transitions in
   Turing subtype.
2. Modifying the type signature of bundle_samples to take a
   Vector{<:Union{AbstractTransition,AbstractVarInfo}} as the first
   argument. The AbstractVarInfo case occurs when sampling with Prior(),
   so the type signature of this argument mirrors that of the Sampler in
   the same function.

* Fix get() ambiguities

Done by:

1. Constraining the type parameter to AbstractVector{Symbol}
2. Modifying the method below it to use a vector instead of a tuple

* Bump to 0.34.0

---------

Co-authored-by: Abhinav Singh <[email protected]>
Co-authored-by: Xianda Sun <[email protected]>
  • Loading branch information
3 people committed Aug 30, 2024
1 parent 5b5da11 commit a26ce11
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.33.3"
version = "0.34.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
8 changes: 5 additions & 3 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ end
# Extended in contrib/inference/abstractmcmc.jl
getstats(t) = nothing

struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}}
abstract type AbstractTransition end

struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}} <: AbstractTransition
θ :: T
lp :: F # TODO: merge `lp` with `stat`
stat :: S
Expand Down Expand Up @@ -409,7 +411,7 @@ getlogevidence(transitions, sampler, state) = missing
# Default MCMCChains.Chains constructor.
# This is type piracy (at least for SampleFromPrior).
function AbstractMCMC.bundle_samples(
ts::Vector,
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
state,
Expand Down Expand Up @@ -472,7 +474,7 @@ end

# This is type piracy (for SampleFromPrior).
function AbstractMCMC.bundle_samples(
ts::Vector,
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
state,
Expand Down
4 changes: 2 additions & 2 deletions src/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ end
SMC(space::Symbol...) = SMC(space)
SMC(space::Tuple) = SMC(AdvancedPS.ResampleWithESSThreshold(), space)

struct SMCTransition{T,F<:AbstractFloat}
struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition
"The parameters for any given sample."
θ::T
"The joint log probability of the sample (NOTE: does not work, always set to zero)."
Expand Down Expand Up @@ -222,7 +222,7 @@ end

const CSMC = PG # type alias of PG as Conditional SMC

struct PGTransition{T,F<:AbstractFloat}
struct PGTransition{T,F<:AbstractFloat} <: AbstractTransition
"The parameters for any given sample."
θ::T
"The joint log probability of the sample (NOTE: does not work, always set to zero)."
Expand Down
2 changes: 1 addition & 1 deletion src/mcmc/sghmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ function SGLD(
return SGLD{typeof(adtype),space,typeof(stepsize)}(stepsize, adtype)
end

struct SGLDTransition{T,F<:Real}
struct SGLDTransition{T,F<:Real} <: AbstractTransition
"The parameters for any given sample."
θ::T
"The joint log probability of the sample."
Expand Down
8 changes: 4 additions & 4 deletions src/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ StatsBase.loglikelihood(m::ModeResult) = m.lp

"""
Base.get(m::ModeResult, var_symbol::Symbol)
Base.get(m::ModeResult, var_symbols)
Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
Return the values of all the variables with the symbol(s) `var_symbol` in the mode result
`m`. The return value is a `NamedTuple` with `var_symbols` as the key(s). The second
argument should be either a `Symbol` or an iterator of `Symbol`s.
argument should be either a `Symbol` or a vector of `Symbol`s.
"""
function Base.get(m::ModeResult, var_symbols)
function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
log_density = m.f
# Get all the variable names in the model. This is the same as the list of keys in
# m.values, but they are more convenient to filter when they are VarNames rather than
Expand All @@ -304,7 +304,7 @@ function Base.get(m::ModeResult, var_symbols)
return (; zip(var_symbols, value_vectors)...)
end

Base.get(m::ModeResult, var_symbol::Symbol) = get(m, (var_symbol,))
Base.get(m::ModeResult, var_symbol::Symbol) = get(m, [var_symbol])

"""
ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
Expand Down
5 changes: 3 additions & 2 deletions test/Aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ module AquaTests
using Aqua: Aqua
using Turing

# TODO(mhauru) We skip testing for method ambiguities because it catches a lot of problems
# in dependencies. Would like to check it for just Turing.jl itself though.
# We test ambiguities separately because it catches a lot of problems
# in dependencies but we test it for Turing.
Aqua.test_ambiguities([Turing])
Aqua.test_all(Turing; ambiguities=false)

end

2 comments on commit a26ce11

@penelopeysm
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/114210

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.34.0 -m "<description of version>" a26ce1198354cdb54b352f659369694b11bf489f
git push origin v0.34.0

Please sign in to comment.