Skip to content

Commit de22d17

Browse files
author
C.A.P. Linssen
committed
transform kernels and convolutions using a transformer before code generation [noci]
1 parent d88aa15 commit de22d17

File tree

2 files changed

+80
-1
lines changed

2 files changed

+80
-1
lines changed

pynestml/transformers/convolutions_transformer.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from pynestml.codegeneration.printers.unitless_sympy_simple_expression_printer import UnitlessSympySimpleExpressionPrinter
3636
from pynestml.frontend.frontend_configuration import FrontendConfiguration
3737
from pynestml.meta_model.ast_assignment import ASTAssignment
38+
from pynestml.meta_model.ast_block import ASTBlock
3839
from pynestml.meta_model.ast_data_type import ASTDataType
3940
from pynestml.meta_model.ast_declaration import ASTDeclaration
4041
from pynestml.meta_model.ast_equations_block import ASTEquationsBlock
@@ -46,6 +47,7 @@
4647
from pynestml.meta_model.ast_node import ASTNode
4748
from pynestml.meta_model.ast_node_factory import ASTNodeFactory
4849
from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression
50+
from pynestml.meta_model.ast_small_stmt import ASTSmallStmt
4951
from pynestml.meta_model.ast_variable import ASTVariable
5052
from pynestml.symbols.predefined_functions import PredefinedFunctions
5153
from pynestml.symbols.real_type_symbol import RealTypeSymbol
@@ -86,6 +88,53 @@ def __init__(self, options: Optional[Mapping[str, Any]] = None):
8688
self._ode_toolbox_variable_printer._expression_printer = self._ode_toolbox_printer
8789
self._ode_toolbox_function_call_printer._expression_printer = self._ode_toolbox_printer
8890

91+
def add_restore_kernel_variables_to_start_of_timestep(self, model, solvers_json):
92+
r"""For each integrate_odes() call in the model, append statements restoring the kernel variables to the values at the start of the timestep"""
93+
94+
var_names = []
95+
for solver_dict in solvers_json:
96+
if solver_dict is None:
97+
continue
98+
99+
for var_name, expr in solver_dict["initial_values"].items():
100+
var_names.append(var_name)
101+
102+
class IntegrateODEsFunctionCallVisitor(ASTVisitor):
103+
all_args = None
104+
105+
def __init__(self):
106+
super().__init__()
107+
108+
def visit_small_stmt(self, node: ASTSmallStmt):
109+
self._visit(node)
110+
111+
def visit_simple_expression(self, node: ASTSimpleExpression):
112+
self._visit(node)
113+
114+
def _visit(self, node):
115+
if node.is_function_call() and node.get_function_call().get_name() == PredefinedFunctions.INTEGRATE_ODES:
116+
parent_stmt = node.get_parent()
117+
parent_block = parent_stmt.get_parent()
118+
assert isinstance(parent_block, ASTBlock)
119+
idx = parent_block.stmts.index(parent_stmt)
120+
121+
for i, var_name in enumerate(var_names):
122+
var = ASTNodeFactory.create_ast_variable(var_name + "__at_start_of_timestep", type_symbol=RealTypeSymbol)
123+
var.update_scope(parent_block.get_scope())
124+
expr = ASTNodeFactory.create_ast_simple_expression(variable=var)
125+
ast_assignment = ASTNodeFactory.create_ast_assignment(lhs=ASTUtils.get_variable_by_name(model, var_name),
126+
is_direct_assignment=True,
127+
expression=expr, source_position=ASTSourceLocation.get_added_source_position())
128+
ast_assignment.update_scope(parent_block.get_scope())
129+
ast_small_stmt = ASTNodeFactory.create_ast_small_stmt(assignment=ast_assignment)
130+
ast_small_stmt.update_scope(parent_block.get_scope())
131+
ast_stmt = ASTNodeFactory.create_ast_stmt(small_stmt=ast_small_stmt)
132+
ast_stmt.update_scope(parent_block.get_scope())
133+
134+
parent_block.stmts.insert(idx + i + 1, ast_stmt)
135+
136+
model.accept(IntegrateODEsFunctionCallVisitor())
137+
89138
def add_kernel_variables_to_integrate_odes_calls(self, model, solvers_json):
90139
for solver_dict in solvers_json:
91140
if solver_dict is None:
@@ -124,7 +173,7 @@ def add_temporary_kernel_variables_copy(self, model, solvers_json):
124173
scope = model.get_update_blocks()[0].scope
125174

126175
for var_name in var_names:
127-
var = ASTNodeFactory.create_ast_variable(var_name + "__tmp", type_symbol=RealTypeSymbol)
176+
var = ASTNodeFactory.create_ast_variable(var_name + "__at_start_of_timestep", type_symbol=RealTypeSymbol)
128177
var.scope = scope
129178
expr = ASTNodeFactory.create_ast_simple_expression(variable=ASTUtils.get_variable_by_name(model, var_name))
130179
ast_declaration = ASTNodeFactory.create_ast_declaration(variables=[var],
@@ -163,6 +212,7 @@ def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode,
163212
self.replace_convolve_calls_with_buffers_(model)
164213
self.remove_kernel_definitions_from_equations_blocks(model)
165214
self.add_kernel_variables_to_integrate_odes_calls(model, solvers_json)
215+
self.add_restore_kernel_variables_to_start_of_timestep(model, solvers_json)
166216
self.add_temporary_kernel_variables_copy(model, solvers_json)
167217
self.add_integrate_odes_call_for_kernel_variables(model, solvers_json)
168218
self.add_kernel_equations(model, solvers_json)
@@ -438,6 +488,32 @@ def generate_kernel_buffers(self, model: ASTModel) -> Mapping[ASTKernel, ASTInpu
438488

439489
return kernel_buffers
440490

491+
def add_kernel_equations(self, model, solver_dicts):
492+
if not model.get_equations_blocks():
493+
ASTUtils.create_equations_block()
494+
495+
assert len(model.get_equations_blocks()) <= 1
496+
497+
equations_block = model.get_equations_blocks()[0]
498+
499+
for solver_dict in solver_dicts:
500+
if solver_dict is None:
501+
continue
502+
503+
for var_name, expr_str in solver_dict["update_expressions"].items():
504+
expr = ModelParser.parse_expression(expr_str)
505+
expr.update_scope(model.get_scope())
506+
expr.accept(ASTSymbolTableVisitor())
507+
508+
var = ASTNodeFactory.create_ast_variable(var_name, differential_order=1, source_position=ASTSourceLocation.get_added_source_position())
509+
var.update_scope(equations_block.get_scope())
510+
ast_ode_equation = ASTNodeFactory.create_ast_ode_equation(lhs=var, rhs=expr, source_position=ASTSourceLocation.get_added_source_position())
511+
ast_ode_equation.update_scope(equations_block.get_scope())
512+
equations_block.declarations.append(ast_ode_equation)
513+
514+
model.accept(ASTParentVisitor())
515+
model.accept(ASTSymbolTableVisitor())
516+
441517
def remove_kernel_definitions_from_equations_blocks(self, model: ASTModel) -> ASTDeclaration:
442518
r"""
443519
Removes all kernels in equations blocks.

pynestml/utils/ast_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,9 @@ def create_equations_block(cls, model: ASTModel) -> ASTModel:
456456
block = ASTNodeFactory.create_ast_equations_block(list(),
457457
ASTSourceLocation.get_added_source_position())
458458
model.get_body().get_body_elements().append(block)
459+
460+
model.accept(ASTParentVisitor())
461+
459462
return model
460463

461464
@classmethod

0 commit comments

Comments
 (0)