Skip to content

Commit d88aa15

Browse files
author
C.A.P. Linssen
committed
transform kernels and convolutions using a transformer before code generation
1 parent abba1b4 commit d88aa15

File tree

5 files changed

+148
-86
lines changed

5 files changed

+148
-86
lines changed

pynestml/codegeneration/resources_nest/point_neuron/common/NeuronClass.jinja2

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -717,28 +717,6 @@ void {{neuronName}}::update(nest::Time const & origin,const long from, const lon
717717
update_delay_variables();
718718
{%- endif %}
719719

720-
/**
721-
* subthreshold updates of the convolution variables
722-
*
723-
* step 1: regardless of whether and how integrate_odes() will be called, update variables due to convolutions
724-
**/
725-
726-
{%- if uses_analytic_solver %}
727-
{%- for variable_name in analytic_state_variables: %}
728-
{%- if "__X__" in variable_name %}
729-
{%- set update_expr = update_expressions[variable_name] %}
730-
{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%}
731-
{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%}
732-
{%- if use_gap_junctions %}
733-
const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}__tmp_ = {{ printer.print(update_expr) | replace("B_." + gap_junction_port + "_grid_sum_", "(B_." + gap_junction_port + "_grid_sum_ + __I_gap)") }};
734-
{%- else %}
735-
const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}__tmp_ = {{ printer.print(update_expr) }};
736-
{%- endif %}
737-
{%- endif %}
738-
{%- endfor %}
739-
{%- endif %}
740-
741-
742720
/**
743721
* Begin NESTML generated code for the update block(s)
744722
**/
@@ -768,22 +746,6 @@ const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}_
768746
}
769747
{%- endfor %}
770748

771-
/**
772-
* subthreshold updates of the convolution variables
773-
*
774-
* step 2: regardless of whether and how integrate_odes() was called, update variables due to convolutions. Set to the updated values at the end of the timestep.
775-
**/
776-
{% if uses_analytic_solver %}
777-
{%- for variable_name in analytic_state_variables: %}
778-
{%- if "__X__" in variable_name %}
779-
{%- set update_expr = update_expressions[variable_name] %}
780-
{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%}
781-
{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%}
782-
{{ printer.print(var_ast) }} = {{variable_name}}__tmp_;
783-
{%- endif %}
784-
{%- endfor %}
785-
{%- endif %}
786-
787749
/**
788750
* Begin NESTML generated code for the onCondition block(s)
789751
**/

pynestml/codegeneration/resources_python_standalone/point_neuron/@[email protected]

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ class Neuron_{{neuronName}}(Neuron):
191191
{%- endif %}
192192
{%- endfor %}
193193
{%- endfilter %}
194+
pass
194195
else:
195196
# internals V_
196197
{%- filter indent(6) %}
@@ -262,13 +263,6 @@ class Neuron_{{neuronName}}(Neuron):
262263
{%- set analytic_state_variables_ = utils.filter_variables_list(analytic_state_variables_, ast.get_args()) %}
263264
{%- endif %}
264265

