|
35 | 35 | from pynestml.codegeneration.printers.unitless_sympy_simple_expression_printer import UnitlessSympySimpleExpressionPrinter |
36 | 36 | from pynestml.frontend.frontend_configuration import FrontendConfiguration |
37 | 37 | from pynestml.meta_model.ast_assignment import ASTAssignment |
| 38 | +from pynestml.meta_model.ast_block import ASTBlock |
38 | 39 | from pynestml.meta_model.ast_data_type import ASTDataType |
39 | 40 | from pynestml.meta_model.ast_declaration import ASTDeclaration |
40 | 41 | from pynestml.meta_model.ast_equations_block import ASTEquationsBlock |
|
46 | 47 | from pynestml.meta_model.ast_node import ASTNode |
47 | 48 | from pynestml.meta_model.ast_node_factory import ASTNodeFactory |
48 | 49 | from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression |
| 50 | +from pynestml.meta_model.ast_small_stmt import ASTSmallStmt |
49 | 51 | from pynestml.meta_model.ast_variable import ASTVariable |
50 | 52 | from pynestml.symbols.predefined_functions import PredefinedFunctions |
51 | 53 | from pynestml.symbols.real_type_symbol import RealTypeSymbol |
@@ -86,6 +88,53 @@ def __init__(self, options: Optional[Mapping[str, Any]] = None): |
86 | 88 | self._ode_toolbox_variable_printer._expression_printer = self._ode_toolbox_printer |
87 | 89 | self._ode_toolbox_function_call_printer._expression_printer = self._ode_toolbox_printer |
88 | 90 |
|
| 91 | + def add_restore_kernel_variables_to_start_of_timestep(self, model, solvers_json): |
| 92 | + r"""For each integrate_odes() call in the model, append statements restoring the kernel variables to the values at the start of the timestep""" |
| 93 | + |
| 94 | + var_names = [] |
| 95 | + for solver_dict in solvers_json: |
| 96 | + if solver_dict is None: |
| 97 | + continue |
| 98 | + |
| 99 | + for var_name, expr in solver_dict["initial_values"].items(): |
| 100 | + var_names.append(var_name) |
| 101 | + |
| 102 | + class IntegrateODEsFunctionCallVisitor(ASTVisitor): |
| 103 | + all_args = None |
| 104 | + |
| 105 | + def __init__(self): |
| 106 | + super().__init__() |
| 107 | + |
| 108 | + def visit_small_stmt(self, node: ASTSmallStmt): |
| 109 | + self._visit(node) |
| 110 | + |
| 111 | + def visit_simple_expression(self, node: ASTSimpleExpression): |
| 112 | + self._visit(node) |
| 113 | + |
| 114 | + def _visit(self, node): |
| 115 | + if node.is_function_call() and node.get_function_call().get_name() == PredefinedFunctions.INTEGRATE_ODES: |
| 116 | + parent_stmt = node.get_parent() |
| 117 | + parent_block = parent_stmt.get_parent() |
| 118 | + assert isinstance(parent_block, ASTBlock) |
| 119 | + idx = parent_block.stmts.index(parent_stmt) |
| 120 | + |
| 121 | + for i, var_name in enumerate(var_names): |
| 122 | + var = ASTNodeFactory.create_ast_variable(var_name + "__at_start_of_timestep", type_symbol=RealTypeSymbol) |
| 123 | + var.update_scope(parent_block.get_scope()) |
| 124 | + expr = ASTNodeFactory.create_ast_simple_expression(variable=var) |
| 125 | + ast_assignment = ASTNodeFactory.create_ast_assignment(lhs=ASTUtils.get_variable_by_name(model, var_name), |
| 126 | + is_direct_assignment=True, |
| 127 | + expression=expr, source_position=ASTSourceLocation.get_added_source_position()) |
| 128 | + ast_assignment.update_scope(parent_block.get_scope()) |
| 129 | + ast_small_stmt = ASTNodeFactory.create_ast_small_stmt(assignment=ast_assignment) |
| 130 | + ast_small_stmt.update_scope(parent_block.get_scope()) |
| 131 | + ast_stmt = ASTNodeFactory.create_ast_stmt(small_stmt=ast_small_stmt) |
| 132 | + ast_stmt.update_scope(parent_block.get_scope()) |
| 133 | + |
| 134 | + parent_block.stmts.insert(idx + i + 1, ast_stmt) |
| 135 | + |
| 136 | + model.accept(IntegrateODEsFunctionCallVisitor()) |
| 137 | + |
89 | 138 | def add_kernel_variables_to_integrate_odes_calls(self, model, solvers_json): |
90 | 139 | for solver_dict in solvers_json: |
91 | 140 | if solver_dict is None: |
@@ -124,7 +173,7 @@ def add_temporary_kernel_variables_copy(self, model, solvers_json): |
124 | 173 | scope = model.get_update_blocks()[0].scope |
125 | 174 |
|
126 | 175 | for var_name in var_names: |
127 | | - var = ASTNodeFactory.create_ast_variable(var_name + "__tmp", type_symbol=RealTypeSymbol) |
| 176 | + var = ASTNodeFactory.create_ast_variable(var_name + "__at_start_of_timestep", type_symbol=RealTypeSymbol) |
128 | 177 | var.scope = scope |
129 | 178 | expr = ASTNodeFactory.create_ast_simple_expression(variable=ASTUtils.get_variable_by_name(model, var_name)) |
130 | 179 | ast_declaration = ASTNodeFactory.create_ast_declaration(variables=[var], |
@@ -163,6 +212,7 @@ def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, |
163 | 212 | self.replace_convolve_calls_with_buffers_(model) |
164 | 213 | self.remove_kernel_definitions_from_equations_blocks(model) |
165 | 214 | self.add_kernel_variables_to_integrate_odes_calls(model, solvers_json) |
| 215 | + self.add_restore_kernel_variables_to_start_of_timestep(model, solvers_json) |
166 | 216 | self.add_temporary_kernel_variables_copy(model, solvers_json) |
167 | 217 | self.add_integrate_odes_call_for_kernel_variables(model, solvers_json) |
168 | 218 | self.add_kernel_equations(model, solvers_json) |
@@ -438,6 +488,32 @@ def generate_kernel_buffers(self, model: ASTModel) -> Mapping[ASTKernel, ASTInpu |
438 | 488 |
|
439 | 489 | return kernel_buffers |
440 | 490 |
|
| 491 | + def add_kernel_equations(self, model, solver_dicts): |
| 492 | + if not model.get_equations_blocks(): |
| 493 | + ASTUtils.create_equations_block() |
| 494 | + |
| 495 | + assert len(model.get_equations_blocks()) <= 1 |
| 496 | + |
| 497 | + equations_block = model.get_equations_blocks()[0] |
| 498 | + |
| 499 | + for solver_dict in solver_dicts: |
| 500 | + if solver_dict is None: |
| 501 | + continue |
| 502 | + |
| 503 | + for var_name, expr_str in solver_dict["update_expressions"].items(): |
| 504 | + expr = ModelParser.parse_expression(expr_str) |
| 505 | + expr.update_scope(model.get_scope()) |
| 506 | + expr.accept(ASTSymbolTableVisitor()) |
| 507 | + |
| 508 | + var = ASTNodeFactory.create_ast_variable(var_name, differential_order=1, source_position=ASTSourceLocation.get_added_source_position()) |
| 509 | + var.update_scope(equations_block.get_scope()) |
| 510 | + ast_ode_equation = ASTNodeFactory.create_ast_ode_equation(lhs=var, rhs=expr, source_position=ASTSourceLocation.get_added_source_position()) |
| 511 | + ast_ode_equation.update_scope(equations_block.get_scope()) |
| 512 | + equations_block.declarations.append(ast_ode_equation) |
| 513 | + |
| 514 | + model.accept(ASTParentVisitor()) |
| 515 | + model.accept(ASTSymbolTableVisitor()) |
| 516 | + |
441 | 517 | def remove_kernel_definitions_from_equations_blocks(self, model: ASTModel) -> ASTDeclaration: |
442 | 518 | r""" |
443 | 519 | Removes all kernels in equations blocks. |
|
0 commit comments