Skip to content

Commit

Permalink
llvm, Mechanism: Update output ports in "reset" execution variant
Browse files Browse the repository at this point in the history
The "reset" variants of functions now return a value.

Fixes: PrincetonUniversity#3142
Signed-off-by: Jan Vesely <[email protected]>
  • Loading branch information
jvesely committed Feb 9, 2025
1 parent 879f12a commit 206a868
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
3 changes: 3 additions & 0 deletions psyneulink/core/components/mechanisms/mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -3241,6 +3241,9 @@ def _gen_llvm_function_reset(self, ctx, builder, m_base_params, m_state, m_arg_i

builder.call(reinit_func, [reinit_params, reinit_state, reinit_in, reinit_out])

# update output ports after getting the reinitialized value
builder = self._gen_llvm_output_ports(ctx, builder, reinit_out, m_base_params, m_state, m_arg_in, m_arg_out)

return builder

def _gen_llvm_function(self, *, extra_args=[], ctx:pnlvm.LLVMBuilderContext, tags:frozenset):
Expand Down
22 changes: 22 additions & 0 deletions tests/mechanisms/test_ddm_mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,28 @@ def test_WhenFinished_DDM_Analytical():
c.is_satisfied()


@pytest.mark.mechanism
@pytest.mark.parametrize("initializer_param", [{}, {pnl.INITIALIZER: 5.0}], ids=["default_initializer", "custom_initializer"])
@pytest.mark.parametrize("non_decision_time_param", [{}, {pnl.NON_DECISION_TIME: 13.0}], ids=["default_non_decision_time", "custom_non_decision_time"])
def test_DDM_reset(mech_mode, initializer_param, non_decision_time_param):
D = pnl.DDM(function=pnl.DriftDiffusionIntegrator(**initializer_param, **non_decision_time_param))

ex = pytest.helpers.get_mech_execution(D, mech_mode)

initializer_value = initializer_param.get(pnl.INITIALIZER, 0)
non_decision_time_value = non_decision_time_param.get(pnl.NON_DECISION_TIME, 0)

ex([1])
ex([2])
result = ex([3])
np.testing.assert_array_equal(result, [[100], [102 - initializer_value + non_decision_time_value]])

reset_ex = pytest.helpers.get_mech_execution(D, mech_mode, tags=frozenset({"reset"}), member="reset")

reset_result = reset_ex(None if mech_mode == "Python" else [0])
np.testing.assert_array_equal(reset_result, [[initializer_value], [non_decision_time_value]])


@pytest.mark.composition
@pytest.mark.ddm_mechanism
@pytest.mark.mechanism
Expand Down

0 comments on commit 206a868

Please sign in to comment.