265-
{#- always integrate convolutions in time #}
266-
{%- for var in analytic_state_variables %}
267-
{%- if "__X__" in var %}
268-
{%- set tmp = analytic_state_variables_.append(var) %}
269-
{%- endif %}
270-
{%- endfor %}
271-
272266
{%- include "directives_py/AnalyticIntegrationStep_begin.jinja2" %}
273267

274268
{%- if uses_numeric_solver %}
@@ -283,14 +277,6 @@ class Neuron_{{neuronName}}(Neuron):
283277
def step(self, origin: float, timestep: float) -> None:
284278
__resolution: float = timestep # do not remove, this is necessary for the resolution() function
285279

286-
# -------------------------------------------------------------------------
287-
# integrate variables related to convolutions
288-
# -------------------------------------------------------------------------
289-
290-
{%- with analytic_state_variables_ = analytic_state_variables_from_convolutions %}
291-
{%- include "directives_py/AnalyticIntegrationStep_begin.jinja2" %}
292-
{%- endwith %}
293-
294280
# -------------------------------------------------------------------------
295281
# NESTML generated code for the update block
296282
# -------------------------------------------------------------------------
@@ -304,14 +290,6 @@ class Neuron_{{neuronName}}(Neuron):
304290
{%- endfilter %}
305291
{%- endif %}
306292

307-
# -------------------------------------------------------------------------
308-
# integrate variables related to convolutions
309-
# -------------------------------------------------------------------------
310-
311-
{%- with analytic_state_variables_ = analytic_state_variables_from_convolutions %}
312-
{%- include "directives_py/AnalyticIntegrationStep_end.jinja2" %}
313-
{%- endwith %}
314-
315293
# -------------------------------------------------------------------------
316294
# begin NESTML generated code for the onReceive block(s)
317295
# -------------------------------------------------------------------------

pynestml/meta_model/ast_node_factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,8 @@ def create_ast_update_block(cls, block, source_position):
360360
return ASTUpdateBlock(block, source_position=source_position)
361361

362362
@classmethod
363-
def create_ast_variable(cls, name: str, differential_order: int = 0, vector_parameter=None, is_homogeneous=False, source_position: Optional[ASTSourceLocation] = None, scope: Optional[Scope] = None) -> ASTVariable:
364-
var = ASTVariable(name, differential_order, vector_parameter=vector_parameter, is_homogeneous=is_homogeneous, source_position=source_position)
363+
def create_ast_variable(cls, name: str, differential_order: int = 0, vector_parameter=None, is_homogeneous=False, type_symbol: Optional[str] = None, source_position: Optional[ASTSourceLocation] = None, scope: Optional[Scope] = None) -> ASTVariable:
364+
var = ASTVariable(name, differential_order, type_symbol=type_symbol, vector_parameter=vector_parameter, is_homogeneous=is_homogeneous, source_position=source_position)
365365
if scope:
366366
var.scope = scope
367367

pynestml/transformers/convolutions_transformer.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from pynestml.codegeneration.printers.unitless_sympy_simple_expression_printer import UnitlessSympySimpleExpressionPrinter
3636
from pynestml.frontend.frontend_configuration import FrontendConfiguration
3737
from pynestml.meta_model.ast_assignment import ASTAssignment
38+
from pynestml.meta_model.ast_data_type import ASTDataType
3839
from pynestml.meta_model.ast_declaration import ASTDeclaration
3940
from pynestml.meta_model.ast_equations_block import ASTEquationsBlock
4041
from pynestml.meta_model.ast_expression import ASTExpression
@@ -47,6 +48,7 @@
4748
from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression
4849
from pynestml.meta_model.ast_variable import ASTVariable
4950
from pynestml.symbols.predefined_functions import PredefinedFunctions
51+
from pynestml.symbols.real_type_symbol import RealTypeSymbol
5052
from pynestml.symbols.symbol import SymbolKind
5153
from pynestml.symbols.variable_symbol import BlockType
5254
from pynestml.transformers.transformer import Transformer
@@ -84,6 +86,61 @@ def __init__(self, options: Optional[Mapping[str, Any]] = None):
8486
self._ode_toolbox_variable_printer._expression_printer = self._ode_toolbox_printer
8587
self._ode_toolbox_function_call_printer._expression_printer = self._ode_toolbox_printer
8688

89+
def add_kernel_variables_to_integrate_odes_calls(self, model, solvers_json):
90+
for solver_dict in solvers_json:
91+
if solver_dict is None:
92+
continue
93+
94+
for var_name, expr in solver_dict["initial_values"].items():
95+
var = ASTUtils.get_variable_by_name(model, var_name)
96+
ASTUtils.add_state_var_to_integrate_odes_calls(model, var)
97+
98+
model.accept(ASTParentVisitor())
99+
100+
101+
def add_integrate_odes_call_for_kernel_variables(self, model, solvers_json):
102+
var_names = []
103+
for solver_dict in solvers_json:
104+
if solver_dict is None:
105+
continue
106+
107+
for var_name, expr in solver_dict["initial_values"].items():
108+
var_names.append(var_name)
109+
110+
args = ASTUtils.resolve_variables_to_simple_expressions(model, var_names)
111+
ast_function_call = ASTNodeFactory.create_ast_function_call("integrate_odes", args)
112+
ASTUtils.add_function_call_to_update_block(ast_function_call, model)
113+
model.accept(ASTParentVisitor())
114+
115+
def add_temporary_kernel_variables_copy(self, model, solvers_json):
116+
var_names = []
117+
for solver_dict in solvers_json:
118+
if solver_dict is None:
119+
continue
120+
121+
for var_name, expr in solver_dict["initial_values"].items():
122+
var_names.append(var_name)
123+
124+
scope = model.get_update_blocks()[0].scope
125+
126+
for var_name in var_names:
127+
var = ASTNodeFactory.create_ast_variable(var_name + "__tmp", type_symbol=RealTypeSymbol)
128+
var.scope = scope
129+
expr = ASTNodeFactory.create_ast_simple_expression(variable=ASTUtils.get_variable_by_name(model, var_name))
130+
ast_declaration = ASTNodeFactory.create_ast_declaration(variables=[var],
131+
data_type=ASTDataType(is_real=True),
132+
expression=expr, source_position=ASTSourceLocation.get_added_source_position())
133+
ast_declaration.update_scope(scope)
134+
ast_small_stmt = ASTNodeFactory.create_ast_small_stmt(declaration=ast_declaration)
135+
ast_small_stmt.update_scope(scope)
136+
ast_stmt = ASTNodeFactory.create_ast_stmt(small_stmt=ast_small_stmt)
137+
ast_stmt.update_scope(scope)
138+
139+
model.get_update_blocks()[0].get_block().stmts.insert(0, ast_stmt)
140+
141+
model.accept(ASTParentVisitor())
142+
model.accept(ASTSymbolTableVisitor())
143+
87144
def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]:
88145
r"""Transform a model or a list of models. Return an updated model or list of models."""
89146
for model in models:
@@ -105,6 +162,11 @@ def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode,
105162
self.create_spike_update_event_handlers(model, solvers_json, kernel_buffers)
106163
self.replace_convolve_calls_with_buffers_(model)
107164
self.remove_kernel_definitions_from_equations_blocks(model)
165+
self.add_kernel_variables_to_integrate_odes_calls(model, solvers_json)
166+
self.add_temporary_kernel_variables_copy(model, solvers_json)
167+
self.add_integrate_odes_call_for_kernel_variables(model, solvers_json)
168+
self.add_kernel_equations(model, solvers_json)
169+
108170
print("-------- MODEL AFTER TRANSFORM ------------")
109171
print(model)
110172
print("-------------------------------------------")

