Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Implement a denormalize custom Jaxpr operator simplifying MCX logpdfs #71

Open
wants to merge 46 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
b4d63bf
Implement a `denormalize` custom Jaxpr operator simplifying MCX logpd…
balancap Jan 26, 2021
6398e8f
wip
balancap Jan 31, 2021
04cc11f
wip
balancap Feb 6, 2021
ffb7329
wip
balancap Feb 6, 2021
c2c53a4
wip
balancap Feb 6, 2021
5d8b681
wip
balancap Feb 6, 2021
098ca4f
wip
balancap Feb 6, 2021
2f125da
wip
balancap Feb 6, 2021
5369cc7
wip
balancap Feb 6, 2021
3652b60
wip
balancap Feb 7, 2021
cf60bf3
wip
balancap Feb 7, 2021
ba4b560
wip
balancap Feb 7, 2021
3d80f5c
wip
balancap Feb 7, 2021
f39bce6
wip
balancap Feb 7, 2021
29435e7
wip
balancap Feb 7, 2021
9b9f5fb
wip
balancap Feb 8, 2021
89d76d2
wip
balancap Feb 8, 2021
da25e8e
wip
balancap Feb 8, 2021
8113a51
wip
balancap Feb 9, 2021
f1c310e
wip
balancap Feb 9, 2021
d6e452b
wip
balancap Feb 9, 2021
51c7ed5
wip
balancap Feb 10, 2021
c370c66
wip
balancap Feb 11, 2021
6c635e2
wip
balancap Feb 13, 2021
9396129
wip
balancap Feb 13, 2021
c3096c7
wip
balancap Feb 13, 2021
0286705
wip
balancap Feb 13, 2021
7f895d5
wip
balancap Feb 13, 2021
542177d
wip
balancap Feb 13, 2021
1c784e9
wip
balancap Feb 13, 2021
85cec94
wip
balancap Feb 14, 2021
2853d37
wip
balancap Feb 14, 2021
ec44124
wip
balancap Feb 14, 2021
15a5ff9
wip
balancap Feb 14, 2021
facafc2
wip
balancap Feb 14, 2021
9717abe
wip
balancap Feb 14, 2021
756781e
wip
balancap Feb 14, 2021
05c1be4
wip
balancap Feb 14, 2021
1504f11
wip
balancap Feb 15, 2021
9ee2b77
wip
balancap Feb 15, 2021
e1713ae
wip
balancap Feb 15, 2021
825c79a
wip
balancap Feb 17, 2021
be7f8e0
wip
balancap Feb 17, 2021
c594ff6
wip
balancap Feb 19, 2021
66133f5
wip
balancap Feb 19, 2021
d8d1180
wip
balancap Feb 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
wip
balancap committed Feb 13, 2021
commit 6c635e28bca1f0c2c353d96ada819b0b26714afb
133 changes: 127 additions & 6 deletions mcx/core/jaxpr_ops.py
Original file line number Diff line number Diff line change
@@ -10,15 +10,18 @@

from dataclasses import dataclass
from functools import wraps
from typing import List, Dict, Tuple, Any, Type, TypeVar, Callable
from typing import List, Dict, Optional, Set, Tuple, Any, Type, TypeVar, Callable

Array = Any
"""Generic Array type.
"""
TState = TypeVar("TState")
"""Generic Jaxpr visitor state.
"""

TRecState = Tuple[TState, List[Optional[List["TRecState"]]]]
"""Full recursive state, representing the visitor state of the Jaxpr as well as
the sub-states of all sub-jaxprs.
"""

