Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Implement a denormalize custom Jaxpr operator simplifying MCX logpdfs #71

Open
wants to merge 46 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
b4d63bf
Implement a `denormalize` custom Jaxpr operator simplifying MCX logpd…
balancap Jan 26, 2021
6398e8f
wip
balancap Jan 31, 2021
04cc11f
wip
balancap Feb 6, 2021
ffb7329
wip
balancap Feb 6, 2021
c2c53a4
wip
balancap Feb 6, 2021
5d8b681
wip
balancap Feb 6, 2021
098ca4f
wip
balancap Feb 6, 2021
2f125da
wip
balancap Feb 6, 2021
5369cc7
wip
balancap Feb 6, 2021
3652b60
wip
balancap Feb 7, 2021
cf60bf3
wip
balancap Feb 7, 2021
ba4b560
wip
balancap Feb 7, 2021
3d80f5c
wip
balancap Feb 7, 2021
f39bce6
wip
balancap Feb 7, 2021
29435e7
wip
balancap Feb 7, 2021
9b9f5fb
wip
balancap Feb 8, 2021
89d76d2
wip
balancap Feb 8, 2021
da25e8e
wip
balancap Feb 8, 2021
8113a51
wip
balancap Feb 9, 2021
f1c310e
wip
balancap Feb 9, 2021
d6e452b
wip
balancap Feb 9, 2021
51c7ed5
wip
balancap Feb 10, 2021
c370c66
wip
balancap Feb 11, 2021
6c635e2
wip
balancap Feb 13, 2021
9396129
wip
balancap Feb 13, 2021
c3096c7
wip
balancap Feb 13, 2021
0286705
wip
balancap Feb 13, 2021
7f895d5
wip
balancap Feb 13, 2021
542177d
wip
balancap Feb 13, 2021
1c784e9
wip
balancap Feb 13, 2021
85cec94
wip
balancap Feb 14, 2021
2853d37
wip
balancap Feb 14, 2021
ec44124
wip
balancap Feb 14, 2021
15a5ff9
wip
balancap Feb 14, 2021
facafc2
wip
balancap Feb 14, 2021
9717abe
wip
balancap Feb 14, 2021
756781e
wip
balancap Feb 14, 2021
05c1be4
wip
balancap Feb 14, 2021
1504f11
wip
balancap Feb 15, 2021
9ee2b77
wip
balancap Feb 15, 2021
e1713ae
wip
balancap Feb 15, 2021
825c79a
wip
balancap Feb 17, 2021
be7f8e0
wip
balancap Feb 17, 2021
c594ff6
wip
balancap Feb 19, 2021
66133f5
wip
balancap Feb 19, 2021
d8d1180
wip
balancap Feb 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
wip
balancap committed Feb 6, 2021
commit 04cc11f36fc93cf02bd02c0fed222e4e9e37111e
59 changes: 58 additions & 1 deletion mcx/core/jaxpr_ops.py
Original file line number Diff line number Diff line change
@@ -5,9 +5,66 @@
from jax.util import safe_map

from functools import wraps
from typing import List, Dict, Tuple, Any
from typing import List, Dict, Tuple, Any, Type, TypeVar, Callable

Array = Any
TState = TypeVar("TState")

jaxpr_high_order_primitives_to_subjaxprs = {
jax.lax.cond_p: lambda jxpr: (None,),
jax.lax.while_p: None,
jax.lax.scan_p: None,
jax.core.CallPrimitive: None, # xla_call, from jax.jit
jax.core.MapPrimitive: None,
}
"""Collection of high-order Jax primitives, with sub-Jaxprs.
"""


def jaxpr_visitor(
jaxpr: jax.core.Jaxpr,
initial_state: TState,
visitor_fn: Callable[[jax.core.JaxprEqn, TState, Any], TState],
init_sub_state_fn: Callable[[jax.core.JaxprEqn, TState], List[TState]],
reverse: bool = False,
) -> Tuple[TState, List[Any]]:
"""Visitor pattern on a Jaxpr, traversing equations and supporting higher-order primitives
with sub-Jaxprs.

Parameters
----------
initial_state: Initial state to feed to the visitor method.
visitor_fn: Visitor function, taking an input state and Jaxpr, outputting an updated state.
init_sub_state_fn: Initializing method for higher-order primitives sub-Jaxprs. Taking as input
the existing state, and outputting input states to respective sub-Jaxprs.
reverse: Traverse the Jaxpr equations in reverse order.
Returns
-------
Output state of the last iteration.
"""
state = initial_state
subjaxprs_visit = []

equations = jax.eqns if not reverse else jax.eqns[::-1]
for eqn in equations:
if eqn.primitive in jaxpr_high_order_primitives_to_subjaxprs:
sub_jaxprs = jaxpr_high_order_primitives_to_subjaxprs[eqn.primitive]
sub_states = init_sub_state_fn(eqn, state)
# Map visitor method to each sub-jaxpr.
res_sub_states = [
jaxpr_visitor(
sub_jaxpr, sub_state, visitor_fn, init_sub_state_fn, reverse
)
for sub_jaxpr, sub_state in zip(sub_jaxprs, sub_states)
]
# Reduce, to update the current state.
sate = visitor_fn(eqn, state, res_sub_states)
subjaxprs_visit.append(res_sub_states)
else:
# Common Jaxpr equation: apply the visitor and update state.
state = visitor_fn(eqn, state, None)
subjaxprs_visit.append(None)
return state, subjaxprs_visit


def jax_lax_identity(x: Array) -> Array: