Skip to content

Commit

Permalink
llvm: Return initialized values as output from "reset" function execu…
Browse files Browse the repository at this point in the history
…tion variants (#3188)

Return the value of "previous_value" Parameter in Function "reset" variant for Functions that don't have other Parameters with initializers.
Return the values of "previous_value" and "previous_time" in DriftDiffusionIntegrator "reset" variant".
Update output ports in Mechanism "reset" execution variants, using the value returned from the Function's "reset" variant.

Update compiled test helpers to allow the selection of execution variants for Mechanisms and Functions.
Add reproducer from #3142 as a regression test.

Closes: #3142
  • Loading branch information
jvesely authored Feb 11, 2025
2 parents f441887 + 9b41a73 commit 5e78a74
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 51 deletions.
16 changes: 8 additions & 8 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,29 +207,29 @@ def cuda_param(val):
return pytest.param(val, marks=[pytest.mark.llvm, pytest.mark.cuda])

@pytest.helpers.register
def get_func_execution(func, func_mode):
def get_func_execution(func, func_mode, *, tags:frozenset=frozenset(), member='function'):
if func_mode == 'LLVM':
return pnlvm.execution.FuncExecution(func).execute
return pnlvm.execution.FuncExecution(func, tags=tags).execute

elif func_mode == 'PTX':
return pnlvm.execution.FuncExecution(func).cuda_execute
return pnlvm.execution.FuncExecution(func, tags=tags).cuda_execute

elif func_mode == 'Python':
return func.function
return getattr(func, member)
else:
assert False, "Unknown function mode: {}".format(func_mode)

@pytest.helpers.register
def get_mech_execution(mech, mech_mode):
def get_mech_execution(mech, mech_mode, *, tags:frozenset=frozenset(), member='execute'):
if mech_mode == 'LLVM':
return pnlvm.execution.MechExecution(mech).execute
return pnlvm.execution.MechExecution(mech, tags=tags).execute

elif mech_mode == 'PTX':
return pnlvm.execution.MechExecution(mech).cuda_execute
return pnlvm.execution.MechExecution(mech, tags=tags).cuda_execute

elif mech_mode == 'Python':
def mech_wrapper(x):
mech.execute(x)
getattr(mech, member)(x)
return mech.output_values

return mech_wrapper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2590,8 +2590,7 @@ def _gen_llvm_integrate(self, builder, index, ctx, vi, vo, params, state):
val = pnlvm.helpers.fclamp(builder, val, neg_threshold, threshold)

# Store value result
data_vo_ptr = builder.gep(vo, [ctx.int32_ty(0),
ctx.int32_ty(0), index])
data_vo_ptr = builder.gep(vo, [ctx.int32_ty(0), ctx.int32_ty(0), index])
builder.store(val, data_vo_ptr)
builder.store(val, prev_val_ptr)

Expand All @@ -2604,6 +2603,22 @@ def _gen_llvm_integrate(self, builder, index, ctx, vi, vo, params, state):
time_vo_ptr = builder.gep(vo, [ctx.int32_ty(0), ctx.int32_ty(1), index])
builder.store(curr_time, time_vo_ptr)

def _gen_llvm_function_reset(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
assert "reset" in tags

builder = super()._gen_llvm_function_reset(ctx, builder, params, state, arg_in, arg_out, tags=tags)

# Return the reconstructed combination of previous value and previous tim
prev_value_ptr = ctx.get_param_or_state_ptr(builder, self, PREVIOUS_VALUE, state_struct_ptr=state)
value_out_ptr = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(0)])
builder.store(builder.load(prev_value_ptr), value_out_ptr)

prev_time_ptr = ctx.get_param_or_state_ptr(builder, self, "previous_time", state_struct_ptr=state)
time_out_ptr = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(1)])
builder.store(builder.load(prev_time_ptr), time_out_ptr)

return builder