jaxpr_high_order_primitives_to_subjaxprs = {
jax.lax.cond_p: lambda jxpr: jxpr.params["branches"],
@@ -86,7 +89,7 @@ def jaxpr_visitor(
map_sub_states_fn: Callable[[jax.core.JaxprEqn, TState], List[TState]],
reduce_sub_states_fn: Callable[[jax.core.JaxprEqn, TState, List[TState]], TState],
reverse: bool = False,
) -> Tuple[TState, List[Any]]:
) -> TRecState:
"""Visitor pattern on a Jaxpr, traversing equations and supporting higher-order primitives
with sub-Jaxprs.

@@ -153,6 +156,10 @@ class ConstVarInfo:
ConstVarState = Dict[jax.core.Var, ConstVarInfo]
"""Const variables visitor state: dictionary associating const variables with their info.
"""
ConstVarRecState = Tuple[ConstVarState, List[Optional[List["ConstVarRecState"]]]]


# Garthoks = Union[Garthok, Iterable['Garthoks']]


def get_variable_const_info(v: Any, state: ConstVarState) -> ConstVarInfo:
@@ -287,7 +294,7 @@ def jaxpr_find_constvars_reduce_sub_states_fn(

def jaxpr_find_constvars(
jaxpr: jax.core.Jaxpr, constvars: Dict[jax.core.Var, ConstVarInfo]
) -> Dict[jax.core.Var, ConstVarInfo]:
) -> ConstVarRecState:
"""Find all intermediates variables in a JAX expression which are expected to be constants.

