@@ -556,6 +556,18 @@ def visit_function_call(self, node: ASTFunctionCall):
556556 remove_state_var_from_integrate_odes_calls_visitor = RemoveStateVarFromIntegrateODEsCallsVisitor ()
557557 model .accept (remove_state_var_from_integrate_odes_calls_visitor )
558558
559+ @classmethod
560+ def add_state_var_to_integrate_odes_calls (cls , model : ASTModel , var : ASTExpression ):
561+ r"""Add a state variable to the arguments to each integrate_odes() calls in the model."""
562+
563+ class AddStateVarToIntegrateODEsCallsVisitor (ASTVisitor ):
564+ def visit_function_call (self , node : ASTFunctionCall ):
565+ if node .get_name () == PredefinedFunctions .INTEGRATE_ODES :
566+ expr = ASTNodeFactory .create_ast_simple_expression (variable = var .clone ())
567+ node .args .append (expr )
568+
569+ model .accept (AddStateVarToIntegrateODEsCallsVisitor ())
570+
559571 @classmethod
560572 def resolve_variables_to_expressions (cls , astnode , analytic_state_variables_moved ):
561573 """receives a list of variable names (as strings) and returns a list of ASTExpressions containing each ASTVariable"""
@@ -564,7 +576,19 @@ def resolve_variables_to_expressions(cls, astnode, analytic_state_variables_move
564576 for var_name in analytic_state_variables_moved :
565577 node = ASTUtils .get_variable_by_name (astnode , var_name )
566578 assert node is not None
567- expressions .append (ASTNodeFactory .create_ast_expression (False , None , False , ASTNodeFactory .create_ast_simple_expression (variable = node )))
579+ expressions .append (ASTNodeFactory .create_ast_expression (expression = ASTNodeFactory .create_ast_simple_expression (variable = node )))
580+
581+ return expressions
582+
583+ @classmethod
584+ def resolve_variables_to_simple_expressions (cls , model , vars ):
585+ """receives a list of variable names (as strings) and returns a list of ASTSimpleExpressions containing each ASTVariable"""
586+ expressions = []
587+
588+ for var_name in vars :
589+ node = ASTUtils .get_variable_by_name (model , var_name )
590+ assert node is not None
591+ expressions .append (ASTNodeFactory .create_ast_simple_expression (variable = node ))
568592
569593 return expressions
570594
@@ -1113,45 +1137,80 @@ def declaration_in_state_block(cls, neuron: ASTModel, variable_name: str) -> boo
11131137 return False
11141138
11151139 @classmethod
1116- def add_assignment_to_update_block (cls , assignment : ASTAssignment , neuron : ASTModel ) -> ASTModel :
1140+ def add_assignment_to_update_block (cls , assignment : ASTAssignment , model : ASTModel ) -> ASTModel :
11171141 """
1118- Adds a single assignment to the end of the update block of the handed over neuron . At most one update block should be present.
1142+ Adds a single assignment to the end of the update block of the handed over model . At most one update block should be present.
11191143
11201144 :param assignment: a single assignment
1121- :param neuron : a single neuron instance
1122- :return: the modified neuron
1145+ :param model : a single model instance
1146+ :return: the modified model
11231147 """
1124- assert len (neuron .get_update_blocks ()) <= 1 , "At most one update block should be present"
1148+ assert len (model .get_update_blocks ()) <= 1 , "At most one update block should be present"
11251149 small_stmt = ASTNodeFactory .create_ast_small_stmt (assignment = assignment ,
11261150 source_position = ASTSourceLocation .get_added_source_position ())
11271151 stmt = ASTNodeFactory .create_ast_stmt (small_stmt = small_stmt ,
11281152 source_position = ASTSourceLocation .get_added_source_position ())
1129- if not neuron .get_update_blocks ():
1130- neuron .create_empty_update_block ()
1131- neuron .get_update_blocks ()[0 ].get_block ().get_stmts ().append (stmt )
1132- small_stmt .update_scope (neuron .get_update_blocks ()[0 ].get_block ().get_scope ())
1133- stmt .update_scope (neuron .get_update_blocks ()[0 ].get_block ().get_scope ())
1134- return neuron
1153+ if not model .get_update_blocks ():
1154+ model .create_empty_update_block ()
1155+ model .get_update_blocks ()[0 ].get_block ().get_stmts ().append (stmt )
1156+ small_stmt .update_scope (model .get_update_blocks ()[0 ].get_block ().get_scope ())
1157+ stmt .update_scope (model .get_update_blocks ()[0 ].get_block ().get_scope ())
1158+
1159+ from pynestml .visitors .ast_parent_visitor import ASTParentVisitor
1160+ model .accept (ASTParentVisitor ())
1161+
1162+ return model
1163+
1164+ @classmethod
1165+ def add_function_call_to_update_block (cls , function_call : ASTFunctionCall , model : ASTModel ) -> ASTModel :
1166+ """
1167+ Adds a single assignment to the end of the update block of the handed over model.
1168+
1169+ :param function_call: a single function call
1170+ :param neuron: a single model instance
1171+ :return: the modified model
1172+ """
1173+ assert len (model .get_update_blocks ()) <= 1 , "At most one update block should be present"
1174+
1175+ if not model .get_update_blocks ():
1176+ model .create_empty_update_block ()
1177+
1178+ small_stmt = ASTNodeFactory .create_ast_small_stmt (function_call = function_call ,
1179+ source_position = ASTSourceLocation .get_added_source_position ())
1180+ stmt = ASTNodeFactory .create_ast_stmt (small_stmt = small_stmt ,
1181+ source_position = ASTSourceLocation .get_added_source_position ())
1182+ model .get_update_blocks ()[0 ].get_block ().get_stmts ().append (stmt )
1183+ small_stmt .update_scope (model .get_update_blocks ()[0 ].get_block ().get_scope ())
1184+ stmt .update_scope (model .get_update_blocks ()[0 ].get_block ().get_scope ())
1185+
1186+ from pynestml .visitors .ast_parent_visitor import ASTParentVisitor
1187+ model .accept (ASTParentVisitor ())
1188+
1189+ return model
11351190
11361191 @classmethod
1137- def add_declaration_to_update_block (cls , declaration : ASTDeclaration , neuron : ASTModel ) -> ASTModel :
1192+ def add_declaration_to_update_block (cls , declaration : ASTDeclaration , model : ASTModel ) -> ASTModel :
11381193 """
1139- Adds a single declaration to the end of the update block of the handed over neuron .
1194+ Adds a single declaration to the end of the update block of the handed over model .
11401195 :param declaration: ASTDeclaration node to add
1141- :param neuron : a single neuron instance
1142- :return: a modified neuron
1196+ :param model : a single model instance
1197+ :return: a modified model
11431198 """
1144- assert len (neuron .get_update_blocks ()) <= 1 , "At most one update block should be present"
1199+ assert len (model .get_update_blocks ()) <= 1 , "At most one update block should be present"
11451200 small_stmt = ASTNodeFactory .create_ast_small_stmt (declaration = declaration ,
11461201 source_position = ASTSourceLocation .get_added_source_position ())
11471202 stmt = ASTNodeFactory .create_ast_stmt (small_stmt = small_stmt ,
11481203 source_position = ASTSourceLocation .get_added_source_position ())
1149- if not neuron .get_update_blocks ():
1150- neuron .create_empty_update_block ()
1151- neuron .get_update_blocks ()[0 ].get_block ().get_stmts ().append (stmt )
1152- small_stmt .update_scope (neuron .get_update_blocks ()[0 ].get_block ().get_scope ())
1153- stmt .update_scope (neuron .get_update_blocks ()[0 ].get_block ().get_scope ())
1154- return neuron
1204+ if not model .get_update_blocks ():
1205+ model .create_empty_update_block ()
1206+ model .get_update_blocks ()[0 ].get_block ().get_stmts ().append (stmt )
1207+ small_stmt .update_scope (model .get_update_blocks ()[0 ].get_block ().get_scope ())
1208+ stmt .update_scope (model .get_update_blocks ()[0 ].get_block ().get_scope ())
1209+
1210+ from pynestml .visitors .ast_parent_visitor import ASTParentVisitor
1211+ model .accept (ASTParentVisitor ())
1212+
1213+ return model
11551214
11561215 @classmethod
11571216 def add_state_updates (cls , neuron : ASTModel , update_expressions : Mapping [str , str ]) -> ASTModel :
@@ -1165,6 +1224,7 @@ def add_state_updates(cls, neuron: ASTModel, update_expressions: Mapping[str, st
11651224 for variable , update_expression in update_expressions .items ():
11661225 declaration_statement = variable + '__tmp real = ' + update_expression
11671226 cls .add_declaration_to_update_block (ModelParser .parse_declaration (declaration_statement ), neuron )
1227+
11681228 for variable , update_expression in update_expressions .items ():
11691229 cls .add_assignment_to_update_block (ModelParser .parse_assignment (variable + ' = ' + variable + '__tmp' ),
11701230 neuron )
0 commit comments