Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Implement a denormalize custom Jaxpr operator simplifying MCX logpdfs #71

Open
wants to merge 46 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
b4d63bf
Implement a `denormalize` custom Jaxpr operator simplifying MCX logpd…
balancap Jan 26, 2021
6398e8f
wip
balancap Jan 31, 2021
04cc11f
wip
balancap Feb 6, 2021
ffb7329
wip
balancap Feb 6, 2021
c2c53a4
wip
balancap Feb 6, 2021
5d8b681
wip
balancap Feb 6, 2021
098ca4f
wip
balancap Feb 6, 2021
2f125da
wip
balancap Feb 6, 2021
5369cc7
wip
balancap Feb 6, 2021
3652b60
wip
balancap Feb 7, 2021
cf60bf3
wip
balancap Feb 7, 2021
ba4b560
wip
balancap Feb 7, 2021
3d80f5c
wip
balancap Feb 7, 2021
f39bce6
wip
balancap Feb 7, 2021
29435e7
wip
balancap Feb 7, 2021
9b9f5fb
wip
balancap Feb 8, 2021
89d76d2
wip
balancap Feb 8, 2021
da25e8e
wip
balancap Feb 8, 2021
8113a51
wip
balancap Feb 9, 2021
f1c310e
wip
balancap Feb 9, 2021
d6e452b
wip
balancap Feb 9, 2021
51c7ed5
wip
balancap Feb 10, 2021
c370c66
wip
balancap Feb 11, 2021
6c635e2
wip
balancap Feb 13, 2021
9396129
wip
balancap Feb 13, 2021
c3096c7
wip
balancap Feb 13, 2021
0286705
wip
balancap Feb 13, 2021
7f895d5
wip
balancap Feb 13, 2021
542177d
wip
balancap Feb 13, 2021
1c784e9
wip
balancap Feb 13, 2021
85cec94
wip
balancap Feb 14, 2021
2853d37
wip
balancap Feb 14, 2021
ec44124
wip
balancap Feb 14, 2021
15a5ff9
wip
balancap Feb 14, 2021
facafc2
wip
balancap Feb 14, 2021
9717abe
wip
balancap Feb 14, 2021
756781e
wip
balancap Feb 14, 2021
05c1be4
wip
balancap Feb 14, 2021
1504f11
wip
balancap Feb 15, 2021
9ee2b77
wip
balancap Feb 15, 2021
e1713ae
wip
balancap Feb 15, 2021
825c79a
wip
balancap Feb 17, 2021
be7f8e0
wip
balancap Feb 17, 2021
c594ff6
wip
balancap Feb 19, 2021
66133f5
wip
balancap Feb 19, 2021
d8d1180
wip
balancap Feb 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
wip
balancap committed Feb 9, 2021
commit 8113a51a49aac50ea185fafb029afb5db5f5459b
62 changes: 18 additions & 44 deletions mcx/core/jaxpr_ops.py
Original file line number Diff line number Diff line change
@@ -147,6 +147,18 @@ class ConstVarStatus(enum.Enum):
"""


def get_variable_const_status(v: Any, state: ConstVarState) -> ConstVarStatus:
"""Get the constant status on a variable (or literal)."""
if type(v) is jax.core.Literal:
# Non-finite if all entries are non-finite.
return (
ConstVarStatus.Unknown
if np.any(np.isfinite(v.val))
else ConstVarStatus.NonFinite
)
return state.get(v, None)


def jaxpr_find_constvars_visitor_fn(
eqn: jax.core.JaxprEqn,
state: ConstVarState,
@@ -166,27 +178,18 @@ def jaxpr_find_constvars_visitor_fn(
Updated constant variables collection with outputs of the Jaxpr equation.
"""

def get_var_status(v) -> ConstVarStatus:
if type(v) is jax.core.Literal:
# Non-finite if all entries are non-finite.
return (
ConstVarStatus.Unknown
if np.any(np.isfinite(v.val))
else ConstVarStatus.NonFinite
)
return state.get(v, ConstVarStatus.Unknown)

# Common ops logic: are inputs literal or const variables?
# NOTE: Jax literal are not hashable!
is_const_invars = [type(v) is jax.core.Literal or v in state for v in eqn.invars]
status_invars = [get_var_status(v) for v in eqn.invars]
status_invars = [get_variable_const_status(v, state) for v in eqn.invars]
if all(is_const_invars):
# Using a form of heuristic here: outputs are non-finite if one the input is. Should
# refine this logic per op.
any_non_finite_invar = any(
[s == ConstVarStatus.NonFinite for s in status_invars]
)
outvar_status = (
ConstVarStatus.NonFinite
if any([s == ConstVarStatus.NonFinite for s in status_invars])
else ConstVarStatus.Unknown
ConstVarStatus.NonFinite if any_non_finite_invar else ConstVarStatus.Unknown
)
state.update({v: outvar_status for v in eqn.outvars})
return state
@@ -222,11 +225,7 @@ def jaxpr_find_constvars_map_sub_states_fn(
sub_init_state[sub_invar] = state[eqn_invar]
elif type(sub_invar) is jax.core.Literal:
# Literal argument: check the value fo the status.
sub_init_state[sub_invar] = (
ConstVarStatus.Unknown
if np.any(np.isfinite(sub_invar.val))
else ConstVarStatus.NonFinite
)
sub_init_state[sub_invar] = get_variable_const_status(sub_invar, None)
return [sub_init_state]
else:
# TODO: support other high primitives. No constants passed at the moment.
@@ -295,31 +294,6 @@ def jaxpr_find_constvars(
return const_state


def jaxpr_find_constvars_old(
jaxpr: jax.core.Jaxpr, consts: List[jax.core.Var]
) -> List[jax.core.Var]:
"""Find all intermediates variables in a JAX expression which are expected to be constants.

Parameters
----------
jaxpr: JAX expression.
consts: List of known constant variables in the JAX expression.

Returns
-------
List of all intermediate constant variables.
"""
constvars_dict = {str(v): v for v in consts}
for eqn in jaxpr.eqns:
# Are inputs literal or const variables?
is_const_invars = [
str(v) in constvars_dict or type(v) is jax.core.Literal for v in eqn.invars
]
if all(is_const_invars):
constvars_dict.update({str(v): v for v in eqn.outvars})
return list(constvars_dict.values())


def jaxpr_find_denormalize_mapping(
jaxpr: jax.core.Jaxpr, consts: List[jax.core.Var]
) -> Dict[jax.core.Var, Tuple[jax.core.Primitive, jax.core.Var]]:
2 changes: 1 addition & 1 deletion tests/core/jaxpr_ops_test.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@
},
# Simple inf constant propagation.
{"fn": lambda x: x + np.ones((2,)) + np.inf, "status": ConstVarStatus.NonFinite},
# Handle properly jax.jit sub-jaxpr.
# Handle properly jax.jit sub-jaxpr + inf constant.
{
"fn": lambda x: jax.jit(lambda y: y + np.full((2,), np.inf))(x) + np.exp(2.0),
"status": ConstVarStatus.NonFinite,