Parameters
@@ -302,18 +309,132 @@ def jaxpr_find_constvars(
# Start with the collection of input constants.
const_state = copy.copy(constvars)
print(jaxpr)
const_state, _ = jaxpr_visitor(
const_rec_state = jaxpr_visitor(
jaxpr,
const_state,
jaxpr_find_constvars_visitor_fn,
jaxpr_find_constvars_map_sub_states_fn,
jaxpr_find_constvars_reduce_sub_states_fn,
reverse=False,
)
return const_state
return const_rec_state


DenormalizeState = Tuple[
Dict[jax.core.Var, Tuple[Any, jax.core.Var]], Set[jax.core.Var], ConstVarRecState
]
"""Denormalization state, combination of:
- dictionary of variable mapping, corresponding to `add` or `sub` ops which can be simplified;
- set of variables which can be traverse backward for denormalization;
- full recursive const variable state of the Jaxpr.
"""
DenormalizeRecState = Tuple[
DenormalizeState, List[Optional[List["DenormalizeRecState"]]]
]

denorm_supported_linear_ops = [
jax.lax.broadcast_in_dim_p,
jax.lax.broadcast_p,
jax.lax.neg_p,
jax.lax.reshape_p,
jax.lax.squeeze_p,
jax.lax.reduce_sum_p,
]


def jaxpr_denorm_mapping_visitor_fn(
eqn: jax.core.JaxprEqn,
state: DenormalizeState,
) -> DenormalizeState:
"""pass
fdsafas
"""
# Un-stack input complex input state!
denorm_map_dict, denorm_valid_vars, constvar_full_state = state
constvar_state, _ = constvar_full_state

def is_var_constant(v: Any) -> bool:
return type(v) is jax.core.Literal or v in constvar_state

if eqn.primitive in denorm_supported_linear_ops:
# Can continue denormalizing inputs if all outputs are in the linear vars collection.
if all([o in denorm_valid_vars for o in eqn.outvars]):
denorm_valid_vars |= set(eqn.invars)
elif eqn.primitive == jax.lax.add_p and eqn.outvars[0] in denorm_valid_vars:
lhs_invar, rhs_invar = eqn.invars[0], eqn.invars[1]
# Mapping the output to the non-const input.
if is_var_constant(lhs_invar):
denorm_valid_vars.add(rhs_invar)
denorm_map_dict[eqn.outvars[0]] = (jax_lax_identity, rhs_invar)
elif is_var_constant(rhs_invar):
denorm_valid_vars.add(lhs_invar)
denorm_map_dict[eqn.outvars[0]] = (jax_lax_identity, lhs_invar)
elif eqn.primitive == jax.lax.sub_p and eqn.outvars[0] in denorm_valid_vars:
lhs_invar, rhs_invar = eqn.invars[0], eqn.invars[1]
# Mapping the output to the non-const input (or the negative).
if is_var_constant(lhs_invar):
denorm_valid_vars.add(rhs_invar)
denorm_map_dict[eqn.outvars[0]] = (jax.lax.neg, rhs_invar)
elif is_var_constant(rhs_invar):
denorm_valid_vars.add(lhs_invar)
denorm_map_dict[eqn.outvars[0]] = (jax_lax_identity, lhs_invar)

# Re-construct updated state.
return (denorm_map_dict, denorm_valid_vars, constvar_full_state)


def jaxpr_denorm_mapping_map_sub_states_fn(
eqn: jax.core.JaxprEqn, state: DenormalizeState
) -> List[DenormalizeState]:
""""""
denorm_map_dict, denorm_valid_vars, constvar_full_state = state
sub_jaxprs = jaxpr_find_sub_jaxprs(eqn)
# TODO: fix properly.
return [state for _ in sub_jaxprs]


def jaxpr_denorm_mapping_reduce_sub_states_fn(
eqn: jax.core.JaxprEqn, state: DenormalizeState, sub_states: List[DenormalizeState]
) -> DenormalizeState:
""""""
# TODO: fix properly.
return state


def jaxpr_find_denormalize_mapping(
jaxpr: jax.core.Jaxpr, constvar_state: ConstVarRecState
) -> DenormalizeRecState:
"""Find all assignment simplifications in a JAX expression when denormalizing.

More specifically, this method is looking to simplify `add` and `sub` operations, with output linear
with respect to the Jaxpr outputs, and where one of the input is constant. It returns the simplified mapping
between input and output of `add`/`sub` ops which can be removed.

Parameters
----------
jaxpr: JAX expression.
consts: List of known constant variables in the JAX expression.

Returns
-------
Simplified mapping between `add` output and input (with the proper assignment lax op `identity` or `neg`).
"""
# Initialize the denormalize state, starting from the ouput variables.
denormalize_mapping = {}
denorm_valid_vars = set(jaxpr.outvars)
denorm_state = (denormalize_mapping, denorm_valid_vars, constvar_state)
denorm_rec_state = jaxpr_visitor(
jaxpr,
denorm_state,
jaxpr_denorm_mapping_visitor_fn,
jaxpr_denorm_mapping_map_sub_states_fn,
jaxpr_denorm_mapping_reduce_sub_states_fn,
reverse=True,
)
return denorm_rec_state


def jaxpr_find_denormalize_mapping_old(
jaxpr: jax.core.Jaxpr, consts: List[jax.core.Var]
) -> Dict[jax.core.Var, Tuple[jax.core.Primitive, jax.core.Var]]:
"""Find all assignment simplifications in a JAX expression when denormalizing.
21 changes: 12 additions & 9 deletions tests/core/jaxpr_ops_test.py
Original file line number Diff line number Diff line change
@@ -56,7 +56,7 @@ def test__jaxpr_find_constvars__propagate_constants(case):
{v: ConstVarInfo(False, True) for v in typed_jaxpr.jaxpr.constvars}
)

constvars = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars)
constvars, _ = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars)
for outvar in typed_jaxpr.jaxpr.outvars:
assert outvar in constvars
assert constvars[outvar] == expected_const_info
@@ -73,9 +73,12 @@ def test__jaxpr_find_constvars__propagate_constants(case):
@pytest.mark.parametrize("case", denorm_expected_add_mapping_op)
def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case):
typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0)
denorm_map = jaxpr_find_denormalize_mapping(
typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars
)
constvars = {v: ConstVarInfo(False, True) for v in typed_jaxpr.jaxpr.constvars}
constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars)

denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state)
denorm_map = denorm_rec_state[0][0]

invar = typed_jaxpr.jaxpr.invars[0]
outvar = typed_jaxpr.jaxpr.outvars[0]

@@ -103,13 +106,13 @@ def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case):
@pytest.mark.parametrize("case", denorm_linear_op_propagating)
def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping(case):
typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0)
denorm_map = jaxpr_find_denormalize_mapping(
typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars
)
invar = typed_jaxpr.jaxpr.invars[0]
constvars = {v: ConstVarInfo(False, True) for v in typed_jaxpr.jaxpr.constvars}
constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars)

print(typed_jaxpr)
denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state)
denorm_map = denorm_rec_state[0][0]

invar = typed_jaxpr.jaxpr.invars[0]
# Proper mapping of the output to the input.
assert len(denorm_map) == 1
map_op, map_invar = list(denorm_map.values())[0]