pynestml/utils/ast_utils.py

Lines changed: 83 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,18 @@ def visit_function_call(self, node: ASTFunctionCall):
556556
remove_state_var_from_integrate_odes_calls_visitor = RemoveStateVarFromIntegrateODEsCallsVisitor()
557557
model.accept(remove_state_var_from_integrate_odes_calls_visitor)
558558

559+
@classmethod
560+
def add_state_var_to_integrate_odes_calls(cls, model: ASTModel, var: ASTExpression):
561+
r"""Add a state variable to the arguments to each integrate_odes() calls in the model."""
562+
563+
class AddStateVarToIntegrateODEsCallsVisitor(ASTVisitor):
564+
def visit_function_call(self, node: ASTFunctionCall):
565+
if node.get_name() == PredefinedFunctions.INTEGRATE_ODES:
566+
expr = ASTNodeFactory.create_ast_simple_expression(variable=var.clone())
567+
node.args.append(expr)
568+
569+
model.accept(AddStateVarToIntegrateODEsCallsVisitor())
570+
559571
@classmethod
560572
def resolve_variables_to_expressions(cls, astnode, analytic_state_variables_moved):
561573
"""receives a list of variable names (as strings) and returns a list of ASTExpressions containing each ASTVariable"""
@@ -564,7 +576,19 @@ def resolve_variables_to_expressions(cls, astnode, analytic_state_variables_move
564576
for var_name in analytic_state_variables_moved:
565577
node = ASTUtils.get_variable_by_name(astnode, var_name)
566578
assert node is not None
567-
expressions.append(ASTNodeFactory.create_ast_expression(False, None, False, ASTNodeFactory.create_ast_simple_expression(variable=node)))
579+
expressions.append(ASTNodeFactory.create_ast_expression(expression=ASTNodeFactory.create_ast_simple_expression(variable=node)))
580+
581+
return expressions
582+
583+
@classmethod
584+
def resolve_variables_to_simple_expressions(cls, model, vars):
585+
"""receives a list of variable names (as strings) and returns a list of ASTSimpleExpressions containing each ASTVariable"""
586+
expressions = []
587+
588+
for var_name in vars:
589+
node = ASTUtils.get_variable_by_name(model, var_name)
590+
assert node is not None
591+
expressions.append(ASTNodeFactory.create_ast_simple_expression(variable=node))
568592

569593
return expressions
570594

@@ -1113,45 +1137,80 @@ def declaration_in_state_block(cls, neuron: ASTModel, variable_name: str) -> boo
11131137
return False
11141138

11151139
@classmethod
1116-
def add_assignment_to_update_block(cls, assignment: ASTAssignment, neuron: ASTModel) -> ASTModel:
1140+
def add_assignment_to_update_block(cls, assignment: ASTAssignment, model: ASTModel) -> ASTModel:
11171141
"""
1118-
Adds a single assignment to the end of the update block of the handed over neuron. At most one update block should be present.
1142+
Adds a single assignment to the end of the update block of the handed over model. At most one update block should be present.
11191143
11201144
:param assignment: a single assignment
1121-
:param neuron: a single neuron instance
1122-
:return: the modified neuron
1145+
:param model: a single model instance
1146+
:return: the modified model
11231147
"""
1124-
assert len(neuron.get_update_blocks()) <= 1, "At most one update block should be present"
1148+
assert len(model.get_update_blocks()) <= 1, "At most one update block should be present"
11251149
small_stmt = ASTNodeFactory.create_ast_small_stmt(assignment=assignment,
11261150
source_position=ASTSourceLocation.get_added_source_position())
11271151
stmt = ASTNodeFactory.create_ast_stmt(small_stmt=small_stmt,
11281152
source_position=ASTSourceLocation.get_added_source_position())
1129-
if not neuron.get_update_blocks():
1130-
neuron.create_empty_update_block()
1131-
neuron.get_update_blocks()[0].get_block().get_stmts().append(stmt)
1132-
small_stmt.update_scope(neuron.get_update_blocks()[0].get_block().get_scope())
1133-
stmt.update_scope(neuron.get_update_blocks()[0].get_block().get_scope())
1134-
return neuron
1153+
if not model.get_update_blocks():
1154+
model.create_empty_update_block()
1155+
model.get_update_blocks()[0].get_block().get_stmts().append(stmt)
1156+
small_stmt.update_scope(model.get_update_blocks()[0].get_block().get_scope())
1157+
stmt.update_scope(model.get_update_blocks()[0].get_block().get_scope())
1158+
1159+
from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
1160+
model.accept(ASTParentVisitor())
1161+
1162+
return model
1163+
1164+
@classmethod
1165+
def add_function_call_to_update_block(cls, function_call: ASTFunctionCall, model: ASTModel) -> ASTModel:
1166+
"""
1167+
Adds a single assignment to the end of the update block of the handed over model.
1168+
1169+
:param function_call: a single function call
1170+
:param neuron: a single model instance
1171+
:return: the modified model
1172+
"""
1173+
assert len(model.get_update_blocks()) <= 1, "At most one update block should be present"
1174+
1175+
if not model.get_update_blocks():
1176+
model.create_empty_update_block()
1177+
1178+
small_stmt = ASTNodeFactory.create_ast_small_stmt(function_call=function_call,
1179+
source_position=ASTSourceLocation.get_added_source_position())
1180+
stmt = ASTNodeFactory.create_ast_stmt(small_stmt=small_stmt,
1181+
source_position=ASTSourceLocation.get_added_source_position())
1182+
model.get_update_blocks()[0].get_block().get_stmts().append(stmt)
1183+
small_stmt.update_scope(model.get_update_blocks()[0].get_block().get_scope())
1184+
stmt.update_scope(model.get_update_blocks()[0].get_block().get_scope())
1185+
1186+
from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
1187+
model.accept(ASTParentVisitor())
1188+
1189+
return model
11351190

11361191
@classmethod
1137-
def add_declaration_to_update_block(cls, declaration: ASTDeclaration, neuron: ASTModel) -> ASTModel:
1192+
def add_declaration_to_update_block(cls, declaration: ASTDeclaration, model: ASTModel) -> ASTModel:
11381193
"""
1139-
Adds a single declaration to the end of the update block of the handed over neuron.
1194+
Adds a single declaration to the end of the update block of the handed over model.
11401195
:param declaration: ASTDeclaration node to add
1141-
:param neuron: a single neuron instance
1142-
:return: a modified neuron
1196+
:param model: a single model instance
1197+
:return: a modified model
11431198
"""
1144-
assert len(neuron.get_update_blocks()) <= 1, "At most one update block should be present"
1199+
assert len(model.get_update_blocks()) <= 1, "At most one update block should be present"
11451200
small_stmt = ASTNodeFactory.create_ast_small_stmt(declaration=declaration,
11461201
source_position=ASTSourceLocation.get_added_source_position())
11471202
stmt = ASTNodeFactory.create_ast_stmt(small_stmt=small_stmt,
11481203
source_position=ASTSourceLocation.get_added_source_position())
1149-
if not neuron.get_update_blocks():
1150-
neuron.create_empty_update_block()
1151-
neuron.get_update_blocks()[0].get_block().get_stmts().append(stmt)
1152-
small_stmt.update_scope(neuron.get_update_blocks()[0].get_block().get_scope())
1153-
stmt.update_scope(neuron.get_update_blocks()[0].get_block().get_scope())
1154-
return neuron
1204+
if not model.get_update_blocks():
1205+
model.create_empty_update_block()
1206+
model.get_update_blocks()[0].get_block().get_stmts().append(stmt)
1207+
small_stmt.update_scope(model.get_update_blocks()[0].get_block().get_scope())
1208+
stmt.update_scope(model.get_update_blocks()[0].get_block().get_scope())
1209+
1210+
from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
1211+
model.accept(ASTParentVisitor())
1212+
1213+
return model
11551214

11561215
@classmethod
11571216
def add_state_updates(cls, neuron: ASTModel, update_expressions: Mapping[str, str]) -> ASTModel:
@@ -1165,6 +1224,7 @@ def add_state_updates(cls, neuron: ASTModel, update_expressions: Mapping[str, st
11651224
for variable, update_expression in update_expressions.items():
11661225
declaration_statement = variable + '__tmp real = ' + update_expression
11671226
cls.add_declaration_to_update_block(ModelParser.parse_declaration(declaration_statement), neuron)
1227+
11681228
for variable, update_expression in update_expressions.items():
11691229
cls.add_assignment_to_update_block(ModelParser.parse_assignment(variable + ' = ' + variable + '__tmp'),
11701230
neuron)

0 commit comments

Comments
 (0)