def reset(self, previous_value=None, previous_time=None, context=None):
return super().reset(
previous_value=previous_value,
Expand Down
22 changes: 17 additions & 5 deletions psyneulink/core/components/functions/stateful/statefulfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,10 +538,12 @@ def reset(self, *args, context=None, **kwargs):
value = []
for attr, v in kwargs.items():
# FIXME: HACK: Do not reinitialize random_state
if attr != "random_state":
getattr(self.parameters, attr).set(kwargs[attr],
context, override=True)
value.append(getattr(self.parameters, attr)._get(context))
if attr == "random_state":
continue

param = getattr(self.parameters, attr)
param.set(kwargs[attr], context, override=True)
value.append(param._get(context))

self.parameters.value.set(value, context, override=True)
return value
Expand All @@ -552,7 +554,17 @@ def _gen_llvm_function_reset(self, ctx, builder, params, state, arg_in, arg_out,
initializer = getattr(self.parameters, a).initializer
source_ptr = ctx.get_param_or_state_ptr(builder, self, initializer, param_struct_ptr=params)
dest_ptr = ctx.get_param_or_state_ptr(builder, self, a, state_struct_ptr=state)
builder.store(builder.load(source_ptr), dest_ptr)
initial_value = builder.load(source_ptr)
builder.store(initial_value, dest_ptr)

# previous_value is the only output of the reset function
if a == "previous_value" and len(self.stateful_attributes) == 1:
unwrapped_ptr = arg_out
if initial_value.type != unwrapped_ptr.type.pointee:
unwrapped_ptr = pnlvm.helpers.unwrap_2d_array(builder, arg_out)

assert initial_value.type == unwrapped_ptr.type.pointee, "{}: {} vs. {}".format(self.name, initial_value.type, arg_out.type.pointee)
builder.store(initial_value, unwrapped_ptr)

return builder

Expand Down
37 changes: 20 additions & 17 deletions psyneulink/core/components/mechanisms/mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -3241,6 +3241,9 @@ def _gen_llvm_function_reset(self, ctx, builder, m_base_params, m_state, m_arg_i

builder.call(reinit_func, [reinit_params, reinit_state, reinit_in, reinit_out])

# update output ports after getting the reinitialized value
builder = self._gen_llvm_output_ports(ctx, builder, reinit_out, m_base_params, m_state, m_arg_in, m_arg_out)

return builder

def _gen_llvm_function(self, *, extra_args=[], ctx:pnlvm.LLVMBuilderContext, tags:frozenset):
Expand All @@ -3251,23 +3254,23 @@ def _gen_llvm_function(self, *, extra_args=[], ctx:pnlvm.LLVMBuilderContext, tag
on top of the variants supported by Component.
"""

# Call parent "_gen_llvm_function", this should result in calling
# "_gen_llvm_function_body" below
if "is_finished" not in tags:
return super()._gen_llvm_function(extra_args=extra_args, ctx=ctx, tags=tags)

# Keep all 4 standard arguments to ease invocation
args = [ctx.get_param_struct_type(self).as_pointer(),
ctx.get_state_struct_type(self).as_pointer(),
ctx.get_input_struct_type(self).as_pointer(),
ctx.get_output_struct_type(self).as_pointer()]

builder = ctx.create_llvm_function(args, self, return_type=ctx.bool_ty,
tags=tags)
params, state, inputs = builder.function.args[:3]
finished = self._gen_llvm_is_finished_cond(ctx, builder, params, state, inputs)
builder.ret(finished)
return builder.function
if "is_finished" in tags:

# Keep all 4 standard arguments to ease invocation
args = [ctx.get_param_struct_type(self).as_pointer(),
ctx.get_state_struct_type(self).as_pointer(),
ctx.get_input_struct_type(self).as_pointer(),
ctx.get_output_struct_type(self).as_pointer()]

builder = ctx.create_llvm_function(args, self, return_type=ctx.bool_ty, tags=tags)
params, state, inputs = builder.function.args[:3]
finished = self._gen_llvm_is_finished_cond(ctx, builder, params, state, inputs)
builder.ret(finished)
return builder.function

# Call parent "_gen_llvm_function". This handles standard variants like
# no tags, or the "reset" tag.
return super()._gen_llvm_function(extra_args=extra_args, ctx=ctx, tags=tags)

def _gen_llvm_function_body(self, ctx, builder, base_params, state, arg_in, arg_out, *, tags:frozenset):
"""
Expand Down
64 changes: 45 additions & 19 deletions tests/functions/test_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ def DriftOnASphereFun(init, value, iterations, noise, **kwargs):

else:
if "initializer" not in kwargs:
return [ 0.23690849, 0.00140115, 0.0020072, -0.00128063,
-0.00096267, -0.01620475, -0.02644836, 0.46090672,
0.82875571, -0.31584261, -0.00132534]
return [ 0.23690849474294814, 0.0014011543771184686, 0.0020071969614023914, -0.0012806262650772564,
-0.0009626666466757963, -0.016204753263919822, -0.026448355473615546, 0.4609067174067295,
0.828755706263852, -0.3158426068946889, -0.0013253357638719173]

else:
return [-3.72900858e-03, -3.38148799e-04, -6.43154678e-04, 4.36274120e-05,
Expand All @@ -186,16 +186,20 @@ def DriftOnASphereFun(init, value, iterations, noise, **kwargs):
@pytest.mark.parametrize("noise", [RAND2, test_noise_arr, pnl.NormalDist],
ids=["SNOISE", "VNOISE", "FNOISE"])
@pytest.mark.parametrize("func", [
(pnl.AdaptiveIntegrator, AdaptiveIntFun),
(pnl.SimpleIntegrator, SimpleIntFun),
(pnl.DriftDiffusionIntegrator, DriftIntFun),
(pnl.LeakyCompetingIntegrator, LeakyFun),
(pnl.AccumulatorIntegrator, AccumulatorFun),
pytest.param((pnl.DriftOnASphereIntegrator, DriftOnASphereFun), marks=pytest.mark.llvm_not_implemented),
(pnl.AdaptiveIntegrator, AdaptiveIntFun, {}),
(pnl.SimpleIntegrator, SimpleIntFun, {}),
(pnl.DriftDiffusionIntegrator, DriftIntFun, {'time_step_size': 1.0}),
(pnl.LeakyCompetingIntegrator, LeakyFun, {}),
(pnl.AccumulatorIntegrator, AccumulatorFun, {'increment': RAND0_1}),
pytest.param((pnl.DriftOnASphereIntegrator,
DriftOnASphereFun,
{'dimension': len(test_var) + 1},
), marks=pytest.mark.llvm_not_implemented),
], ids=lambda x: x[0])
@pytest.mark.parametrize("mode", ["test_execution", "test_reset"])
@pytest.mark.benchmark
def test_execute(func, func_mode, variable, noise, params, benchmark):
func_class, func_res = func
def test_execute(func, func_mode, variable, noise, params, mode, benchmark):
func_class, func_res, func_params = func
benchmark.group = GROUP_PREFIX + func_class.componentName

try:
Expand All @@ -207,29 +211,51 @@ def test_execute(func, func_mode, variable, noise, params, benchmark):
if issubclass(func_class, (pnl.DriftDiffusionIntegrator, pnl.DriftOnASphereIntegrator)):
pytest.skip("{} doesn't support functional noise".format(func_class.componentName))

if issubclass(func_class, pnl.DriftOnASphereIntegrator):
params = {**params, 'dimension': len(variable) + 1}
params = {**params, **func_params}

elif issubclass(func_class, pnl.AccumulatorIntegrator):
params = {**params, 'increment': RAND0_1}
if issubclass(func_class, pnl.AccumulatorIntegrator):
params.pop('offset', None)

elif issubclass(func_class, pnl.DriftDiffusionIntegrator):
# If we are dealing with a DriftDiffusionIntegrator, noise and
# time_step_size defaults have changed since this test was created.
# Hard code their old values.
params = {**params, 'time_step_size': 1.0}
noise = np.sqrt(noise)

f = func_class(default_variable=variable, noise=noise, **params)
ex = pytest.helpers.get_func_execution(f, func_mode)

# Execute few times to update the internal state
ex(variable)
ex(variable)
res = benchmark(ex, variable)

expected = func_res(f.initializer, variable, 3, noise, **params)
np.testing.assert_allclose(res, expected, rtol=1e-5, atol=1e-8)
if mode == "test_execution":
res = benchmark(ex, variable)

expected = func_res(f.initializer, variable, 3, noise, **params)

tolerance = {} if pytest.helpers.llvm_current_fp_precision() == 'fp64' else {'rtol':1e-5, 'atol':1e-8}
np.testing.assert_allclose(res, expected, **tolerance)

elif mode == "test_reset":
ex_res = pytest.helpers.get_func_execution(f, func_mode, tags=frozenset({'reset'}), member='reset')

# Compiled mode ignores input variable, but python uses it if it's provided
post_reset = ex_res(None if func_mode == "Python" else variable)

# Python implementations return 2d arrays,
# while most compiled variants return 1d
if func_mode != "Python":
post_reset = np.atleast_2d(post_reset)

# The order in which the reinitialized values are returned
# is hardcoded in kwargs of the reset() methods of the respective
# Function classes. The first one is 'initializer' in all cases.
# The other ones are reset to 0 in the test cases.
reset_expected = np.zeros_like(post_reset)
reset_expected[0] = f.parameters.initializer.get()

np.testing.assert_allclose(post_reset, reset_expected)


def test_integrator_function_no_default_variable_and_params_len_more_than_1():
Expand Down
22 changes: 22 additions & 0 deletions tests/mechanisms/test_ddm_mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,28 @@ def test_WhenFinished_DDM_Analytical():
c.is_satisfied()


@pytest.mark.mechanism
@pytest.mark.parametrize("initializer_param", [{}, {pnl.INITIALIZER: 5.0}], ids=["default_initializer", "custom_initializer"])
@pytest.mark.parametrize("non_decision_time_param", [{}, {pnl.NON_DECISION_TIME: 13.0}], ids=["default_non_decision_time", "custom_non_decision_time"])
def test_DDM_reset(mech_mode, initializer_param, non_decision_time_param):
D = pnl.DDM(function=pnl.DriftDiffusionIntegrator(**initializer_param, **non_decision_time_param))

ex = pytest.helpers.get_mech_execution(D, mech_mode)

initializer_value = initializer_param.get(pnl.INITIALIZER, 0)
non_decision_time_value = non_decision_time_param.get(pnl.NON_DECISION_TIME, 0)

ex([1])
ex([2])
result = ex([3])
np.testing.assert_array_equal(result, [[100], [102 - initializer_value + non_decision_time_value]])

reset_ex = pytest.helpers.get_mech_execution(D, mech_mode, tags=frozenset({"reset"}), member="reset")

reset_result = reset_ex(None if mech_mode == "Python" else [0])
np.testing.assert_array_equal(reset_result, [[initializer_value], [non_decision_time_value]])


@pytest.mark.composition
@pytest.mark.ddm_mechanism
@pytest.mark.mechanism
Expand Down
24 changes: 24 additions & 0 deletions tests/mechanisms/test_processing_mechanism.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest

import psyneulink as pnl
from psyneulink.core.components.functions.function import FunctionError
from psyneulink.core.components.functions.nonstateful.learningfunctions import Hebbian, Reinforcement, TDLearning
from psyneulink.core.components.functions.nonstateful.objectivefunctions import Distance
Expand Down Expand Up @@ -343,3 +344,26 @@ def test_output_ports2(self, op, expected):
var = [1, 2, 4] if op in {MEAN, MEDIAN, STANDARD_DEVIATION, VARIANCE} else [1, 2, -4]
PM1.execute(var)
np.testing.assert_allclose(PM1.output_ports[0].value, expected)

@pytest.mark.mechanism
@pytest.mark.parametrize("initializer_param", [{}, {pnl.INITIALIZER: 5.0}], ids=["default_initializer", "custom_initializer"])
def test_processing_mechanism_reset(mech_mode, initializer_param):
T = pnl.ProcessingMechanism(function=pnl.AdaptiveIntegrator(**initializer_param, rate=0.5))

ex = pytest.helpers.get_mech_execution(T, mech_mode)
initializer_value = initializer_param.get(pnl.INITIALIZER, 0)

ex([1])
ex([2])
result = ex([3])
np.testing.assert_array_equal(result, [[2.125 + initializer_value / 8]])

reset_ex = pytest.helpers.get_mech_execution(T, mech_mode, tags=frozenset({"reset"}), member="reset")

reset_result = reset_ex(None if mech_mode == "Python" else [0])

# FIXME: Python returns 3d value with default initializer
if mech_mode == "Python" and pnl.INITIALIZER not in initializer_param:
reset_result = reset_result[0]

np.testing.assert_array_equal(reset_result, [[initializer_value]])
52 changes: 52 additions & 0 deletions tests/mechanisms/test_transfer_mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -1793,3 +1793,55 @@ def test_combine_standard_output_port(self):
# np.testing.assert_allclose(T.output_ports[1].value, [2.0])
# np.testing.assert_allclose(T.output_ports[2].value, [3.0])
# np.testing.assert_allclose(T.output_ports[3].value, [4.0])

@pytest.mark.mechanism
@pytest.mark.parametrize("initializer_param", [{}, {pnl.INITIALIZER: 5.0}], ids=["default_initializer", "custom_initializer"])
def test_integrator_mode_reset(mech_mode, initializer_param):
T = pnl.TransferMechanism(integrator_mode=True,
integrator_function=pnl.AdaptiveIntegrator(**initializer_param))

ex = pytest.helpers.get_mech_execution(T, mech_mode)
initializer_value = initializer_param.get(pnl.INITIALIZER, 0)

ex([1])
ex([2])
result = ex([3])
np.testing.assert_array_equal(result, [[2.125 + initializer_value / 8]])

reset_ex = pytest.helpers.get_mech_execution(T, mech_mode, tags=frozenset({"reset"}), member="reset")

reset_result = reset_ex(None if mech_mode == "Python" else [0])
np.testing.assert_array_equal(reset_result, [[initializer_value]])

@pytest.mark.composition
@pytest.mark.usefixtures("comp_mode_no_per_node")
def test_integrator_mode_reset_in_composition(comp_mode):

# Note, input_mech is scheduled to only execute on pass 5!
input_mech = pnl.TransferMechanism(
input_shapes=2,
integrator_mode=True,
integration_rate=1,
reset_stateful_function_when=pnl.AtTrialStart()
)

lca = pnl.LCAMechanism(
input_shapes=2,
termination_threshold=10,
termination_measure=pnl.TimeScale.TRIAL,
execute_until_finished=False
)

gate = pnl.ProcessingMechanism(input_shapes=2)

comp = pnl.Composition()
comp.add_linear_processing_pathway([input_mech, lca, gate])

comp.scheduler.add_condition(input_mech, pnl.AtPass(5))

comp.scheduler.add_condition(lca, pnl.Always())

comp.scheduler.add_condition(gate, pnl.WhenFinished(lca))

comp.run([[1, 0], [0, 1]], execution_mode=comp_mode)
np.testing.assert_allclose(comp.results, [[[0.52293998, 0.40526519]], [[0.4336115, 0.46026939]]])

0 comments on commit 5e78a74

Please sign in to comment.